In [2]:
from torch import nn
from torch import Tensor
import math
import torch


class RMSNorm(nn.Module):
    def __init__(self, input_size: int, prob: float=1.0, eps: float=1e-8):
        super().__init__()
        assert prob > 0.0

        self.input_size = input_size
        self.gamma = nn.Parameter(torch.ones(input_size), requires_grad=True)
        self.prob = prob
        self.eps = eps

    def forward(self, input_: Tensor) -> Tensor:
        if self.prob >= 1.0:
            input_norm = input_.norm(2, dim=-1, keepdim=True)
            den = input_norm / math.sqrt(self.input_size) + self.eps
            return (input_ / den) * self.gamma

        est_size = int(self.prob * self.input_size)
        to_calc = input_[..., :est_size]
        est_norm = to_calc.norm(2, dim=-1, keepdim=True)
        den = est_norm / torch.sqrt(est_size) + self.eps
        return (input_ / den) * self.gamma
        

In [3]:
import torch
from torch import nn


class RoPE:
    def __init__(self, seq_len: int, embed_size: int, device: torch.device):
        self.base = 10000
        self.seq_len = seq_len
        self.embed_size = embed_size
        arange = torch.arange(0, self.embed_size, 2, device=device)
        theta = 1.0 / (self.base ** (arange / embed_size))
        idxs = torch.arange(seq_len, device=device)

        outer_product = torch.outer(theta, idxs)
        self.storage = torch.stack((torch.cos(outer_product), torch.sin(outer_product)), dim=-1).to(device)

    def apply(self, input_: torch.Tensor):
        batch_size, seq_len, nheads, hidden_size = input_.shape
        # cur_storage = self.storage[:seq_len]
        x = input_.view(batch_size, seq_len, nheads, hidden_size // 2, 2)
        cur_storage = self.storage.view(1, seq_len, 1, hidden_size // 2, 2)
        
        cos_theta = cur_storage[..., 0]
        sin_theta = cur_storage[..., 1]
        output = torch.stack([
            x[..., 0] * cos_theta - x[..., 1] * sin_theta,
            x[..., 0] * sin_theta + x[..., 1] * cos_theta,
        ], dim=-1)

        return output.view(batch_size, seq_len, nheads, hidden_size)
    

In [4]:
from torch import nn
from torch import Tensor
import torch


class Swish(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.sigmoid = nn.Sigmoid()
    def forward(self, input_: Tensor) -> Tensor:
        return input_ * self.sigmoid(self.beta * input_)


class SwiGLUFeedForward(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, device):
        super().__init__()
        self.U = nn.Linear(input_size, hidden_size, bias=True, device=device)
        self.W = nn.Linear(input_size, hidden_size, bias=True, device=device)
        self.V = nn.Linear(hidden_size, input_size, bias=True, device=device)
        self.swish = Swish(device)

    def forward(self, input_: Tensor) -> Tensor:
        x1 = self.swish(self.U(input_)) # batch_size, seq_len, hidden_dim * c
        x2 = self.W(input_) # batch_size, seq_len, hidden_dim * c
        x = self.V(x1 * x2)
        return x # batch_size, seq_len, hidden_dim


In [5]:
import torch
import math
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, seq_len: int, hidden_dim: int, n_head: int, rope: RoPE, device: torch.device):
        super().__init__()

        self.n_head = n_head

        self.w_q = nn.Linear(hidden_dim, hidden_dim, bias=False, device=device)
        self.w_k = nn.Linear(hidden_dim, hidden_dim, bias=False, device=device)
        self.w_v = nn.Linear(hidden_dim, hidden_dim, bias=False, device=device)

        self.shuffler = nn.Linear(hidden_dim, hidden_dim, bias=False, device=device)
        self.rope = RoPE(seq_len, hidden_dim // n_head, device)

    def forward(self, input_: torch.Tensor, mask: bool=True):
        batch_size, seq_len, hidden_dim = input_.shape
        head_size = hidden_dim // self.n_head

        Q = self.w_q(input_).view(batch_size, seq_len, self.n_head, head_size)
        K = self.w_k(input_).view(batch_size, seq_len, self.n_head, head_size)
        V = self.w_v(input_).view(batch_size, seq_len, self.n_head, head_size)

        Q = self.rope.apply(Q).transpose(1, 2)
        K = self.rope.apply(K).transpose(1, 2)
        V = V.transpose(1, 2)

        coefs = torch.softmax(torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(head_size), dim=-1) # batch_size, nhead, seq_len, seq_len
        if mask:
            coefs = torch.tril(coefs)

        output = torch.matmul(coefs, V) # batch_size, nhead, seq_len, head_size
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_dim)
        output = self.shuffler(output)

        return output # batch_size, seq_len, hidden_dim





In [6]:
import torch
from torch import nn


class LLaMaBlock(nn.Module):
    def __init__(self, seq_len:int, hidden_size: int, n_head: int, rope: RoPE, device: torch.device):
        super().__init__()
        self.rms_attn = RMSNorm(hidden_size)
        self.attention = MultiHeadAttention(seq_len, hidden_size, n_head, rope, device)
        self.rms_ffn = RMSNorm(hidden_size)
        self.swiglu = SwiGLUFeedForward(hidden_size, 8 * hidden_size // 3, device)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        x = self.attention(self.rms_attn(input_)) # batch_size, seq_len, hidden_dim
        x = input_ + x 
        x = x + self.swiglu(self.rms_ffn(x)) 
        return x # batch_size, seq_len, hidden_dim


class LLaMa(nn.Module):
    def __init__(self, vocab_size:int, n_stacks:int, seq_len: int, hidden_size: int, n_head: int, device: torch.device):
        super().__init__()
        rope = RoPE(seq_len, hidden_size // n_head, device)
        self.n_stacks = n_stacks
        self.embed = nn.Embedding(vocab_size, hidden_size, device=device)
        self.rmsnorm = RMSNorm(hidden_size)
        self.blocks = nn.Sequential()
        for i in range(n_stacks):
            self.blocks.add_module(f"LLaMa Block {i}", LLaMaBlock(seq_len, hidden_size, n_head, rope, device))

        self.output_linear = nn.Linear(hidden_size, vocab_size, bias=True, device=device)

    def forward(self, input_: torch.Tensor):
        x = self.embed(input_) # batch_size, seq_len, hidden_dim
        for block in self.blocks:
            x = block(x)

        x = self.rmsnorm(x) 
        x = self.output_linear(x) 
        return x # batch_size, seq_len, vocab_size

    def __str__(self):
        """
        Model prints with the number of parameters.
        """
        all_parameters = sum([p.numel() for p in self.parameters()])
        trainable_parameters = sum(
            [p.numel() for p in self.parameters() if p.requires_grad]
        )

        result_info = super().__str__()
        result_info = result_info + f"\nAll parameters: {all_parameters}"
        result_info = result_info + f"\nTrainable parameters: {trainable_parameters}"

        return result_info

    



In [7]:
# %pip install datasets
# %pip install wandb
# %pip install transformers

In [8]:
import wandb
wandb.login(key='')
run = wandb.init()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: timtorbakhov111 (popegorov). Use `wandb login --relogin` to force relogin
wandb: Appending key for api.wandb.ai to your netrc file: /home/jupyter/.netrc
wandb: Tracking run with wandb version 0.18.6
wandb: Run data is saved locally in /home/jupyter/work/resources/wandb/run-20241109_131105-mpmuoqkg
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run winter-rain-5
wandb: ⭐️ View project at https://wandb.ai/popegorov/uncategorized
wandb: 🚀 View run at https://wandb.ai/popegorov/uncategorized/runs/mpmuoqkg


In [9]:
from torch.utils.data import DataLoader, Dataset

import torch

from datasets import load_dataset
from huggingface_hub import login
from transformers import AutoTokenizer

import torch.nn.functional as F
from tqdm.notebook import tqdm
import numpy as np


def tokenize(batch):
    return tokenizer(batch['text'])

class MyDataset(Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.tokenized_data = dataset.map(tokenize, batched=True, remove_columns=['text'])['input_ids']

    def __getitem__(self, index):
        return self.tokenized_data[index]

    def __len__(self):
        return len(self.tokenized_data)


def compute_loss(criterion, logits: torch.Tensor, labels: torch.Tensor, pad_id):
    logits = logits.reshape(-1, 32000)
    labels = labels.view(-1)

    loss = criterion(logits, labels, ignore_index=pad_id)
    return loss


counter = 0
PAD_ID = 2
MAX_SEQ_LEN = 256
login(token="")
tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1')

    
dataset_id = "ashaba1in/small_openwebtext"
dataset = load_dataset(dataset_id)

print("Dataset size", len(dataset['train']['text']))
# print(dataset['train']['text'][0])
tokenizer_id ="mistralai/Mistral-7B-v0.1"


if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

PAD_ID = tokenizer.eos_token_id









Dataset size 1000000


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
def collate_fn(batch):
    flatten = []
    for text in batch:
        flatten += text
    target_size = len(batch) * MAX_SEQ_LEN
    if len(flatten) < target_size:
        flatten += [PAD_ID] * (target_size - len(flatten))
    else:
        flatten = flatten[:target_size]
    
    flatten = torch.Tensor(flatten).to(torch.long).view(len(batch), MAX_SEQ_LEN)
    return flatten


def train_epoch(model, criterion, optimizer, train_loader, epoch, pad_id, log_step, save_period, run):
    loss_log = []
    model.train()
    counter = 0
    log_count = 0
    for data in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
        data = data.to(device) # batch_size, seq_len
        counter += 1
        optimizer.zero_grad()
        out = model(data) # batch_size, seq_len, vocab_size
        loss = compute_loss(criterion, out[:, :-1], data[:, 1:].clone(), pad_id)
        loss_log.append(loss.item())
        
        loss.backward()
        optimizer.step()
        if not counter % log_step:
            print(f"Loss from {log_count * log_step} to {(log_count + 1) * log_step} step", np.mean(loss_log[log_count * log_step: (log_count + 1) * log_step]))
            run.log({'loss': np.mean(loss_log[log_count * log_step: (log_count + 1) * log_step])})
            log_count += 1
            
        if not counter % save_period:
            torch.save(model, "best_model.pth")
            wandb.save("best_model.pth")

    return loss_log

def train(model, criterion, optimizer, n_epochs, pad_id, log_step, save_period, run, dataset):
    print(model)
    len_epoch = 100000
    for epoch in range(n_epochs):
        cropped_dataset = dataset['train'].select(range(epoch * len_epoch, (epoch + 1) * len_epoch))
        my_data = MyDataset(cropped_dataset)
        train_loader = DataLoader(my_data, batch_size=8, collate_fn=collate_fn)
        
        train_loss = train_epoch(model, criterion, optimizer, train_loader, epoch, pad_id, log_step, save_period, run)
        print(f"Train loss: {np.mean(train_loss)}")

In [12]:
# from tqdm import tqdm

MAX_SEQ_LEN = 1024
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab_size = tokenizer.vocab_size
model = LLaMa(vocab_size, 8, MAX_SEQ_LEN, 768, 8, device).to(device)
criterion = F.cross_entropy
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
n_epochs = 10
log_step = 20
save_period = 5000
train(model, criterion, optimizer, n_epochs, PAD_ID, log_step, save_period, run, dataset)

LLaMa(
  (embed): Embedding(32000, 768)
  (rmsnorm): RMSNorm()
  (blocks): Sequential(
    (LLaMa Block 0): LLaMaBlock(
      (rms_attn): RMSNorm()
      (attention): MultiHeadAttention(
        (w_q): Linear(in_features=768, out_features=768, bias=False)
        (w_k): Linear(in_features=768, out_features=768, bias=False)
        (w_v): Linear(in_features=768, out_features=768, bias=False)
        (shuffler): Linear(in_features=768, out_features=768, bias=False)
      )
      (rms_ffn): RMSNorm()
      (swiglu): SwiGLUFeedForward(
        (U): Linear(in_features=768, out_features=2048, bias=True)
        (W): Linear(in_features=768, out_features=2048, bias=True)
        (V): Linear(in_features=2048, out_features=768, bias=True)
        (swish): Swish(
          (sigmoid): Sigmoid()
        )
      )
    )
    (LLaMa Block 1): LLaMaBlock(
      (rms_attn): RMSNorm()
      (attention): MultiHeadAttention(
        (w_q): Linear(in_features=768, out_features=768, bias=False)
        (w_k)

Training Epoch 0:   0%|          | 0/12500 [00:00<?, ?it/s]

Loss from 0 to 20 step 8.42400221824646
Loss from 20 to 40 step 7.0235237121582035
Loss from 40 to 60 step 6.822653651237488
Loss from 60 to 80 step 6.586639356613159
Loss from 80 to 100 step 6.394759893417358
Loss from 100 to 120 step 6.401404309272766
Loss from 120 to 140 step 6.304096984863281
Loss from 140 to 160 step 6.292648983001709
Loss from 160 to 180 step 6.172917222976684
Loss from 180 to 200 step 6.16012670993805
Loss from 200 to 220 step 6.07097475528717
Loss from 220 to 240 step 6.004295539855957
Loss from 240 to 260 step 5.981633973121643
Loss from 260 to 280 step 5.970033144950866
Loss from 280 to 300 step 6.015690159797669
Loss from 300 to 320 step 5.854319095611572
Loss from 320 to 340 step 6.023437714576721
Loss from 340 to 360 step 5.953654384613037
Loss from 360 to 380 step 5.883006310462951
Loss from 380 to 400 step 5.7898475408554075
Loss from 400 to 420 step 5.902785301208496
Loss from 420 to 440 step 6.053717994689942
Loss from 440 to 460 step 5.850425219535827

Map: 100%|██████████| 100000/100000 [01:59<00:00, 837.65 examples/s]


Training Epoch 1:   0%|          | 0/12500 [00:00<?, ?it/s]

Loss from 0 to 20 step 0.33669422268867494
Loss from 20 to 40 step 0.33777383863925936
Loss from 40 to 60 step 0.35498461723327634
Loss from 60 to 80 step 0.3584600403904915
Loss from 80 to 100 step 0.354111148416996
Loss from 100 to 120 step 0.3348711460828781
Loss from 120 to 140 step 0.33628710359334946
Loss from 140 to 160 step 0.3861187607049942
Loss from 160 to 180 step 0.36915919184684753
Loss from 180 to 200 step 0.3411905601620674
Loss from 200 to 220 step 0.3465409651398659
Loss from 220 to 240 step 0.34457228034734727
Loss from 240 to 260 step 0.33218675702810285
Loss from 260 to 280 step 0.33961435556411745
Loss from 280 to 300 step 0.3350308120250702
Loss from 300 to 320 step 0.3546433448791504
Loss from 320 to 340 step 0.3483290195465088
Loss from 340 to 360 step 0.34993422478437425
Loss from 360 to 380 step 0.44374899864196776
Loss from 380 to 400 step 0.3405505523085594
Loss from 400 to 420 step 0.32917495667934416
Loss from 420 to 440 step 0.35489716827869416
Loss from

Map: 100%|██████████| 100000/100000 [01:59<00:00, 834.42 examples/s]


Training Epoch 2:   0%|          | 0/12500 [00:00<?, ?it/s]

Loss from 0 to 20 step 0.21620865762233735
Loss from 20 to 40 step 0.23154037520289422
Loss from 40 to 60 step 0.23773094341158868
Loss from 60 to 80 step 0.22272755578160286
Loss from 80 to 100 step 0.2311197392642498
Loss from 100 to 120 step 0.22758381590247154
Loss from 120 to 140 step 0.2201256826519966
Loss from 140 to 160 step 0.24708592891693115
Loss from 160 to 180 step 0.22181706205010415
Loss from 180 to 200 step 0.21560477912425996
Loss from 200 to 220 step 0.22340274453163148
Loss from 220 to 240 step 0.23231444433331488
Loss from 240 to 260 step 0.23521733433008193
Loss from 260 to 280 step 0.24628016129136085
Loss from 280 to 300 step 0.22454896569252014
Loss from 300 to 320 step 0.2787293650209904
Loss from 320 to 340 step 0.20748767703771592
Loss from 340 to 360 step 0.20392531231045724
Loss from 360 to 380 step 0.25865465477108956
Loss from 380 to 400 step 0.22991502285003662
Loss from 400 to 420 step 0.24059441909193993
Loss from 420 to 440 step 0.20755440816283227
L

Map: 100%|██████████| 100000/100000 [01:58<00:00, 845.00 examples/s]


Training Epoch 3:   0%|          | 0/12500 [00:00<?, ?it/s]

Loss from 0 to 20 step 0.10285154841840267
Loss from 20 to 40 step 0.11080565750598907
Loss from 40 to 60 step 0.3078573767095804
Loss from 60 to 80 step 0.10265103206038476
Loss from 80 to 100 step 0.22896673195064068
Loss from 100 to 120 step 0.11339891590178013
Loss from 120 to 140 step 0.118510040640831
Loss from 140 to 160 step 0.10684010088443756
Loss from 160 to 180 step 0.10796084329485893
Loss from 180 to 200 step 0.10797228924930095
Loss from 200 to 220 step 0.12071321345865726
Loss from 220 to 240 step 0.10634805709123611
Loss from 240 to 260 step 0.144910204783082
Loss from 260 to 280 step 0.1301879409700632
Loss from 280 to 300 step 0.1127145305275917
Loss from 300 to 320 step 0.09767359867691994
Loss from 320 to 340 step 0.1087566576898098
Loss from 340 to 360 step 0.10986167304217816
Loss from 360 to 380 step 0.11803590245544911
Loss from 380 to 400 step 0.13229920230805875
Loss from 400 to 420 step 0.18985508047044278
Loss from 420 to 440 step 0.12711454294621943
Loss f

KeyboardInterrupt: 

In [10]:
torch.save(model, "best_model.pth")


In [11]:
wandb.save("best_model.pth")

['/home/jupyter/work/resources/wandb/run-20241109_131105-mpmuoqkg/files/best_model.pth']