# Introduction
This notebook will attempt to adapt an existing CNN model to use LoRA for finetuning. 

Model choice - [MobileNetV2 trained on ImageNet1k](https://pytorch.org/vision/stable/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights) Reason: smallest classification CNN available via torchvision (by number of parameters), which means I can fine tune relatively faster, and there shouldn't be anything fundementally different with larger models.

Original Dataset - ImageNet1k_V2

Finetuning Dataset - [FGVCAIRCRAFT](https://pytorch.org/vision/stable/generated/torchvision.datasets.FGVCAircraft.html#torchvision.datasets.FGVCAircraft) (The dataset contains 10,000 images of aircrafts across 30 different manufacturers) Reason: Relatively small dataset (thus faster training and more representative of realworld finetuning task) and looking at airplanes is fun!

Ideas: try varying the amount of training data? how well does finetuning work on small data?

In [146]:
import lightning as L

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision.datasets import FGVCAircraft

from torchmetrics import Accuracy

import loralib as lora

import os

# make results reproducible
L.seed_everything(42)

Global seed set to 42


42

In [147]:
PATH_DATASETS = '/home/sunil/cnn-lora/data'

In [150]:
class LitCNN(L.LightningModule):
    
    def __init__(self, lora_on=False, data_dir=PATH_DATASETS):
        '''
        :lora: flag - if true uses LoRA
        :data_dir: - string path to where to store data
        '''
        super().__init__()

        # some hyperparameters
        self.lr = 1e-4
        self.batch_size = 4
        
        # dataset specific information
        self.data_dir = data_dir
        self.num_classes = 30 # there are 30 manufacturers in our dataset 
        
        # define lore hyper params
        self.lora_on = lora_on
        lora_rank = 8

        # define the model
        self.model = torchvision.models.mobilenet_v2()

        # reset the mlp head regardless of if we are using LoRA or not 
        self.model.classifier[1] = nn.Linear(1280, self.num_classes)
        nn.init.normal_(self.model.classifier[1].weight, 0, 0.01)
        nn.init.zeros_(self.model.classifier[1].bias)

        # setup model for lora if desired
        if self.lora_on:
            # replace all conv layers with lora conv layers
            for name, module in self.model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    in_channels = module.in_channels
                    out_channels = module.out_channels
                    kernel_size = module.kernel_size[0]
                    stride = module.stride
                    padding = module.padding
                    dialtion = module.dilation
                    groups = module.groups
                    bias = module.bias if module.bias is not None else False
                    padding_mode = module.padding_mode

                    new_conv = lora.Conv2d(in_channels=in_channels, 
                                         out_channels=out_channels,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dialtion,
                                         groups=groups,
                                         bias=bias,
                                         padding_mode=padding_mode,
                                         r=lora_rank
                                         )


                    # `parts` is essentially a list of keys we can use to determine where this layer belongs
                    parts = name.split('.')
                    self.set_model_feature(parts, new_conv)

            # setup layer freezes properly for lora
            # TODO is this fixing our new mlp head too?
            lora.mark_only_lora_as_trainable(self.model)
            
            # load in pretrained weights
            # TODO: remove mlp head weights from the state dict
            self.model.load_state_dict(torch.load(self.get_initial_state_dict_path()), strict=False)
            
        
            
        # define metrics
        self.val_accuray = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        
    # helper function to update layers with lora layers during init
    def set_model_feature(self, keys, value):
        pointer = self.model
        while len(keys) > 1:
            key = keys.pop(0)
            pointer = getattr(pointer, key)

        key = keys.pop(0)
        pointer[int(key)] = value

    def get_initial_state_dict_path(self):
        '''
        saves the imagenet 1k weights to a state dict on disk (if necessary) and returns path to where they are saved
        '''
        weights_path = 'mobilenet_v2_imagenet1k_state_dict.pt'

        if not os.path.isfile(weights_path):
            m = torchvision.models.mobilenet_v2(weights=torchvision.models.mobilenetv2.MobileNet_V2_Weights.DEFAULT)
            torch.save(m.state_dict(), weights_path)

        return weights_path
        
    
    def forward(self, x):
        # TODO: any transforms?
        x = self.model(x)
        return F.log_softmax(x,1)

    # centralize stuff we need in train/val/test
    def common_step(self, batch, batch_idx):
        x,y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return x, y, logits, loss

    def training_step(self, batch, batch_idx):
        _, _, _, loss = self.common_step(batch, batch_idx)
        self.log("train_loss", loss, batch_size=self.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        _, y, logits, loss = self.common_step(batch, batch_idx)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)
    
    def test_step(self, batch, batch_idx):
        _, y, logits, loss = self.common_step(batch, batch_idx)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


    def prepare_data(self):
        # download data and set transforms

        #TODO: transforms!
        FGVCAircraft(self.data_dir, split='train', annotation_level='manufacturer', download=True) 
        FGVCAircraft(self.data_dir, split='val', annotation_level='manufacturer', download=True)
        FGVCAircraft(self.data_dir, split='test', annotation_level='manufacturer', download=True)
        
        

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.aircraft_train = FGVCAircraft(self.data_dir, split='train', annotation_level='manufacturer')
            self.aircraft_val = FGVCAircraft(self.data_dir, split='val', annotation_level='manufacturer')
            
        if stage == 'test' or stage is None:
            self.aircraft_test = FGVCAircraft(self.data_dir, split='test', annotation_level='manufacturer')


    
            
    def train_dataloader(self):
        loader = DataLoader(self.aircraft_train, batch_size=self.batch_size, num_workers=16)

    def val_dataloader(self):
        loader = DataLoader(self.aircraft_val, batch_size=self.batch_size, num_workers=16)

    def train_dataloader(self):
        loader = DataLoader(self.aircraft_test, batch_size=self.batch_size, num_workers=16)
        

In [151]:
m = LitCNN(lora_on=True)
model = m.model

RuntimeError: Error(s) in loading state_dict for MobileNetV2:
	size mismatch for classifier.1.weight: copying a param with shape torch.Size([1000, 1280]) from checkpoint, the shape in current model is torch.Size([30, 1280]).
	size mismatch for classifier.1.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([30]).