In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [89]:
class Linear:
  
  def __init__(self, fan_in, fan_out, bias=True):
    self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5 # note: kaiming init
    self.bias = torch.zeros(fan_out) if bias else None
  
  def __call__(self, x):
    self.out = x @ self.weight
    if self.bias is not None:
      self.out += self.bias
    return self.out
  
  def parameters(self):
    return [self.weight] + ([] if self.bias is None else [self.bias])

# -----------------------------------------------------------------------------------------------
class BatchNorm1d:
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # Parameters (trainable via backprop)
    self.gamma = torch.ones(dim).view(1, -1, 1)  # Shape [1, C, 1]
    self.beta = torch.zeros(dim).view(1, -1, 1)  # Shape [1, C, 1]
    # Buffers (updated via momentum)
    self.running_mean = torch.zeros(dim)  # Shape [C]
    self.running_var = torch.ones(dim)    # Shape [C]
  
  def __call__(self, x):
    if self.training:
      # Compute mean and variance across batch and sequence length (dim=(0,2))
      xmean = x.mean(dim=(0, 2), keepdim=True)  # Shape [1, C, 1]
      xvar = x.var(dim=(0, 2), keepdim=True)    # Shape [1, C, 1]
    else:
      # Use running statistics for inference
      xmean = self.running_mean.view(1, -1, 1)  # Shape [1, C, 1]
      xvar = self.running_var.view(1, -1, 1)    # Shape [1, C, 1]
    
    # Normalize input
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # Normalize to unit variance
    self.out = self.gamma * xhat + self.beta         # Scale and shift

    # Update running statistics during training
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean.squeeze()
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar.squeeze()
    
    return self.out
  
  def parameters(self):
    # Return trainable parameters
    return [self.gamma, self.beta]

# -----------------------------------------------------------------------------------------------
class Tanh:
  def __call__(self, x):
    self.out = torch.tanh(x)
    return self.out
  def parameters(self):
    return []

# -----------------------------------------------------------------------------------------------
class Embedding:
  
  def __init__(self, num_embeddings, embedding_dim):
    self.weight = torch.randn((num_embeddings, embedding_dim))
    
  def __call__(self, IX):
    self.out = self.weight[IX].transpose(1, 2)
    
    return self.out
  
  def parameters(self):
    return [self.weight]

# -----------------------------------------------------------------------------------------------
class FlattenConsecutive:
  
  def __init__(self, n):
    self.n = n
    
  def __call__(self, x):
    B, T, C = x.shape
    x = x.view(B, T//self.n, C*self.n)
    if x.shape[1] == 1:
      x = x.squeeze(1)
    self.out = x
    return self.out
  
  def parameters(self):
    return []

class Flatten:
    def __call__(self, x):
        self.out = x.view(x.shape[0], -1)
        return self.out
    def parameters(self):
        return []

# -----------------------------------------------------------------------------------------------
class Sequential:
  
  def __init__(self, layers):
    self.layers = layers
  
  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    self.out = x
    return self.out
  
  def parameters(self):
    # get parameters of all layers and stretch them out into one list
    return [p for layer in self.layers for p in layer.parameters()]

# --------------------------------------------
class Conv1d:
    def __init__(self, sequence_length, in_channels, out_channels, kernel=2, stride=1, dilation=1):
        self. sequence_length = sequence_length
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.filters = torch.randn((out_channels, in_channels, kernel)) * ((2 / (in_channels * kernel)) ** 0.5)
        self.bias = torch.randn(out_channels) * 0.01
        self.effective_kernel = ((self.kernel - 1) * self.dilation) + 1
        self.Lout = ((self.sequence_length - self.effective_kernel) // self.stride) + 1
    def __call__(self, x):
        # Compute effective kernel size based on dilation 
        # effective_kernel = ((self.kernel - 1) * self.dilation) + 1
        
        N, C, L = x.shape
        assert self.effective_kernel <= L
            
        # create the sliding windows of the input 
        x_unfolded = x.unfold(2, self.effective_kernel, self.stride)

        # Extract dilated inputs from x_unfolded which used effective_kernel. The shape of the unfolded vector is [N, C, L, effective_k] 
        # where L is the length of the sequence depending on the effective kernel. From the dimension of effective_kernel, we clip every 'dilated' index
        # If effective_kernel is 3 and dilation is 2, [1, 2, 3] will result in [1, 3]. [1,3] has length of 2, which is equal to actual kernel value
        x_unfolded = x_unfolded[:, :, :, ::self.dilation]

        # The dilation also changes the sequence length, since effective kernel value changes with dilation > 1. 
        # Compute Lout based on effective_kernel
        
        # Lout = ((self.sequence_length - self.effective_kernel) // self.stride) + 1
        
        # Before cross correlation, we need to broadcast the filters and the input correctly
        x_unfolded = x_unfolded.view(N, 1, C, self.Lout, self.kernel)
        filters = self.filters.view(1, self.out_channels, self.in_channels, 1, self.kernel)

        # Perform element wise multiplication
        self.out = torch.mul(x_unfolded, filters).sum((2, 4)) + self.bias.view(1, self.out_channels, 1)
        return self.out        
    
    def parameters(self): 
        return [self.filters] + [self.bias]

class ReLu: 
    def __call__(self, x):
        self.out = torch.relu(x)
        return self.out

    def parameters(self):
        return []
        
class Transpose:
    def __call__(self, x):
        self.out = x.transpose(1, 2)
        return self.out
    
    def parameters(self):
        return []

In [6]:
torch.backends.mps.is_available()

True

In [90]:
x = torch.randn((32, 10, 8))

In [98]:
c = Conv1d(8, 10, 10)
o = c(x)


In [99]:
o.shape

torch.Size([32, 10, 7])

In [None]:
def get_Lout(sequence_length, kernel, dilation, stride):
    effective_kernel = (kernel - 1) * dilation + 1
    Lout = ((sequence_length - effective_kernel) // stride) + 1
    return Lout
    
### Let's redefine the model with conv layers 
n_embedding = 24

#h1
n_h1_fanout = 100

#h2
n_h2_fanout = 100

model = Sequential([
    Embedding(vocab_size, n_embedding), 
    Residual([
        Conv1d(
                sequence_length=block_size,
                in_channels=n_embedding,
                out_channels=n_embedding,
                kernel=2,
            ),
    ]),
    BatchNorm1d(n_embedding),
    ReLu(),
    Conv1d(
                sequence_length=block_size,
                in_channels=n_embedding,
                out_channels=n_embedding,
                kernel=2,
            ),
    BatchNorm1d(n_embedding),
    ReLu(), 

    Conv1d(
                sequence_length=14,
                in_channels=n_embedding,
                out_channels=n_embedding,
                kernel=2,
            ),
    BatchNorm1d(n_embedding),
    ReLu(),
    
    # Output of residual will be the [out_channels, input_sequence_length of the layer before the residual layer]
    Flatten(), Linear(fan_in=13 * n_embedding, fan_out=n_h1_fanout), Tanh(),
    Linear(fan_in=n_h1_fanout, fan_out=n_h2_fanout), Tanh(),
    Linear(fan_in=n_h1_fanout, fan_out=vocab_size)
])



with torch.no_grad():
    model.layers[-1].weight *= 0.1

# parameters = [p for layer in layers for p in layer.parameters()]
print(f"parameters: {sum(p.nelement() for p in model.parameters())}")

for p in model.parameters():
    p.requires_grad = True

In [151]:
wxh = torch.randn((10, 10), requires_grad=True)


whh = torch.randn((10, 10), requires_grad=True)

who = torch.randn((10, 1), requires_grad=True)


In [152]:
import torch.nn as nn 
import torch.nn.functional as F

Ht = torch.tanh(x @ wxh + H @ whh)
H = Ht
logits = torch.tanh(Ht @ who)

loss = F.mse_loss(logits, y)


who.grad = None
whh.grad = None
wxh.grad = None

loss.backward()

who.data += -0.1 * who.grad
whh.data += -0.1 * whh.grad
wxh.data += -0.1 * wxh.grad

tensor([[ 6.9904e-08, -7.2189e-04,  9.9771e-07, -2.8659e-06,  8.7122e-09,
          9.8482e-05,  6.4838e-05,  5.7035e-05,  0.0000e+00, -6.4472e-03],
        [-6.9234e-08,  7.1498e-04, -9.8815e-07,  2.8384e-06, -8.6287e-09,
         -9.7539e-05, -6.4217e-05, -5.6489e-05,  0.0000e+00,  6.3854e-03],
        [-1.4945e-07,  1.5434e-03, -2.1330e-06,  6.1270e-06, -1.8626e-08,
         -2.1055e-04, -1.3862e-04, -1.2194e-04,  0.0000e+00,  1.3784e-02],
        [-5.7292e-08,  5.9165e-04, -8.1770e-07,  2.3488e-06, -7.1403e-09,
         -8.0714e-05, -5.3140e-05, -4.6745e-05,  0.0000e+00,  5.2840e-03],
        [-4.3347e-08,  4.4764e-04, -6.1867e-07,  1.7771e-06, -5.4023e-09,
         -6.1068e-05, -4.0206e-05, -3.5367e-05,  0.0000e+00,  3.9978e-03],
        [ 1.2471e-08, -1.2879e-04,  1.7799e-07, -5.1127e-07,  1.5543e-09,
          1.7569e-05,  1.1567e-05,  1.0175e-05,  0.0000e+00, -1.1502e-03],
        [ 3.8416e-07, -3.9672e-03,  5.4830e-06, -1.5750e-05,  4.7879e-08,
          5.4122e-04,  3.5633e-0

In [165]:
x = torch.randn((2, 10, 8)) # [N, C, L]

for i in range(x.size(2)):
    xi = x[:, :, i]
    print(xi.shape)


torch.Size([2, 10])
torch.Size([2, 10])
torch.Size([2, 10])
torch.Size([2, 10])
torch.Size([2, 10])
torch.Size([2, 10])
torch.Size([2, 10])
torch.Size([2, 10])


In [292]:
x = torch.randn((1, 10, 8)) # [N, C, L]
wxh = torch.randn((10, 10), requires_grad=True) # Weight of input transformation
y = torch.randn((1, 8))
whh = torch.randn((10, 10), requires_grad=True)

who = torch.randn((10, 1), requires_grad=True) # Output layer weights. Reduces channels to 1. So we can use mse_loss just for understanding purposes. 

In [293]:
y.shape

torch.Size([1, 8])

In [311]:
for k in range(10):
    H = torch.randn((1, 10))
    loss = 0
    for i in range(x.size(2)):
        xi = x[:, :, i]
        xiw = xi @ wxh
    
        Hw = H @ whh
        Ht = torch.tanh(xiw + Hw)
        H = Ht
    
        logits = Ht @ who
        
        loss += F.mse_loss(logits, y[:, i].view(1, 1))
    
    loss = loss / 8
    
    wxh.grad = None
    whh.grad = None
    who.grad = None 
    
    loss.backward()
    
    wxh.data += -0.0001 * wxh.grad 
    whh.data += -0.0001 * whh.grad
    who.data += -0.0001 * who.grad
    print(f"Loss: {loss.item()}")

Loss: 0.10798269510269165
Loss: 0.46418341994285583
Loss: 0.4001181721687317
Loss: 0.6215532422065735
Loss: 0.1756177693605423
Loss: 0.567419171333313
Loss: 1.5683540105819702
Loss: 0.8218318819999695
Loss: 1.0374850034713745
Loss: 0.11863193660974503


In [291]:
y[:, i].shape

torch.Size([1])

### Let's try this for a batch size greater than 1

In [442]:
x = torch.randn((32, 10, 8)) # [N, C, L]
wxh = torch.randn((10, 10), requires_grad=True) # Weight of input transformation
y = torch.randn((32, 8))
whh = torch.randn((10, 10), requires_grad=True)

who = torch.randn((10, 1), requires_grad=True) # Output layer weights. Reduces channels to 1. So we can use mse_loss just for understanding purposes. 

In [458]:
loss = 0
H = torch.randn((32, 10))
for i in range(x.size(2)):
    xi = x[:, :, i]
    xiw = xi @ wxh
    Hw = H @ whh
    Ht = torch.tanh(xiw + Hw)
    H = Ht
    logits = Ht @ who
    mse = ((logits - y[:, i].view(32, 1)) ** 2)
    loss += (mse.sum() / 32)

loss = loss / 8

In [459]:
loss

tensor(6.6472, grad_fn=<DivBackward0>)

In [470]:
x = torch.randn((32, 10, 8)) # [N, C, L]
wxh = torch.randn((10, 10), requires_grad=True) # Weight of input transformation
y = torch.randn((32, 8))
whh = torch.randn((10, 10), requires_grad=True)

who = torch.randn((10, 1), requires_grad=True) # Output layer weights. Reduces channels to 1. So we can use mse_loss just for understanding purposes. 

In [489]:
x = torch.randn((32, 10, 8)) # [N, C, L]
wxh = torch.randn((10, 10), requires_grad=True) # Weight of input transformation
y = torch.randn((32, 8))
whh = torch.randn((10, 10), requires_grad=True)

who = torch.randn((10, 1), requires_grad=True) # Output layer weights. Reduces channels to 1. So we can use mse_loss just for understanding purposes. 

for e in range(1000):
    loss = 0
    H = torch.zeros((32, 10))
    for i in range(x.size(2)):
        xi = x[:, :, i]
        xiw = xi @ wxh
        Hw = H @ whh
        Ht = torch.tanh(xiw + Hw)
        H = Ht
        logits = Ht @ who
        mse = ((logits - y[:, i].view(32, 1)) ** 2)
        loss += (mse.sum() / 32)
    
    loss = loss / 8
    print(f"Loss: {loss}")
    wxh.grad = None 
    whh.grad = None 
    who.grad = None 
    
    loss.backward()
    who.data += -0.1 * who.grad
    whh.data += -0.1 * whh.grad
    wxh.data += -0.1 * wxh.grad
    

Loss: 0.32360967993736267
Loss: 0.3211156129837036
Loss: 0.3226601481437683
Loss: 0.32002511620521545
Loss: 0.3211105465888977
Loss: 0.3190150260925293
Loss: 0.3186630606651306
Loss: 0.3209323585033417
Loss: 0.3302017152309418
Loss: 0.3243445158004761
Loss: 0.32479310035705566
Loss: 0.32001444697380066
Loss: 0.3282221853733063
Loss: 0.32079848647117615
Loss: 0.31698065996170044
Loss: 0.31668299436569214
Loss: 0.3178521692752838
Loss: 0.3196158707141876
Loss: 0.32613605260849
Loss: 0.31975990533828735
Loss: 0.3188309371471405
Loss: 0.31719425320625305
Loss: 0.32432639598846436
Loss: 0.3203420042991638
Loss: 0.3164690434932709
Loss: 0.31519240140914917
Loss: 0.3159310221672058
Loss: 0.31561172008514404
Loss: 0.3231756389141083
Loss: 0.31855830550193787
Loss: 0.3169868290424347
Loss: 0.31573304533958435
Loss: 0.32558998465538025
Loss: 0.32174164056777954
Loss: 0.3188849687576294
Loss: 0.3155690133571625
Loss: 0.31760549545288086
Loss: 0.31364965438842773
Loss: 0.3131725490093231
Loss: 0.3