## Lab 5 (Part 1) - Low-Rank Adaptation
---
In this lab, we will implement low-rank adaptation from scratch for one of the Linear (fully connected feedforward) layers of an image classification neural network. The underlying mechanism is much the same for LLMs; however, a small image classifier is more tractable to work with in class!

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

### MNIST Digit Recognition 
Let us first build a convolutional neural network to classify images from the MNIST handwritten digits dataset.

Your network should include two convolutional (Conv2d) layers, each followed by a ReLU activation. After *both* conv/ReLU pairs, apply max pooling (max_pool2d) followed by dropout (dropout1). Then flatten the result and pass it through two fully connected layers: the first with 2048 output neurons, and the second with 10 output neurons, one per class. Finally, return the log softmax of the outputs.

Feel free to experiment with layer sizes, kernel sizes, etc.

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ''' YOUR CODE HERE'''

    def forward(self, x):
        ''' YOUR CODE HERE'''
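One possible sketch of the architecture described above, assuming layer sizes similar to the official PyTorch MNIST example (3×3 kernels, 32 and 64 channels, dropout probability 0.25); these values are assumptions, not part of the lab specification. The description does not say what sits between the two fully connected layers; this sketch assumes a ReLU there, since two stacked linear layers with nothing in between collapse into a single linear map. With 28×28 inputs, the two unpadded 3×3 convolutions give 24×24 feature maps and pooling halves that to 12×12, so fc1 sees 64 × 12 × 12 = 9216 features.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed sizes: 3x3 kernels, stride 1, no padding, 32 then 64 channels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.25)            # assumed dropout probability
        self.fc1 = nn.Linear(64 * 12 * 12, 2048)    # 9216 -> 2048
        self.fc2 = nn.Linear(2048, 10)              # one output per digit class

    def forward(self, x):
        x = F.relu(self.conv1(x))    # (N, 32, 26, 26)
        x = F.relu(self.conv2(x))    # (N, 64, 24, 24)
        x = F.max_pool2d(x, 2)       # (N, 64, 12, 12)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)      # (N, 9216)
        x = F.relu(self.fc1(x))      # ReLU here is an assumption (see above)
        x = self.fc2(x)              # (N, 10)
        return F.log_softmax(x, dim=1)


model = Net()
out = model(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```

Because the network ends in log_softmax, its outputs are log-probabilities, which is what F.nll_loss in the provided training loop expects.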

The training and testing functions are provided below. We use the negative log-likelihood (NLL) loss, which expects log-probabilities as input; this is why the network ends with a log softmax.
https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html

Feel free to make changes if necessary.

In [None]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

Next, we check for CUDA/MPS availability and select the appropriate device. We then normalize each image channel as

$\text{output}[c] = (\text{input}[c] - \text{mean}[c]) / \text{std}[c]$

using the known MNIST mean (0.1307) and standard deviation (0.3081).

After initializing the dataloaders, we create an instance of our model, set up the optimizer and a learning rate scheduler, and call the training/testing functions for 10 epochs. Finally, we save the model weights for later use.

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Known mean and std dev of MNIST Dataset
    ])

dataset1 = datasets.MNIST('../data', train=True, download=True,
                   transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                   transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1, batch_size=64, shuffle=True)  # reshuffle training data each epoch
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=64)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=0.01)

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(1, 10 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()


torch.save(model.state_dict(), "mnist_cnn.pt")

### Fine-Tuning on a Different Dataset
---
Now imagine we wish to use the same neural network on the [FashionMNIST dataset](https://github.com/zalandoresearch/fashion-mnist) instead, while also retaining the ability to classify the MNIST digits it has just learned.

Low-rank adaptation (LoRA) serves a dual purpose here. First, it learns a separate set of weights that acts as an 'add-on' for the new task, leaving the original MNIST weights untouched. Second, it trains far fewer weights than fine-tuning the entire MNIST model would. To get a sense of the compute savings, print the number of parameters in the original model in the code block below:

In [None]:
'''YOUR CODE HERE'''
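A minimal sketch: a model's parameter count is the sum of numel() over model.parameters(). So that it runs on its own, it counts a hypothetical ModuleDict with assumed layer shapes (3×3 kernels with 32 and 64 channels, fc1 from 9216 to 2048 features, fc2 from 2048 to 10); in the lab you would call the hypothetical helper count_parameters(model) on your own network.

```python
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Total number of scalar parameters (weights + biases) in `model`."""
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)


# Hypothetical stand-in with assumed layer shapes; your sizes may differ.
layers = nn.ModuleDict({
    "conv1": nn.Conv2d(1, 32, 3, 1),
    "conv2": nn.Conv2d(32, 64, 3, 1),
    "fc1": nn.Linear(9216, 2048),
    "fc2": nn.Linear(2048, 10),
})

total = count_parameters(layers)
for name, module in layers.items():
    print(f"{name:>5}: {count_parameters(module):>10,}")
print(f"total: {total:>10,}")
```

With these assumed shapes, fc1 alone accounts for about 99.8% of all parameters, which makes it the natural place to apply LoRA.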

Now, let us load the new dataset and create the corresponding dataloaders.

In [None]:
dataset1 = datasets.FashionMNIST('../data', train=True, download=True,
                   transform=transform)
dataset2 = datasets.FashionMNIST('../data', train=False,
                   transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=64)

LoRA is typically applied to fully connected (linear) layers, including the query, key, value, and output projections inside self-attention layers. In our case, we will use LoRA as an additive set of weights for the *first fully connected layer* in our neural network. Let the weights of this layer be represented by the matrix W.

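We need to find a pair of smaller matrices A and B such that the weight update is $\Delta W = AB$ (approximately). If W has dimension $m \times n$, then A has dimension $m \times r$ and B has dimension $r \times n$, where r is the rank of the update: a hyperparameter we choose, normally much smaller than both m and n. (PyTorch stores a Linear layer's weight as (out_features, in_features), so here m = 2048.) As a point of reference, let us start by calculating the rank of the matrix W below. Experiment with the 'rtol' argument to see how the numerical rank changes.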

In [None]:
'''YOUR CODE HERE'''
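A sketch of how rtol changes the numerical rank reported by torch.linalg.matrix_rank, using a synthetic matrix (an assumption, chosen so the effect is easy to see): an exact rank-8 product plus small dense noise. matrix_rank counts singular values above a tolerance; when only rtol is given, the cut-off is rtol times the largest singular value. In the lab you would pass model.fc1.weight.detach() instead.

```python
import torch

torch.manual_seed(0)

# Synthetic stand-in (hypothetical, not the lab's fc1): an exactly rank-8
# product U @ V plus small dense noise, giving a 64 x 32 matrix.
U = torch.randn(64, 8)
V = torch.randn(8, 32)
W = U @ V + 1e-2 * torch.randn(64, 32)

# Default tolerance: even the small noise singular values count, so W is full rank.
print(torch.linalg.matrix_rank(W).item())              # 32
# rtol=1e-2 ignores singular values below 1% of the largest one.
print(torch.linalg.matrix_rank(W, rtol=1e-2).item())   # 8

# The singular values show where the cut-off falls.
print(torch.linalg.svdvals(W)[:10])

# In the lab: torch.linalg.matrix_rank(model.fc1.weight.detach(), rtol=...)
```

For a dense layer such as fc1, the rank at the default tolerance is likely to be at or near min(m, n) = 2048; a looser rtol shows how quickly the singular values decay.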

Now initialize A and B as matrices of the appropriate size in the block below: A should contain random values sampled from a normal distribution with mean 0 and variance 1, and B should contain all zeros. Make sure their requires_grad flag is set to True. Because B starts at zero, the product AB is initially zero, so the adapted model starts out behaving exactly like the original.

In [None]:
'''YOUR CODE HERE'''
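A minimal sketch of the initialization, assuming a hypothetical stand-in layer with fc1's shape and an arbitrarily chosen r = 4. Creating the tensors directly on the target device with requires_grad=True keeps them as leaf tensors; writing torch.randn(..., requires_grad=True).to(device) instead produces a non-leaf copy on a GPU, which an optimizer will refuse to optimize.

```python
import torch
import torch.nn as nn

device = torch.device("cpu")          # in the lab, reuse the `device` chosen earlier

# Hypothetical stand-in with fc1's shape; in the lab use model.fc1 instead.
fc1 = nn.Linear(9216, 2048).to(device)
m, n = fc1.weight.shape               # (out_features, in_features) = (2048, 9216)
r = 4                                 # LoRA rank: an assumed, freely chosen value

# Create A and B directly on the device so they stay leaf tensors.
A = torch.randn(m, r, device=device, requires_grad=True)   # entries ~ N(0, 1)
B = torch.zeros(r, n, device=device, requires_grad=True)   # all zeros

delta_W = A @ B                       # same shape as fc1.weight, all zeros for now
print(A.shape, B.shape, delta_W.shape)
```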

How many parameters are we learning with this approach? How does this compare to the original linear layer's number of weights?

In [None]:
'''YOUR CODE HERE'''
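The count is simple arithmetic: A and B together hold $r(m + n)$ entries, versus $mn$ in W. A small sketch, assuming fc1 maps 9216 inputs to 2048 outputs (the 9216 depends on your convolution and pooling choices) and trying a few hypothetical values of r; lora_param_count is a hypothetical helper name.

```python
def lora_param_count(m: int, n: int, r: int) -> int:
    """Number of trainable entries in A (m x r) plus B (r x n)."""
    return m * r + r * n


# Hypothetical fc1 weight shape: (out_features, in_features) = (2048, 9216).
m, n = 2048, 9216
full = m * n                      # weights in W itself (bias excluded)

for r in (1, 4, 16, 64):
    lora = lora_param_count(m, n, r)
    print(f"r={r:>2}: {lora:>7,} LoRA weights vs {full:,} in W ({100 * lora / full:.2f}%)")

# LoRA only saves parameters while r * (m + n) < m * n.
break_even = full // (m + n)
print(f"largest r that still saves parameters: {break_even}")
```

With these shapes, r = 4 trains 45,056 weights instead of 18,874,368, and any r up to 1,675 still saves parameters. This is also why setting r to the full rank of W (up to 2048 here) would defeat the purpose.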

Before we can use these new weight matrices, we need to add some functionality to our model definition. We need to obtain the model's activations just before the first fully connected layer (so that we can pass them through our new weights separately), and isolate the remainder of the network so that it can be applied to the combined result. For an input x to that layer, the combination is $Wx + b + A(Bx)$, where b is the layer's bias; with PyTorch's batch-first layout this becomes $\text{fc1}(x) + (xB^\top)A^\top$. Copy the constructor and forward pass definitions from earlier in the lab, and complete the two new functions below:

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ''' YOUR CODE HERE'''

    def forward(self, x):
        ''' YOUR CODE HERE'''

    def lora_inputs(self, x):
        '''YOUR CODE HERE'''

    def remaining_forward_pass(self,x):
        '''YOUR CODE HERE'''
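One way to split the forward pass, sketched under assumptions: 3×3 kernels with 32 and 64 channels, a dropout probability of 0.25, and a ReLU after fc1, none of which are fixed by the lab. lora_inputs returns exactly what fc1 consumes, remaining_forward_pass starts from fc1's output, and forward is rebuilt from the two so the original behavior is unchanged. The short check at the end uses random A and zero B to confirm that the LoRA path reproduces the original output before any training.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Keep the original attribute names so mnist_cnn.pt still loads.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(9216, 2048)
        self.fc2 = nn.Linear(2048, 10)

    def forward(self, x):
        # Same computation as before, expressed through the two helpers.
        return self.remaining_forward_pass(self.fc1(self.lora_inputs(x)))

    def lora_inputs(self, x):
        # Everything up to (not including) fc1: the activations fc1 consumes.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        return torch.flatten(x, 1)            # (N, 9216)

    def remaining_forward_pass(self, x):
        # Everything after fc1's affine map: activation, fc2, log-softmax.
        x = F.relu(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


torch.manual_seed(0)
model = Net().eval()                          # eval() disables dropout for the comparison
x = torch.randn(2, 1, 28, 28)
r = 4                                         # hypothetical LoRA rank
A = torch.randn(2048, r)
B = torch.zeros(r, 9216)

with torch.no_grad():
    h = model.lora_inputs(x)
    # W h + b + A (B h), written batch-first without forming A @ B.
    lora_out = model.remaining_forward_pass(model.fc1(h) + (h @ B.T) @ A.T)
    base_out = model(x)

print(torch.allclose(lora_out, base_out))    # True: B is zero, so LoRA adds nothing yet
```

Applying the ReLU inside remaining_forward_pass, rather than right after fc1, matters: the LoRA term has to be added before the activation, so that the whole of $Wx + b + A(Bx)$ passes through the nonlinearity.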

Before we go ahead and train our new weights, let us check how well the model trained on MNIST digits performs on the FashionMNIST classification task:

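In [None]:
# Re-create the model from the updated class (so it has the new methods), then reload the MNIST weights
model = Net().to(device)
model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
'''YOUR CODE HERE'''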

Not great! Let's finally go ahead and learn A and B with a few epochs of tuning over the new dataset. First, freeze the original model's weights, since we are not going to change them.

In [None]:
'''YOUR CODE HERE'''
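A minimal sketch of freezing, using a small hypothetical nn.Sequential in place of the lab's loaded model. Setting requires_grad to False stops autograd from computing gradients for those parameters. Note that model.eval() is a different thing: it only switches layers such as dropout to inference behavior and freezes nothing.

```python
import torch.nn as nn

# Hypothetical stand-in for the MNIST model loaded from mnist_cnn.pt.
model = nn.Sequential(nn.Linear(784, 32), nn.ReLU(), nn.Linear(32, 10))

# Freeze every original parameter; equivalent to looping over
# model.parameters() and setting p.requires_grad = False.
model.requires_grad_(False)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable}, frozen: {frozen}")
```

A and B are created separately with requires_grad=True, so after this step they are the only tensors that receive gradients.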

Adapt the training and testing functions given earlier to fill in the functions below. The difference is that the model's outputs are now computed with A and B included, using the two new functions we added to our model definition.

In [None]:
def train_lora(model, device, train_loader, optimizer, epoch, A, B):
    '''YOUR CODE HERE'''

def test_lora(model, device, test_loader, A, B):
    '''YOUR CODE HERE'''

Let's now train A and B and see how well we do!

In [None]:
optimizer = optim.Adadelta([A,B], lr=0.01)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

for epoch in range(1,3):
    train_lora(model, device, train_loader, optimizer, epoch, A, B)
    test_lora(model, device, test_loader, A, B)
    scheduler.step()

You should see a large improvement over the original model's FashionMNIST accuracy, with far fewer parameters trained: not only compared to the entire original network, but also compared to just the first linear layer. In LLMs, a single self-attention projection matrix typically holds millions of weights (a 4096 × 4096 projection has about 16.8 million), whereas with r = 8 its A and B together hold only 65,536, about 0.4% of that. LoRA often makes fine-tuning feasible on limited hardware where a full fine-tuning pipeline would not be.