

DONE:

-Create a dataset class for your own dataset

-Create a network class that wraps a pretrained resnet

-Implement unfreezing on the network (I trained it earlier, but I then restarted the runtime and interrupted the rerun, which cleared the previous output, so I attached a screenshot below)

![alt text](https://thleats-bucket.s3.us-east-2.amazonaws.com/lab10.JPG)

In [0]:


from torchvision.models import resnet152
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import torch
from torch import optim, nn
import zipfile
from google.colab import files
import os
import sys
from PIL import Image, ImageOps
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
from os import path
import urllib.request
from zipfile import ZipFile
import pdb

In [0]:
class birddataset(Dataset):
    """Bird-species images loaded through ``torchvision.datasets.ImageFolder``.

    Constructing an instance first makes sure the dataset archive is present
    (downloading it if missing) and extracted (extracting it if the target
    folder is missing or the archive was just downloaded). It then wraps the
    requested split.

    Args:
        Train (bool): load the training split when truthy, otherwise the
            test split.
        img_size (int): images are resized and center-cropped to this size.

    Attributes:
        data: the wrapped ``ImageFolder``. ``self.data[i]`` is an
            ``(image_tensor, label)`` pair, and so is ``self[i]``.
    """

    URL = 'https://thleats-bucket.s3.us-east-2.amazonaws.com/bird-species-classification.zip'
    ARCHIVE = '/content/bird-species-classification.zip'
    EXTRACT_DIR = '/content/dataset/'

    def __init__(self, Train=True, img_size=256):
        super(birddataset, self).__init__()
        self._fetch()
        split = 'train_data/train_data/' if Train else 'test_data/test_data/'
        self.data = datasets.ImageFolder(root=path.join(self.EXTRACT_DIR, split),
                                         transform=self._make_transform(img_size))

    @classmethod
    def _fetch(cls):
        """Download the archive if missing, then extract it if needed."""
        fresh = not path.exists(cls.ARCHIVE)
        if fresh:
            print('downloading')
            # Download under a temporary name and rename only on success. An
            # interrupted download then cannot leave a truncated archive that
            # later runs would mistake for a complete one.
            partial = cls.ARCHIVE + '.part'
            urllib.request.urlretrieve(cls.URL, partial)
            os.replace(partial, cls.ARCHIVE)
        else:
            print('already downloaded!')
        # Also extract when the extracted folder is missing, not only right
        # after a download.
        if fresh or not path.isdir(cls.EXTRACT_DIR):
            with ZipFile(cls.ARCHIVE, 'r') as archive:
                archive.extractall(cls.EXTRACT_DIR)

    @staticmethod
    def _make_transform(img_size):
        """Return the resize -> center-crop -> tensor -> normalize pipeline shared by both splits."""
        return transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    # These two were previously nested inside __init__ (wrong indentation),
    # so the Dataset had no working __len__/__getitem__.
    def __getitem__(self, i):
        """Return the ``(image, label)`` pair at index ``i``."""
        return self.data[i]

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.data)

# Build both splits once at notebook level. Each construction downloads and
# extracts the archive only if it is not already on disk.
bird_data_train, bird_data_test = birddataset(Train=True), birddataset(Train=False)

downloading
already downloaded!


In [0]:
class ResNetBirds(nn.Module):
    """ImageNet-pretrained ResNet-152 whose final ``fc`` layer maps to ``num_classes`` logits.

    Args:
        num_classes (int): number of outputs of the replacement ``fc`` layer.
        start_frozen (bool): if True, every pretrained parameter starts with
            ``requires_grad=False``. Only the new ``fc`` layer is trainable
            until :meth:`unfreeze` is called.
    """

    def __init__(self, num_classes, start_frozen=False):
        super(ResNetBirds, self).__init__()
        # Part 1.2: load the model, pre-trained.
        model = resnet152(pretrained=True, progress=True)

        # Part 1.4: freezing means turning off requires_grad on the parameters.
        # (model.zero_grad() only clears .grad buffers and freezes nothing.)
        if start_frozen:
            for param in model.parameters():
                param.requires_grad = False

        # Part 1.2: override the last layer. Assigning fc.out_features would
        # only change an attribute; the weight matrix would keep its original
        # shape. A new Linear layer is needed. Its fresh parameters have
        # requires_grad=True, so the head trains even when the rest is frozen.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        self.model = model

    def unfreeze(self, n_layers):
        """Turn on gradients for the last ``n_layers`` top-level children of the wrapped model.

        Children are taken in ``children()`` order. Children without
        parameters still count toward ``n_layers``. ``n_layers <= 0``
        unfreezes nothing. Values larger than the number of children
        unfreeze everything instead of raising.
        """
        children = list(self.model.children())
        n_layers = max(0, min(n_layers, len(children)))
        # Slice from an explicit start index: children[-0:] would be the whole list.
        for child in children[len(children) - n_layers:]:
            # Setting requires_grad on the Module object itself has no effect;
            # it must be set on each parameter.
            for param in child.parameters():
                param.requires_grad = True

    def forward(self, x):
        """Part 1.2: pass ``x`` through the wrapped ResNet and return its logits."""
        return self.model(x)



# Output size = largest integer label in the training split, plus one.
classes = 1 + max(bird_data_train.data.targets)
model = ResNetBirds(classes, start_frozen=True)

Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /root/.cache/torch/checkpoints/resnet152-b121ed2d.pth
100%|██████████| 230M/230M [00:03<00:00, 71.8MB/s]


In [0]:
def accuracy(y_hat, y_truth):
    """Return the fraction of rows whose arg-max over dim 1 equals the true label.

    The result is a 0-d float tensor.
    """
    hits = y_hat.argmax(dim=1).eq(y_truth)
    return hits.float().mean()

def evaluate(model, objective, val_loader, device):
    """Return the per-sample average loss and accuracy of ``model`` over ``val_loader``.

    Each batch is weighted by its size, so a smaller final batch does not skew
    the averages the way an equal-weight mean of per-batch means does.

    Args:
        model: module mapping an input batch to class logits.
        objective: loss callable ``objective(y_hat, y_truth)``. Assumed to
            return the batch *mean* (the default for ``nn.CrossEntropyLoss()``).
        val_loader: iterable of ``(x, y_truth)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        tuple: ``(mean_loss, mean_accuracy)`` as 0-d CPU float tensors. Both
        are NaN when ``val_loader`` yields no samples.
    """
    was_training = model.training
    # eval() so batchnorm and dropout work in eval mode.
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    try:
        # no_grad() turns off computation-graph creation. That saves time and
        # memory and allows larger validation batches.
        with torch.no_grad():
            for x, y_truth in val_loader:
                x, y_truth = x.to(device), y_truth.to(device)
                y_hat = model(x)
                batch = y_truth.size(0)
                loss_sum += objective(y_hat, y_truth).item() * batch
                # Same rule as accuracy(), but counted so batches can be weighted.
                correct += (torch.argmax(y_hat, dim=1) == y_truth).sum().item()
                seen += batch
    finally:
        # Restore the caller's mode rather than always switching to train().
        model.train(was_training)

    if seen == 0:
        nan = torch.tensor(float('nan'))
        return nan, nan.clone()
    return torch.tensor(loss_sum / seen), torch.tensor(correct / seen)

In [0]:
def _plot_curves(train_accs, train_losses, val_accs, val_losses, val_every):
    """Plot per-step train accuracy/loss next to the val points taken every ``val_every`` steps."""
    # Validation ran at steps 0, val_every, 2*val_every, ..., so
    # np.arange(n, step=val_every) gives one x position per val point.
    plt.subplot(121)
    plt.plot(np.arange(len(train_accs)), train_accs, label='Train Accuracy')
    plt.plot(np.arange(len(train_accs), step=val_every), val_accs, label='Val Accuracy')
    plt.legend()
    plt.subplot(122)
    plt.plot(np.arange(len(train_losses)), train_losses, label='Train Loss')
    plt.plot(np.arange(len(train_losses), step=val_every), val_losses, label='Val Loss')
    plt.legend()
    plt.show()


def train(start_frozen=False, model_unfreeze=0, num_classes=None, epochs=20,
          lr=.00005, val_every=3, batch_size=32):
    """Fine-tunes a CNN

    Args:
        start_frozen (bool): whether to start with the network weights frozen.
        model_unfreeze (int): the maximum number of network layers to unfreeze.
            One more trailing layer is unfrozen at the start of each epoch
            until this many are unfrozen.
        num_classes (int or None): output size of the classifier. If None, it
            is the largest training label + 1 (the same rule as the ``classes``
            cell above).
        epochs (int): number of passes over the training set.
        lr (float): Adam learning rate (very low by default).
        val_every (int): run a full validation pass every this many steps.
        batch_size (int): batch size for both loaders.
    """
    # Use the GPU when there is one; otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Data
    train_dataset = birddataset(Train=True).data
    val_dataset = birddataset(Train=False).data
    if num_classes is None:
        num_classes = max(train_dataset.targets) + 1
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              num_workers=8,
                              batch_size=batch_size)
    # Order doesn't affect the validation averages, so no shuffling is needed.
    val_loader = DataLoader(val_dataset,
                            shuffle=False,
                            num_workers=8,
                            batch_size=batch_size)

    # Model
    model = ResNetBirds(num_classes, start_frozen=start_frozen).to(device)

    # Objective
    objective = nn.CrossEntropyLoss()
    # All parameters (including frozen ones) go to Adam up front, so layers
    # unfrozen later are optimized without rebuilding the optimizer.
    # NOTE(review): weight_decay=1e-1 is unusually strong for fine-tuning; confirm.
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-1)

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    pbar = tqdm(total=len(train_loader) * epochs)
    step = 0
    try:
        for epoch in range(epochs):
            # Part 1.4: unfreeze one more trailing layer each epoch. Using
            # epoch + 1 makes epoch 0 unfreeze one layer and lets the count
            # actually reach model_unfreeze. unfreeze(epoch) stopped at
            # model_unfreeze - 1.
            if epoch < model_unfreeze:
                model.unfreeze(epoch + 1)

            for x, y_truth in train_loader:
                x, y_truth = x.to(device), y_truth.to(device)

                optimizer.zero_grad()
                y_hat = model(x)
                train_loss = objective(y_hat, y_truth)
                train_acc = accuracy(y_hat, y_truth)
                train_loss.backward()
                optimizer.step()

                # Store plain floats: matplotlib cannot plot CUDA tensors.
                train_losses.append(train_loss.item())
                train_accs.append(train_acc.item())

                if step % val_every == 0:
                    val_loss, val_acc = evaluate(model, objective, val_loader, device)
                    val_losses.append(float(val_loss))
                    val_accs.append(float(val_acc))

                pbar.set_description('train loss:{:.4f}, train accuracy:{:.4f}.'.format(train_losses[-1], train_accs[-1]))
                pbar.update(1)
                step += 1
    finally:
        # Close the progress bar even if training is interrupted.
        pbar.close()

    _plot_curves(train_accs, train_losses, val_accs, val_losses, val_every)

In [0]:
train(start_frozen=True, model_unfreeze=4)

already downloaded!
already downloaded!


train loss:0.9634, train accuracy:0.9091.:  15%|█▌        | 15/100 [05:00<17:34, 12.41s/it]

KeyboardInterrupt: ignored