# Chapter 2: Our First Model

In [48]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES=True

## Setting up DataLoaders

We'll use the built-in dataset of `torchvision.datasets.ImageFolder` to quickly set up some dataloaders of downloaded cat and fish images. 

`check_image`  is a quick little function that is passed to the `is_valid_file` parameter in the ImageFolder and will do a sanity check to make sure PIL can actually open the file. We're going to use this in lieu of cleaning up the downloaded dataset.


In [49]:
def check_image(path):
    try:
        im = Image.open(path)
        return True
    except:
        return False

Set up the transforms for every image:

* Resize to 64x64
* Convert to tensor
* Normalize using ImageNet mean & std


In [50]:
img_transforms = transforms.Compose([
    transforms.Resize((64,64)),    
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225] )
    ])



In [51]:
train_data_path = "../data/train/"
train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=img_transforms, is_valid_file=check_image)

In [52]:
train_data

Dataset ImageFolder
    Number of datapoints: 787
    Root location: ../data/train/
    StandardTransform
Transform: Compose(
               Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [53]:
val_data_path = "../data/val/"
val_data = torchvision.datasets.ImageFolder(root=val_data_path,transform=img_transforms, is_valid_file=check_image)

In [54]:
val_data

Dataset ImageFolder
    Number of datapoints: 102
    Root location: ../data/val/
    StandardTransform
Transform: Compose(
               Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [55]:
test_data_path = "../data/test/"
test_data = torchvision.datasets.ImageFolder(root=test_data_path,transform=img_transforms, is_valid_file=check_image) 

In [56]:
test_data

Dataset ImageFolder
    Number of datapoints: 160
    Root location: ../data/test/
    StandardTransform
Transform: Compose(
               Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [57]:
batch_size=64

In [58]:
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
val_data_loader  = torch.utils.data.DataLoader(val_data, batch_size=batch_size) 
test_data_loader  = torch.utils.data.DataLoader(test_data, batch_size=batch_size) 

In [59]:
train_data_loader

<torch.utils.data.dataloader.DataLoader at 0x15c052010>

## Our First Model, SimpleNet

SimpleNet is a very simple combination of three Linear layers and ReLu activations between them. Note that as we don't do a `softmax()` in our `forward()`, we will need to make sure we do it in our training function during the validation phase.

In [60]:
class SimpleNet(nn.Module):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(12288, 1026)
        self.fc2 = nn.Linear(1026, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64,2)
    
    def forward(self, x):
        x = x.view(-1, 12288)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

In [61]:
simplenet = SimpleNet()

## Copy the model to GPU

Copy the model to the GPU if available.

In [62]:
# torch.cuda.is_available()
torch.backends.mps.is_available() 

True

In [63]:
if torch.backends.mps.is_available() :
    device = torch.device("mps") 
    print('mps')
else:
    device = torch.device("cpu")

simplenet.to(device)

mps


SimpleNet(
  (fc1): Linear(in_features=12288, out_features=1026, bias=True)
  (fc2): Linear(in_features=1026, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=64, bias=True)
  (fc4): Linear(in_features=64, out_features=2, bias=True)
)

## Create an optimizer

Here, we're just using Adam as our optimizer with a learning rate of 0.001.

In [64]:
optimizer = optim.Adam(simplenet.parameters(), lr=0.0001)

## Training 

Trains the model, copying batches to the GPU if required, calculating losses, optimizing the network and perform validation for each epoch.

In [65]:
(2.1/len(train_data_loader.dataset))

0.002668360864040661

In [66]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    for epoch in range(1, epochs+1):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)
        
        model.eval()
        num_correct = 0 
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets) 
            valid_loss += loss.data.item() * inputs.size(0)
            print(f'Output Shape: {output.shape}')
            correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
            print(f'F.softmax(output, dim=1): {F.softmax(output, dim=1)}')
            print(f'torch.max(F.softmax(output, dim=1), dim=1): {torch.max(F.softmax(output, dim=1), dim=1)}')
            print(f'torch.max(F.softmax(output, dim=1), dim=1)[1]: {torch.max(F.softmax(output, dim=1), dim=1)[1]}')
            print(f'targets{targets}')
            print(f'torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets):{torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)}')
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
        valid_loss, num_correct / num_examples))

In [67]:
device

device(type='mps')

In [68]:
%time train(simplenet, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader,val_data_loader, epochs=10, device=device)

Output Shape: torch.Size([64, 2])
F.softmax(output, dim=1): tensor([[0.7359, 0.2641],
        [0.7060, 0.2940],
        [0.6308, 0.3692],
        [0.6010, 0.3990],
        [0.6318, 0.3682],
        [0.5479, 0.4521],
        [0.5776, 0.4224],
        [0.6278, 0.3722],
        [0.6686, 0.3314],
        [0.5658, 0.4342],
        [0.5807, 0.4193],
        [0.4460, 0.5540],
        [0.5721, 0.4279],
        [0.6888, 0.3112],
        [0.6401, 0.3599],
        [0.5085, 0.4915],
        [0.5705, 0.4295],
        [0.5636, 0.4364],
        [0.5470, 0.4530],
        [0.7900, 0.2100],
        [0.6888, 0.3112],
        [0.6076, 0.3924],
        [0.5580, 0.4420],
        [0.4451, 0.5549],
        [0.4898, 0.5102],
        [0.7616, 0.2384],
        [0.4045, 0.5955],
        [0.6095, 0.3905],
        [0.8587, 0.1413],
        [0.5010, 0.4990],
        [0.5398, 0.4602],
        [0.8677, 0.1323],
        [0.4833, 0.5167],
        [0.5372, 0.4628],
        [0.5188, 0.4812],
        [0.7277, 0.2723],
    

## Making predictions

Labels are in alphanumeric order, so `cat` will be 0, `fish` will be 1. We'll need to transform the image and also make sure that the resulting tensor is copied to the appropriate device before applying our model to it.

In [71]:
labels = ['cat','fish']

img = Image.open("../data/val/fish/100_1422.JPG") 
img = img_transforms(img).to(device)
img = torch.unsqueeze(img, 0)

simplenet.eval()
prediction = F.softmax(simplenet(img), dim=1)
prediction = prediction.argmax()
print(labels[prediction]) 

fish


## Saving Models

We can either save the entire model using `save` or just the parameters using `state_dict`. Using the latter is normally preferable, as it allows you to reuse parameters even if the model's structure changes (or apply parameters from one model to another).

In [72]:
torch.save(simplenet, "tmp/simplenet") 
simplenet = torch.load("tmp/simplenet")    


In [73]:
torch.save(simplenet.state_dict(), "tmp/simplenet")    
simplenet = SimpleNet()
simplenet_state_dict = torch.load("tmp/simplenet")
simplenet.load_state_dict(simplenet_state_dict, strict=False)   

<All keys matched successfully>

In [74]:
simplenet.state_dict()

OrderedDict([('fc1.weight',
              tensor([[-7.6798e-03, -9.6895e-04,  9.2678e-03,  ...,  8.6839e-03,
                        7.6343e-03, -6.1496e-04],
                      [ 4.8398e-03,  6.5706e-03,  3.9849e-03,  ...,  3.6685e-03,
                        5.5833e-03,  8.1738e-03],
                      [-2.6032e-03,  4.9011e-03,  7.9491e-03,  ..., -5.2361e-03,
                        3.9874e-03, -4.5173e-03],
                      ...,
                      [ 2.8446e-03, -7.0135e-03,  4.9849e-03,  ...,  2.8595e-03,
                       -2.8795e-03, -9.3262e-03],
                      [-6.5447e-03,  2.4438e-03, -8.7639e-03,  ...,  2.7842e-05,
                       -1.3247e-03, -6.4040e-04],
                      [ 4.3212e-03,  2.5143e-03, -2.4288e-03,  ...,  4.0716e-03,
                        5.8293e-03, -4.0808e-04]])),
             ('fc1.bias',
              tensor([0.0022, 0.0028, 0.0052,  ..., 0.0008, 0.0076, 0.0013])),
             ('fc2.weight',
              tensor([[

In [75]:
torch.hub.list('pytorch/vision')

Downloading: "https://github.com/pytorch/vision/zipball/main" to /Users/saadnaeem/.cache/torch/hub/main.zip


['alexnet',
 'convnext_base',
 'convnext_large',
 'convnext_small',
 'convnext_tiny',
 'deeplabv3_mobilenet_v3_large',
 'deeplabv3_resnet101',
 'deeplabv3_resnet50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b3',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_v2_l',
 'efficientnet_v2_m',
 'efficientnet_v2_s',
 'fcn_resnet101',
 'fcn_resnet50',
 'get_model_weights',
 'get_weight',
 'googlenet',
 'inception_v3',
 'lraspp_mobilenet_v3_large',
 'maxvit_t',
 'mc3_18',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'mvit_v1_b',
 'mvit_v2_s',
 'r2plus1d_18',
 'r3d_18',
 'raft_large',
 'raft_small',
 'regnet_x_16gf',
 'regnet_x_1_6gf',
 'regnet_x_32gf',
 'regnet_x_3_2gf',
 'regnet_x_400mf',
 'regnet_x_800mf',
 'regnet_x_8gf',
 'regnet_y_128gf',
 'regnet_y_16gf',
 'regnet_y_1