In [66]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
import time
import numpy as np
import plotly.express as px


def visualise_tensor(data):
    data = data.detach().numpy()

    ## Create colormap (red to blue)
    colormap = 'RdBu_r'
    # Plot the matrix
    fig = px.imshow(data, color_continuous_scale=colormap, zmin=-1, zmax=1, )
    fig.show()
    fig.data = ()

input_size = 20
hidden_layer_size= 5
output_size = input_size
learning_rate = 0.0001
max_iters = 300000
dataset_size = 1024*8
sparsity = 0.01 # appears only one in 100 times
batch_size = 32

decay_factor = 0.95
iters = []
loss_value = []
val_loss_value = []

imp_vector =torch.tensor([0.9**i for i in range(input_size)])


def get_l1_penalty():
    l1_penalty = 0
    for param in model.parameters():
        if param.requires_grad:
            l1_penalty += torch.norm(param, p=1)
    return l1_penalty

def sample_data():
    # TODO Implement mini batch
    a,b = training_dataset.shape
    sparsity_mat = torch.bernoulli(torch.full((a,b), sparsity))
    X = training_dataset*sparsity_mat
    return X, X

def sample_validation_data():
    a,b = validation_dataset.shape
    sparsity_mat = torch.bernoulli(torch.full((a,b), sparsity))
    X = validation_dataset*sparsity_mat
    return X, X


def get_loss(target, output):
    loss = (imp_vector*((target-output)**2)).mean()
    return loss


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(input_size, hidden_layer_size, bias= False)
        self.layer2 = nn.Linear(hidden_layer_size, output_size)
        self.relu = nn.ReLU()
    
    def forward(self, input, targets = None):
         # Input is (B,I)
        logits = self.layer1(input) # (B, dim)
        
        #logits = self.layer2(logits)
        layer_1_weight = self.layer1.weight.data # (dim, I)
        logits = logits @ layer_1_weight # (B, dim) * (dim, I) -> (B,I)
        logits += self.layer2.bias

        logits = self.relu(logits)
        if targets is None:
            return logits
        else:
            loss = get_loss(targets, logits)
            return logits, loss

training_dataset = torch.rand(dataset_size, input_size)
validation_dataset = torch.rand(int(dataset_size/2), input_size)
output = input


model = Model()

LOAD_MODEL = False
PATH = "../models/basic_linear_0.99.bin"
if LOAD_MODEL:
    model = Model()
    model.load_state_dict(torch.load(PATH))
    model.eval()
else:
    optimizer = torch.optim.Adam(model.parameters() , learning_rate, weight_decay= 1e-5)
    start_time = time.time()
    step_value = max_iters/20
    for i in range(1, max_iters):
        X, Y = sample_data()
        logits, loss = model(X, Y)
    
        if i%step_value ==0 :
            model.eval()
            with torch.no_grad():
                X_V,Y_V = sample_validation_data()
                _, val_loss = model(X_V, Y_V)
                iters.append(i)
                loss_value.append(loss.item())
                val_loss_value.append(val_loss.item())
                print(f"iter:{i} training loss: {loss.item()}, val loss: {val_loss.item()}")
            model.train()

        optimizer.zero_grad(set_to_none=True)
        #loss+= get_l1_penalty()
        loss.backward()
        optimizer.step()


    end_time = time.time()
    print(f"Took {end_time-start_time}s for {max_iters} epochs")

    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.plot(iters,loss_value, color='blue', label="Training")
    plt.plot(iters, val_loss_value, "red", label = "validation")
    plt.legend()
    plt.show()




iter:15000 training loss: 0.0003155978338327259, val loss: 0.00030334931216202676
iter:30000 training loss: 0.0002725335652939975, val loss: 0.0002409634180366993
iter:45000 training loss: 0.0002533670049160719, val loss: 0.00030243751825764775
iter:60000 training loss: 0.0002789786667563021, val loss: 0.0002511065104044974
iter:75000 training loss: 0.0002539022243581712, val loss: 0.00026827564579434693
iter:90000 training loss: 0.0002923064457718283, val loss: 0.00026593025540933013
iter:105000 training loss: 0.00026889954460784793, val loss: 0.00026642094599083066
iter:120000 training loss: 0.00025493904831819236, val loss: 0.00025779198040254414
iter:135000 training loss: 0.0002746892278082669, val loss: 0.0002743047080002725


KeyboardInterrupt: 

In [67]:


w1 = model.layer1.weight
w2 = model.layer2.weight
b1 = model.layer1.bias

b2 = model.layer2.bias
visualise_tensor(torch.transpose(w1,0,1)@w1)
visualise_tensor(b2.data.reshape(-1,1))

In [60]:
norm = [0 for i in range(input_size)]
for i in range(input_size):
    idx = torch.zeros(input_size)
    idx[i] = 1
    idx = w1@idx
    norm[i] = idx.norm()
norm = torch.tensor(norm)
visualise_tensor(norm.reshape(-1,1))

In [61]:
dot = [0 for i in range(input_size)]
for i in range(input_size):
    idx = torch.zeros(input_size)
    idx[i] = 1
    s = 0
    embed_a = w1@idx
    for j in range(input_size):
        if i==j:
            continue
        idx = torch.zeros(input_size)
        idx[j] = 1
        embed_b = w1@idx
        s+= (torch.dot(embed_a,embed_b))**2
    
    dot[i] = s

dot = torch.tensor(dot)
print(dot)
visualise_tensor(dot.reshape(-1,1))


tensor([2.9976, 2.8351, 2.9370, 3.0515, 3.0011, 2.9778, 2.8924, 2.8668, 2.7888,
        2.7806, 2.7520, 2.5444, 2.5298, 2.2452, 2.3704, 0.6492, 1.6259, 0.7086,
        0.2118, 0.2062])


In [62]:
PATH = "../models/basic_relu_0.03.bin"
torch.save(model.state_dict(), PATH)


Model(
  (layer1): Linear(in_features=20, out_features=5, bias=False)
  (layer2): Linear(in_features=5, out_features=20, bias=True)
  (relu): ReLU()
)

In [5]:
model(torch.rand(1,20))

tensor([[0.0189, 0.6945, 0.0000, 0.0000, 0.2214, 0.0000, 0.8834, 0.2291, 0.4173,
         0.4626, 0.4563, 0.0000, 0.4424, 0.0000, 0.4148, 0.0000, 0.3660, 0.0000,
         0.2745, 0.0000]], grad_fn=<ReluBackward0>)