In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

In [2]:
# Loading the model
class FeedForwardNetwork(torch.nn.Module):
    def __init__(self, in_dim, embedding_dim=128, out_dim=10):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(in_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, out_dim),
        )

    def forward(self, x):
        x = torch.flatten(x)
        return self.linear(x)
    
model = FeedForwardNetwork(in_dim=784, out_dim=10, embedding_dim=128)
path = "./model/two_layer_linear_model.pth"
model.load_state_dict(torch.load(path))
(w1,b1,w2,b2) = model.parameters()

# Freezing weight
for param in model.parameters():
    param.requires_grad = False

  model.load_state_dict(torch.load(path))


In [3]:
# Load  the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor()
])
full_data = datasets.MNIST(root='./data', transform=transform, download=True)
split_percent = int(0.25 * len(full_data))
train_data1, test_data1, train_data2, test_data2 = random_split(full_data, [split_percent, split_percent, split_percent,
                                                                            split_percent])
train_data2 = DataLoader(dataset=train_data2, shuffle=True)
test_loader2 = DataLoader(dataset=test_data2, shuffle=True)
in_dim = (train_data2.dataset[0][0].size()[1]) ** 2


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5714039.48it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 462150.65it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4777682.56it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1515796.37it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [4]:
class NN_LoRA_layer(torch.nn.Module):
    def __init__(self, original_model, in_dim, out_dim, rank=4, alpha=1):
        super().__init__()
        self.original_model = original_model
        self.A = torch.nn.Parameter(torch.rand(in_dim, rank), requires_grad=True)
        print(self.A)
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim), requires_grad=True)
        print("B")
        print(self.B)
        self.alpha = alpha

    def forward(self, x):
        output1 = self.original_model(x)
        # LORA 
        x = torch.flatten(x)
        output2 = self.alpha * (x @ self.A @ self.B)
        return output1 + output2


# Model creation
lora_model = NN_LoRA_layer(original_model= model, in_dim=in_dim, out_dim=10)
optimizer = torch.optim.SGD(lora_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()


# Training
for epoch in range(5):
    loss = 0
    for images, labels in train_data2.dataset:
        optimizer.zero_grad()
        output = lora_model.forward(images)
        loss = criterion(output, torch.tensor(labels))
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f}')
print("Finished Training!")


Parameter containing:
tensor([[0.3193, 0.7511, 0.7246, 0.0439],
        [0.5606, 0.7165, 0.4825, 0.4677],
        [0.3309, 0.1528, 0.3955, 0.1771],
        ...,
        [0.6981, 0.6977, 0.3373, 0.5016],
        [0.9370, 0.8224, 0.4838, 0.1313],
        [0.0512, 0.0374, 0.9726, 0.6047]], requires_grad=True)
B
Parameter containing:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], requires_grad=True)
Epoch: 1, Loss: 0.0002
Epoch: 2, Loss: 0.0002
Epoch: 3, Loss: 0.0002
Epoch: 4, Loss: 0.0002
Epoch: 5, Loss: 0.0002
Finished Training!


In [5]:
# Prediction using created model.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader2.dataset:
        outputs = lora_model.forward(images)
        predicted = torch.argmax(outputs)
        total += 1
        if predicted == labels:
            correct += 1
    print(f'Accuracy: {(correct / total) * 100}%')

Accuracy: 98.56666666666666%


In [6]:
lora_model.state_dict()

OrderedDict([('A',
              tensor([[0.3193, 0.7511, 0.7246, 0.0439],
                      [0.5606, 0.7165, 0.4825, 0.4677],
                      [0.3309, 0.1528, 0.3955, 0.1771],
                      ...,
                      [0.6981, 0.6977, 0.3373, 0.5016],
                      [0.9370, 0.8224, 0.4838, 0.1313],
                      [0.0512, 0.0374, 0.9726, 0.6047]])),
             ('B',
              tensor([[-0.1666,  0.1460,  0.0210,  0.1407, -0.2974,  0.0495,  0.0820, -0.0860,
                       -0.2111,  0.3220],
                      [-0.0098,  0.2407, -0.1892, -0.1981,  0.2429, -0.1438,  0.0071,  0.1365,
                        0.0014, -0.0876],
                      [ 0.1027, -0.2304,  0.1188, -0.0760,  0.2408,  0.1243, -0.0160, -0.1987,
                        0.0144, -0.0799],
                      [ 0.4395, -0.1162, -0.0628,  0.0967, -0.1426, -0.0964,  0.1148,  0.0774,
                       -0.1670, -0.1435]])),
             ('original_model.linear.0.weight