In [159]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from torch.nn.utils.parametrize import register_parametrization


class ModelConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    heads: int = 12
    n_layer: int = 12
    n_embd: int = 768
    bias: bool = False
    parametrize: bool = True
    factor: int = 4


class AttentionConfig:
    d: int = -1
    groups: int = 1
    norm_eps: int = 0
    eps: float = 1e-6
    init_scale = 1
    scale: int = 1


class FFConfig:
    d: int = -1
    groups: int = 1
    norm_eps: int = 0
    eps: float = 1e-6
    init_scale: int = 1
    scale: int = 1


def exist(v):
    return v is not None


def default(v, d):
    return v if exist(v) else d


def l2Norm(x, d=-1, groups=1, eps=1e-6, norm_eps=0):
    eps = default(eps, 1e-5 if x.dtype == torch.float16 else 1e-10)

    if groups > 1:
        x = x.chunk(groups, dim=d)
        x = torch.stack(x)

    if norm_eps == 0:
        x_norm = F.normalize(x, dim=d, p=2, eps=eps)

    if norm_eps != 0:
        norm = x.norm(dim=d, keepdim=True)
        d_norm = norm.detach().clamp(min=1 - norm_eps, max=1 + norm_eps)
        divisor = norm / d_norm
        x_norm = x / divisor.clamp(min=eps)

    if groups > 1:
        x_norm = torch.cat([*x_norm], dim=d)

    return x_norm


class L2Norm(nn.Module):
    def __init__(self, d=-1, groups=1, eps=1e-6, norm_eps=0):
        super().__init__()
        self.d = d
        self.groups = groups
        self.eps = eps
        self.norm_eps = norm_eps

    def forward(self, x):
        return l2Norm(
            x, d=self.d, groups=self.groups, eps=self.eps, norm_eps=self.norm_eps
        )


class LinearNormWeight(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        parametrize=False,
        groups=1,
        d=-1,
        eps=1e-6,
        norm_eps=0,
        bias=False,
    ):
        super().__init__()
        self.scale = groups**-1
        self.parametrize = parametrize
        self.linear = nn.Linear(dim_in, dim_out, bias=bias)
        self.L2Norm = L2Norm(d, groups, eps, norm_eps)
        if parametrize:
            register_parametrization(self.linear, "weight", self.L2Norm)

        self.norm_weight_()

    @torch.no_grad()
    def norm_weight_(self):
        if self.parametrize:
            norm = self.weights
            original = self.linear.parametrizations.weight.original
            original.copy_(norm)
        else:
            self.weights.copy_(self.L2Norm(self.weights))

    @property
    def weights(self):
        return self.linear.weight

    def forward(self, x):
        return self.linear(x) * self.scale


class Scale(nn.Module):
    def __init__(self, dim, init_scale=1, scale=1):
        super().__init__()
        self.params = nn.Parameter(torch.ones(dim) * scale)
        self.divide_scale = init_scale / scale

    def forward(self):
        return self.params * self.divide_scale


class Attention(nn.Module):
    def __init__(self, args: ModelConfig, args_attn: AttentionConfig):
        super().__init__()
        self.args = args
        self.to_q = LinearNormWeight(
            args.n_embd,
            args.n_embd,
            args.parametrize,
            args_attn.groups,
            args_attn.d,
            args_attn.eps,
            args_attn.norm_eps,
        )
        self.to_k = LinearNormWeight(
            args.n_embd,
            args.n_embd,
            args.parametrize,
            args_attn.groups,
            args_attn.d,
            args_attn.eps,
            args_attn.norm_eps,
        )
        self.to_v = LinearNormWeight(
            args.n_embd,
            args.n_embd,
            args.parametrize,
            args_attn.groups,
            args_attn.d,
            args_attn.eps,
            args_attn.norm_eps,
        )

        self.dim_head = args.n_embd // args.heads
        self.n_heads = args.heads
        self.softmax_scale = self.dim_head**0.5
        self.q_scale = Scale(args.n_embd, 1, args.n_embd ** (-0.5))
        self.k_scale = Scale(args.n_embd, 1, args.n_embd ** (-0.5))
        self.rotary_embed=RotaryEmbedding(self.dim_head)
        self.register_buffer(
            "mask",
            torch.tril(
                torch.ones(args.block_size, args.block_size).view(
                    1, 1, args.block_size, args.block_size
                )
            ),
        )
        self.c_proj = LinearNormWeight(
            args.n_embd,
            args.n_embd,
            args.parametrize,
            args_attn.groups,
            args_attn.d,
            args_attn.eps,
            args_attn.norm_eps,
        )

    def forward(self, x):
        B, T, C = x.size()
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)

        q = self.rotary_embed.rotate_queries_or_keys(q)
        k = self.rotary_embed.rotate_queries_or_keys(k)

        q = q * rearrange(self.q_scale(), "(h d) -> h 1 d", h=self.n_heads)
        k = k * rearrange(self.q_scale(), "(h d) -> h 1 d", h=self.n_heads)

        attn = q @ k.transpose(-1, -2)

        attn = attn * self.softmax_scale

        attn = attn.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = torch.matmul(attn, v)
        out = attn.transpose(1, 2).contiguous().view(B, T, C)

        return self.c_proj(out)


class FeedForward(nn.Module):
    def __init__(self, args: ModelConfig, args_ffn: FFConfig):
        super().__init__()
        hidden_dim = args.factor * args.n_embd
        self.w1 = LinearNormWeight(args.n_embd, hidden_dim)
        self.w2 = LinearNormWeight(hidden_dim, args.n_embd)
        self.w3 = LinearNormWeight(args.n_embd, hidden_dim)

        self.scale_u = Scale(
            hidden_dim, init_scale=args_ffn.init_scale, scale=args_ffn.scale
        )
        self.scale_v = Scale(
            hidden_dim, init_scale=args_ffn.init_scale, scale=args_ffn.scale
        )
        self.scale_ = hidden_dim**0.5

    def forward(self, x):
        u = self.w1(x)*self.scale_u()
        
        v = self.w3(x)*self.scale_v()

        v = v * self.scale_

        return self.w2(F.silu(v) * u)


class Lerp_Residual(nn.Module):
    def __init__(self, args: ModelConfig, index_layer, fc):
        super().__init__()
        self.fc = fc
        self.l2Norm = L2Norm(d=-1)
        self.scale = Scale(
            args.n_embd, init_scale=(0.05 / (index_layer+1)), scale=args.n_embd ** (-0.5)
        )

    def forward(self, x, **kwargs):
        connect_ = x
        out = self.l2Norm(self.fc(x, **kwargs))
        out = torch.lerp(connect_, out, self.scale())

        return self.l2Norm(out)


class nGPT(nn.Module):
    def __init__(
        self, args: ModelConfig, args_attn: AttentionConfig, args_ffn: FFConfig
    ):
        super().__init__()
        self.n_layer = args.n_layer
        self.n_attn_layeers = nn.ModuleList(
            [Attention(args, args_attn) for i in range(args.n_layer)]
        )
        self.n_ffn_layers = nn.ModuleList(
            [FeedForward(args, args_ffn) for i in range(args.n_layer)]
        )
        self.residual_attn = nn.ModuleList(
            [
                Lerp_Residual(args, i, self.n_attn_layeers[i])
                for i in range(args.n_layer)
            ]
        )
        self.residual_ffn = nn.ModuleList(
            [Lerp_Residual(args, i, self.n_ffn_layers[i]) for i in range(args.n_layer)]
        )
        self.to_logits = nn.Linear(args.n_embd, args.vocab_size)
        self.to_embedding=nn.Embedding(args.vocab_size,args.n_embd)
        
    def forward(self, x,targets=None):
        
        x=self.to_embedding(x)
        B, T, C = x.size()
        for residual_attn, residual_ffn in zip(self.residual_attn, self.residual_ffn):
            x = residual_attn(x)
            x = residual_ffn(x)
        logits = self.to_logits(x)
        if targets is not None:
            loss=F.cross_entropy(logits.view(-1,logits.size(-1)),targets.view(-1),ignore_index=-1)
        else: 
            loss=None

        return loss,logits


In [160]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [161]:
from torch.utils.data import Dataset, DataLoader

In [162]:
model=nGPT(ModelConfig,AttentionConfig,FFConfig).to(device)

In [163]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 190674432


In [81]:
!pip install einops

[0m

In [82]:
!pip install tiktoken

[0m

In [83]:
import tiktoken

In [112]:
def collate_fn(batch):
    inputs, targets = zip(*batch)
    return torch.tensor(inputs), torch.tensor(targets)

In [164]:
class nGPTDataset(Dataset):
    def __init__(self,txt,tokenizer,block_size,stride):
        super().__init__()
        self.input_ds=[]
        self.target_ds=[]
        tokens_data=tokenizer.encode(txt,allowed_special={"<|endoftext|>"})
        for i in range(0,len(tokens_data)-block_size,stride):
            inputs=tokens_data[i:i+block_size]
            targets=tokens_data[i+1:i+block_size+1]
            self.input_ds.append(inputs)
            self.target_ds.append(targets)
    def __len__(self):
        return len(self.input_ds)

    def __getitem__(self,idx):
        return self.input_ds[idx],self.target_ds[idx]


def create_dataloader(txt,block_size=256,stride=128,batch_size=4,shuffle=True,drop_last=True,num_workers=0):
    tokenizer=tiktoken.get_encoding("gpt2")
    dataset=nGPTDataset(txt,tokenizer,block_size,stride)
    print(len(dataset.input_ds[1]))
    dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=shuffle,drop_last=drop_last,num_workers=num_workers,collate_fn=collate_fn)

    return dataloader
    
            

In [165]:
with open("input.txt","r") as f:
    data=f.read()

In [166]:
dataloader=create_dataloader(data,block_size=256,stride=256,batch_size=4,shuffle=True,drop_last=True,num_workers=0)

256


In [167]:
optimizer=torch.optim.AdamW(model.parameters(),lr=1e-3)


In [None]:
for i in range(1000):
    for inputs,targets in dataloader:
        optimizer.zero_grad()
        inputs,targets=inputs.to(device),targets.to(device)
        loss,logits=model(inputs,targets)
        loss.backward()
        optimizer.step()
        print(loss.item())

10.824902534484863
10.809331893920898
10.796067237854004
10.78587532043457
10.776914596557617
10.773719787597656
10.751914978027344
10.746347427368164
10.730432510375977
10.7125825881958
10.698067665100098
10.683918952941895
10.664010047912598
10.640602111816406
10.6179780960083
10.607752799987793
10.59593391418457
10.55465316772461
10.545076370239258
10.524458885192871
10.51523208618164
10.48146915435791
10.473435401916504
10.450030326843262
10.405782699584961
10.379457473754883
10.366362571716309
10.310890197753906
10.31900405883789
10.259428024291992
10.24765396118164
10.176809310913086
10.195982933044434
10.157027244567871
10.110694885253906
10.087054252624512
10.040319442749023
10.02815055847168
9.956095695495605
9.893836975097656
9.907533645629883
9.919057846069336
9.869524002075195
9.774945259094238
9.761590957641602
9.675058364868164
9.745437622070312
9.6355619430542
9.601641654968262
9.550468444824219
9.571961402893066
9.493193626403809
9.424906730651855
9.385324478149414
9.38

In [4]:
!pip install rotary_embedding_torch

Collecting rotary_embedding_torch
  Downloading rotary_embedding_torch-0.8.4-py3-none-any.whl.metadata (678 bytes)
Downloading rotary_embedding_torch-0.8.4-py3-none-any.whl (5.6 kB)
Installing collected packages: rotary_embedding_torch
Successfully installed rotary_embedding_torch-0.8.4
[0m