In [1]:
model_size = "gpt2-small (124M)" # Options: 'gpt2-small (124M)', 'gpt2-medium (355M)', 'gpt2-large (774M)', 'gpt2-xl (1558M)'
weights = "pretrained" # Options: 'pretrained' or 'random'
trainable_layers = "last_block" # Options: 'all', 'last_block', 'last_two_blocks', 'last_layer', 'lora', 'lora_alternative'
trainable_token_pos = "last" # Options: 'first', 'last'
average_embeddings = False
context_length = "longest_training_example" # Options: 'longest_training_example', 'model_context_length' or integer value.
lora_rank = 8 
lora_alpha = 8
no_padding = False
num_epochs = 5
batch_size = 8 
accumulation_steps = 1
disable_causal_mask = False
ignore_index = -100

In [2]:
import math
import os
from pathlib import Path
import time
import urllib.request
import zipfile

import pandas as pd
import tiktoken
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from gpt_download import download_and_load_gpt2
from utils import GPTModel, load_weights_into_gpt

In [3]:
class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x


class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)


# This LoRA code is equivalent to LinearWithLoRA
class LinearWithLoRAMerged(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        lora = self.lora.A @ self.lora.B
        combined_weight = self.linear.weight + self.lora.alpha*lora.T
        return torch.nn.functional.linear(x, combined_weight, self.linear.bias)

In [4]:
class SentimentDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
        self.data = data
        self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text)[:self.max_length]
            for text in self.data["text"]
        ]

        if not no_padding:
            # Pad sequences to the longest sequence
            self.encoded_texts = [
                et + [pad_token_id] * (self.max_length - len(et))
                for et in self.encoded_texts
            ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["sentiment"]
        return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self, tokenizer):
        max_length = 0
        for text in self.data["text"]:
            encoded_length = len(tokenizer.encode(text))
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length

In [13]:
def instantiate_model(choose_model, load_weights):

    BASE_CONFIG = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    BASE_CONFIG.update(model_configs[choose_model])

    if not load_weights:
        torch.manual_seed(123)
    model = GPTModel(BASE_CONFIG)

    if load_weights:
        model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
        load_weights_into_gpt(model, params)

    model.eval()
    return model

In [6]:
def calc_loss_batch(input_batch, target_batch, model, device,
                    trainable_token_pos=-1, ignore_index=-100, average_embeddings=False):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    model_output = model(input_batch)
    if average_embeddings:
        # Average over the sequence dimension (dim=1)
        logits = model_output.mean(dim=1)
    else:
        # Select embeddings at the specified token position
        logits = model_output[:, trainable_token_pos, :]

    loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
    return loss


def calc_loss_loader(data_loader, model, device,
                     num_batches=None, trainable_token_pos=-1,
                     ignore_index=-100, average_embeddings=False):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(
                input_batch, target_batch, model, device,
                trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
                average_embeddings=average_embeddings
            )
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


@torch.no_grad()  # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None,
                         trainable_token_pos=-1, average_embeddings=False):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            model_output = model(input_batch)
            if average_embeddings:
                # Average over the sequence dimension (dim=1)
                logits = model_output.mean(dim=1)
            else:
                # Select embeddings at the specified token position
                logits = model_output[:, trainable_token_pos, :]

            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def evaluate_model(model, train_loader, val_loader, device,
                   eval_iter, trainable_token_pos=-1,
                   ignore_index=-100, average_embeddings=False):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(
            train_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
            average_embeddings=average_embeddings
        )
        val_loss = calc_loss_loader(
            val_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
            average_embeddings=average_embeddings
        )
    model.train()
    return train_loss, val_loss

In [25]:

def calc_loss_batch(input_batch, target_batch, model, device,
                    trainable_token_pos=-1, ignore_index=-100, average_embeddings=False):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    model_output = model(input_batch)
    if average_embeddings:
        # Average over the sequence dimension (dim=1)
        logits = model_output.mean(dim=1)
    else:
        # Select embeddings at the specified token position
        logits = model_output[:, trainable_token_pos, :]

    loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
    return loss


def calc_loss_loader(data_loader, model, device,
                     num_batches=None, trainable_token_pos=-1,
                     ignore_index=-100, average_embeddings=False):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(
                input_batch, target_batch, model, device,
                trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
                average_embeddings=average_embeddings
            )
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


@torch.no_grad()  # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None,
                         trainable_token_pos=-1, average_embeddings=False):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            model_output = model(input_batch)
            if average_embeddings:
                # Average over the sequence dimension (dim=1)
                logits = model_output.mean(dim=1)
            else:
                # Select embeddings at the specified token position
                logits = model_output[:, trainable_token_pos, :]

            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def evaluate_model(model, train_loader, val_loader, device,
                   eval_iter, trainable_token_pos=-1,
                   ignore_index=-100, average_embeddings=False):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(
            train_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
            average_embeddings=average_embeddings
        )
        val_loss = calc_loss_loader(
            val_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
            average_embeddings=average_embeddings
        )
    model.train()
    return train_loss, val_loss

def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
                            accumulation_steps=1, ignore_index=-100, average_embeddings=False):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
            loss = calc_loss_batch(
                input_batch, target_batch, model, device,
                trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
                average_embeddings=average_embeddings
            )

            # Use gradient accumulation if accumulation_steps > 1
            # See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html
            # for an explanation
            loss /= accumulation_steps

            loss.backward()  # Calculate loss gradients

            # Use gradient accumulation if accumulation_steps > 1
            is_update_step = ((batch_idx + 1) % accumulation_steps == 0) or ((batch_idx + 1) == len(train_loader))
            if is_update_step:
                optimizer.step()  # Update model weights using loss gradients
                optimizer.zero_grad()  # Reset loss gradients from previous batch iteration

            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter,
                    trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
                    average_embeddings=average_embeddings
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

            if max_steps is not None and global_step > max_steps:
                break

        # New: Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(
            train_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
        )
        val_accuracy = calc_accuracy_loader(
            val_loader, model, device, num_batches=eval_iter,
            trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
        )
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

        if max_steps is not None and global_step > max_steps:
            break

    return train_losses, val_losses, train_accs, val_accs, examples_seen


In [8]:
def replace_linear_with_lora(model, rank, alpha, alternative=False):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            # Replace the Linear layer with LinearWithLoRA
            if alternative:
                setattr(model, name, LinearWithLoRAMerged(module, rank, alpha))
            else:
                setattr(model, name, LinearWithLoRA(module, rank, alpha))
        else:
            # Recursively apply the same function to child modules
            replace_linear_with_lora(module, rank, alpha)

In [14]:
if trainable_token_pos == "first":
    trainable_token_pos = 0
elif trainable_token_pos == "last":
    trainable_token_pos = -1
else:
    print("Invalid --trainable_token_pos argument")

###############################
# Load model
###############################

if weights == "pretrained":
    load_weights = True
elif weights == "random":
    load_weights = False
else:
    print("Invalid --weights argument.")

model = instantiate_model(model_size, load_weights)
for param in model.parameters():
    param.requires_grad = False

if model_size == "gpt2-small (124M)":
    in_features = 768
elif model_size == "gpt2-medium (355M)":
    in_features = 1024
elif model_size == "gpt2-large (774M)":
    in_features = 1280
elif model_size == "gpt2-xl (1558M)":
    in_features = 1600
else:
    print("Invalid --model_size argument")

torch.manual_seed(123)
model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)

if trainable_layers == "last_layer":
    pass
elif trainable_layers == "last_block" or trainable_layers == "last_two_blocks":
    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True
    for param in model.final_norm.parameters():
        param.requires_grad = True
    if trainable_layers == "last_two_blocks":
        for param in model.trf_blocks[-2].parameters():
            param.requires_grad = True
elif trainable_layers == "all":
    for param in model.parameters():
        param.requires_grad = True
elif trainable_layers in ("lora", "lora_alternative"):
    if trainable_layers == "lora_alternative":
        alternative = True
    else:
        alternative = False
    replace_linear_with_lora(model, rank=lora_rank, alpha=lora_alpha, alternative=alternative)
else:
    print("Invalid --trainable_layers argument.")

Invalid --trainable_token_pos argument
File already exists and is up-to-date: gpt2\124M\checkpoint
File already exists and is up-to-date: gpt2\124M\encoder.json
File already exists and is up-to-date: gpt2\124M\hparams.json
File already exists and is up-to-date: gpt2\124M\model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2\124M\model.ckpt.index
File already exists and is up-to-date: gpt2\124M\model.ckpt.meta
File already exists and is up-to-date: gpt2\124M\vocab.bpe


In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_resid): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768,

In [17]:
import tiktoken
import pandas as pd
from sklearn.model_selection import train_test_split

# Load the dataset
dataset = pd.read_csv("tweet-sentiment-extraction/train.csv")
dataset = dataset.dropna(subset=["text", "sentiment"])

dataset = dataset[["text", "sentiment"]]
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

train_data = pd.DataFrame(train_data)
test_data = pd.DataFrame(test_data)

# Check the dataset sizes
print(f"Train size: {len(train_data)}, Test size: {len(test_data)}")

# Define the mapping
sentiment_mapping = {'neutral': 0, 'negative': 1, 'positive': 2}

# Map sentiments in both train and test datasets
train_data['sentiment'] = train_data['sentiment'].map(sentiment_mapping)
test_data['sentiment'] = test_data['sentiment'].map(sentiment_mapping)

tokenizer = tiktoken.get_encoding("gpt2")
print(tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))

Train size: 21984, Test size: 5496
[50256]


In [18]:
train_dataset = SentimentDataset(
    data=train_data,
    max_length=None,
    tokenizer=tokenizer
)

print(train_dataset.max_length)

100


In [19]:
test_dataset = SentimentDataset(
    data=test_data,
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)

In [20]:
from torch.utils.data import DataLoader

num_workers = 0
batch_size = 8

torch.manual_seed(123)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=False,
)

In [21]:
print("Train loader:")
for input_batch, target_batch in train_loader:
    pass

print("Input batch dimensions:", input_batch.shape)
print("Label batch dimensions", target_batch.shape)

Train loader:
Input batch dimensions: torch.Size([8, 100])
Label batch dimensions torch.Size([8])


In [22]:
print(f"{len(train_loader)} training batches")
print(f"{len(test_loader)} test batches")

2748 training batches
687 test batches


In [23]:
train_dataset.max_length <= model.pos_emb.weight.shape[0], (
        f"Dataset length {train_dataset.max_length} exceeds model's context "
        f"length {model.pos_emb.weight.shape[0]}. Reinitialize data sets with "
        f"`max_length={model.pos_emb.weight.shape[0]}`"
    )

(True,
 "Dataset length 100 exceeds model's context length 1024. Reinitialize data sets with `max_length=1024`")

In [27]:
###############################
# Train model
###############################

start_time = time.time()
torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
    model, train_loader, test_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5,
    max_steps=None, trainable_token_pos=trainable_token_pos,
    accumulation_steps=accumulation_steps, average_embeddings=average_embeddings
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

###############################
# Evaluate model
###############################

train_accuracy = calc_accuracy_loader(
    train_loader, model, device,
    trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
# val_accuracy = calc_accuracy_loader(
#     val_loader, model, device,
#     trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
# )
test_accuracy = calc_accuracy_loader(
    test_loader, model, device,
    trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
# print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
