In [1]:
import torch
from torch import nn
from pathlib import Path
from tokenizers import Tokenizer
from huggingface_hub import PyTorchModelHubMixin

In [2]:
import torch

# Check if CUDA is available
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# If CUDA is available, print the GPU name and perform a test operation
if cuda_available:
    # Get the name of the GPU
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {gpu_name}")

    # Create a tensor and move it to the GPU
    x = torch.tensor([1.0, 2.0, 3.0], device='cuda')
    print(f"Tensor on GPU: {x}")

    # Perform a simple operation
    y = x * 2
    print(f"Result of operation on GPU: {y}")
else:
    print("CUDA is not available. Please check your PyTorch installation and GPU drivers.")

CUDA available: True
GPU Name: NVIDIA A100-SXM4-40GB
Tensor on GPU: tensor([1., 2., 3.], device='cuda:0')
Result of operation on GPU: tensor([2., 4., 6.], device='cuda:0')


In [3]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7935a6682530>

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"

In [5]:
!pip install torchinfo



In [6]:

#Data

In [7]:
#Collab setup

data_path = Path('data')
data_path.mkdir(exist_ok=True)
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!cp input.txt data/input.txt


--2024-12-29 18:13:48--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-12-29 18:13:48 (161 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [8]:

#Datasets

# Using tinyshakespeare

with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

####################################################################

#Using BookCorpus

# from datasets import load_dataset
# data = load_dataset('bookcorpus/bookcorpus')

In [9]:

# Extracting the content of  the Dataset
# Open a file for writing
# with open('bookcorpus_text.txt', 'w', encoding='utf-8') as f:
#     # Traverse the dataset and write text data to the file
#     for record in data['train']['text']:
#         f.write(record)

# print("Writing to file complete.")

# Read the file contents into a single string
with open('data/input.txt', 'r', encoding='utf-8') as f:
    concatenated_text = f.read()

# print("Reading from file and concatenation complete.")
# print(concatenated_text[:225000000])  # Print the first 1000 characters
# print(f"Total characters: {len(concatenated_text)}")
# print("Total words: ", len(concatenated_text.split()))

#Using only 1% of the total characters (225 million out of 4.2 billion ->Total words:  45756831 )
# concatenated_text = concatenated_text[:225000000]
print("Total words: ", len(concatenated_text.split()))


Total words:  202651


In [10]:


###############################################################################

#Subword level tokenization

#Loading custom trained BPE
# Load the tokenizer
tokenizer = Tokenizer.from_file("bpe_tokenizer_tinyshakespeare_1k.json")
vocab_size = tokenizer.get_vocab_size()
# Encode and decode functions
encode = lambda s: tokenizer.encode(s).ids
decode = lambda l: tokenizer.decode(l)

###############################################################################

#Character level tokenization

# here are all the unique characters that occur in this text
# chars = sorted(list(set(text)))
# vocab_size = len(chars)


# create a mapping from characters to integers
# stoi = { ch: i for i,ch in enumerate(chars) }
# itos = { i:ch for i,ch in enumerate(chars) }
# encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
# decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


In [22]:
#Hyperparameters

block_size = 256
batch_size = 64
embeddings_dims = 768
attn_dropout = 0.1
no_of_heads = 12 #IMP needs to be thoroughly calculated
dropout = 0.1
init_lambda = 0.8
epochs = 100
max_lr = 3e-4
no_of_decoder_layers = 12 #IMP needs to be thoroughly calculated
attn_dropout = 0.1
weight_decay_optim = 0.01

In [23]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [24]:
len(data)
# steps_per_epoch = 383747 / 64 = 6000
#Tptal epochs = epcoh * steps_per_epoch = 1 * 6000 = 6000

383747

In [25]:
# Text embeddings
class TextEmbeddings(nn.Module):
    def __init__(
        self,
        vocab_size = vocab_size,
        embeddings_dims = embeddings_dims
    ):
        super().__init__()
        self.embeddings_table = nn.Embedding(num_embeddings = vocab_size, embedding_dim=embeddings_dims, device=device) #Just a look up table to convert the toekns_ids to some numbers
        # nn.init.normal_(self.embeddings_table.weight.data, mean=0, std=0.02)

    def forward(self, x):
        return self.embeddings_table(x)

In [26]:
#Position embeddings
class PositionEmbeddings(nn.Module):
    def __init__(
        self,
        embeddings_dims = embeddings_dims,
        block_size = block_size
    ):
        super().__init__()

        self.position_embeddings = nn.Parameter(torch.randn(1, block_size, embeddings_dims, device=device), requires_grad=True) #To give positional embeddings to each token of the input text, hence num_embeddings=block_size
        # nn.init.normal_(self.position_embeddings.weight.data, mean=0, std=0.02)

    def forward(self):
        return self.position_embeddings

In [27]:

#Layer Normalization

class RMSNormalization(nn.Module):
    def __init__(
        self,
        embeddings_dims = embeddings_dims
    ):
        super().__init__()

        self.layer_norm = nn.RMSNorm(normalized_shape=embeddings_dims, eps=1e-5, elementwise_affine=True)

    def forward(self, x):
        return self.layer_norm(x)

In [28]:
class Swish(nn.Module):
    def __init__(
        self,
        block_size: int = block_size,
        embeddings_dims: int = embeddings_dims
    ):
        super().__init__()

        self.sig = torch.nn.Sigmoid()


    def forward(self, x):
        swish = x * self.sig(x)

        return swish


In [29]:
class SWiGLU(nn.Module):
    def __init__(
        self,
        block_size: int = block_size,
        embeddings_dims: int = embeddings_dims
    ):
        super().__init__()

        self.swish = Swish(block_size=block_size, embeddings_dims=embeddings_dims)
        self.linear_layer1 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)
        self.linear_layer2 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)
        self.linear_layer3 = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)


    def forward(self, x):
        swish_res = self.swish(self.linear_layer1(x))
        x_V = self.linear_layer2(x)
        res = torch.mul(swish_res, x_V)
        out = self.linear_layer3(res)
        return out


In [30]:
#FeedForward Neural Network

class MLPBlock(nn.Module):
    def __init__(
        self,
        dropout = dropout,
        embeddings_size = embeddings_dims,
        # inner_dimensional_states: int = 3072
    ):
        super().__init__()
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device)
        self.swiglue = SWiGLU(block_size=block_size, embeddings_dims=embeddings_dims)
        self.dropout = nn.Dropout(p = dropout)

    def forward(self, x):
        x = self.swiglue(x)
        x = self.linear_layer(x)
        x = self.dropout(x)
        return x

In [31]:
# #Weights Initilization (for MLP Block)
# def weights_init(m):
#     classname = m.__class__.__name__
#     if classname.find('Linear') != -1:
#         nn.init.normal_(m.weight.data, 0.0, 0.02)  #mean = 0, std = 0.02



In [32]:

class AttentionHead(nn.Module):
    def __init__(
        self,
        lambda_init = init_lambda,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,

    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device,bias=False)
        self.dropout = nn.Dropout(p = attn_dropout)
        self.norm = RMSNormalization(embeddings_dims=self.head_size)
        self.lambda_init = init_lambda

        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_size, device=device).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_size, device=device).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_size, device=device).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_size, device=device).normal_(mean=0,std=0.1))
        # self.lambda_final = torch.exp(self.lambda_q1 * self.lambda_k1) - torch.exp(self.lambda_q2 * self.lambda_k2) + lambda_init

    def split(self, tensor: torch.tensor):

        split_1 = tensor[:, :, :tensor.shape[-1] // 2 ]
        split_2 = tensor[:, :, tensor.shape[-1] // 2:]
        # print("split1: ", split_1.shape)
        # print("split2: ", split_2.shape)
        return (split_1, split_2)

    def forward(self, x):
        batch, block_size, embd_dims = x.shape
        k = self.keys(x)
        q = self.query(x)
        v = self.values(x)

        k_1, k_2 = self.split(k)
        q_1, q_2 = self.split(q)

        self.lambda_final = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1)) - torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1)) + self.lambda_init

        # print("K2: " , k_2.shape)
        # print("q2: ", q_2.shape)
        masked_table = torch.tril(torch.ones(block_size, block_size, device=device))
        weights_1 = q_1 @ torch.transpose(k_1, dim0=-2, dim1=-1) * (k_1.shape[-1] ** -0.5)
        weights_2 = q_2 @ torch.transpose(k_2, dim0=-2, dim1=-1) * (k_2.shape[-1] ** -0.5)
        # print("Weights 1: ", weights_1.shape)
        # print("Weights 2: ", weights_2.shape)
        masked_values_1 = weights_1.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
        masked_values_2 = weights_2.masked_fill(masked_table[: block_size, : block_size] == 0, float('-inf'))
        weights_normalized_1 = nn.functional.softmax(masked_values_1, dim=-1) #Normalize along the embeddings dimension for all the tokens
        weights_normalized_2 = nn.functional.softmax(masked_values_2, dim=-1)
        # weights_normalized = self.dropout(weights_normalized)
        # print("Weights: ", weights_normalized_1.shape)
        # print("weights norm 1: ", (self.lambda_final * weights_normalized_2).shape)
        # print("Lambda final: ", self.lambda_final)
        out = (weights_normalized_1 - self.lambda_final * weights_normalized_2 )@ v
        out = self.norm(out) * (1-init_lambda)
        out = self.dropout(out)
        return out



In [33]:

class MHA(nn.Module):
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        lambda_init = init_lambda
    ):
        super().__init__()
        self.norm = RMSNormalization(embeddings_dims=embeddings_dims)
        self.heads = nn.ModuleList([AttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads) for _ in range(no_of_heads)])
        self.dropout = nn.Dropout(p = attn_dropout)
        self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)
        # self.lambda_q1 = nn.Parameter(torch.randn(1, block_size, block_size, device=device))
        # self.lambda_q2 = nn.Parameter(torch.randn(1, block_size, block_size, device=device))
        # self.lambda_k1 = nn.Parameter(torch.randn(1, block_size, block_size, device=device))
        # self.lambda_k2 = nn.Parameter(torch.randn(1, block_size, block_size, device=device))


    def forward(self, x):
        # self.lambda_final = (torch.exp(self.lambda_q1 * self.lambda_k1) - torch.exp(self.lambda_q2 * self.lambda_k2)) + lambda_init
        concat = self.norm(torch.cat([head(x) for head in self.heads], dim=-1))
        concat_norm = self.norm(concat)
        linear_layer = self.linear(concat_norm)
        out = self.dropout(linear_layer)
        return out

In [34]:


# Decoder Block

class TransformerDecoderBlock(nn.Module):
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        dropout = dropout,
        vocab_size = vocab_size
    ):
        super().__init__()

        self.mha = MHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
        self.norm1 = RMSNormalization(embeddings_dims=embeddings_dims)
        self.norm2 = RMSNormalization(embeddings_dims=embeddings_dims)
        self.mlp_block = MLPBlock(dropout=dropout, embeddings_size=embeddings_dims)

    def forward(self, x):
        # print("Hiii:", x.shape)
        x = x + self.norm1(self.mha(x))
        # print("Hello: ", x.shape)
        x = x + self.norm2(self.mlp_block(x))

        return x

In [35]:

# Decoder Block

class DecoderModel(nn.Module):
    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        block_size = block_size,
        dropout = dropout,
        no_of_decoder_layers = no_of_decoder_layers,
        vocab_size = vocab_size
    ):
        super().__init__()

        self.positional_embeddings = nn.Parameter(torch.randn(1, block_size, embeddings_dims, device=device), requires_grad=True)
        # self.positional_embeddings = PositionEmbeddings(block_size=block_size, embeddings_dims=embeddings_dims) #To give positional embeddings to each token of the input text, hence num_embeddings=block_size
        # torch.nn.init.normal_(self.positional_embeddings, mean=0.0, std=0.02)
        self.text_embds = TextEmbeddings(vocab_size=vocab_size, embeddings_dims=embeddings_dims)
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=vocab_size, device=device, bias=False) # Takes in logits of dimensions- embeds_dims and converts it into dimension of vocab_size (logits in range of vocab_size)
        self.norm = RMSNormalization(embeddings_dims=embeddings_dims)
        self.decoder_layers = nn.Sequential(*[TransformerDecoderBlock(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, dropout=dropout, vocab_size=vocab_size) for _ in range(no_of_decoder_layers)])
        self.apply(self._init_weights)
        self.dropout = nn.Dropout(p = dropout)

    def _init_weights(self, module):  #Weight Initialization
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.1)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.1)

    def forward(self, x):
        # print(x.shape)
        x = self.text_embds(x)
        # print("AFTER: ", x.shape)
        x = x + self.positional_embeddings
        x = self.dropout(x)
        # print(x.shape)
        x = self.decoder_layers(x)
        x = self.norm(x)
        out = self.linear_layer(x)
        return out

In [36]:
#Instantiating the model
model = DecoderModel(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, block_size=block_size, dropout=dropout, no_of_decoder_layers=no_of_decoder_layers, vocab_size=vocab_size)
model = model.to(device)

In [37]:

#Printing a summary of the architecture
from torchinfo import summary
idx, targets = get_batch('test')
# print(idx.shape)
# idx = idx.to(device)
summary(model=model,
        input_data=idx,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
DecoderModel (DecoderModel)                                  [64, 256]            [64, 256, 1000]      196,608              True
├─TextEmbeddings (text_embds)                                [64, 256]            [64, 256, 768]       --                   True
│    └─Embedding (embeddings_table)                          [64, 256]            [64, 256, 768]       768,000              True
├─Dropout (dropout)                                          [64, 256, 768]       [64, 256, 768]       --                   --
├─Sequential (decoder_layers)                                [64, 256, 768]       [64, 256, 768]       --                   True
│    └─TransformerDecoderBlock (0)                           [64, 256, 768]       [64, 256, 768]       --                   True
│    │    └─MHA (mha)                                        [64, 256, 768]       [64, 256, 76

In [38]:
 # Optimizer setup and scheduler steup

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
total_steps = 2000
eval_iters = 100
# warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=2000)
# lr_scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max= total_steps - initial_iters)
# lr_scheduler_linear = torch.optim.lr_scheduler.LinearLR(optimizer=optimizer, total_iters=initial_iters)

@torch.inference_mode()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            idx, targets = get_batch(split=split)
            logits = model(idx)
            batch_size, block_size, embeddings_dims = logits.shape
            logits = logits.view(batch_size*block_size, embeddings_dims) # Total tokens(words) => batch_size * block_size
            targets = targets.view(batch_size * block_size)
            loss = nn.functional.cross_entropy(logits, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [39]:
#Train the  model
from tqdm import tqdm

model.train()
for step in tqdm(range(total_steps)):

    # every once in a while evaluate the loss on train and val sets
    if (step  % eval_iters == 0 and step != 0) or step == total_steps - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # torch.save(model.state_dict(), 'weights/_differential_transformer_86M_steps_%d.pth' % (step))

    idx, targets = get_batch(split='train')
    logits = model(idx)
    batch_size, block_size, embeddings_dims = logits.shape
    logits = logits.view(batch_size*block_size, embeddings_dims)
    targets = targets.view(batch_size * block_size)
    loss = nn.functional.cross_entropy(logits, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # print(loss.item())
    # break

    # if step != 0 and (step % eval_iters == 0 or step == total_steps -1) :
    #     loss_values = estimate_loss()
    #     print("Train Loss at {} steps : {}".format(step, loss.item()), "Val Loss at {} steps : {}".format(step, loss_values['val']))

  5%|▌         | 100/2000 [08:39<2:43:41,  5.17s/it]

step 100: train loss 8.3722, val loss 8.3632


 10%|█         | 200/2000 [18:01<2:34:35,  5.15s/it]

step 200: train loss 7.3941, val loss 7.4096


 15%|█▌        | 300/2000 [27:24<2:25:55,  5.15s/it]

step 300: train loss 6.8554, val loss 6.8805


 20%|██        | 400/2000 [36:46<2:17:19,  5.15s/it]

step 400: train loss 6.5609, val loss 6.5737


 25%|██▌       | 500/2000 [46:09<2:09:28,  5.18s/it]

step 500: train loss 6.3883, val loss 6.4120


 30%|███       | 600/2000 [55:35<2:00:42,  5.17s/it]

step 600: train loss 6.2789, val loss 6.3017


 35%|███▌      | 700/2000 [1:04:58<1:51:53,  5.16s/it]

step 700: train loss 6.2082, val loss 6.2322


 40%|████      | 800/2000 [1:14:20<1:43:58,  5.20s/it]

step 800: train loss 6.1570, val loss 6.1888


 45%|████▌     | 900/2000 [1:23:42<1:34:40,  5.16s/it]

step 900: train loss 6.1198, val loss 6.1452


 50%|█████     | 1000/2000 [1:33:07<1:26:19,  5.18s/it]

step 1000: train loss 6.0913, val loss 6.1139


 55%|█████▌    | 1100/2000 [1:42:30<1:17:27,  5.16s/it]

step 1100: train loss 6.0653, val loss 6.0859


 60%|██████    | 1200/2000 [1:51:51<1:08:52,  5.17s/it]

step 1200: train loss 6.0419, val loss 6.0662


 65%|██████▌   | 1300/2000 [2:01:14<1:01:27,  5.27s/it]

step 1300: train loss 6.0270, val loss 6.0487


 70%|███████   | 1400/2000 [2:10:38<52:08,  5.21s/it]

step 1400: train loss 6.0127, val loss 6.0341


 75%|███████▌  | 1500/2000 [2:20:04<43:21,  5.20s/it]

step 1500: train loss 6.0011, val loss 6.0283


 80%|████████  | 1600/2000 [2:29:28<34:41,  5.20s/it]

step 1600: train loss 5.9890, val loss 6.0091


 85%|████████▌ | 1700/2000 [2:38:51<25:56,  5.19s/it]

step 1700: train loss 5.9802, val loss 6.0121


 90%|█████████ | 1800/2000 [2:48:14<17:18,  5.19s/it]

step 1800: train loss 5.9713, val loss 5.9976


 95%|█████████▌| 1900/2000 [2:57:36<08:38,  5.19s/it]

step 1900: train loss 5.9654, val loss 5.9913


100%|█████████▉| 1999/2000 [3:06:54<00:05,  5.23s/it]

step 1999: train loss 5.9556, val loss 5.9870


100%|██████████| 2000/2000 [3:07:42<00:00,  5.63s/it]
