In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [1]:
class Linear:
  """Affine layer: `x @ weight`, plus `bias` when one is configured.

  The most recent forward result is cached on `self.out`.
  """

  def __init__(self, fan_in, fan_out, bias=True):
    # Gaussian weights scaled by 1/sqrt(fan_in) (kaiming-style init)
    scale = fan_in**0.5
    self.weight = torch.randn((fan_in, fan_out)) / scale
    self.bias = None
    if bias:
      self.bias = torch.zeros(fan_out)

  def __call__(self, x):
    result = x @ self.weight
    if self.bias is not None:
      # in-place add on the freshly created matmul result
      result += self.bias
    self.out = result
    return result

  def parameters(self):
    # weight first, then bias when present
    params = [self.weight]
    if self.bias is not None:
      params.append(self.bias)
    return params

# -----------------------------------------------------------------------------------------------
class BatchNorm1d:
  """Batch normalization with a learnable scale (`gamma`) and shift (`beta`).

  In training mode the statistics are taken over dim 0 for 2D inputs and over
  dims (0, 1) for 3D inputs, i.e. everything except the last dimension, and
  the running buffers are updated with an exponential moving average.
  In eval mode (`self.training = False`) the running buffers are used instead.

  Note: batch statistics are computed with `keepdim=True`, so after the first
  training step the running buffers carry the leading singleton dimensions of
  the batch statistics (they still broadcast against the input).
  """

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)

  def __call__(self, x):
    """Normalize `x` and return `gamma * xhat + beta` (also cached on `self.out`).

    Raises:
      ValueError: in training mode, if `x` is neither 2D nor 3D.
    """
    # calculate the forward pass
    if self.training:
      if x.ndim == 2:
        dim = 0
      elif x.ndim == 3:
        dim = (0, 1)
      else:
        # previously fell through and failed later with an unbound `dim`
        raise ValueError(f"BatchNorm1d expects a 2D or 3D input, got {x.ndim}D")
      xmean = x.mean(dim, keepdim=True) # batch mean
      xvar = x.var(dim, keepdim=True) # batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers (outside autograd: they are not trained with backprop)
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out

  def parameters(self):
    # only gamma/beta are trained; the running buffers are excluded
    return [self.gamma, self.beta]

# -----------------------------------------------------------------------------------------------
class Tanh:
  """Elementwise tanh activation; the last result is cached on `self.out`."""

  def __call__(self, x):
    activated = torch.tanh(x)
    self.out = activated
    return activated

  def parameters(self):
    # stateless layer: nothing to train
    return list()

# -----------------------------------------------------------------------------------------------
class Embedding:
  """Lookup table of `num_embeddings` rows, each of size `embedding_dim`."""

  def __init__(self, num_embeddings, embedding_dim):
    table_shape = (num_embeddings, embedding_dim)
    self.weight = torch.randn(table_shape)

  def __call__(self, IX):
    # index the table, then swap dims 1 and 2 of the result
    # (for a 2D index tensor this puts the embedding dimension before the sequence dimension)
    looked_up = self.weight[IX]
    self.out = looked_up.transpose(1, 2)
    return self.out

  def parameters(self):
    table = self.weight
    return [table]

# -----------------------------------------------------------------------------------------------
class FlattenConsecutive:
  """Merge every `n` consecutive steps of dim 1 into the last dimension."""

  def __init__(self, n):
    self.n = n

  def __call__(self, x):
    batch, steps, channels = x.shape
    grouped = x.view(batch, steps // self.n, channels * self.n)
    # drop the middle dimension entirely once it has collapsed to a single group
    self.out = grouped.squeeze(1) if grouped.shape[1] == 1 else grouped
    return self.out

  def parameters(self):
    # no trainable state
    return list()

class Flatten:
    """Collapse every dimension after the first into a single one."""

    def __call__(self, x):
        batch = x.shape[0]
        flat = x.view(batch, -1)
        self.out = flat
        return flat

    def parameters(self):
        # no trainable state
        return list()

# -----------------------------------------------------------------------------------------------
class Sequential:
  """Container that feeds its input through `layers` in order."""

  def __init__(self, layers):
    self.layers = layers

  def __call__(self, x):
    activation = x
    for module in self.layers:
      activation = module(activation)
    self.out = activation
    return activation

  def parameters(self):
    # concatenate every layer's parameter list, preserving layer order
    collected = []
    for module in self.layers:
      collected.extend(module.parameters())
    return collected

# --------------------------------------------
class Conv1d:
    """1D cross-correlation layer (no padding) built on `Tensor.unfold`.

    Input is expected as [N, C, L] with C == in_channels; the output is
    [N, out_channels, Lout] where
    Lout = (L - ((kernel - 1) * dilation + 1)) // stride + 1.

    Args:
        sequence_length: nominal input length. Stored for reference only;
            the forward pass derives the output length from the actual input.
        in_channels: number of input channels.
        out_channels: number of filters / output channels.
        kernel: number of taps per filter.
        stride: step between consecutive windows.
        dilation: spacing between taps inside a window.
    """

    def __init__(self, sequence_length, in_channels, out_channels, kernel=2, stride=1, dilation=1):
        self.sequence_length = sequence_length
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.filters = torch.randn((out_channels, in_channels, kernel)) * 0.1
        # zero-initialised bias (still drawn via randn so the RNG stream is consumed as before)
        self.bias = torch.randn(out_channels) * 0

    def __call__(self, x):
        """Apply the filters to `x` and return the result (also cached on `self.out`).

        Raises:
            ValueError: if `x` is not 3D or its channel dimension differs from `in_channels`.
        """
        if x.ndim != 3 or x.shape[1] != self.in_channels:
            # a mismatched channel count could otherwise broadcast silently
            raise ValueError(
                f"Conv1d expects input of shape [N, {self.in_channels}, L], got {tuple(x.shape)}"
            )
        # span covered by one dilated window
        effective_kernel = (self.kernel - 1) * self.dilation + 1
        # sliding windows over the length dimension: [N, C, Lout, effective_kernel]
        windows = x.unfold(2, effective_kernel, self.stride)
        # keep only the dilated taps: [N, C, Lout, kernel]
        windows = windows[:, :, :, ::self.dilation]
        # add a broadcast axis for the output channels: [N, 1, C, Lout, kernel]
        windows = windows.unsqueeze(1)
        filters = self.filters.view(1, self.out_channels, self.in_channels, 1, self.kernel)

        # multiply elementwise, then reduce over input channels and kernel taps
        self.out = torch.mul(windows, filters).sum((2, 4)) + self.bias.view(1, self.out_channels, 1)
        return self.out

    def parameters(self):
        # the bias takes part in the forward pass, so it is exposed for training too
        return [self.filters, self.bias]

class ReLu:
    """Elementwise ReLU activation; the last result is cached on `self.out`."""

    def __call__(self, x):
        rectified = torch.relu(x)
        self.out = rectified
        return rectified

    def parameters(self):
        # stateless layer: nothing to train
        return list()
        
class Transpose:
    """Swap dimensions 1 and 2 of the input; the result is cached on `self.out`."""

    def __call__(self, x):
        swapped = x.transpose(1, 2)
        self.out = swapped
        return swapped

    def parameters(self):
        # no trainable state
        return list()

In [18]:
# toy input used to explore Tensor.unfold
a = torch.randn(1, 10, 4)
a.shape

torch.Size([1, 10, 4])

In [17]:
# sliding windows of size 2, step 1, taken along the last dimension of `a`
a.unfold(dimension=2, size=2, step=1).shape

torch.Size([1, 10, 3, 2])

In [85]:
kernel = 2
sequence_length = 10
positions = torch.arange(sequence_length).view(1, sequence_length)
# without dilation: Lout = (10 - 2) + 1 = 9 windows
positions.unfold(1, kernel, 1).shape
dilation = 2
# span of one dilated window
effective_kernel = (kernel - 1) * dilation + 1
# unfold with the effective size, then keep every `dilation`-th tap
x = positions.unfold(1, effective_kernel, 1)[:, :, ::dilation]
Lout = (sequence_length - effective_kernel) + 1
Lout, x.shape

(8, torch.Size([1, 8, 2]))

In [83]:
# NOTE(review): this builds a 2-D (1, Lout) tensor, while the output recorded
# under this cell shows a 3-D size -- the cell looks unfinished/stale; confirm intent.
filters = torch.randn(1, Lout)

torch.Size([1, 8, 2])

In [92]:
x = torch.randn(32, 10, 8)
# windows of size 3 along the last dimension, then every 2nd element of each window
x.unfold(dimension=2, size=3, step=1)[..., ::2].shape

torch.Size([32, 10, 6, 2])

In [130]:
x1 = torch.randn(32, 10, 8) # Follows [N, C, L]

# Mock conv1: one filter over 10 channels, kernel size 2
filters = torch.randn(1, 10, 2).unsqueeze(2) # [1, 10, 1, 2], broadcasts over the windows
# unfold the input so it can be cross correlated with the filters
x = x1.unfold(2, 2, 1) # [32, 10, 7, 2]
# reduce over channels and kernel taps, then restore the channel axis: [32, 1, 7]
out = (x * filters).sum((1, 3)).unsqueeze(1)
out = torch.relu(out)

# Mock conv2: one filter over 1 channel, kernel size 2
filters = torch.randn(1, 1, 2).unsqueeze(2) # [1, 1, 1, 2]
out = out.unfold(2, 2, 1) # [32, 1, 6, 2]
out = (out * filters).sum((1, 3)).unsqueeze(1) # [32, 1, 6]
out = torch.relu(out)

out.shape, x1.shape

(torch.Size([32, 1, 6]), torch.Size([32, 10, 8]))

RuntimeError: only one dimension can be inferred

In [141]:
# single scalar weight, broadcast over x1
ws = torch.randn(1, 1, 1)
# collapse the channel dimension but keep it as a size-1 axis: [32, 1, 8]
new_x = (x1 * ws).sum(1, keepdim=True)

In [142]:
import torch.nn.functional as F
# pad one element on each side of the last dimension (default constant mode)
out_padded = F.pad(out, pad=(1, 1))
out_padded.shape

torch.Size([32, 1, 8])

In [143]:
# elementwise sum of the padded conv output and the channel-collapsed input
torch.add(out_padded, new_x)

tensor([[[-1.9552e-01,  1.0429e-02, -7.9978e-01, -3.6721e-01,  7.5553e-01,
           8.9859e-02,  4.6809e-01,  6.7790e-04]],

        [[ 3.5495e-01, -4.0436e-01, -2.3255e-01, -5.2633e-01, -2.0903e-01,
          -4.6143e-02,  1.3084e-01,  2.2296e-01]],

        [[-6.0859e-01,  1.2943e-01,  3.1384e-01,  7.3723e-01,  1.8838e-01,
           6.2714e-01, -1.2009e+00,  2.4364e-01]],

        [[ 1.3663e-01,  1.2647e+00,  5.1808e-01,  1.4569e-02, -9.5868e-01,
           4.9978e-01,  2.2269e-01,  1.3858e+00]],

        [[ 7.2894e-02,  1.8587e-01, -2.0079e-01, -7.4890e-01,  2.7596e-02,
           6.5100e-01, -5.4487e-01, -2.3686e-01]],

        [[ 6.0973e-01, -3.3955e-01,  5.6073e-01,  6.4402e-01,  4.9537e-01,
           6.1137e-01, -4.5211e-01, -5.8515e-01]],

        [[ 1.1916e+00,  7.2823e-02,  7.0781e-01,  1.1280e+00, -6.6807e-01,
          -1.0612e-01, -7.7295e-01,  1.0129e+00]],

        [[ 7.4925e-02, -4.9892e-01, -1.1025e-02, -1.1195e-01,  8.3276e-01,
           4.5315e-01, -1.0737e-01, 