In [1]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding
from torch.nn.utils.parametrize import register_parametrization

#parallel computing
from torch.distributed import init_process_group,destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

class ModelConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    heads: int = 12
    n_layer: int = 12
    n_embd: int = 768
    bias: bool = False
    parametrize: bool = True
    factor: int = 4
    dropout: float=0.0


class AttentionConfig:
    d: int = -1
    groups: int = 1
    norm_eps: int = 0
    eps: float = 1e-6
    init_scale = 1
    scale: int = 1


class FFConfig:
    d: int = -1
    groups: int = 1
    norm_eps: int = 0
    eps: float = 1e-6
    init_scale: int = 1
    scale: int = 1


def exist(v):
    return v is not None


def default(v, d):
    return v if exist(v) else d


def l2Norm(x, d=-1, groups=1, eps=1e-6, norm_eps=0):
    eps = default(eps, 1e-5 if x.dtype == torch.float16 else 1e-10)

    if groups > 1:
        x = x.chunk(groups, dim=d)
        x = torch.stack(x)

    if norm_eps == 0:
        x_norm = F.normalize(x, dim=d, p=2, eps=eps)

    if norm_eps != 0:
        norm = x.norm(dim=d, keepdim=True)
        d_norm = norm.detach().clamp(min=1 - norm_eps, max=1 + norm_eps)
        divisor = norm / d_norm
        x_norm = x / divisor.clamp(min=eps)

    if groups > 1:
        x_norm = torch.cat([*x_norm], dim=d)

    return x_norm


class L2Norm(nn.Module):
    def __init__(self, d=-1, groups=1, eps=1e-6, norm_eps=0):
        super().__init__()
        self.d = d
        self.groups = groups
        self.eps = eps
        self.norm_eps = norm_eps

    def forward(self, x):
        return l2Norm(
            x, d=self.d, groups=self.groups, eps=self.eps, norm_eps=self.norm_eps
        )


class LinearNormWeight(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        parametrize=False,
        groups=1,
        d=-1,
        eps=1e-6,
        norm_eps=0,
        bias=False,
    ):
        super().__init__()
        self.scale = groups**-1
        self.parametrize = parametrize
        self.linear = nn.Linear(dim_in, dim_out, bias=bias)
        self.L2Norm = L2Norm(d, groups, eps, norm_eps)
        if parametrize:
            register_parametrization(self.linear, "weight", self.L2Norm)

        self.norm_weight_()

    @torch.no_grad()
    def norm_weight_(self):
        if self.parametrize:
            norm = self.weights
            original = self.linear.parametrizations.weight.original
            original.copy_(norm)
        else:
            self.weights.copy_(self.L2Norm(self.weights))

    @property
    def weights(self):
        return self.linear.weight

    def forward(self, x):
        return self.linear(x) * self.scale


class Scale(nn.Module):
    def __init__(self, dim, init_scale=1, scale=1):
        super().__init__()
        self.params = nn.Parameter(torch.ones(dim) * scale)
        self.divide_scale = init_scale / scale

    def forward(self):
        return self.params * self.divide_scale


class Attention(nn.Module):
    def __init__(self, args: ModelConfig, args_attn: AttentionConfig):
        super().__init__()
        self.args = args
        self.to_q = LinearNormWeight(
            args.n_embd,
            args.n_embd,
            args.parametrize,
            args_attn.groups,
            args_attn.d,
            args_attn.eps,
            args_attn.norm_eps,
        )
        self.to_k = LinearNormWeight(
            args.n_embd,
            args.n_embd,
            args.parametrize,
            args_attn.groups,
            args_attn.d,
            args_attn.eps,
            args_attn.norm_eps,
        )
        self.to_v = LinearNormWeight(
            args.n_embd,
            args.n_embd,
            args.parametrize,
            args_attn.groups,
            args_attn.d,
            args_attn.eps,
            args_attn.norm_eps,
        )

        self.dim_head = args.n_embd // args.heads
        self.n_heads = args.heads
        self.softmax_scale = self.dim_head**0.5
        self.q_scale = Scale(args.n_embd, 1, args.n_embd ** (-0.5))
        self.k_scale = Scale(args.n_embd, 1, args.n_embd ** (-0.5))
        self.rotary_embed=RotaryEmbedding(self.dim_head)
        self.flash=hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.dropout=args.dropout
        if not self.flash:
            self.register_buffer(
                "mask",
                torch.tril(
                    torch.ones(args.block_size, args.block_size).view(
                        1, 1, args.block_size, args.block_size
                    )
                ),)
            
        self.c_proj = LinearNormWeight(
            args.n_embd,
            args.n_embd,
            args.parametrize,
            args_attn.groups,
            args_attn.d,
            args_attn.eps,
            args_attn.norm_eps,
        )

    def forward(self, x):
        B, T, C = x.size()
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)
        v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2)

        q = self.rotary_embed.rotate_queries_or_keys(q)
        k = self.rotary_embed.rotate_queries_or_keys(k)
    
        q = q * rearrange(self.q_scale(), "(h d) -> h 1 d", h=self.n_heads)
        k = k * rearrange(self.q_scale(), "(h d) -> h 1 d", h=self.n_heads)
        if self.flash:
            attn=torch.nn.functional.scaled_dot_product_attention(q,k,v,attn_mask=None,dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            attn = q @ k.transpose(-1, -2)
    
            attn = attn * self.softmax_scale
    
            attn = attn.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = torch.matmul(attn, v)
        out = attn.transpose(1, 2).contiguous().view(B, T, C)

        return self.c_proj(out)


class FeedForward(nn.Module):
    def __init__(self, args: ModelConfig, args_ffn: FFConfig):
        super().__init__()
        hidden_dim = args.factor * args.n_embd
        self.w1 = LinearNormWeight(args.n_embd, hidden_dim)
        self.w2 = LinearNormWeight(hidden_dim, args.n_embd)
        self.w3 = LinearNormWeight(args.n_embd, hidden_dim)

        self.scale_u = Scale(
            hidden_dim, init_scale=args_ffn.init_scale, scale=args_ffn.scale
        )
        self.scale_v = Scale(
            hidden_dim, init_scale=args_ffn.init_scale, scale=args_ffn.scale
        )
        self.scale_ = hidden_dim**0.5

    def forward(self, x):
        u = self.w1(x)*self.scale_u()
        
        v = self.w3(x)*self.scale_v()

        v = v * self.scale_

        return self.w2(F.silu(v) * u)


class Lerp_Residual(nn.Module):
    def __init__(self, args: ModelConfig, index_layer, fc):
        super().__init__()
        self.fc = fc
        self.l2Norm = L2Norm(d=-1)
        self.scale = Scale(
            args.n_embd, init_scale=(0.05 / (index_layer+1)), scale=args.n_embd ** (-0.5)
        )

    def forward(self, x, **kwargs):
        connect_ = x
        out = self.l2Norm(self.fc(x, **kwargs))
        out = torch.lerp(connect_, out, self.scale())

        return self.l2Norm(out)


class nGPT(nn.Module):
    def __init__(
        self, args: ModelConfig, args_attn: AttentionConfig, args_ffn: FFConfig
    ):
        super().__init__()
        self.n_layer = args.n_layer
        self.n_attn_layeers = nn.ModuleList(
            [Attention(args, args_attn) for i in range(args.n_layer)]
        )
        self.n_ffn_layers = nn.ModuleList(
            [FeedForward(args, args_ffn) for i in range(args.n_layer)]
        )
        self.residual_attn = nn.ModuleList(
            [
                Lerp_Residual(args, i, self.n_attn_layeers[i])
                for i in range(args.n_layer)
            ]
        )
        self.residual_ffn = nn.ModuleList(
            [Lerp_Residual(args, i, self.n_ffn_layers[i]) for i in range(args.n_layer)]
        )
        self.to_logits = LinearNormWeight(args.n_embd, args.vocab_size)
        self.scale_logits=Scale(args.vocab_size,1,args.n_embd**-0.5)
        self.to_embedding=nn.Embedding(args.vocab_size,args.n_embd)
        self.block_size=args.block_size
    def forward(self, x,targets=None):
        
        x=self.to_embedding(x)
        B, T, C = x.size()
        for residual_attn, residual_ffn in zip(self.residual_attn, self.residual_ffn):
            x = residual_attn(x)
            x = residual_ffn(x)
        logits = (self.to_logits(x)*self.scale_logits())
        if targets is not None:
            loss=F.cross_entropy(logits.view(-1,logits.size(-1)),targets.view(-1),ignore_index=-1)
        else: 
            loss=None

        return loss,logits

    @torch.no_grad()
    def generate(self,idx,max_new_tokens,temperature=1.0,top_k=None):
        for i in range(max_new_tokens):
            idx_cond=idx if idx.size(1) <self.block_size else idx[:,-self.block_size:]
            _,logits=self(idx_cond)
            logits=logits[:,-1,:]/temperature
            if top_k is not None:
                v,_=torch.topk(logits,min(top_k,logits.size(-1)))
                logits[logits<v[:,[-1]]]=-float('Inf')
            probs=F.softmax(logits,dim=-1)
            idx_next=torch.multinomial(probs,num_samples=1)
            idx=torch.cat((idx,idx_next),dim=1)
        return idx 

In [2]:
import torch
print(torch.__version__)


2.2.1


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
from torch.utils.data import Dataset, DataLoader

In [5]:
model=nGPT(ModelConfig,AttentionConfig,FFConfig).to(device)
# Load the model weights


In [6]:
print(hasattr(torch.nn.functional, 'scaled_dot_product_attention'))


True


In [7]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 190674432


In [2]:
!pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0
[0m

In [3]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting regex>=2022.1.18 (from tiktoken)
  Downloading regex-2024.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading regex-2024.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (782 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m782.7/782.7 kB[0m [31m41.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: regex, tiktoken
Successfully installed regex-2024.9.11 tiktoken-0.8.0
[0m

In [4]:
!pip install rotary_embedding_torch

Collecting rotary_embedding_torch
  Downloading rotary_embedding_torch-0.8.4-py3-none-any.whl.metadata (678 bytes)
Downloading rotary_embedding_torch-0.8.4-py3-none-any.whl (5.6 kB)
Installing collected packages: rotary_embedding_torch
Successfully installed rotary_embedding_torch-0.8.4
[0m

In [8]:
import tiktoken

In [9]:
test_data=torch.tensor([1.5,4.2,6.5,7.8,9.67])

In [10]:
output=l2norm(test_data)

NameError: name 'l2norm' is not defined

In [36]:
output

NameError: name 'output' is not defined

In [11]:
def collate_fn(batch):
    inputs, targets = zip(*batch)
    return torch.tensor(inputs), torch.tensor(targets)

In [12]:
class nGPTDataset(Dataset):
    def __init__(self,txt,tokenizer,block_size,stride):
        super().__init__()
        self.input_ds=[]
        self.target_ds=[]
        tokens_data=tokenizer.encode(txt,allowed_special={"<|endoftext|>"})
        for i in range(0,len(tokens_data)-block_size,stride):
            inputs=tokens_data[i:i+block_size]
            targets=tokens_data[i+1:i+block_size+1]
            self.input_ds.append(inputs)
            self.target_ds.append(targets)
    def __len__(self):
        return len(self.input_ds)

    def __getitem__(self,idx):
        return self.input_ds[idx],self.target_ds[idx]


def create_dataloader(txt,block_size=256,stride=128,batch_size=4,shuffle=True,drop_last=True,num_workers=0):
    tokenizer=tiktoken.get_encoding("gpt2")
    dataset=nGPTDataset(txt,tokenizer,block_size,stride)
    dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=shuffle,drop_last=drop_last,num_workers=num_workers,collate_fn=collate_fn)

    return dataloader
    
            

In [13]:
with open("input.txt","r") as f:
    data=f.read()

In [14]:
dataloader=create_dataloader(data,block_size=256,stride=256,batch_size=4,shuffle=True,drop_last=True,num_workers=0)

In [15]:
torch.set_float32_matmul_precision('high')

In [16]:
optimizer=torch.optim.AdamW(model.parameters(),lr=1e-3)


In [17]:
import time

In [None]:
loss_first=1.7
for i in range(1000):
    t0=time.time()
    optimizer.zero_grad()
    for inputs,targets in dataloader:
        inputs,targets=inputs.to(device),targets.to(device)
        
        with torch.autocast(device_type=device,dtype=torch.bfloat16):
            loss,logits=model(inputs,targets)
        loss.backward()
        norm=torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        optimizer.step()
        if device=="cuda":
            torch.cuda.synchronize()
        t1=time.time()
        dtime=t1-t0
        if loss_first>loss.item():
            torch.save(model.state_dict(), 'model.pth')
            loss_first=loss.item()
    print(f"loss: {loss.item()}: time per epoch{dtime}")

loss: 4.8282976150512695: time per epoch53.448368310928345
loss: 4.626946926116943: time per epoch53.04391622543335
loss: 4.311898708343506: time per epoch53.04680037498474
loss: 3.9512829780578613: time per epoch53.0349657535553
loss: 3.6069490909576416: time per epoch53.11629557609558
loss: 3.333874464035034: time per epoch53.05999422073364
loss: 3.0872554779052734: time per epoch52.99052286148071


In [101]:
@torch.no_grad()
def Generate(model,idx,max_new_tokens,context_size):
    for _ in range(max_new_tokens):
        idx_cond=ix if 

SyntaxError: incomplete input (1427814074.py, line 4)

In [None]:
!pip install rotary_embedding_torch

In [175]:
inputs="hello, how are you"

In [184]:
tokenizer=tiktoken.get_encoding("gpt2")
encoded=tokenizer.encode(inputs)
encoded_tensor=torch.tensor(encoded).unsqueeze(0)

In [185]:
encoded_tensor=encoded_tensor.to(device)

In [191]:
out=model.generate(encoded_tensor,max_new_tokens=10,temperature=1.0,top_k=5)

In [192]:
decoded_text=tokenizer.decode(out.squeeze(0).tolist())

In [193]:
print(decoded_text)

hello, how are youEdward Hag drumearancesthia freeingBonus Strategy Inquisitioniate


In [190]:
 decoded_text = tokenizer.decode(out.squeeze().tolist())

TypeError: argument 'tokens': 'list' object cannot be interpreted as an integer

In [None]:
# Assuming 'out' is a tensor
tokens = out.squeeze().tolist()  # Remove unnecessary dimensions and convert to list

# Decode tokens to text
decoded_text = tokenizer.decode(tokens)
