# Preparing repo :D

In [1]:
%%writefile model.py

import os
import random
import math


import torch
import torch.nn as nn
import torch.nn.functional as F


from einops import rearrange, reduce, repeat
from torch import einsum 

import numpy as np



class GPT(nn.Module):
    def __init__(self,
                 vocab_size,
                 num_layers,
                 num_heads,
                 hidden_dim,
                 ffc_hidden_dim,
                 attn_dropout_p=0.1,
                 ffc_dropout_p=0.1,
                 max_seq_len=512,
                 ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.ffc_hidden_dim = ffc_hidden_dim
        self.attn_dropout_p = attn_dropout_p
        self.ffc_dropout_p = ffc_dropout_p
        self.max_seq_len = max_seq_len

        self.decoder_block = nn.ModuleList([DecoderLayer(self.num_heads,
                                                         self.hidden_dim,
                                                         self.ffc_hidden_dim,
                                                         self.attn_dropout_p,
                                                         self.ffc_dropout_p) for _ in range(self.num_layers)])
        
        self.pos_embeddings = nn.Embedding(self.max_seq_len, self.hidden_dim)
        self.token_embeddings = nn.Embedding(self.vocab_size, self.hidden_dim)

        self.proj_layer = nn.Linear(self.hidden_dim, self.vocab_size)
        
        self.register_buffer('tril',
                             torch.tril(torch.ones(self.max_seq_len, self.max_seq_len)).bool())
        self.register_buffer('pos_ids',
                             torch.arange(self.max_seq_len))
    


    def forward(self,
                input_tokens,
                tokenizer_mask=None):
        seq_len = input_tokens.shape[-1]
        b_size = input_tokens.shape[0]
        
        mask = self.make_attn_mask(seq_len, b_size, tokenizer_mask)
        
        x = self.pos_embeddings(self.pos_ids[:seq_len]) + self.token_embeddings(input_tokens)

        for layer in self.decoder_block:
            x = layer(x)
        x = self.proj_layer(x)
        return x


    def make_attn_mask(self, seq_len, b_size, tokenizer_mask=None):
        mask = self.tril[:seq_len, :seq_len].unsqueeze(0).repeat(b_size, 1, 1)
        
        if tokenizer_mask is not None:
            mask = mask & tokenizer_mask.bool().unsqueeze(1)
        return mask

class MSALayer(nn.Module):
    def __init__(self,
                 num_heads,
                 hidden_dim,
                 attn_dropout_p=0.1
                 ):
        
        assert hidden_dim % num_heads == 0
        
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = self.hidden_dim // self.num_heads
        self.attn_dropout_p = attn_dropout_p

        self.toq = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.tok = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.tov = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.ffc = nn.Linear(self.hidden_dim, self.hidden_dim)

        self.layer_norm = nn.LayerNorm(normalized_shape=self.hidden_dim)
        self.attn_dropout = nn.Dropout(p=self.attn_dropout_p if self.attn_dropout_p else 0)

    def forward(self,
                x,
                mask=None):
        # shape of input is [b_size, seq_len, hidden_dim]
        q = self.toq(x)
        k = self.tok(x)
        v = self.tov(x)
        
        q = rearrange(q, 'b s (num_heads h) -> (b num_heads) s h', num_heads=self.num_heads)
        k = rearrange(k, 'b s (num_heads h) -> (b num_heads) s h', num_heads=self.num_heads)
        v = rearrange(v, 'b s (num_heads h) -> (b num_heads) s h', num_heads=self.num_heads)

        output, probs = attn_function(q, k, v, mask=mask, attn_dropout=self.attn_dropout)

        output = rearrange(output, '(b num_heads) s h -> b s (num_heads h)', num_heads=self.num_heads)
        output = self.ffc(output)

        output = self.layer_norm(output + x)
        return output, probs
        
        
class DecoderLayer(nn.Module):
    def __init__(self, 
                 num_heads,
                 hidden_dim,
                 ffc_hidden_dim,
                 attn_dropout_p=0.1,
                 ffc_dropout_p=0.1,
                 ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.ffc_hidden_dim = ffc_hidden_dim
        self.attn_dropout_p = attn_dropout_p
        self.ffc_dropout_p = ffc_dropout_p

        self.ffc_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.ffc_hidden_dim),
            nn.GELU(),
            nn.Dropout(p=self.ffc_dropout_p),
            nn.Linear(self.ffc_hidden_dim, self.hidden_dim)
        ) 
        self.ffc_layer_norm = nn.LayerNorm(normalized_shape=self.hidden_dim)

        self.msalayer = MSALayer(self.num_heads,
                                 self.hidden_dim,
                                 self.attn_dropout_p,)

    def forward(self,
                x,
                mask=None):
        res = x
        x, _ = self.msalayer(x, mask=mask)
        
        return self.ffc_layer_norm(self.ffc_layer(x) + res)
    

def attn_function(q, k, v, mask=None, attn_dropout=None):
    
    #q, k, v shape is [b, s, h]
    b_size = q.shape[0]
    seq_len = q.shape[1]
    hidden_dim = q.shape[2]


    scaled_dot_product = einsum('bsh, bvh -> bsv', [q, k])/math.sqrt(hidden_dim)

    if mask:
        scaled_dot_product = scaled_dot_product.masked_fill(mask==False, 1e-9)
    
    if attn_dropout:
        scaled_dot_product = attn_dropout(scaled_dot_product)
    
    attn_probs = F.softmax(scaled_dot_product, dim=-1)
    attn_output = einsum('bsv, bvd -> bsd', [attn_probs, v])

    return attn_output, attn_probs

Writing model.py


In [79]:
%%writefile train_tokenizer.py
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing

if __name__ == "__main__":
    tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    trainer = BpeTrainer(special_tokens=["<unk>",
                                         "<s>",
                                         "<pad>",
                                         "<bos>",
                                         ], vocab_size=5000) #i took 5k just randomly
    tokenizer.pre_tokenizer = Whitespace()

    files = ["data/input.txt"]
    tokenizer.train(files, trainer)

    tokenizer.post_processor = TemplateProcessing(
        single="<bos> $A <s>",
        special_tokens=[
            ("<s>", tokenizer.token_to_id("<s>")),
            ("<bos>", tokenizer.token_to_id("<bos>")),
            ("<pad>", tokenizer.token_to_id("<pad>"))
        ],
    )
    tokenizer.enable_padding(pad_id=2, pad_token="<pad>")
    tokenizer.save("data/tokenizer.json")

Overwriting train_tokenizer.py


In [3]:
%%writefile objects.py

import numpy as np
import torch

from torch.utils.data import Dataset
from tokenizers import Tokenizer
from typing import List


class ConstantLenghtDataset(Dataset):
    def __init__(self, 
                 texts: List[str],
                 tokenizer: Tokenizer,
                 length: int=512,):
        self.texts = texts
        self.length = length
        self.tokenizer = tokenizer
        self.tokenizer.no_padding()

        encoded_text = tokenizer.encode_batch(self.texts)
        tokens_num = [len(s.tokens) for s in encoded_text]
        constant_len_dataset_ids = []
        concat_sentences_ids = []
        sum=0
        
        for idx, num in enumerate(tokens_num):
            if sum > 512:
                constant_len_dataset_ids.append(concat_sentences_ids)
                concat_sentences_ids = []
                sum = 0

            concat_sentences_ids.append(idx)
            sum+=num
        
        np_text = np.array(self.texts)
        new_dataset = []
        for idxs in constant_len_dataset_ids:
            new_dataset.append(' '.join(np_text[idxs].tolist()))

        self.dataset = new_dataset

    def __len__(self,):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]
    
class TokenizerWrapper():
    def __init__(self,
                 tokenizer,
                 pad_seq_len=512):
        self.tokenizer = tokenizer
        self.pad_seq_len = pad_seq_len
        self.tokenizer.enable_padding(pad_id=2, pad_token="<pad>", length=pad_seq_len)
        self.vocab_size = self.tokenizer.get_vocab_size()

    def __call__(self, input_sentences: List[str], batch=True):
        output = {}
        if batch:
            encoded_input = self.tokenizer.encode_batch(input_sentences)
            ids = torch.tensor([input.ids for input in encoded_input], requires_grad=False)
            attn_masks = torch.tensor([input.attention_mask for input in encoded_input], requires_grad=False)
        else:
            encoded_input = self.tokenizer.encode(input_sentences)
            ids = torch.tensor(encoded_input.ids, requires_grad=False).unsqueeze(0)
            attn_masks = torch.tensor(encoded_input.attention_mask, requires_grad=False).unsqueeze(0)
            
        output['input_ids'] = ids
        output['attn_mask'] = attn_masks

        return output

Writing objects.py


In [80]:
%%writefile utils.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from tqdm import tqdm

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8080'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    dist.barrier()
    
def cleanup():
    dist.destroy_process_group()
    
    
def train(epoch, model, optimizer, train_dataloader, tokenizerwrapped, scaler):
    model.train()
    training_loss = 0
    tokenizerwrapped.tokenizer.enable_padding(pad_id=2, pad_token="<pad>", length=tokenizerwrapped.pad_seq_len)
    for batch_num, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        
        inputs = tokenizerwrapped(batch)
        labels = inputs['input_ids'].cuda(non_blocking=True)
        attn_mask = inputs['attn_mask'].cuda(non_blocking=True)
        
        with torch.cuda.amp.autocast(enabled=True):
            logits = model(labels, attn_mask)
        
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
        
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))        
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        training_loss += loss.item()

    training_loss /= batch_num
    print(f"Epoch: {epoch}, Training loss: {training_loss}")

def test(epoch, model, test_dataloader, tokenizerwrapped):
    model.eval()
    test_loss = 0
    tokenizerwrapped.tokenizer.enable_padding(pad_id=2, pad_token="<pad>", length=tokenizerwrapped.pad_seq_len)
    with tqdm(total=len(test_dataloader.dataset)) as progress_bar:
        with torch.no_grad():
            for batch_idx, batch in enumerate(test_dataloader):
                inputs = tokenizerwrapped(batch)
                labels = inputs['input_ids'].cuda(non_blocking=True)
                attn_mask = inputs['attn_mask'].cuda(non_blocking=True)

                logits = model(labels, attn_mask)
                
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
        
                loss_fn = nn.CrossEntropyLoss()
                loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                
                test_loss += loss.item()
                progress_bar.update(labels.size(0))
            
            test_loss /= batch_idx
    
    return test_loss


def prepare_data():
    with open('data/input.txt', 'r') as f:
        input_text = f.readlines()
    input_text = [i for i in input_text if i!='\n']
    train_size = 0.9
    train_ids = int(len(input_text) * train_size)
    train_data = input_text[: train_ids]
    test_data = input_text[train_ids:]
    
    return train_data, test_data

Overwriting utils.py


In [92]:
%%writefile main.py
import os
import sys
from time import time_ns

import numpy as np
import random
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn


from tokenizers import Tokenizer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader


from model import GPT
from state import load_checkpoint, save_checkpoint
from objects import ConstantLenghtDataset, TokenizerWrapper
from utils import (setup, cleanup, train, test, prepare_data)



tokenizer = Tokenizer.from_file("data/tokenizer.json")
model_config = dict(
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    ffc_hidden_dim=3072,
    max_seq_len=512,
    vocab_size=tokenizer.get_vocab_size()
)

tokenizerwrapped = TokenizerWrapper(tokenizer, pad_seq_len=model_config['max_seq_len'])





train_texts, test_texts = prepare_data()

train_dataset = ConstantLenghtDataset(train_texts, tokenizer, length=model_config['max_seq_len'])
test_dataset = ConstantLenghtDataset(test_texts, tokenizer, length=model_config['max_seq_len'])

NUM_EPOCHS = 200
WORLD_SIZE = 2
BATCH_SIZE = 32
LR = 2e-5
SAVE_INTERVAL=10
SAVE_PATH = "checkpoints/model.pt"




def demo_basic(rank, world_size):
    print(f"Running basic GPT-1 traning on device: {rank}.")
    setup(rank, world_size)
    
    torch.cuda.set_device(rank)
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                              shuffle=False, num_workers=2, pin_memory=True, sampler=train_sampler)
    
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=2, pin_memory=True)

    
    
    model = GPT(**model_config).cuda(rank)
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, nesterov=True)
    model = DDP(model, device_ids=[rank])
    state = load_checkpoint(SAVE_PATH, rank, model, optimizer)

    scaler = torch.cuda.amp.GradScaler(enabled=True)

    cudnn.benchmark = True
    
    for epoch in range(NUM_EPOCHS):
        t0 = time_ns()

        train(epoch, model, optimizer, train_dataloader, tokenizerwrapped, scaler)

        t1 = time_ns()
        delta = (t1 - t0) / (10 ** 9)
        print(f"Device {rank} - Train time: {delta} sec")
        
        if rank == 0:
            loss = test(epoch, model, test_dataloader, tokenizerwrapped)
            print(f"Loss: {loss}%")

        if epoch in [int(NUM_EPOCHS * 0.5), int(NUM_EPOCHS * 0.75)]:
            optimizer.param_groups[0]['lr'] /= 10.
            
        if epoch % SAVE_INTERVAL == 0 and rank == 0:
            save_checkpoint(state, SAVE_PATH)

    state.epoch = epoch
    cleanup()
    
    
def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
    
    
if __name__ == '__main__':
    run_demo(demo_basic,
             2)

Overwriting main.py


In [93]:
%%writefile state.py
import torch
import os

class State:
    """
    Container for objects that we want to checkpoint. Represents the
    current "state" of the worker. This object is mutable.
    """

    def __init__(self, model, optimizer):
        self.epoch = -1
        self.model = model
        self.optimizer = optimizer


    def capture_snapshot(self):
        """
        Essentially a ``serialize()`` function, returns the state as an
        object compatible with ``torch.save()``. The following should work
        ::
        snapshot = state_0.capture_snapshot()
        state_1.apply_snapshot(snapshot)
        assert state_0 == state_1
        """
        return {
            "epoch": self.epoch,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def apply_snapshot(self, obj, device_id):
        """
        The complimentary function of ``capture_snapshot()``. Applies the
        snapshot object that was returned by ``capture_snapshot()``.
        This function mutates this state object.
        """

        self.epoch = obj["epoch"]
        self.model.load_state_dict(obj["model"])
        self.optimizer.load_state_dict(obj["optimizer"])

    def save(self, f):
        torch.save(self.capture_snapshot(), f, _use_new_zipfile_serialization=False)

    def load(self, f, device_id):
        # Map model to be loaded to specified single gpu.
        snapshot = torch.load(f, map_location=f"{device_id}")
        self.apply_snapshot(snapshot, device_id)


def save_checkpoint(state: State, filename):
    checkpoint_dir = os.path.dirname(filename)
    os.makedirs(checkpoint_dir, exist_ok=True)
    state.save(filename)
    print(f"=> saved checkpoint for epoch {state.epoch} at {filename}")


def load_checkpoint(checkpoint_file, device_id, model, optimizer) -> State:
    state = State(model, optimizer)

    if os.path.isfile(checkpoint_file):
        print(f"=> loading checkpoint file: {checkpoint_file}")
        state.load(checkpoint_file, device_id)
        print(f"=> loaded checkpoint file: {checkpoint_file}")
    return state


Overwriting state.py


# Code execution part

In [32]:
!pip install tokenizers
!pip install einops
!pip install torch



In [8]:
%%writefile get_data.sh
mkdir data
cd data
wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt


Writing get_data.sh


In [9]:
!sh get_data.sh

--2023-07-21 11:29:34--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1,1M) [text/plain]
Saving to: ‘input.txt’


2023-07-21 11:29:35 (1,09 MB/s) - ‘input.txt’ saved [1115394/1115394]

get_data.sh: 4: !sh: not found


In [94]:
!python train_tokenizer.py

[00:00:00] Pre-processing files (1 Mo)              ░░░░░░░░                  0%


[2K[1B[1A[00:00:00] Pre-processing files (1 Mo)              ████████                100%
[00:00:00] Tokenize words                           ████████ 0        /        0
[2K[1B[1A[00:00:00] Tokenize words                           ████████ 13355    /    13355

[2K[1B[1A[00:00:00] Count pairs                              ████████ 13355    /    13355

[2K[1B[1A[00:00:00] Compute merges                           ░░░░░░░░ 550      /     5000
[2K[1B[1A[00:00:00] Compute merges                           ████░░░░ 2700     /     5000
[2K[1B[1A[00:00:00] Compute merges                           ████████ 4933     /     4933



In [95]:
!python main.py

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Running basic GPT-1 traning on device: 1.
Running basic GPT-1 traning on device: 0.
Epoch: 0, Training loss: 9.643802960713705
Device 1 - Train time: 5

In [104]:
%%writefile inference.py
import torch
import torch.multiprocessing as mp

from model import GPT
from utils import setup
from state import load_checkpoint, save_checkpoint
from torch.nn.parallel import DistributedDataParallel as DDP
from tokenizers import Tokenizer
from objects import TokenizerWrapper

import os

def example(rank, world_size):

    setup(rank, world_size)
    tokenizer = Tokenizer.from_file("data/tokenizer.json")
    prefix = "<bos> A thou"


    model_config = dict(
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        ffc_hidden_dim=3072,
        max_seq_len=512,
        vocab_size=tokenizer.get_vocab_size()
    )

    parallel_model = DDP(GPT(**model_config).to(rank))
    parallel_model.load_state_dict(torch.load("checkpoints/model.pt", map_location={str(rank): 'cuda:0'})['model']
    )
    parallel_model.eval()

    tokenizerwrapped = TokenizerWrapper(tokenizer, 0)
    batch = tokenizerwrapped(prefix, batch=False)

    
    num_generations = 400
    with torch.cuda.amp.autocast():
        for i in range(num_generations):
            attn_mask = batch['attn_mask']
            curr_num_tokens = batch['input_ids'].shape[-1]
            outputs = parallel_model(batch['input_ids'].cuda(rank), attn_mask.cuda(rank))
            probs = outputs[0, -1].div(0.8).softmax(-1)
            token = torch.multinomial(probs, 1).view([])

            print(tokenizerwrapped.tokenizer.decode([token]), end=' ', flush=True)
            batch = dict(input_ids=outputs[0, -1].argmax(-1).reshape(1, 1),
                         attn_mask=torch.ones(1, curr_num_tokens+1, requires_grad=False).cuda(rank))

        
    
if __name__ == "__main__":

    world_size = 1
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)




Overwriting inference.py


In [105]:
!python inference.py
#inference is bad. I think its due to data...

Fie                                                                                                                                                                                                                                                                                                                                                                                                                