# Introduction
This notebook attempts to adapt an existing CNN model to use LoRA (Low-Rank Adaptation) for finetuning.
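For reference, LoRA keeps a pretrained weight matrix $W_0 \in \mathbb{R}^{d_{out} \times d_{in}}$ frozen and trains only a low-rank update, scaled by a constant $\alpha / r$:

$$
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x, \qquad B \in \mathbb{R}^{d_{out} \times r},\; A \in \mathbb{R}^{r \times d_{in}},\; r \ll \min(d_{in}, d_{out})
$$

$B$ is initialised to zero, so finetuning starts exactly from the pretrained model. A convolution weight is a 4-D tensor rather than a matrix, so some reshaping has to be chosen. One possible choice (an assumption here, not the only option) views a `Conv2d` weight of shape $(C_{out}, C_{in}/g, k, k)$, with $g$ groups, as a $C_{out} \times (C_{in} k^2 / g)$ matrix. That gives $r\,(C_{out} + C_{in} k^2 / g)$ trainable parameters per layer instead of $C_{out} C_{in} k^2 / g$.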

Model choice - [MobileNetV2 trained on ImageNet1k](https://pytorch.org/vision/stable/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights). Reason: one of the smallest classification CNNs available via torchvision (about 3.5M parameters), so it can be finetuned relatively quickly, and there shouldn't be anything fundamentally different about larger models.

Original Dataset - ImageNet-1k (pretrained weights `IMAGENET1K_V2`, the torchvision default)

Finetuning Dataset - [FGVCAircraft](https://pytorch.org/vision/stable/generated/torchvision.datasets.FGVCAircraft.html#torchvision.datasets.FGVCAircraft) (The dataset contains 10,000 images of aircraft, with 100 images for each of 100 different aircraft model variants, most of which are airplanes. Aircraft models are organized in a three-level hierarchy.) Reason: relatively small dataset (thus faster training and more representative of a real-world finetuning task), and looking at airplanes is fun!


Ideas: try varying the amount of training data. How well does finetuning work on small data?
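One way to vary the amount of training data is to keep a fixed number `k` of images per class, so that every aircraft variant stays represented as the dataset shrinks. Below is a stdlib-only sketch. The helper `subsample_per_class` is hypothetical, and it assumes you already have a list of integer labels for the training split.

```python
import random
from collections import defaultdict


def subsample_per_class(labels, k, seed=42):
    """Return sorted indices keeping at most k examples of each class (hypothetical helper)."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)  # local RNG so the subset is reproducible
    chosen = []
    for label in sorted(by_class):
        idxs = by_class[label]
        chosen.extend(rng.sample(idxs, min(k, len(idxs))))
    return sorted(chosen)


# toy labels: 3 classes with 5 examples each
labels = [i % 3 for i in range(15)]
subset = subsample_per_class(labels, k=2)
print(len(subset))  # 6
```

The returned indices can be passed to `torch.utils.data.Subset(dataset, subset)`. Sweeping `k` then gives a learning curve over training-set size.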

In [36]:
import lightning as L
import torch
import torchvision
import loralib as lora

# make results reproducible
L.seed_everything(42)

Global seed set to 42


42

In [97]:
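class LitCNNLoRA(L.LightningModule):
    def __init__(self):
        super().__init__()

        # MobileNetV2 with the default pretrained ImageNet-1k weights (IMAGENET1K_V2)
        self.model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)

        # walk over every Conv2d in the model and read off the arguments
        # that would be needed to rebuild it as a LoRA-enabled convolution
        counter = 0
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                counter += 1
                in_channels = module.in_channels
                out_channels = module.out_channels
                kernel_size = module.kernel_size
                stride = module.stride
                padding = module.padding
                dilation = module.dilation
                groups = module.groups
                bias = module.bias is not None  # Conv2d's bias argument is a bool
                padding_mode = module.padding_mode
                print(in_channels, out_channels, kernel_size, stride, padding,
                      dilation, groups, module.bias, padding_mode)

        # total number of Conv2d layers found
        print(counter)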
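The cell above collects the constructor arguments of each `Conv2d` so that the layer could be rebuilt. An alternative is to wrap each existing convolution, leaving its pretrained weight untouched, and add a trainable low-rank update on top. This is a pure-PyTorch sketch under stated assumptions, not loralib's implementation. The names `LoRAConv2d` and `add_lora_to_convs` are hypothetical. The weight is reshaped to a $C_{out} \times (C_{in}/g \cdot k_H k_W)$ matrix, and only `padding_mode="zeros"` is handled, which is what every MobileNetV2 conv uses.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRAConv2d(nn.Module):
    """Frozen nn.Conv2d plus a trainable low-rank weight update (alpha / r) * B @ A.

    Assumption: the weight (C_out, C_in/groups, kH, kW) is treated as a
    C_out x (C_in/groups * kH * kW) matrix; other factorisations are possible.
    """

    def __init__(self, base: nn.Conv2d, r: int = 4, alpha: float = 1.0):
        super().__init__()
        if base.padding_mode != "zeros":
            raise NotImplementedError("sketch only handles padding_mode='zeros'")
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # pretrained weight (and bias) stay frozen
        fan_in = base.weight[0].numel()  # C_in/groups * kH * kW
        # A gets a Kaiming init; B starts at zero so the initial output is unchanged
        self.lora_A = nn.Parameter(base.weight.new_empty((r, fan_in)))
        self.lora_B = nn.Parameter(base.weight.new_zeros((base.out_channels, r)))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        delta = (self.lora_B @ self.lora_A).view_as(self.base.weight)
        weight = self.base.weight + self.scaling * delta
        return F.conv2d(x, weight, self.base.bias, self.base.stride,
                        self.base.padding, self.base.dilation, self.base.groups)


def add_lora_to_convs(model: nn.Module, r: int = 4, alpha: float = 1.0) -> int:
    """Replace every nn.Conv2d inside model with a LoRAConv2d; return how many were wrapped."""
    # collect names first so the module tree is not mutated while iterating over it;
    # the root module itself (empty name) cannot be replaced in place, so skip it
    names = [n for n, m in model.named_modules() if n and isinstance(m, nn.Conv2d)]
    for name in names:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, LoRAConv2d(getattr(parent, child_name), r, alpha))
    return len(names)


# toy check with a 1x1 conv like the 32 -> 192 layers printed below
conv = nn.Conv2d(32, 192, kernel_size=1, bias=False)
lora_conv = LoRAConv2d(conv, r=4)
x = torch.randn(2, 32, 8, 8)
print(torch.allclose(lora_conv(x), conv(x)))  # True, because B starts at zero
```

`add_lora_to_convs` freezes only the convolution weights. BatchNorm layers and the classifier remain trainable, so freezing them is a separate decision. loralib (imported above as `lora`) also provides a `lora.Conv2d` layer, which may factorise the weight differently from this sketch.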
                
        
                

        

In [98]:
m = LitCNNLoRA()
model = m.model


3 32 (3, 3) (2, 2) (1, 1) (1, 1) 1 None zeros
32 32 (3, 3) (1, 1) (1, 1) (1, 1) 32 None zeros
32 16 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
16 96 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
96 96 (3, 3) (2, 2) (1, 1) (1, 1) 96 None zeros
96 24 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
24 144 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
144 144 (3, 3) (1, 1) (1, 1) (1, 1) 144 None zeros
144 24 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
24 144 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
144 144 (3, 3) (2, 2) (1, 1) (1, 1) 144 None zeros
144 32 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
32 192 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
192 192 (3, 3) (1, 1) (1, 1) (1, 1) 192 None zeros
192 32 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
32 192 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
192 192 (3, 3) (1, 1) (1, 1) (1, 1) 192 None zeros
192 32 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
32 192 (1, 1) (1, 1) (0, 0) (1, 1) 1 None zeros
192 192 (3, 3) (2, 2) (1, 1) (1, 1) 192 None zeros
192 64 (1, 1) (1, 1) (0, 0) (1
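The printed shapes show two kinds of layers: pointwise 1x1 convs with `groups=1`, and depthwise 3x3 convs with `groups` equal to the channel count. The arithmetic below compares weight counts with LoRA trainable counts for a few of the printed layers. It assumes a factorisation in which B is C_out x r and A is r x (C_in/groups * k * k). Bias-free convs and rank `r=4` are illustrative choices.

```python
def conv_params(c_in, c_out, k, groups):
    """Weight count of a bias-free Conv2d with a square k x k kernel."""
    return c_out * (c_in // groups) * k * k


def lora_params(c_in, c_out, k, groups, r):
    """Trainable count for B (c_out x r) plus A (r x c_in/groups*k*k), the assumed factorisation."""
    return r * (c_out + (c_in // groups) * k * k)


# (c_in, c_out, k, groups) rows taken from the printed MobileNetV2 layers
layers = [(32, 192, 1, 1), (192, 192, 3, 192), (192, 32, 1, 1)]
for c_in, c_out, k, g in layers:
    full = conv_params(c_in, c_out, k, g)
    n_lora = lora_params(c_in, c_out, k, g, r=4)
    print(f"{c_in}->{c_out} k={k} groups={g}: full={full}, LoRA r=4: {n_lora}")
```

Under this factorisation, a 1x1 conv with 6144 weights gets 896 LoRA parameters (about 15%). A depthwise 192-channel 3x3 conv has only 1728 weights and gets 804 (about 47%). Low-rank updates therefore save much less on MobileNetV2's depthwise layers, which are already cheap.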

In [78]:
# count the Conv2d layers in the pretrained model
count = 0
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        count += 1
print(count)

<class 'torchvision.models.mobilenetv2.MobileNetV2'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torchvision.ops.misc.Conv2dNormActivation'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torch.nn.modules.activation.ReLU6'>
<class 'torchvision.models.mobilenetv2.InvertedResidual'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torchvision.ops.misc.Conv2dNormActivation'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torch.nn.modules.activation.ReLU6'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torchvision.models.mobilenetv2.InvertedResidual'>
<class 'torch.nn.modules.container.Sequential'>
<class 'torchvision.ops.misc.Conv2dNormActivation'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
<class 'torch.nn.modules.activation.ReLU6'>
<class 'torchvision.ops.misc.Conv2dNormActivati
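Counting layers of a given type can be done more compactly with `model.modules()`. Once weights are frozen for LoRA, the number of parameters that still require gradients is a useful sanity check. Below is a small sketch; the helper names `count_parameters` and `count_modules` are hypothetical.

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


def count_modules(model: nn.Module, cls: type) -> int:
    """Count submodules (including model itself) that are instances of cls."""
    return sum(isinstance(m, cls) for m in model.modules())


# toy example: freeze the first layer of a small MLP
toy = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
toy[0].requires_grad_(False)
print(count_parameters(toy))  # (105, 325)
print(count_modules(toy, nn.Linear))  # 2
```

On the notebook's MobileNetV2, `count_modules(model, torch.nn.Conv2d)` replaces the loop above. Calling `count_parameters(model)` before and after freezing shows how much of the network is actually being finetuned.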