In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

In [2]:
# Loading the model
class FeedForwardNetwork(torch.nn.Module):
    def __init__(self, in_dim, embedding_dim=128, out_dim=10):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(in_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, out_dim),
        )

    def forward(self, x):
        x = torch.flatten(x)
        return self.linear(x)
    
model = FeedForwardNetwork(in_dim=784, out_dim=10, embedding_dim=128)
path = "./model/two_layer_linear_model.pth"
model.load_state_dict(torch.load(path))
(w1,b1,w2,b2) = model.parameters()

# Freezing weight
for param in model.parameters():
    param.requires_grad = False

  model.load_state_dict(torch.load(path))


In [3]:
# Load  the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor()
])
full_data = datasets.MNIST(root='./data', transform=transform, download=True)
split_percent = int(0.25 * len(full_data))
train_data1, test_data1, train_data2, test_data2 = random_split(full_data, [split_percent, split_percent, split_percent,
                                                                            split_percent])
train_data2 = DataLoader(dataset=train_data2, shuffle=True)
test_loader2 = DataLoader(dataset=test_data2, shuffle=True)
in_dim = (train_data2.dataset[0][0].size()[1]) ** 2


In [4]:
class NN_LoRA_layer(torch.nn.Module):
    def __init__(self, original_model, in_dim, out_dim, rank=4, alpha=1):
        super().__init__()
        self.original_model = original_model
        self.A = torch.nn.Parameter(torch.rand(in_dim, rank), requires_grad=True)
        print(self.A)
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim), requires_grad=True)
        print("B")
        print(self.B)
        self.alpha = alpha

    def forward(self, x):
        output1 = self.original_model(x)
        # LORA 
        x = torch.flatten(x)
        output2 = self.alpha * (x @ self.A @ self.B)
        return output1 + output2


# Model creation
lora_model = NN_LoRA_layer(original_model= model, in_dim=in_dim, out_dim=10)
optimizer = torch.optim.SGD(lora_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()


# Training
for epoch in range(5):
    loss = 0
    for images, labels in train_data2.dataset:
        optimizer.zero_grad()
        output = lora_model.forward(images)
        loss = criterion(output, torch.tensor(labels))
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f}')
print("Finished Training!")


Parameter containing:
tensor([[0.1003, 0.8953, 0.3099, 0.8571],
        [0.6511, 0.8674, 0.5029, 0.2171],
        [0.5526, 0.5707, 0.4473, 0.4281],
        ...,
        [0.5971, 0.9654, 0.6815, 0.2721],
        [0.0553, 0.8235, 0.5654, 0.5020],
        [0.7317, 0.2869, 0.7150, 0.0820]], requires_grad=True)
B
Parameter containing:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], requires_grad=True)
Epoch: 1, Loss: 0.0010
Epoch: 2, Loss: 0.0012
Epoch: 3, Loss: 0.0011
Epoch: 4, Loss: 0.0009
Epoch: 5, Loss: 0.0007
Finished Training!


In [5]:
# Prediction using created model.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader2.dataset:
        outputs = lora_model.forward(images)
        predicted = torch.argmax(outputs)
        total += 1
        if predicted == labels:
            correct += 1
    print(f'Accuracy: {(correct / total) * 100}%')

Accuracy: 98.78%


In [6]:
lora_model.state_dict()

OrderedDict([('A',
              tensor([[0.1003, 0.8953, 0.3099, 0.8571],
                      [0.6511, 0.8674, 0.5029, 0.2171],
                      [0.5526, 0.5707, 0.4473, 0.4281],
                      ...,
                      [0.5971, 0.9654, 0.6815, 0.2721],
                      [0.0553, 0.8235, 0.5654, 0.5020],
                      [0.7317, 0.2869, 0.7150, 0.0820]])),
             ('B',
              tensor([[ 0.1827, -0.2141,  0.1311, -0.1663,  0.0314, -0.2191, -0.1150,  0.1282,
                        0.2670, -0.0259],
                      [-0.0154,  0.0873, -0.0267, -0.1106, -0.0615,  0.2729,  0.1619,  0.0656,
                       -0.1395, -0.2341],
                      [ 0.1395, -0.1523, -0.0575,  0.2883, -0.2756, -0.0160, -0.0550,  0.0310,
                       -0.1087,  0.2063],
                      [-0.2355,  0.0183, -0.1001, -0.0234,  0.3345, -0.0789,  0.1689, -0.1382,
                        0.0313,  0.0230]])),
             ('original_model.linear.0.weight