In [2]:
import pandas as pd
from torch.utils.data import DataLoader, random_split
import torch
from torchvision import transforms
from ImageDataset import CustomImageDataset
from ResNet import ResNet

In [3]:
# check for gpu
gpu = None
if torch.cuda.is_available():
    gpu = torch.device('cuda')
elif torch.backends.mps.is_available():
    gpu = torch.device('mps')
else:
    gpu = torch.device('cpu')
print(gpu)

mps


## Training Data

### Data Source

**Snapshot Serengeti**: opensource dataset containing images from camera traps in the Serengeti, with crowdsourced labels with species and other information. More information can be found at https://www.zooniverse.org/projects/zooniverse/snapshot-serengeti and https://www.nature.com/articles/sdata201526#MOESM66. \
The csv files can be downloaded at https://datadryad.org/stash/dataset/doi:10.5061/dryad.5pt92#usage

### Features

**consensus_data.csv**: `CaptureEventID` `NumImages` `SiteID` `LocationX` `LocationY` `NumSpecies` `Species` `Count` `Standing` `Resting` `Moving` `Eating` `Interacting` `Babies` `NumClassifications` `NumVotes` `NumBlanks` `Evenness` \
**all_images.csv**: `CaptureEventID` `URL_Info`

### Data Cleanup:

For each entry, we need to get the image from the url https://snapshotserengeti.s3.msi.umn.edu/`URL_Info`. We need to maintain the features `Species` `CaputureEventID` from `consensus_data.csv` and merge them into the `URL_Info` feature of `all_images.csv`.

In [4]:
# read in the data
consensus_data = pd.read_csv('data/consensus_data.csv')[['CaptureEventID', 'Species']]
images = pd.read_csv('data/all_images.csv')

# Create a dataframe with the image urls and species label
df = pd.merge(images, consensus_data, on='CaptureEventID')
df["URL"] = "https://snapshotserengeti.s3.msi.umn.edu/" + df["URL_Info"]
df = df[['URL', 'Species']]

# save a list of class names
classes = ['human', 'gazelleGrants', 'reedbuck', 'dikDik', 'zebra', 'porcupine',
 'gazelleThomsons', 'hyenaSpotted', 'warthog', 'impala', 'elephant', 'giraffe',
 'mongoose', 'buffalo', 'hartebeest', 'guineaFowl', 'wildebeest', 'leopard',
 'ostrich', 'lionFemale', 'koriBustard', 'otherBird', 'batEaredFox', 'bushbuck',
 'jackal', 'cheetah', 'eland', 'aardwolf', 'hippopotamus', 'hyenaStriped',
 'aardvark', 'hare', 'baboon', 'vervetMonkey', 'waterbuck', 'secretaryBird',
 'serval', 'lionMale', 'topi', 'honeyBadger', 'rodents', 'wildcat', 'civet',
 'genet', 'caracal', 'rhinoceros', 'reptiles', 'zorilla']

In [16]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Create the train & test dataloaders
train_data = CustomImageDataset(df=df, transform=transform)
train_set, test_set = random_split(train_data, [0.7, 0.3])
train_dataloader = DataLoader(train_set, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=32, shuffle=True)

## Training ResNet Model

In [17]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    num_batches = len(train_dataloader)
    print(f'Number of batches: {num_batches}')
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            dataloader = train_dataloader if phase == 'train' else test_dataloader
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloader:
                inputs = torch.tensor(inputs)
                labels = torch.tensor([classes.index(label) for label in labels])
                inputs, labels = inputs.to(gpu), labels.to(gpu)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                loss = loss.item() * inputs.size(0)
                corrects = torch.sum(preds == labels.data)

                print(f'Batch Loss: {loss:.4f}, Correct: {corrects.item()}')

                running_loss += loss
                running_corrects += corrects

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects.double() / len(dataloader.dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    return model

In [19]:
# load the model
model = ResNet(num_classes=48)
model.to(gpu)

# Define the loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Train the model
epochs = 5
model = train_model(model, loss_fn, optimizer, scheduler, num_epochs=epochs)
torch.save(model.state_dict(), 'elephant_classifier_resnet50.pth')


Number of batches: 19473
Epoch 1/5
----------


  inputs = torch.tensor(inputs)


Batch Loss: 122.8931, Correct: 1
Batch Loss: 122.2517, Correct: 2


KeyboardInterrupt: 