<a href="https://colab.research.google.com/github/varadasantosh/deep-learning-notes/blob/tensorflow/Rotary_Embeddings_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
import torch

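def determine_rotation_theta(max_seqlen, d_model):

  """ Takes the sequence length and the embedding dimension and returns cosθ and sinθ
      of the rotation angle for every position and every embedding pair.
      Output shape is (max_seqlen, d_model), laid out as [cos, sin, cos, sin, ...]
  """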
  theta = 1/torch.pow(10000,torch.arange(0,d_model,2)/d_model)
  positions = torch.arange(0,max_seqlen)
  position_theta = positions.unsqueeze(1) * theta.unsqueeze(0)
  position_theta = torch.stack((position_theta.cos(),position_theta.sin()),dim=2).flatten(1)
  return  position_theta

def calc_rotary_embeddings(embeddings):

  batch_size,max_seqlen,d_model= embeddings.shape
  rotation_theta = determine_rotation_theta(max_seqlen,d_model)
  cos_theta = rotation_theta[...,0::2]
  sin_theta = rotation_theta[...,1::2]

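  # Read both halves of each pair before writing, so y' is computed from the original x
  x = embeddings[...,0::2]
  y = embeddings[...,1::2]
  rotated = torch.empty_like(embeddings)
  rotated[...,0::2] =  x * cos_theta  - y * sin_theta   # x' = x cosθ - y sinθ
  rotated[...,1::2] =  x * sin_theta  + y * cos_theta   # y' = x sinθ + y cosθ
  return rotated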




In [29]:
embeddings= torch.randn(1,4,8)
rotated_embeddings = calc_rotary_embeddings(embeddings)
rotated_embeddings.shape

torch.Size([1, 4, 8])

# determine_rotation_theta

Step 1:- This line of code calculates the base angle θ for each embedding pair.
          The formula is θ_i = 10000^(-2i/d), where i = 0, 1, ..., d/2 - 1 is the
          index of the pair and d is the embedding dimension
   
    theta = 1/torch.pow(10000,torch.arange(0,d_model,2)/d_model)
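
To make the formula concrete, here is a small sketch that evaluates it for `d_model = 8`, the embedding size assumed throughout this notebook's examples. The variable name `exponents` is introduced only for illustration.

```python
import torch

d_model = 8  # assumed embedding size, matching the example in this notebook

# The pair index i = 0, 1, 2, 3 enters the exponent as 2i/d_model
exponents = torch.arange(0, d_model, 2) / d_model
theta = 1 / torch.pow(10000, exponents)

print(exponents)  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
print(theta)      # approximately 1, 0.1, 0.01, 0.001
```

The first pair gets the largest angle per position and later pairs rotate more and more slowly.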

Step 2:- Construct the positions tensor, which holds one index per token in the
          sequence

    positions = torch.arange(0,max_seqlen)

Step 3:- The angle needs to be calculated for each token and each embedding
         pair, so the final angle depends on both the position of the token
         and the index of the embedding pair. For example, take a sequence of length
         **4** and embeddings of length **8**. Because the rotation happens in 2D space, the embedding dimensions are rotated in pairs,
         which gives the pairs (0,1) (2,3) (4,5) (6,7) for every token.
         The code below computes all of these angles at once through broadcasting: positions.unsqueeze(1) has shape (4, 1), theta.unsqueeze(0) has shape (1, 4), hence the final result has shape (4, 4), one angle per token and pair

    position_theta = positions.unsqueeze(1) * theta.unsqueeze(0)     
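
The broadcasting in this step can be checked with a short sketch, assuming the same sizes as the example (4 tokens, 8 dimensions). The names `col` and `row` are illustrative only; the comparison with `torch.outer` is there to show that the broadcast product is an outer product.

```python
import torch

max_seqlen, d_model = 4, 8  # assumed sizes from the example above

theta = 1 / torch.pow(10000, torch.arange(0, d_model, 2) / d_model)
positions = torch.arange(0, max_seqlen)

col = positions.unsqueeze(1)  # shape (4, 1): one row per token
row = theta.unsqueeze(0)      # shape (1, 4): one column per embedding pair
position_theta = col * row    # broadcast to shape (4, 4)

# The same result written as an explicit outer product
same_as_outer = torch.allclose(position_theta, torch.outer(positions.float(), theta))

print(col.shape, row.shape, position_theta.shape)
print(same_as_outer)
```

Row p of `position_theta` holds the angles for the token at position p, so the first row (position 0) is all zeros.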

Step 4:- We need to calculate cosθ and sinθ for every angle θ, that is for all
         tokens and embedding pairs. This is one of the important parts of the calculation; to understand it better, let us take an example

    position_theta = torch.stack((position_theta.cos(),position_theta.sin()),dim=2).flatten(1)      

   

In [41]:
cos= torch.arange(1,9).reshape(2,4)
sin= torch.arange(9,17).reshape(2,4)

# Stack Operation

  By default, torch.stack stacks along the zeroth dimension. We need cosθ and sinθ side by side for every angle (each token position and each embedding pair), so we stack along a new last dimension instead, which is **2**. We will see the effect of both choices below

In [36]:
torch.stack((cos,sin))

tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8]],

        [[ 9, 10, 11, 12],
         [13, 14, 15, 16]]])

In [42]:
torch.stack((cos,sin)).shape

torch.Size([2, 2, 4])

# Stacking along the last dimension

Stacking along dim=2 places each cosθ next to its sinθ, but the resulting shape differs from the one we need, so we have to flatten it. We can observe the shape before and after flattening

In [37]:
torch.stack((cos,sin),dim=2)

tensor([[[ 1,  9],
         [ 2, 10],
         [ 3, 11],
         [ 4, 12]],

        [[ 5, 13],
         [ 6, 14],
         [ 7, 15],
         [ 8, 16]]])

In [39]:
torch.stack((cos,sin),dim=2).shape

torch.Size([2, 4, 2])

# Flattening from dimension **1**

Flattening from dimension 1 merges the last two dimensions, so the shape changes from (2, 4, 2) to (2, 8), as we can observe below. To summarize, we started with cos and sin matrices of size (2, 4), which correspond to 2 tokens with an embedding dimension of 8, since 8 dimensions form 4 pairs. Each pair needs both cosθ and sinθ, which results in tokens * pairs * 2 = 2 * 4 * 2 = 16 values. After flattening, each row is laid out as [cos, sin, cos, sin, ...], one (cosθ, sinθ) couple per embedding pair



## (cosθ, sinθ) for each token and pair

| Token | Pair   | cosθ | sinθ |
|-------|--------|------|------|
| 1     | PAIR-1 | 1    | 9    |
| 1     | PAIR-2 | 2    | 10   |
| 1     | PAIR-3 | 3    | 11   |
| 1     | PAIR-4 | 4    | 12   |
| 2     | PAIR-1 | 5    | 13   |
| 2     | PAIR-2 | 6    | 14   |
| 2     | PAIR-3 | 7    | 15   |
| 2     | PAIR-4 | 8    | 16   |





In [38]:
torch.stack((cos,sin),dim=2).flatten(1)

tensor([[ 1,  9,  2, 10,  3, 11,  4, 12],
        [ 5, 13,  6, 14,  7, 15,  8, 16]])

In [40]:
torch.stack((cos,sin),dim=2).flatten(1).shape

torch.Size([2, 8])
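
As a cross-check on what stack followed by flatten does, the sketch below builds the same interleaved layout by hand with strided assignment. The tensor `manual` is a name introduced here for illustration; `cos` and `sin` are the placeholder tensors from the cells above.

```python
import torch

cos = torch.arange(1, 9).reshape(2, 4)
sin = torch.arange(9, 17).reshape(2, 4)

# Layout produced by stacking on a new last dimension and flattening
interleaved = torch.stack((cos, sin), dim=2).flatten(1)

# The same layout written by hand: cos in the even columns, sin in the odd columns
manual = torch.empty(2, 8, dtype=cos.dtype)
manual[:, 0::2] = cos
manual[:, 1::2] = sin

print(torch.equal(interleaved, manual))  # True
```

This is why slicing with `0::2` and `1::2` later recovers the cos and sin values separately.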

# calc_rotary_embeddings

Step 1:- Find the dimensions of Batch, Sequence Length, Embedding Dimension

    batch_size,max_seqlen,d_model= embeddings.shape

Step 2:- Determine the rotation angles(θ) required for each position and
         embedding pair using the method defined for the same which takes
         sequence length & embedding dimensions as input

    rotation_theta = determine_rotation_theta(max_seqlen,d_model)

Step 3:- Get cosθ and sinθ for each token and embedding pair using the lines of
         code below. The even indexes of the last dimension hold the cosθ values and the odd indexes hold the sinθ values; the cells at the end of this section show which indexes correspond to which values

    cos_theta = rotation_theta[...,0::2]
    sin_theta = rotation_theta[...,1::2]     

Step 4:- Finally apply the rotation to each embedding pair using the code below.
         x and y are read before anything is written, so the second line works
         on the original x and not on the already rotated one

    x, y = embeddings[...,0::2], embeddings[...,1::2]
    rotated = torch.empty_like(embeddings)
    rotated[...,0::2] =  x * cos_theta  - y * sin_theta
    rotated[...,1::2] =  x * sin_theta  + y * cos_theta

   For instance, within a single token, embedding[0] is the value at dimension 0 and embedding[1] is the value at dimension 1. We apply the rotation matrix to this pair, where x and y correspond to dimensions 0 and 1

   x' = x cosθ - y sinθ \
   y' = x sinθ + y cosθ

   R(0) = embedding[0] * cosθ  - embedding[1] * sinθ
   R(1) = embedding[0] * sinθ  + embedding[1] * cosθ

   $$
\begin{bmatrix} x' \\
y' \end{bmatrix} =
\begin{bmatrix} \cos\theta & -\sin\theta \\
\sin\theta & \cos\theta \end{bmatrix}
\begin{bmatrix} x \\
y \end{bmatrix}
$$
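
The rotation above can be sanity-checked numerically. The sketch below is not the notebook's function; `rotate_pairs` is a hypothetical helper written only for this check, and the sizes (batch 1, 4 tokens, 8 dimensions) are assumed from the earlier example. It verifies three things a rotation should satisfy: position 0 is left unchanged, vector norms are preserved, and one pair matches the explicit 2x2 matrix product.

```python
import torch

def rotate_pairs(x, angles):
    """Hypothetical helper: rotate each pair (x[..., 2i], x[..., 2i+1]) by angles[..., i]."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * angles.cos() - x_odd * angles.sin()  # x' = x cosθ - y sinθ
    out[..., 1::2] = x_even * angles.sin() + x_odd * angles.cos()  # y' = x sinθ + y cosθ
    return out

torch.manual_seed(0)
x = torch.randn(1, 4, 8)  # assumed sizes: batch 1, 4 tokens, 8 dimensions
theta = 1 / torch.pow(10000, torch.arange(0, 8, 2) / 8)
angles = torch.arange(0, 4).unsqueeze(1) * theta.unsqueeze(0)  # shape (4, 4)

rotated = rotate_pairs(x, angles)

# Check 1: every angle at position 0 is zero, so the first token is unchanged
unchanged_at_pos0 = torch.allclose(rotated[:, 0], x[:, 0])

# Check 2: a rotation never changes the length of a vector
norms_preserved = torch.allclose(rotated.norm(dim=-1), x.norm(dim=-1), atol=1e-5)

# Check 3: token 2, pair 1 (dimensions 2 and 3) against the explicit matrix
t = angles[2, 1]
R = torch.stack((torch.stack((t.cos(), -t.sin())),
                 torch.stack((t.sin(), t.cos()))))
matches_matrix = torch.allclose(rotated[0, 2, 2:4], R @ x[0, 2, 2:4], atol=1e-6)

print(unchanged_at_pos0, norms_preserved, matches_matrix)
```

If the minus sign were dropped, or if y' were computed from an already updated x, the norm check would be expected to fail for random input, which makes it a convenient regression test.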

In [52]:
rotation_theta = torch.stack((cos,sin),dim=2).flatten(1)


In [53]:
rotation_theta[...,0::2]

tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

In [54]:
rotation_theta[...,1::2]

tensor([[ 9, 10, 11, 12],
        [13, 14, 15, 16]])
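
The cells above use placeholder integers to make the indexing visible. As a complementary sketch with real angles (sizes assumed from the earlier example: 4 tokens, 8 dimensions), the same slicing should return values that satisfy cos²θ + sin²θ = 1 for every token and pair.

```python
import torch

max_seqlen, d_model = 4, 8  # assumed sizes from the earlier example

theta = 1 / torch.pow(10000, torch.arange(0, d_model, 2) / d_model)
angles = torch.arange(0, max_seqlen).unsqueeze(1) * theta.unsqueeze(0)
rotation_theta = torch.stack((angles.cos(), angles.sin()), dim=2).flatten(1)

cos_theta = rotation_theta[..., 0::2]  # even columns hold the cosines
sin_theta = rotation_theta[..., 1::2]  # odd columns hold the sines

# Every (cos, sin) couple should lie on the unit circle
identity_holds = torch.allclose(cos_theta**2 + sin_theta**2, torch.ones(4, 4))

print(rotation_theta.shape, cos_theta.shape, sin_theta.shape)
print(identity_holds)
```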