In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [3]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081 (the values commonly quoted for MNIST).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split (fetched into ./data if it is not already there) and a
# shuffled loader that serves it in mini-batches of 10.
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=10,
    shuffle=True,
)

# Test split and its loader, configured exactly like the training pair.
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    mnist_testset,
    batch_size=10,
    shuffle=True,
)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.10MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 133kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.62MB/s]


In [4]:
class DigiNet(nn.Module):
  """Fully connected classifier: 784 inputs -> one ReLU hidden layer -> 10 outputs."""

  def __init__(self, HiddenLayer1=1500):
    """Create the two linear layers and the activation.

    Args:
      HiddenLayer1: number of units in the hidden layer.
    """
    super().__init__()
    # Creation order is kept (linear1, linear2, relu) so parameter ordering
    # and random initialisation match the usual definition of this model.
    self.linear1 = nn.Linear(in_features=28 * 28, out_features=HiddenLayer1)
    self.linear2 = nn.Linear(in_features=HiddenLayer1, out_features=10)
    self.relu = nn.ReLU()

  def forward(self, img):
    """Return the 10 raw output scores for each image in `img`.

    Args:
      img: tensor whose elements can be viewed as rows of 784 values.

    Returns:
      Tensor of shape (rows, 10); no softmax is applied.
    """
    flat = img.view(-1, 28 * 28)
    hidden = self.relu(self.linear1(flat))
    return self.linear2(hidden)

# Build the classifier with its default hidden width, then place it on the
# device selected earlier in the notebook.
net = DigiNet()
net = net.to(device)

In [5]:
def train(train_loader, net, epochs=7):
  """Train `net` in place with Adam (lr=0.001) and cross-entropy loss.

  Args:
    train_loader: iterable yielding `(inputs, labels)` batches.
    net: model to optimise. Each batch is reshaped to rows of 784 values
      before being passed to it.
    epochs: number of full passes over `train_loader`.

  Returns:
    None. Progress is reported through a tqdm bar per epoch whose `loss`
    field is the mean loss over the batches seen so far in that epoch.
  """
  CELoss=nn.CrossEntropyLoss()
  optim=torch.optim.Adam(net.parameters(), lr=0.001)
  # Move batches to wherever the model's parameters live rather than trusting
  # a notebook-level global, so model and data can never end up on different
  # devices.
  model_device=next(net.parameters()).device

  for epoch in range(epochs):
    net.train() # training mode (matters for layers such as dropout/batch-norm)
    # Reset the running statistics every epoch; otherwise the loss displayed
    # for later epochs is diluted by the (higher) losses of earlier ones.
    epoch_loss=0.0
    num_batches=0
    data_iterator=tqdm(train_loader, desc=f'Epoch {epoch +1}')
    for x, y in data_iterator:
      num_batches+=1
      x=x.to(model_device)
      y=y.to(model_device)
      optim.zero_grad()
      output=net(x.view(-1, 28*28))
      loss=CELoss(output, y)
      epoch_loss+=loss.item()
      data_iterator.set_postfix(loss=epoch_loss/num_batches)
      loss.backward()
      optim.step()

train(train_loader, net, epochs=1)


Epoch 1: 100%|██████████| 6000/6000 [00:32<00:00, 182.14it/s, loss=0.198]


In [6]:
def test():
  """Evaluate the notebook-level `net` on `test_loader` and print a summary.

  Prints the overall accuracy (rounded to 3 decimals) followed by, for each
  label 0-9, how many samples with that true label were misclassified.
  Reads the globals `net`, `test_loader` and `device`; returns None.
  """
  num_classes=10
  correct=0
  total=0
  wrong_counts=torch.zeros(num_classes, dtype=torch.long)

  # Evaluate in eval mode, then put the model back the way we found it so a
  # later training cell is unaffected.
  was_training=net.training
  net.eval()
  try:
    with torch.no_grad():
      for x, y in tqdm(test_loader, desc='Testing'):
        x=x.to(device)
        y=y.to(device)
        output=net(x.view(-1, 28*28))

        # Score the whole batch at once: one host sync per batch instead of
        # one per sample.
        predictions=output.argmax(dim=1)
        missed=predictions!=y
        correct+=int((~missed).sum())
        total+=predictions.numel()
        # Histogram of the true labels of the misclassified samples.
        wrong_counts+=torch.bincount(y[missed], minlength=num_classes).cpu()
  finally:
    net.train(was_training)

  if total==0:
    # Nothing was evaluated; avoid dividing by zero.
    print('Accuracy: n/a (test_loader yielded no samples)')
  else:
    print(f'Accuracy: {round(correct/total, 3)}')

  for i, count in enumerate(wrong_counts.tolist()):
      print(f'wrong counts for digit {i}: {count}')
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 319.45it/s]

Accuracy: 0.959
wrong counts for digit 0: 8
wrong counts for digit 1: 11
wrong counts for digit 2: 36
wrong counts for digit 3: 40
wrong counts for digit 4: 63
wrong counts for digit 5: 27
wrong counts for digit 6: 20
wrong counts for digit 7: 85
wrong counts for digit 8: 66
wrong counts for digit 9: 52





In [7]:
# Report the weight/bias shapes of both linear layers and add up how many
# scalar parameters the network holds.
total_original_parameter = 0
for layer_number, linear in enumerate((net.linear1, net.linear2), start=1):
  weight, bias = linear.weight, linear.bias
  print(f'Layer {layer_number}: W: {weight.shape} + B:{bias.shape}')
  total_original_parameter += weight.numel() + bias.numel()
print(f'Total parameters: {total_original_parameter}')


Layer 1: W: torch.Size([1500, 784]) + B:torch.Size([1500])
Layer 2: W: torch.Size([10, 1500]) + B:torch.Size([10])
Total parameters: 1192510


## LoRA Parameterization

In [None]:
class LoRAParameterization(nn.Module):
  """Low-rank additive update for a weight matrix.

  Holds two trainable factors, `LoRA_B` (out_features x rank) and `LoRA_A`
  (rank x in_features). `forward` returns the given weights plus
  `(LoRA_B @ LoRA_A) * (alpha / rank)`. Subclassing `nn.Module` registers
  both factors as parameters and makes the object usable wherever a module
  is required (e.g. as a registered parametrization).
  """

  def __init__(self, in_features, out_features,rank, alpha, device=None):
    """Create the low-rank factors.

    Args:
      in_features: number of columns of the weight matrix being adapted.
      out_features: number of rows of the weight matrix being adapted.
      rank: inner dimension of the factorisation; must be positive.
      alpha: numerator of the scaling factor `alpha / rank`.
      device: where to allocate the factors. When omitted, CUDA is used if
        available and the CPU otherwise.

    Raises:
      ValueError: if `rank` is not positive.
    """
    super().__init__()
    if rank <= 0:
      raise ValueError(f'rank must be a positive integer, got {rank}')
    if device is None:
      device='cuda' if torch.cuda.is_available() else 'cpu'

    self.LoRA_A=nn.Parameter(torch.zeros((rank, in_features), device=device))
    self.LoRA_B=nn.Parameter(torch.zeros((out_features, rank), device=device))
    # A is random and B stays zero, so B @ A is zero at construction time and
    # the adapted weights initially equal the original ones.
    nn.init.normal_(self.LoRA_A, mean=0, std=1)
    self.scale=alpha/rank
    self.enabled=True

  def forward(self, original_weights):
    """Return `original_weights` with the scaled low-rank update added.

    When `self.enabled` is False the input is returned unchanged.
    """
    if self.enabled:
      return original_weights + torch.matmul(self.LoRA_B, self.LoRA_A).view(original_weights.shape)*self.scale
    else:
      return original_weights
