<a href="https://colab.research.google.com/github/samitha278/CoreLlama/blob/main/llama2_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn

In [3]:
class LLaMA2Config:
    """Hyperparameters for the LLaMA-2 modules in this notebook.

    Every attribute assigned below can be overridden by keyword, e.g.
    ``LLaMA2Config(vocab_size=32000, n_layers=2)``. An unknown keyword raises
    ``TypeError`` so a misspelled field name is caught immediately instead of
    being silently ignored.
    """

    def __init__(self, **overrides):

        self.n_embd = 4096        # model (embedding) width
        self.n_layers = 32        # number of decoder layers
        self.heads = 32           # number of query heads
        self.n_kv_heads = None    # number of key/value heads; None presumably means "same as heads" — confirm
        self.vocab_size = None    # must be set before building the model (nn.Embedding / nn.Linear need it)
        self.n_hidden = None      # presumably the MLP hidden width; LlamaMLP does not read it yet

        self.norm_eps = 1e-5      # epsilon added inside RMSNorm's square root

        self.max_batch_size = 32
        self.max_seq_len = 2048

        # Device on which the RoPE frequency table is precomputed.
        self.device = "cpu"

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"LLaMA2Config() got an unexpected keyword argument {name!r}")
            setattr(self, name, value)



In [4]:
class LlamaForCausalLM(nn.Module):
    """LLaMA decoder stack followed by a linear projection onto the vocabulary.

    Args:
        config: an ``LLaMA2Config``; ``config.vocab_size`` must be set.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.model = LlamaModel(config)

        # LLaMA's output projection has no bias term; keeping bias=False also keeps
        # the parameter set aligned with the HF-style names used here (model.*, lm_head).
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)


    def forward(self, tokens, start_pos):
        """Return logits over the vocabulary for every input position.

        Args:
            tokens: 2-D tensor of token ids (``LlamaModel.forward`` unpacks two dims).
            start_pos: position of ``tokens[:, 0]`` in the overall sequence; used
                to select the matching rows of the rotary-embedding table.

        Returns:
            ``lm_head`` applied to the final normalized hidden states; last dim is
            ``config.vocab_size``.
        """
        hidden = self.model(tokens, start_pos)

        return self.lm_head(hidden)


In [5]:
def precompute_theta_pos_frequencies(head_dim, seq_len, device, theta=10000.0):
    """Precompute rotary-position-embedding (RoPE) rotations as unit complex numbers.

    Entry ``[m, i]`` equals ``exp(1j * m * theta_i)`` with
    ``theta_i = theta ** (-2i / head_dim)``, for position ``m`` in ``[0, seq_len)``
    and dimension-pair index ``i`` in ``[0, head_dim // 2)``.

    Args:
        head_dim: per-head feature size; must be even because RoPE rotates
            features in pairs.
        seq_len: number of positions to precompute.
        device: device for the returned tensor (``None`` -> default device).
        theta: RoPE base frequency.

    Returns:
        complex64 tensor of shape ``(seq_len, head_dim // 2)``.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "head_dim must be even: RoPE rotates features in pairs"

    # Exponents 2i / head_dim for i = 0 .. head_dim/2 - 1.
    exponents = torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(seq_len, device=device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)  # (seq_len, head_dim // 2): m * theta_i

    # Magnitude 1, phase m * theta_i  ->  cos(angle) + 1j * sin(angle).
    return torch.polar(torch.ones_like(angles), angles)

In [6]:
class LlamaModel(nn.Module):
    """Token embedding -> stack of decoder layers -> final RMSNorm.

    Args:
        config: an ``LLaMA2Config``; reads ``vocab_size``, ``n_embd``, ``n_layers``,
            ``heads``, ``norm_eps``, ``max_seq_len`` and (optionally) ``device``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)

        self.layers = nn.ModuleList([LlamaDecorderLayer(config) for _ in range(config.n_layers)])

        self.norm = LlamaRMSNorm(config.n_embd, config.norm_eps)

        # RoPE table sized for 2 * max_seq_len positions (as in the original code;
        # presumably headroom so start_pos + T stays in range). Kept as a plain
        # attribute rather than a buffer so Module.to(dtype) cannot cast the
        # complex values to a real dtype; forward() moves the slice to the right device.
        head_dim = config.n_embd // config.heads
        self.freqs_complex = precompute_theta_pos_frequencies(
            head_dim,
            config.max_seq_len * 2,
            device=getattr(config, "device", None),
        )


    def forward(self, tokens, start_pos):
        """Run the decoder stack.

        Args:
            tokens: 2-D tensor of token ids (unpacked as ``(batch, seq_len)``).
            start_pos: absolute position of ``tokens[:, 0]``; selects the rows of
                the RoPE table used by this call.

        Returns:
            Final-normalized hidden states, last dim ``config.n_embd``.
        """
        _, seq_len = tokens.shape

        hidden = self.embed_tokens(tokens)

        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(hidden.device)

        for layer in self.layers:
            hidden = layer(hidden, start_pos, freqs_complex)

        return self.norm(hidden)

In [7]:
class LlamaDecorderLayer(nn.Module):
    """One pre-norm decoder block: x + attn(norm(x)), then h + mlp(norm(h)).

    Args:
        config: an ``LLaMA2Config``; reads ``n_embd`` and ``norm_eps`` here and is
            passed through to ``LlamaAttention`` / ``LlamaMLP``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        self.input_layernorm = LlamaRMSNorm(config.n_embd, config.norm_eps)

        self.self_attn = LlamaAttention(config)

        self.post_attention_layernorm = LlamaRMSNorm(config.n_embd, config.norm_eps)

        self.mlp = LlamaMLP(config)


    def forward(self, embds, start_pos, freqs_complex):
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            embds: hidden states; last dim ``config.n_embd``.
            start_pos: absolute position of the first token in ``embds``.
            freqs_complex: RoPE rotations for the positions in ``embds``.

        Returns:
            Hidden states with the same shape as ``embds``.
        """
        # TODO: forward start_pos / freqs_complex to self_attn once LlamaAttention
        # applies RoPE and a KV cache; its forward currently takes only the input.
        hidden = embds + self.self_attn(self.input_layernorm(embds))

        return hidden + self.mlp(self.post_attention_layernorm(hidden))


In [8]:
class LlamaAttention(nn.Module):
    """Placeholder for the LLaMA self-attention sub-layer (not implemented yet).

    Args:
        config: an ``LLaMA2Config``; stored for the future implementation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config


    def forward(self, x):
        # Fail loudly instead of returning None, which would otherwise surface
        # later as an obscure "tensor + NoneType" error in the decoder layer.
        raise NotImplementedError("LlamaAttention.forward is not implemented yet")


In [9]:
class LlamaMLP(nn.Module):
    """Placeholder for the LLaMA feed-forward sub-layer (not implemented yet).

    Args:
        config: an ``LLaMA2Config``; stored for the future implementation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config


    def forward(self, x):
        # Fail loudly instead of returning None, which would otherwise surface
        # later as an obscure "tensor + NoneType" error in the decoder layer.
        raise NotImplementedError("LlamaMLP.forward is not implemented yet")


In [10]:
class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-feature scale.

    ``y = x / sqrt(mean(x**2, dim=-1) + norm_eps) * gamma``

    Args:
        n_embd: size of the last dimension being normalized.
        norm_eps: epsilon added inside the square root; defaults to 1e-5, the
            same default as ``LLaMA2Config.norm_eps``.
    """

    def __init__(self, n_embd, norm_eps=1e-5):
        super().__init__()
        self.norm_eps = norm_eps

        self.gamma = nn.Parameter(torch.ones(n_embd))

    def forward(self, x):
        # Compute the statistic in float32: squaring fp16 activations overflows
        # past ~256 (256**2 > 65504) and would zero out the whole row.
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.norm_eps)

        # Cast back to the input dtype before scaling (gamma may promote it).
        return (x_fp32 * inv_rms).type_as(x) * self.gamma
