In [2]:
#!pip install datasets

In [3]:
## Please install torch and datasets
import torch
from torchvision.transforms import functional as t
import torch.nn.functional as f
from datasets import load_dataset
import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [4]:
## Loading our dataset
ds = load_dataset("ylecun/mnist")

In [5]:
## Data splits

X_train_p = ds["train"]["image"]
Y_train = ds["train"]["label"]
X_test_p = ds["test"]["image"]
Y_test = ds["test"]["label"]

In [6]:
## PIL to Tensors

X_train = [t.pil_to_tensor(x) for x in X_train_p]
X_test = [t.pil_to_tensor(x) for x in X_test_p]
X_train = torch.stack(X_train).to(device)
X_test = torch.stack(X_test).to(device)
print(X_train.shape, X_test.shape)

torch.Size([60000, 1, 28, 28]) torch.Size([10000, 1, 28, 28])


In [7]:
## Fixing the shape

X_train = X_train.view(-1, 28, 28)
X_test = X_test.view(-1, 28, 28)
print(X_train.shape, X_test.shape)

torch.Size([60000, 28, 28]) torch.Size([10000, 28, 28])


In [8]:
## Making labels into tensors

Y_train = torch.tensor(Y_train).to(device)
Y_test = torch.tensor(Y_test).to(device)

In [9]:
## Flattening the image as DNN takes flat tensor as input

X_train = X_train.view(-1, 784).float() / 255.0
X_test = X_test.view(-1, 784).float() / 255.0
print(X_train.shape, X_test.shape)

torch.Size([60000, 784]) torch.Size([10000, 784])


In [10]:

import numpy as np
from numpy.linalg import svd



class Linear():
  def __init__(self, input_dims, output_dims, B=True, last=False):
    self.training = True
    self.W = (torch.randn(input_dims, output_dims) * (5/3) / (input_dims**0.5)).to(device) if not last else (torch.randn(input_dims, output_dims) * (5/3) / (input_dims**0.5) * 0.1).to(device)
    if B: self.B = torch.randn(output_dims).to(device) if not last else (torch.randn(output_dims) * 0.1).to(device)
    else: self.B = torch.tensor([]).to(device)

  def __call__(self, x):
    if not torch.equal(self.B, torch.tensor([]).to(device)): self.result = x@self.W + self.B
    else: self.result = x@self.W
    return self.result

  def parameters(self):
    return [self.W] + [self.B]


class Tanh():
  def __init__(self):
    self.training = True
    return None

  def __call__(self, x):
    self.result = torch.tanh(x)
    return self.result

  def parameters(self):
    return []

class Dropout():
  def __init__(self, batch_size, output_dims, rate=0.9):
    self.training = True
    self.rate = rate
    self.factor = (torch.rand(batch_size, output_dims) < self.rate).int().to(device)
    return None

  def __call__(self, x):
    if self.training: self.result = x * self.factor
    else: self.result = x
    return self.result

  def parameters(self):
    return []


def decomposition(A, k=1):

    # SVD
    U, S, VT = svd(A, full_matrices=False)


    # Truncate to rank-k
    U_k = U[:, :k]                # (784 x k)
    S_k = np.diag(S[:k])          # (k x k)
    VT_k = VT[:k, :]              # (k x 10)

    # Factor A_k = L @ R, where L and R are low-rank factors
    sqrt_S_k = np.sqrt(S_k)       # (k x k)
    L = U_k @ sqrt_S_k            # (784 x k)
    R = sqrt_S_k @ VT_k           # (k x 10)

    L_flat = L.flatten()
    R_flat = R.flatten()
    LR_concat = np.concatenate([L_flat, R_flat])

    return LR_concat


In [21]:
n1 = 512
n2 = 256
n3 = 512
n4 = 794
batch_size = 1

updater = [
    Linear(794, n1), Tanh(), Dropout(batch_size, n1),
    #Linear(n1, n2), Tanh(), Dropout(batch_size, n2),
    #Linear(n2, n3), Tanh(), Dropout(batch_size, n3),
   Linear(n1, 794, last=True),
    
]

predictor = [
    Linear(784, 10, last=True, B=False)
]

updater_params = [p for layer in updater for p in layer.parameters()]
numparams = 0
for p in updater_params:
    p.requires_grad = True
    numparams += p.numel()


predictor_params = [p for layer in predictor for p in layer.parameters()]
for p in predictor_params:
    p.requires_grad = True
print(numparams)

814362


In [23]:
## Training loop for updater and predictor
iters = 4000  
alpha = 0.01

for c in range(iters):
    # Zero updater gradients
    for p in updater_params:
        if p.grad is not None:
            p.grad.zero_()

    # Decompose predictor weights
    with torch.no_grad():
        #current_W = predictor[0].W.detach().cpu().numpy()
        if predictor[0].W.grad != None: 
            current_W = predictor[0].W.grad.cpu().numpy()
        else:
            print("None update")
            current_W = torch.zeros((predictor[0].W.shape)).cpu().numpy()
        LR_concat = decomposition(current_W, k=1)
    
    # Forward pass through updater
    for layer in updater:
        layer.training = True
    updater_input = torch.tensor(LR_concat, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
    updater_output = updater_input
    for layer in updater:
        updater_output = layer(updater_output)
    L_update = updater_output[:, :784].reshape(784, 1)
    R_update = updater_output[:, 784:].reshape(1, 10)
    W_update = L_update @ R_update

    # Compute loss with updated weights
    for layer in predictor:
        layer.training = False
    updated_weights = predictor[0].W + W_update * 0.1
    predictions = X_train @ updated_weights 
    loss = f.cross_entropy(predictions, Y_train)
    
    # Backward and update
    loss.backward()
    with torch.no_grad():
        predictor[0].W += W_update.detach() * 0.1
    for p in updater_params:
        if p.grad is not None:
            p.data -= alpha * p.grad

    if c % 20 == 0:
        print(f"Iteration {c:4d}, Loss: {loss.item():.6f}")


Iteration    0, Loss: 1.831913
Iteration   20, Loss: 1.894346
Iteration   40, Loss: 1.910147
Iteration   60, Loss: 1.885522
Iteration   80, Loss: 1.856037
Iteration  100, Loss: 1.845506
Iteration  120, Loss: 1.807259
Iteration  140, Loss: 1.823778
Iteration  160, Loss: 1.835210
Iteration  180, Loss: 1.852521
Iteration  200, Loss: 1.868489
Iteration  220, Loss: 1.873372
Iteration  240, Loss: 1.870271
Iteration  260, Loss: 1.854848
Iteration  280, Loss: 1.861901
Iteration  300, Loss: 1.869761
Iteration  320, Loss: 1.861971
Iteration  340, Loss: 1.847721
Iteration  360, Loss: 1.811841
Iteration  380, Loss: 1.815572
Iteration  400, Loss: 1.832138
Iteration  420, Loss: 1.857014
Iteration  440, Loss: 1.853353
Iteration  460, Loss: 1.811177
Iteration  480, Loss: 1.837409
Iteration  500, Loss: 1.866915
Iteration  520, Loss: 1.869753
Iteration  540, Loss: 1.877122
Iteration  560, Loss: 1.886601
Iteration  580, Loss: 1.889726
Iteration  600, Loss: 1.883915
Iteration  620, Loss: 1.868430
Iteratio

KeyboardInterrupt: 

In [None]:

# ## Training
# iters = 4000
# alpha = 0.01

# for c in range(iters):
#     ## Forward Pass through predictor
#     for layer in predictor:
#       layer.training = False
#     x = X_train
#     print(x.shape)
#     for layer in predictor:
#       x = layer(x)
#     # Loss
#     Loss = f.cross_entropy(x, Y_train)


#     ## Forward Pass through updater
#     for layer in layers:
#       layer.training = True

#     # Full SVD
#     A = predictor[0].W.detach().cpu().numpy()  # 784 x 10
#     LR_concat = decomposition(A, k=1) # 794
#     print("LR_concat shape:", LR_concat.shape)

#     i = torch.stack([torch.tensor(LR_concat).to(device)])
#     print(i.shape)

#     for layer in layers:
#       i = layer(i)
#     print(i.shape)
    
#     L_update = i[:, :784].reshape(784, 1)
#     R_update = i[:, 784:].reshape(1, 10)
#     predictor_W_update = L_update @ R_update


#     ## Weight update for predicter
#     predictor[0].W+=predictor_W_update


#     # Calculating Gradient for updater model
#     for layer in layers:
#       layer.result.retain_grad() # This stores grad of layers like Tanh that have no params to update

#     for p in params:
#         p.grad = None

#     Loss.backward()

#     # Weight Update for updater model
#     for p in params:
#         p.data += -alpha * p.grad



#     if c % (iters/20) == 0:
#         print(Loss)


In [None]:
def accuracy(X, Y, layersv=predictor):
    for layer in layersv:
      layer.training=False
    # Forward
    x = X
    for layer in layersv:
      x = layer(x)
    probs = f.softmax(x, 1)
    answers = x.argmax(1)
    c = 0
    for a, y in zip(answers, Y):
        if a==y: c+=1
    return c / answers.shape[0] * 100

def loss(X, Y, layersv = predictor):
    x = X
    for layer in layersv:
      x = layer(x)
    return f.cross_entropy(x, Y)

print(f"train accuracy: {accuracy(X_train, Y_train)} | test accuracy: {accuracy(X_test, Y_test)}")
print(f"train loss: {loss(X_train, Y_train)} | test loss: {loss(X_test, Y_test)}")

train accuracy: 76.36833333333334 | test accuracy: 77.13
train loss: 1.5593186616897583 | test loss: 1.4697450399398804
