In [47]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 512 # max number of tokens per sequence; also sizes the position table and the Feebler/Booster weights

# NOTE(review): max_iters, eval_interval and eval_iters are not read by any cell visible in this notebook.
max_iters = 5000
eval_interval = 100
eval_iters = 200

learning_rate = 1e-3 # step size handed to the Adam optimizer
# NOTE(review): both branches of this conditional evaluate to 'cpu', so CUDA is never
# selected even when it is available -- confirm whether 'cuda' was intended for the first branch.
device = 'cpu' if torch.cuda.is_available() else 'cpu'

n_embd = 64 # embedding dimension C
sqrt_d = torch.sqrt(torch.tensor(n_embd)).int().item() # square root of n_embd truncated to int (8 for n_embd = 64)
n_head = sqrt_d // 2 # number of attention heads (4 with the values above)
n_layer = 4 # number of stacked Block modules
dropout = 0.0 # probability passed to every nn.Dropout layer

In [48]:
class Feebler(nn.Module):
    '''
    Shrinks every embedding from C = sqrt_d * sqrt_d features to sqrt_d features
    by weighting the input elementwise and summing groups of sqrt_d values.

    T: Number of words going into the model
    C: Embedding dimension
    B: Batch size

    input: B, T, C
    output: B, T, sqrt(C)

    Args:
        sqrt_d: square root of the embedding dimension.
        seq_len: sequence length T the weights are sized for. Defaults to the
            module-level ``block_size`` when omitted.

    Note:
        The ``.view`` calls reinterpret the contiguous (B, T, C) buffer as
        (B, sqrt_d, sqrt_d, T); they do not transpose axes.
    '''
    def __init__(self, sqrt_d, seq_len=None):
        super().__init__()
        self.sqrt_d = sqrt_d
        # Fall back to the notebook-wide block_size only when no length is given.
        self.seq_len = block_size if seq_len is None else seq_len
        self.weights = nn.Parameter(
            torch.randn(sqrt_d, sqrt_d, self.seq_len)
        )

    def forward(self, data):
        # Read the batch size from the tensor itself so that a smaller final
        # batch from the DataLoader does not break the reshape.
        batch = data.shape[0]
        windows = data.view(batch, self.sqrt_d, self.sqrt_d, self.seq_len)
        product = windows * self.weights  # weights broadcast over the batch axis
        # Sum over the inner sqrt_d axis: (B, sqrt_d, sqrt_d, T) -> (B, sqrt_d, T)
        reduced = torch.sum(product, dim=2, keepdim=False)
        return reduced.view(batch, self.seq_len, self.sqrt_d)
    

class Booster(nn.Module):
    '''
    Expands every sqrt_d-feature vector back to C = sqrt_d * sqrt_d features by
    multiplying the input with a learned (sqrt_d, sqrt_d, T) weight tensor.

    input: B, T, sqrt(C)
    output: B, T, C

    Args:
        sqrt_d: square root of the embedding dimension.
        seq_len: sequence length T the weights are sized for. Defaults to the
            module-level ``block_size`` when omitted.

    Note:
        The ``.view`` calls reinterpret the contiguous buffer; they do not
        transpose axes.
    '''
    def __init__(self, sqrt_d, seq_len=None):
        super().__init__()
        self.sqrt_d = sqrt_d
        # Fall back to the notebook-wide block_size only when no length is given.
        self.seq_len = block_size if seq_len is None else seq_len
        self.weights = nn.Parameter(
            torch.randn(sqrt_d, sqrt_d, self.seq_len)
        )

    def forward(self, attention_output):
        # attention_output is of shape (batch, n, sqrt_d)
        # Read the batch size from the tensor so partial batches are accepted.
        batch = attention_output.shape[0]
        # A singleton axis lets broadcasting pair the same (sqrt_d, T) slab with
        # each of the sqrt_d weight rows; this gives the same values as
        # flattening, repeating sqrt_d times and reshaping, without the copy.
        rows = attention_output.view(batch, 1, self.sqrt_d, self.seq_len)
        revived_output = self.weights * rows  # (B, sqrt_d, sqrt_d, T)
        return revived_output.view(batch, self.seq_len, self.sqrt_d * self.sqrt_d)

class QuickHead(nn.Module):
    """ one head of the linear-cost "quick" self-attention

    Keys and values are each summed over the time axis into a single vector.
    Every query is multiplied elementwise with the summed key, a softmax is
    taken over the time axis, and the result scales the summed value.

    input:  (B, T, sqrt_d)
    output: (B, T, head_size)
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(sqrt_d, head_size, bias=False)
        self.query = nn.Linear(sqrt_d, head_size, bias=False)
        self.value = nn.Linear(sqrt_d, head_size, bias=False)
        # Created here but not applied anywhere in forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x is of shape (batch_size, n, sqrt_d)
        k = self.key(x)   # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        v = self.value(x) # (B,T,head_size)

        # Summed key/value keep a singleton time axis so broadcasting expands
        # them to the actual T of the input instead of a hard-coded block_size.
        collective_k = k.sum(dim=1, keepdim=True)  # (B,1,head_size)
        collective_v = v.sum(dim=1, keepdim=True)  # (B,1,head_size)

        # Softmax runs over the time axis (dim=1), separately per feature.
        attention_weights = torch.softmax(q * collective_k, dim=1)
        return collective_v * attention_weights
    
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel

    Runs ``num_heads`` QuickHead modules on the same input, concatenates their
    outputs along the feature axis and projects the result to ``sqrt_d``
    features.

    input:  (B, T, sqrt_d)
    output: (B, T, sqrt_d)
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([QuickHead(head_size) for _ in range(num_heads)])
        # The concatenated head outputs have num_heads * head_size features.
        # That equals sqrt_d only when sqrt_d is divisible by num_heads, so the
        # projection input is sized from the heads rather than assumed.
        self.proj = nn.Linear(num_heads * head_size, sqrt_d) # global variable sqrt_d
        self.dropout = nn.Dropout(dropout)  # global variable dropout

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
    
class FeedFoward(nn.Module):
    """ position-wise MLP: widen to 4x, apply ReLU, project back, then dropout """

    def __init__(self, sqrt_d):
        super().__init__()
        hidden = 4 * sqrt_d
        # Layer order is kept so the parameter names (net.0, net.2) stay the same.
        layers = [
            nn.Linear(sqrt_d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, sqrt_d),
            nn.Dropout(dropout),  # probability comes from the module-level `dropout`
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        transformed = self.net(x)
        return transformed
    
class Block(nn.Module):
    """ Quickformer block: shrink features, attend, feed forward, expand back

    The input of shape (B, T, C) is reduced to (B, T, sqrt_d) by a Feebler,
    passed through pre-norm residual attention and feed-forward sub-layers at
    that reduced width, and restored to (B, T, C) by a Booster.

    Note: ``n_embd`` is accepted but not read; the working width comes from
    the module-level ``sqrt_d``.
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        width = sqrt_d
        # Sub-modules are created in the same order as before so parameter
        # registration and random initialisation are unchanged.
        self.feebler = Feebler(width)
        self.sa = MultiHeadAttention(n_head, width // n_head)
        self.ffwd = FeedFoward(width)
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.booster = Booster(width)

    def forward(self, x):
        reduced = self.feebler(x)
        attended = reduced + self.sa(self.ln1(reduced))     # residual around attention
        mixed = attended + self.ffwd(self.ln2(attended))    # residual around the MLP
        return self.booster(mixed)

In [49]:
# Smoke test: build one Block and push a random (batch_size, block_size, n_embd)
# tensor through it; the displayed shape should match the input shape.
b = Block(n_embd, n_head)
b(torch.rand(batch_size, block_size, n_embd)).shape

torch.Size([16, 512, 64])

In [66]:
# Number of rows in the token-embedding table; every token id fed to the model must be below this.
# NOTE(review): presumably larger than the tokenizer's vocabulary -- confirm against tokenizer.vocab_size.
vocab_size = 100000

# super simple quickformer model
class Quickformer(nn.Module):
    """Binary sequence classifier built from a stack of Quickformer blocks.

    Token ids are embedded, summed with learned position embeddings, passed
    through ``n_layer`` Block modules and a final LayerNorm, reduced to one
    score per position, then reduced over all positions to a single value
    that is squashed with a sigmoid.

    The final reduction is an ``nn.Linear(block_size, 1)``, so inputs must
    contain exactly ``block_size`` tokens.
    """

    def __init__(self):
        super().__init__()
        # learned lookup tables for token identity and token position
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, 1)            # one score per position
        self.logits_maker = nn.Linear(block_size, 1)   # one score per sequence
        self.classifier = nn.Sigmoid()

    def forward(self, idx, targets=None):
        """Score a batch of token-id sequences.

        Args:
            idx: (B, T) integer tensor of token ids, with T == block_size.
            targets: optional (B, 1) float tensor of 0/1 labels.

        Returns:
            (results, loss): ``results`` is a (B, 1) tensor of probabilities;
            ``loss`` is the binary cross-entropy against ``targets``, or None
            when no targets are given.
        """
        _, T = idx.shape

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        # Build the position indices on the same device as the input so the
        # model follows wherever it was moved, not a module-level setting.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        x = self.lm_head(x) # (B,T,1)
        x = x.squeeze(2) # (B,T)
        logits = self.logits_maker(x) # (B,1)
        results = self.classifier(logits) # (B, 1)

        if targets is None:
            loss = None
        else:
            loss = F.binary_cross_entropy(results, targets)

        return results, loss

    def generate(self, idx, max_new_tokens):
        """Not supported: this model classifies sequences, it does not predict tokens.

        forward() returns one probability per sequence with shape (B, 1), so
        there is no per-token distribution to sample from. The sampling loop
        that used to live here indexed that 2-D tensor with three indices and
        could not run.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Quickformer is a binary classifier; forward() returns a (B, 1) "
            "probability, so autoregressive generation is not available."
        )

# Build the classifier and place it on the configured device in one step.
model = Quickformer().to(device)

In [51]:
# Shape check with random token ids. torch.rand() draws from [0, 1), so casting
# it to long gave an all-zero batch; randint samples real ids from the vocabulary.
inp = torch.randint(0, vocab_size, (batch_size, block_size), device=device)
print(inp.shape)
l, ll = model(inp)  # l: per-sequence probabilities, ll: loss (None without targets)
l.shape

torch.Size([16, 512])


torch.Size([16, 1])

# Data handling

In [52]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import torch

# Load IMDb dataset from Hugging Face
# Load the IMDb dataset from Hugging Face (the splits used below are "train" and "test")
dataset = load_dataset("imdb")

# Use a pre-trained tokenizer; its token ids index the model's token-embedding table
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Define a custom PyTorch Dataset
class IMDbDataset(Dataset):
    """Map-style dataset that tokenizes one review per lookup.

    Each item is a dict with keys "input_ids", "attention_mask" and "label".
    Text is padded or truncated to ``max_length`` tokens by the tokenizer.
    """

    def __init__(self, dataset, tokenizer, max_length=block_size):
        self.dataset, self.tokenizer, self.max_length = dataset, tokenizer, max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        review, sentiment = record["text"], record["label"]

        # Encode to fixed-length PyTorch tensors.
        encoded = self.tokenizer(
            review,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze() removes the singleton dimensions the tokenizer adds for a single text.
        return {
            "input_ids": encoded["input_ids"].squeeze(),
            "attention_mask": encoded["attention_mask"].squeeze(),
            "label": torch.tensor(sentiment),
        }

# Create Train dataset (pytorch)
imdb_dataset_train = IMDbDataset(dataset["train"], tokenizer)
# Create Test dataset (pytorch)
imdb_dataset_test = IMDbDataset(dataset["test"], tokenizer)

# Create PyTorch DataLoader for train set
dataloader_train = DataLoader(imdb_dataset_train, batch_size=batch_size, shuffle=True)
# Create PyTorch DataLoader for test set
dataloader_test = DataLoader(imdb_dataset_test, batch_size=batch_size, shuffle=True)

In [56]:
# Inspect the dataset info to see the details
print(dataset["train"].info)  # shows 0 = neg, 1 = pos

# List the field names of a single training example
print(dataset["train"][0].keys())

DatasetInfo(description='', citation='', homepage='', license='', features={'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['neg', 'pos'], id=None)}, post_processed=None, supervised_keys=None, task_templates=None, builder_name='parquet', dataset_name='imdb', config_name='plain_text', version=0.0.0, splits={'train': SplitInfo(name='train', num_bytes=33435948, num_examples=25000, shard_lengths=None, dataset_name='imdb'), 'test': SplitInfo(name='test', num_bytes=32653810, num_examples=25000, shard_lengths=None, dataset_name='imdb'), 'unsupervised': SplitInfo(name='unsupervised', num_bytes=67113044, num_examples=50000, shard_lengths=None, dataset_name='imdb')}, download_checksums={'hf://datasets/imdb@e6281661ce1c48d982bc483cf8a173c1bbeb5d31/plain_text/train-00000-of-00001.parquet': {'num_bytes': 20979968, 'checksum': None}, 'hf://datasets/imdb@e6281661ce1c48d982bc483cf8a173c1bbeb5d31/plain_text/test-00000-of-00001.parquet': {'num_bytes': 20470363, 'checksum': None}, 'hf:

In [67]:
# Sanity check: run one training batch through the model and print the loss.
# The tensors are moved to the model's device first so the check also works
# when the model does not live on the CPU.
for batch in dataloader_train:
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)  # not consumed by the model
    labels = batch["label"].unsqueeze(1).float().to(device)  # (B,) ints -> (B, 1) floats

    print('input ids:', input_ids.shape)
    print('labels:', labels.shape)

    print('loss:', model(input_ids, labels)[1])
    break

input ids: torch.Size([16, 512])
labels: torch.Size([16, 1])
loss: tensor(0.6649, grad_fn=<BinaryCrossEntropyBackward0>)


In [42]:
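# THIS DOES NOT WORK: each batch from the DataLoader is a dict with three keys ("input_ids", "attention_mask", "label"), so it cannot be unpacked into (idx, targets).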
# # Testing if my model works well with this dataset
# # Example of how to iterate through the dataloader
# count = 0
# for idx, targets in dataloader_train:
#     idx = idx.to(device)
#     targets = targets.to(device)
#     print('idx:', idx.shape)
#     print('targets:', targets.shape)

#     print(model(idx, targets)[1])

#     count += 1

#     if count > 5:
#         break

In [18]:
# Move the most recently loaded batch of token ids to the configured device.
# NOTE(review): relies on `input_ids` left over from an earlier cell; the recorded
# output shape reflects whatever batch size was in effect when that cell last ran.
input_ids = input_ids.to(device)
print(input_ids.shape)

torch.Size([64, 512])


In [35]:
# Forward pass without targets: displays the (probabilities, None) tuple for the leftover batch.
model(input_ids)

(tensor([[0.5038],
         [0.5094],
         [0.3700],
         [0.5310],
         [0.5326],
         [0.4179],
         [0.4615],
         [0.4602],
         [0.4958],
         [0.5428],
         [0.4960],
         [0.4819],
         [0.4572],
         [0.4408],
         [0.5642],
         [0.5519],
         [0.3988],
         [0.4865],
         [0.4255],
         [0.4941],
         [0.5305],
         [0.4195],
         [0.4113],
         [0.4868],
         [0.3796],
         [0.4663],
         [0.4657],
         [0.5384],
         [0.5043],
         [0.4749],
         [0.5132],
         [0.5581],
         [0.5017],
         [0.4009],
         [0.3565],
         [0.5652],
         [0.4897],
         [0.4586],
         [0.4856],
         [0.4833],
         [0.4004],
         [0.5293],
         [0.5752],
         [0.5063],
         [0.4934],
         [0.4624],
         [0.4528],
         [0.4393],
         [0.4138],
         [0.5384],
         [0.4643],
         [0.5302],
         [0.

In [72]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

def _run_epoch(model, loader, loss_function, device, desc, postfix_key, optimizer=None):
    """Make one pass over `loader`; update weights only when `optimizer` is given.

    Args:
        model: module called as ``model(idx)`` and returning ``(results, loss)``.
        loader: iterable of dict batches with "input_ids" and "label" entries.
        loss_function: callable applied to ``(results, targets)``.
        device: device the batch tensors are moved to.
        desc: label shown on the progress bar.
        postfix_key: name under which the running mean loss is shown.
        optimizer: optimizer to step after each batch; None means evaluate only.

    Returns:
        Sample-weighted mean loss over the batches that were processed, or
        NaN when no batch could be processed.
    """
    training = optimizer is not None
    model.train(training)

    progress = tqdm(loader, desc=desc)
    total_loss = 0.0
    total_samples = 0
    skipped = 0
    first_error = None

    with torch.set_grad_enabled(training):
        for batch in progress:
            # Malformed batches (e.g. missing keys) are programming errors and
            # are allowed to propagate instead of being swallowed.
            idx = batch["input_ids"]
            targets = batch["label"].unsqueeze(1).float()
            idx, targets = idx.to(device), targets.to(device)

            try:
                if training:
                    optimizer.zero_grad()
                results, _ = model(idx)
                loss = loss_function(results, targets)
                if training:
                    loss.backward()
                    optimizer.step()
            except RuntimeError as err:
                # Keep the best-effort behaviour for batches the model cannot
                # process (e.g. a shape mismatch), but count and report them
                # rather than hiding them.
                skipped += 1
                if first_error is None:
                    first_error = err
                continue

            total_loss += loss.item() * len(idx)
            total_samples += len(idx)
            if total_samples:
                progress.set_postfix({postfix_key: total_loss / total_samples})

    if skipped:
        print(f'{desc}: skipped {skipped} batch(es); first error: {first_error}')

    # Avoid ZeroDivisionError when the loader is empty or every batch failed.
    return total_loss / total_samples if total_samples else float('nan')


def train_model(model, train_loader, test_loader, optimizer, loss_function, num_epochs=10, device='cuda'):
    """Train `model` for `num_epochs`, evaluating on `test_loader` after each epoch.

    Args:
        model: module called as ``model(idx)`` and returning ``(results, loss)``.
        train_loader: iterable of training batches (dicts with "input_ids" and "label").
        test_loader: iterable of evaluation batches with the same layout.
        optimizer: optimizer over the model's parameters.
        loss_function: callable applied to ``(results, targets)``.
        num_epochs: number of passes over `train_loader`.
        device: device the model and the batches are moved to.

    Batches that raise RuntimeError are skipped and reported per epoch; any
    other exception propagates to the caller.
    """
    model.to(device)  # Move the model to the specified device (GPU or CPU)

    for epoch in range(num_epochs):
        # The loaders are wrapped in a fresh progress bar each epoch inside
        # the helper, so the caller's loader objects are never rebound.
        average_train_loss = _run_epoch(
            model, train_loader, loss_function, device,
            desc=f'Epoch {epoch + 1}/{num_epochs}, Training',
            postfix_key='Train Loss',
            optimizer=optimizer,
        )
        average_test_loss = _run_epoch(
            model, test_loader, loss_function, device,
            desc=f'Epoch {epoch + 1}/{num_epochs}, Validation',
            postfix_key='Test Loss',
        )

        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}, Test Loss: {average_test_loss:.4f}')

    print('Training complete!')

# Start from a newly constructed model, optimise it with Adam and score the
# sigmoid outputs with binary cross-entropy.
model = Quickformer()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_function = nn.BCELoss()  # expects probabilities, which is what the model returns

# train_model moves the model to `device` before the first epoch
train_model(model, dataloader_train, dataloader_test, optimizer, loss_function, num_epochs=10, device=device)


Epoch 1/10, Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 1/10, Training: 100%|██████████| 1563/1563 [03:06<00:00,  8.39it/s, Train Loss=0.694]
Epoch 1/10, Validation: 100%|██████████| 1563/1563 [00:48<00:00, 31.94it/s, Test Loss=0.69] 


Epoch 1/10, Train Loss: 0.6943, Test Loss: 0.6900


Epoch 2/10, Training: 100%|██████████| 1563/1563 [02:54<00:00,  8.98it/s, Train Loss=0.667]
Epoch 2/10, Validation: 100%|██████████| 1563/1563 [00:44<00:00, 35.26it/s, Test Loss=0.675]


Epoch 2/10, Train Loss: 0.6671, Test Loss: 0.6745


Epoch 3/10, Training: 100%|██████████| 1563/1563 [02:35<00:00, 10.03it/s, Train Loss=0.554]
Epoch 3/10, Validation: 100%|██████████| 1563/1563 [00:45<00:00, 34.21it/s, Test Loss=0.631]


Epoch 3/10, Train Loss: 0.5541, Test Loss: 0.6306


Epoch 4/10, Training: 100%|██████████| 1563/1563 [02:46<00:00,  9.38it/s, Train Loss=0.409]
Epoch 4/10, Validation: 100%|██████████| 1563/1563 [00:49<00:00, 31.77it/s, Test Loss=0.668]


Epoch 4/10, Train Loss: 0.4095, Test Loss: 0.6678


Epoch 5/10, Training: 100%|██████████| 1563/1563 [03:15<00:00,  7.99it/s, Train Loss=0.286]
Epoch 5/10, Validation: 100%|██████████| 1563/1563 [00:54<00:00, 28.75it/s, Test Loss=0.712]


Epoch 5/10, Train Loss: 0.2862, Test Loss: 0.7120


Epoch 6/10, Training: 100%|██████████| 1563/1563 [03:09<00:00,  8.23it/s, Train Loss=0.183]
Epoch 6/10, Validation: 100%|██████████| 1563/1563 [00:57<00:00, 27.31it/s, Test Loss=0.811]


Epoch 6/10, Train Loss: 0.1828, Test Loss: 0.8111


Epoch 7/10, Training:  50%|████▉     | 781/1563 [01:35<01:32,  8.45it/s, Train Loss=0.105]

Epoch 7/10, Training:  50%|█████     | 782/1563 [01:35<01:36,  8.08it/s, Train Loss=0.105]