<a href="https://colab.research.google.com/github/vidushiMaheshwari/nGPT/blob/main/nGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchao

In [None]:
!pip install torchtune

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

In [4]:
from torchtune.modules import RotaryPositionalEmbeddings

## Building Layers
1. `Norm`: L2-normalizes its input along a chosen dimension. It is registered as a parametrization on `NLinear`'s weight, so the weight matrix the layer uses is always normalized, including after every backpropagation update.
2. `NLinear`: Linear layer with the constraint that its weight matrix will always be normalized.
3. `NFeedForward`: Feedforward network with two parallel channels for gated activation. Uses SiLU non linearity on output of one gate and multiplies it by the output of the other. All weights are handled by `NLinear` and the intermediate outputs are normalized.
4. `Scale`: Layer whose scaling factor is a trainable parameter. Used to mimic LERP's eigenvalues (the per-dimension interpolation rates) as well as to scale the query and key matrices in the attention layer.
5. `NAttention`: Classic attention layer with the added constraint that all embedding and weight matrices are `NLinear` and all intermediate outputs are normalized.

In [5]:
class Norm(nn.Module):
  """Module wrapper around L2 normalization along a fixed dimension."""

  def __init__(self, norm_dim=-1) -> None:
    """Remember the dimension along which inputs are normalized (default: last)."""
    super().__init__()
    self.norm_dim = norm_dim

  def forward(self, x):
    """Return `x` rescaled so that its L2 norm along `self.norm_dim` is 1."""
    axis = self.norm_dim
    return F.normalize(x, p=2, dim=axis)

In [6]:
class NLinear(nn.Module):
  """Bias-free linear layer with an L2-normalized weight matrix and output.

  The weight of the wrapped `nn.Linear` is parametrized with `Norm(norm_dim)`,
  so every access to `self.linear.weight` yields the normalized matrix. The
  result of the projection is additionally normalized along its last dimension.
  """

  def __init__(self, dim_in, dim_out, norm_dim=-1) -> None:
    """Create the projection and attach the weight normalization.

    Args:
      dim_in: size of the input features.
      dim_out: size of the output features.
      norm_dim: dimension of the weight matrix along which it is normalized.
    """
    super().__init__()
    self.linear = nn.Linear(dim_in, dim_out, bias=False)

    # Parametrizing the weight keeps it normalized after every optimizer step,
    # because the effective weight is always recomputed through `Norm`.
    weight_norm = Norm(norm_dim)
    nn.utils.parametrize.register_parametrization(self.linear, "weight", weight_norm)

    # The randomly initialized raw weights should start out normalized as well.
    self.norm_weights_init_(norm_dim)

  def norm_weights_init_(self, norm_dim):
    """Overwrite the raw (unparametrized) weight with its normalized value.

    `norm_dim` is accepted for interface compatibility; the normalization
    itself is performed by the parametrization registered in `__init__`.
    """
    with torch.no_grad():
      normalized = self.linear.weight
      self.linear.parametrizations.weight.original.copy_(normalized)

  def forward(self, x):
    """Project `x` and L2-normalize the result along the last dimension."""
    projected = self.linear(x)
    return F.normalize(projected, p=2, dim=-1)


In [7]:
class Scale(nn.Module):
  """Trainable per-feature scaling factor.

  Scaling lets the existing non-linearities (like SiLU) keep working on
  normalized activations, which would otherwise give them very little to
  work with.

  The trainable vector is stored at magnitude `scale`, and the forward pass
  multiplies by the constant `init / scale`, so the effective factor at
  initialization equals `init`.
  """

  def __init__(self, dim_in, scale, init):
    """Create `dim_in` trainable factors stored at `scale`, acting as `init`."""
    super().__init__()
    base = torch.ones(dim_in)
    self.scale = nn.Parameter(base * scale)
    # Constant (non-trainable) factor that restores the intended magnitude.
    self.init = init / scale

  def forward(self, x):
    """Multiply `x` by the trainable vector and the restoring constant."""
    scaled = x * self.scale
    return scaled * self.init

In [8]:
class NFeedForward(nn.Module):
  """Gated feed-forward block with two parallel `NLinear` channels.

  The input is projected into a "hidden" channel `u` and a "gated" channel
  `v`. Each channel has its own trainable `Scale`; the gated channel is
  additionally multiplied by `sqrt(dim_in)` before SiLU, and the result is
  `linear_out(u * SiLU(v))`.

  Args:
    dim_in: width of the input (and of the output).
    dim_hidden: width of the two intermediate channels.
    scale_gated: stored magnitude of the gated channel's trainable scale.
    scale_hidden: stored magnitude of the hidden channel's trainable scale.
    scale_gated_init: effective initial value of the gated channel's scale.
    scale_hidden_init: effective initial value of the hidden channel's scale.
  """

  def __init__(self, dim_in, dim_hidden, scale_gated=1.0, scale_hidden=1.0, scale_gated_init=1.0, scale_hidden_init=1.0) -> None:
    super().__init__()
    self._linear_hidden = NLinear(dim_in, dim_hidden)
    self._linear_gated = NLinear(dim_in, dim_hidden)

    # Each channel gets the hyper-parameters that carry its own name.
    self._scale_hidden = Scale(dim_hidden, scale_hidden, scale_hidden_init)
    self._scale_gated = Scale(dim_hidden, scale_gated, scale_gated_init)

    # The sqrt(dim_in) gain must be applied to the activations themselves.
    # Passing it as `Scale`'s `scale` argument has no effect on the forward
    # value, because `Scale` multiplies by `init / scale` and cancels it.
    self._gate_gain = dim_in ** 0.5

    self._linear_out = NLinear(dim_hidden, dim_in)

  def forward(self, x):
    """Return the gated feed-forward output for `x` (last dim = `dim_in`)."""
    u = self._scale_hidden(self._linear_hidden(x))
    # Normalized activations are tiny; the gain gives SiLU a usable input range.
    v = self._scale_gated(self._linear_gated(x)) * self._gate_gain
    return self._linear_out(u * F.silu(v))


In [9]:
class NAttention(nn.Module):
  """Multi-head attention in which every projection is an `NLinear`.

  Queries and keys are scaled by a shared trainable `Scale`, normalized per
  head, and rotated with RoPE; the attention output is normalized before and
  after the output projection.

  Args:
    dim: width of the input and output embeddings.
    n_heads: number of attention heads.
    dim_head: width of each head.
    max_seq_length: sequence length for which the RoPE cache is pre-built.
      Longer inputs are still accepted: the cache is rebuilt on demand.
    s_qk_init: effective initial value of the query/key scale.
    s_qk_scale: stored magnitude of the query/key scale parameter.
    is_causal: whether to apply a causal attention mask.
  """

  def __init__(self, dim, n_heads=8, dim_head=64, max_seq_length=512, s_qk_init=1.0, s_qk_scale = 1.0, is_causal=False) -> None:
    # Injection of positional information by RoPE distorts q and k. They are
    # additionally normalized so the dot product of every query and key stays
    # under control.
    super().__init__()

    self.n_heads = n_heads
    self.dim_head = dim_head

    dim_out = dim_head * n_heads
    self._linear_q = NLinear(dim, dim_out)
    self._linear_k = NLinear(dim, dim_out)
    self._linear_v = NLinear(dim, dim_out)

    self.rope = RotaryPositionalEmbeddings(dim_head, max_seq_length)
    # Number of positions currently covered by the RoPE cache.
    self._rope_len = max_seq_length

    # One scale shared by q and k (the paper says separate values are not
    # needed).
    # TODO: However, there should be separate k & q scales per head!
    self._scale_qk = Scale(dim_out, s_qk_scale, s_qk_init)

    self._linear_out = NLinear(dim_out, dim)

    # In traditional transformers the softmax scaling factor is 1/sqrt(d_k)
    # because the expected variance of the dot product of non-normalized key
    # and query is d_k. With normalized vectors the expected variance is
    # 1/d_k, so the factor that brings the variance to 1 is sqrt(d_k).
    self.softmax_scale = dim_head ** 0.5
    self.is_causal = is_causal

  def split_heads(self, x):
    """Reshape (batch, seq, n_heads * dim_head) to (batch, n_heads, seq, dim_head)."""
    return x.view(x.shape[0], -1, self.n_heads, self.dim_head).transpose(1, 2)

  def merge_heads(self, x):
    """Reshape (batch, n_heads, seq, dim_head) back to (batch, seq, n_heads * dim_head)."""
    batch_size, n_heads, seq_length, dim_head = x.shape
    x = x.transpose(1, 2).contiguous()
    return x.view(batch_size, seq_length, n_heads * dim_head)

  def _apply_rope(self, x):
    """Rotate a (batch, n_heads, seq, dim_head) tensor by token position.

    torchtune's `RotaryPositionalEmbeddings` expects (batch, seq, n_heads,
    dim_head) and takes the position from axis 1. Feeding it the head-major
    layout would rotate by head index instead of token index, which leaves
    every q.k product unchanged. Hence the transposes around the call.
    """
    seq_length = x.shape[2]
    if seq_length > self._rope_len:
      # Keep accepting inputs longer than `max_seq_length` by growing the cache.
      self.rope.build_rope_cache(seq_length)
      self._rope_len = seq_length
    return self.rope(x.transpose(1, 2)).transpose(1, 2)

  def forward(self, x):
    """Apply normalized self-attention to `x` of shape (batch, seq, dim)."""
    k, q, v = self._linear_k(x), self._linear_q(x), self._linear_v(x)

    # NOTE(review): the scale is applied before the per-head normalization
    # below, so only the relative per-channel scaling survives. Confirm this
    # ordering is intended.
    k = self._scale_qk(k)
    q = self._scale_qk(q)

    k, q, v = self.split_heads(k), self.split_heads(q), self.split_heads(v)

    # Splitting destroys the norm, so re-normalize per head.
    # NOTE: the paper has an ablation on whether to normalize here, and the
    # effects are pretty much the same.
    k = F.normalize(k, p=2, dim=-1)
    q = F.normalize(q, p=2, dim=-1)

    k, q = self._apply_rope(k), self._apply_rope(q)

    out = F.scaled_dot_product_attention(q, k, v, scale=self.softmax_scale, is_causal=self.is_causal)

    # Attention mixes value vectors, so the norm is generally not preserved.
    out = self.merge_heads(out)

    # "Any update that causes the hidden state h to deviate from the manifold
    # is followed by a normalization step": normalize before the output
    # projection instead of feeding it an unnormalized vector. The paper does
    # not say where exactly this normalization belongs.
    out = F.normalize(out, p=2, dim=-1)

    out = self._linear_out(out)

    # Keep the returned embeddings on the unit hypersphere.
    return F.normalize(out, p=2, dim=-1)



## Sanity Checks for NAttention

In [10]:
# Random unit-norm inputs: batch of 1, 1024 tokens, embedding width 512.
# NOTE: 1024 tokens is more than NAttention's default max_seq_length of 512.
x = F.normalize(torch.randn(1, 1024, 512), p=2, dim=-1)

In [11]:
attn_model = NAttention(dim=512)

In [12]:
x = attn_model(x)

In [13]:
x

tensor([[[-0.1062, -0.0041,  0.0865,  ...,  0.0297, -0.1023, -0.0483],
         [-0.0128, -0.0169,  0.0185,  ...,  0.0352,  0.0003, -0.0167],
         [-0.0350, -0.0469, -0.0436,  ...,  0.0538, -0.0168, -0.0389],
         ...,
         [-0.0639, -0.0653,  0.0408,  ...,  0.0189, -0.0413,  0.0031],
         [-0.0098, -0.0113,  0.0224,  ...,  0.0451, -0.0300, -0.0455],
         [-0.0127, -0.0812, -0.0406,  ...,  0.0816, -0.0388,  0.0592]]],
       grad_fn=<DivBackward0>)

In [14]:
torch.norm(x, dim=-1, p=2)

tensor([[1., 1., 1.,  ..., 1., 1., 1.]], grad_fn=<LinalgVectorNormBackward0>)

## Transformer

In [15]:
from typing import List

class nTransformer(nn.Module):
  """Stack of normalized attention + feed-forward layers.

  Each layer updates the hidden state `h` by interpolating towards the
  sub-layer output with a trainable rate and re-normalizing:
  `h <- Norm(h + alpha * (sublayer(h) - h))`.

  Every hyper-parameter below except `dim_in`, `depth`, `is_causal` and
  `alpha_init` may be given either as a single value (used for all layers) or
  as a list with exactly `depth` entries (one per layer).

  Args:
    dim_in: width of the hidden state; it is the same for every layer.
    dim_head: width of each attention head.
    n_heads: number of attention heads.
    max_seq_length: sequence length passed to each `NAttention`.
    depth: number of layers.
    is_causal: whether attention uses a causal mask.
    alpha_init: accepted for backward compatibility; currently not used.
    alpha_ff_init / alpha_attn_init: effective initial interpolation rates of
      the feed-forward / attention updates.
    alpha_ff_scale / alpha_attn_scale: stored magnitudes of those rates.
    s_qk_init / s_qk_scale: query/key scale settings of `NAttention`.
    scale_gated / scale_hidden / scale_gated_init / scale_hidden_init: channel
      scale settings of `NFeedForward`.
    expand_factor: feed-forward hidden width as a multiple of `dim_in`.

  Raises:
    ValueError: if a list-valued hyper-parameter does not have `depth` entries.
  """

  def __init__(self,
               dim_in,
               dim_head=64,
               n_heads=8,
               max_seq_length=512,
               depth=5,
               is_causal=True,
               alpha_init: float | List = 1.0,
               alpha_ff_init: float | List = 1.0,
               alpha_attn_init: float | List = 1.0,
               alpha_ff_scale: float | List = 1.0,
               alpha_attn_scale: float | List = 1.0,
               s_qk_init: float | List = 1.0,
               s_qk_scale: float | List = 1.0,
               scale_gated: float | List = 1.0,
               scale_hidden: float | List = 1.0,
               expand_factor: float | List = 1.0,
               scale_hidden_init: float | List = 1.0,
               scale_gated_init: float | List = 1.0,
               ):
    super().__init__()

    def make_list(value):
      # Broadcast a single value to every layer; validate explicit lists.
      if not isinstance(value, list):
        return [value] * depth
      if len(value) != depth:
        raise ValueError(
            f"per-layer hyper-parameter has {len(value)} entries, expected depth={depth}"
        )
      return value

    # Every per-layer hyper-parameter appears here exactly once, including
    # alpha_attn_scale, so list values are unpacked for each layer.
    per_layer = zip(*(make_list(value) for value in (
        n_heads, dim_head, max_seq_length,
        alpha_ff_init, alpha_attn_init, alpha_ff_scale, alpha_attn_scale,
        s_qk_init, s_qk_scale,
        scale_gated, scale_hidden, expand_factor,
        scale_hidden_init, scale_gated_init,
    )))

    self.layers = nn.ModuleList([])

    for (heads, head_dim, seq_len,
         ff_init, attn_init, ff_scale, attn_scale,
         qk_init, qk_scale,
         gated, hidden, expand,
         hidden_init, gated_init) in per_layer:
      attn_layer = NAttention(dim_in, heads, head_dim, seq_len, qk_init, qk_scale, is_causal)
      attn_lerp = Scale(dim_in, attn_scale, attn_init)

      # The feed-forward block maps dim_in -> dim_hidden -> dim_in, so the
      # residual stream, and therefore its interpolation rate, stays dim_in wide.
      dim_hidden = int(dim_in * expand)
      ff_layer = NFeedForward(dim_in, dim_hidden, gated, hidden, gated_init, hidden_init)
      ff_lerp = Scale(dim_in, ff_scale, ff_init)

      # Each entry holds [attention, attention rate, feed-forward, feed-forward rate].
      self.layers.append(nn.ModuleList([attn_layer, attn_lerp, ff_layer, ff_lerp]))

  def forward(self, x):
    """Run `x` (last dim = `dim_in`) through all layers; the output is unit-norm."""
    for attn_layer, attn_lerp, fn_layer, fn_lerp in self.layers:
      attn_out = attn_layer(x)
      x = attn_lerp(attn_out - x) + x # h <- h + \alpha_A(h_A - h)
      x = F.normalize(x, dim=-1, p=2)

      fn_out = fn_layer(x)
      x = fn_lerp(fn_out - x) + x # h <- h + \alpha_M(h_M - h)
      x = F.normalize(x, dim=-1, p=2)

    return x
