In [1]:
#!pip install datasets

In [2]:
## Please install torch and datasets
import torch
from torchvision.transforms import functional as t
import torch.nn.functional as f
from datasets import load_dataset
import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [3]:
## Loading our dataset
ds = load_dataset("ylecun/mnist")

In [4]:
## Data splits

X_train_p = ds["train"]["image"]
Y_train = ds["train"]["label"]
X_test_p = ds["test"]["image"]
Y_test = ds["test"]["label"]

In [5]:
## PIL to Tensors

X_train = [t.pil_to_tensor(x) for x in X_train_p]
X_test = [t.pil_to_tensor(x) for x in X_test_p]
X_train = torch.stack(X_train).to(device)
X_test = torch.stack(X_test).to(device)
print(X_train.shape, X_test.shape)

torch.Size([60000, 1, 28, 28]) torch.Size([10000, 1, 28, 28])


In [6]:
## Fixing the shape

X_train = X_train.view(-1, 28, 28)
X_test = X_test.view(-1, 28, 28)
print(X_train.shape, X_test.shape)

torch.Size([60000, 28, 28]) torch.Size([10000, 28, 28])


In [7]:
## Making labels into tensors

Y_train = torch.tensor(Y_train).to(device)
Y_test = torch.tensor(Y_test).to(device)

In [8]:
## Flattening the image as DNN takes flat tensor as input

X_train = X_train.view(-1, 784).float() / 255.0
X_test = X_test.view(-1, 784).float() / 255.0
print(X_train.shape, X_test.shape)

torch.Size([60000, 784]) torch.Size([10000, 784])


In [9]:
class Linear():
  def __init__(self, input_dims, output_dims, B=True, last=False):
    self.training = True
    self.W = (torch.randn(input_dims, output_dims) * (5/3) / (input_dims**0.5)).to(device) if not last else (torch.randn(input_dims, output_dims) * (5/3) / (input_dims**0.5) * 0.1).to(device)
    if B: self.B = torch.randn(output_dims).to(device) if not last else (torch.randn(output_dims) * 0.1).to(device)
    else: self.B = torch.tensor([]).to(device)

  def __call__(self, x):
    if not torch.equal(self.B, torch.tensor([]).to(device)): self.result = x@self.W + self.B
    else: self.result = x@self.W
    return self.result

  def parameters(self):
    return [self.W] + [self.B]


class Tanh():
  def __init__(self):
    self.training = True
    return None

  def __call__(self, x):
    self.result = torch.tanh(x)
    return self.result

  def parameters(self):
    return []

class Dropout():
  def __init__(self, batch_size, output_dims, rate=0.9):
    self.training = True
    self.rate = rate
    self.factor = (torch.rand(batch_size, output_dims) < self.rate).int().to(device)
    return None

  def __call__(self, x):
    if self.training: self.result = x * self.factor
    else: self.result = x
    return self.result

  def parameters(self):
    return []




In [10]:
# Hyperparameters

K = 1  # Low-rank decomposition rank - TUNE THIS!
k = K

print(f"Low-rank decomposition rank set to: k={K}")
print(f"For a 784x10 matrix, this uses {784*K + K*10} elements instead of {784*10}")
print(f"Compression ratio: {(784*K + K*10)/(784*10):.3f}")

input_dim = 784  # MNIST input dimension
output_dim = 10  # Number of classes

# Calculate the size of the concatenated low-rank factors
# For a matrix of size (input_dim x output_dim) with rank k:
# L has shape (input_dim x k), R has shape (k x output_dim)
# Total elements = input_dim * k + k * output_dim
decomposed_size = input_dim * k + k * output_dim

print(f"Using k={k} for low-rank decomposition")
print(f"Decomposed representation size: {decomposed_size}")

# Network architecture
n1 = 512
n2 = 256
n3 = 512
batch_size = 1

layers = [
    Linear(decomposed_size, n1), Tanh(), Dropout(batch_size, n1),
    Linear(n1, n2), Tanh(), Dropout(batch_size, n2),
    Linear(n2, n3), Tanh(), Dropout(batch_size, n3),
    Linear(n3, decomposed_size, last=True),  # Output same size as input
]

predictor = [
    Linear(input_dim, output_dim, last=True, B=False)
]

params = [p for layer in layers for p in layer.parameters()]
numparams = 0
for p in params:
    p.requires_grad = True
    numparams += p.numel()
print(f"Number of updater parameters: {numparams}")
print(f"Predictor parameters: {predictor[0].W.numel()}")
print(f"Updater network: {decomposed_size} -> {n1} -> {n2} -> {n3} -> {decomposed_size}")

Low-rank decomposition rank set to: k=1
For a 784x10 matrix, this uses 794 elements instead of 7840
Compression ratio: 0.101
Using k=1 for low-rank decomposition
Decomposed representation size: 794
Number of updater parameters: 1077274
Predictor parameters: 7840
Updater network: 794 -> 512 -> 256 -> 512 -> 794


In [11]:
import numpy as np
from numpy.linalg import svd
from scipy.linalg import svd as scipy_svd



def decomposition(A, k=1):
    # Check for NaN or Inf values and handle them
    if not np.isfinite(A).all():
        print("Warning: Found NaN/Inf in weights, replacing with small random values")
        A = np.where(np.isfinite(A), A, np.random.normal(0, 0.01, A.shape))
    
    # Clip extreme values to prevent numerical issues
    A_clipped = np.clip(A, -5, 5)
    
    # Simple and robust approach - just use random factorization if SVD fails
    try:
        U, S, VT = svd(A_clipped, full_matrices=False)
    except:
        print("Warning: SVD failed, using random low-rank approximation")
        # Create a simple low-rank factorization
        U = np.random.normal(0, 0.1, (A.shape[0], min(k, min(A.shape))))
        S = np.ones(min(k, min(A.shape)))
        VT = np.random.normal(0, 0.1, (min(k, min(A.shape)), A.shape[1]))

    # Truncate to rank-k
    U_k = U[:, :k]                # (784 x k)
    S_k = np.diag(S[:k])          # (k x k)
    VT_k = VT[:k, :]              # (k x 10)

    # Factor A_k = L @ R, where L and R are low-rank factors
    sqrt_S_k = np.sqrt(np.abs(S_k))       # (k x k) - use abs to avoid sqrt of negative
    L = U_k @ sqrt_S_k            # (784 x k)
    R = sqrt_S_k @ VT_k           # (k x 10)

    L_flat = L.flatten()
    R_flat = R.flatten()
    LR_concat = np.concatenate([L_flat, R_flat])

    return LR_concat


In [12]:
# ## Training - Adaptable Version with hyperparameter k
# iters = 2000  # Reduced for testing
# alpha = 0.01

# # First, let's create fresh parameter references to avoid any lingering graph issues
# params = [p for layer in layers for p in layer.parameters()]
# for p in params:
#     p.requires_grad = True

# print(f"Starting training with k={k}")
# print(f"Starting training with {len(params)} updater parameters")
# print(f"Predictor weight shape: {predictor[0].W.shape}")
# print(f"Decomposed size: {decomposed_size}")

# for c in range(iters):
#     # Clear all gradients
#     for p in params:
#         if p.grad is not None:
#             p.grad.zero_()

#     ## Step 1: Get current predictor weights and decompose them
#     with torch.no_grad():  # This prevents building unnecessary graphs
#         current_W = predictor[0].W.detach().cpu().numpy()  # input_dim x output_dim
#         LR_concat = decomposition(current_W, k=k)  # Use hyperparameter k
    
#     ## Step 2: Forward pass through updater
#     for layer in layers:
#         layer.training = True
    
#     # Convert decomposed weights to tensor and pass through updater
#     updater_input = torch.tensor(LR_concat, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
    
#     updater_output = updater_input
#     for layer in layers:
#         updater_output = layer(updater_output)
    
#     # Reconstruct weight update from updater output (adaptable to k)
#     L_size = input_dim * k  # Size of L matrix flattened
#     R_size = k * output_dim  # Size of R matrix flattened
    
#     L_update = updater_output[:, :L_size].reshape(input_dim, k)  # input_dim x k
#     R_update = updater_output[:, L_size:L_size+R_size].reshape(k, output_dim)  # k x output_dim
#     W_update = L_update @ R_update  # Reconstruct full matrix
    
#     ## Step 3: Compute loss with updated weights
#     for layer in predictor:
#         layer.training = False
    
#     # Apply weight update and compute forward pass
#     updated_weights = predictor[0].W + W_update
#     predictions = X_train @ updated_weights
#     loss = f.cross_entropy(predictions, Y_train)
    
#     ## Step 4: Backward pass and update
#     loss.backward()
    
#     # Update predictor weights (detached to avoid double backward)
#     with torch.no_grad():
#         # Clip the weight update to prevent extreme values
#         W_update_clipped = torch.clamp(W_update.detach(), -1.0, 1.0)
#         predictor[0].W += W_update_clipped
#         # Clip the final weights to keep them reasonable
#         predictor[0].W.clamp_(-5.0, 5.0)
    
#     # Update updater parameters
#     for p in params:
#         if p.grad is not None:
#             p.data -= alpha * p.grad

#     if c % (iters//10) == 0:
#         print(f"Iteration {c:4d}, Loss: {loss.item():.6f}")

# print("Training completed!")

In [13]:
## Training - Fixed Version
iters = 4000  # Reduced for testing
alpha = 0.01

# First, let's create fresh parameter references to avoid any lingering graph issues
params = [p for layer in layers for p in layer.parameters()]
for p in params:
    p.requires_grad = True

print(f"Starting training with {len(params)} updater parameters")
print(f"Predictor weight shape: {predictor[0].W.shape}")

for c in range(iters):
    # Clear all gradients
    for p in params:
        if p.grad is not None:
            p.grad.zero_()

    ## Step 1: Get current predictor weights and decompose them
    with torch.no_grad():  # This prevents building unnecessary graphs
        current_W = predictor[0].W.detach().cpu().numpy()  # 784 x 10
        LR_concat = decomposition(current_W, k=1)  # 794 elements
    
    ## Step 2: Forward pass through updater
    for layer in layers:
        layer.training = True
    
    # Convert decomposed weights to tensor and pass through updater
    updater_input = torch.tensor(LR_concat, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
    
    updater_output = updater_input
    for layer in layers:
        updater_output = layer(updater_output)
    
    # Reconstruct weight update from updater output
    L_update = updater_output[:, :784].reshape(784, 1)
    R_update = updater_output[:, 784:].reshape(1, 10)
    W_update = L_update @ R_update
    
    ## Step 3: Compute loss with updated weights
    for layer in predictor:
        layer.training = False
    
    # Apply weight update and compute forward pass
    updated_weights = predictor[0].W + W_update
    predictions = X_train @ updated_weights
    loss = f.cross_entropy(predictions, Y_train)
    
    ## Step 4: Backward pass and update
    loss.backward()
    
    # Update predictor weights (detached to avoid double backward)
    with torch.no_grad():
        predictor[0].W += W_update.detach()
    
    # Update updater parameters
    for p in params:
        if p.grad is not None:
            p.data -= alpha * p.grad

    if c % 20 == 0:
        print(f"Iteration {c:4d}, Loss: {loss.item():.6f}")

print("Training completed!")

Starting training with 8 updater parameters
Predictor weight shape: torch.Size([784, 10])
Iteration    0, Loss: 2.286718
Iteration   20, Loss: inf
Iteration   40, Loss: nan
Iteration   60, Loss: nan
Iteration   80, Loss: nan
Iteration  100, Loss: nan
Iteration  120, Loss: nan
Iteration  140, Loss: nan
Iteration  160, Loss: nan
Iteration  180, Loss: nan
Iteration  200, Loss: nan
Iteration  220, Loss: nan
Iteration  240, Loss: nan
Iteration  260, Loss: nan
Iteration  280, Loss: nan
Iteration  300, Loss: nan
Iteration  320, Loss: nan
Iteration  340, Loss: nan
Iteration  360, Loss: nan
Iteration  380, Loss: nan
Iteration  400, Loss: nan
Iteration  420, Loss: nan
Iteration  440, Loss: nan
Iteration  460, Loss: nan
Iteration  480, Loss: nan
Iteration  500, Loss: nan
Iteration  520, Loss: nan
Iteration  540, Loss: nan
Iteration  560, Loss: nan
Iteration  580, Loss: nan
Iteration  600, Loss: nan
Iteration  620, Loss: nan
Iteration  640, Loss: nan
Iteration  660, Loss: nan
Iteration  680, Loss:

In [15]:

## Training
iters = 4000
alpha = 0.01

for c in range(iters):
    ## Forward Pass through predictor
    for layer in predictor:
      layer.training = False
    x = X_train
    print(x.shape)
    for layer in predictor:
      x = layer(x)
    # Loss
    Loss = f.cross_entropy(x, Y_train)


    ## Forward Pass through updater
    for layer in layers:
      layer.training = True

    # Full SVD
    A = predictor[0].W.detach().cpu().numpy()  # 784 x 10
    print(predictor[0].W.grad)
    #A = predictor[0].W.grad.detach().cpu().numpy()
    LR_concat = decomposition(A, k=1) # 794
    print("LR_concat shape:", LR_concat.shape)

    i = torch.stack([torch.tensor(LR_concat).to(device)])
    print(i.shape)

    for layer in layers:
      i = layer(i)
    print(i.shape)
    
    L_update = i[:, :784].reshape(784, 1)
    R_update = i[:, 784:].reshape(1, 10)
    predictor_W_update = L_update @ R_update


    ## Weight update for predicter
    predictor[0].W+=predictor_W_update


    # Calculating Gradient for updater model
    for layer in layers:
      layer.result.retain_grad() # This stores grad of layers like Tanh that have no params to update

    for p in params:
        p.grad = None

    Loss.backward()

    # Weight Update for updater model
    for p in params:
        p.data += -alpha * p.grad



    if c % (iters/20) == 0:
        print(Loss)


torch.Size([60000, 784])
None
LR_concat shape: (794,)
torch.Size([1, 794])


RuntimeError: expected mat1 and mat2 to have the same dtype, but got: double != float

In [None]:
def accuracy(X, Y, layersv=predictor):
    for layer in layersv:
      layer.training=False
    # Forward
    x = X
    for layer in layersv:
      x = layer(x)
    probs = f.softmax(x, 1)
    answers = x.argmax(1)
    c = 0
    for a, y in zip(answers, Y):
        if a==y: c+=1
    return c / answers.shape[0] * 100

def loss(X, Y, layersv = predictor):
    x = X
    for layer in layersv:
      x = layer(x)
    return f.cross_entropy(x, Y)

print(f"train accuracy: {accuracy(X_train, Y_train)} | test accuracy: {accuracy(X_test, Y_test)}")
print(f"train loss: {loss(X_train, Y_train)} | test loss: {loss(X_test, Y_test)}")

train accuracy: 9.871666666666666 | test accuracy: 9.8
train loss: nan | test loss: nan
