In [None]:
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW
from prettytable import PrettyTable
from torch.cuda.amp import GradScaler, autocast
import math
from datasets import load_dataset

model_name = "gpt2-medium"

# Seed the RNGs this notebook relies on so Restart & Run All is reproducible:
# DataLoader shuffling and LoRA weight initialisation draw from torch's RNG,
# token sampling in `choose_from_top` draws from numpy's global RNG.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

def __device__() -> torch.device:
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

# Resolve the compute device once; later cells move models and tensors here.
device = __device__()
print("Using device: " + str(device))

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 ships without a padding token, and "<PAD>" is not a token in its vocabulary.
# Reuse the end-of-text token as padding, the standard choice for GPT-2.
tokenizer.pad_token = tokenizer.eos_token

path_to_files = {
    'train': 'data/train.csv',
    'test': 'data/test.csv'
}

dataset = load_dataset('csv', data_files=path_to_files)

train_dataset = dataset['train']['Description']
test_dataset = dataset['test']['Description']


def _preview_split(texts, split_name, n_examples=10, leading_newline=False):
    """Print the number of entries in a text split followed by its first ``n_examples`` items.

    The preview is bounded by ``len(texts)`` so splits with fewer than
    ``n_examples`` rows do not raise ``IndexError``.
    """
    prefix = "\n" if leading_newline else ""
    print(f"{prefix}{len(texts)} Entries in {split_name} dataset.")
    print("\nExamples: ")
    for i in range(min(n_examples, len(texts))):
        print(texts[i])


_preview_split(train_dataset, "train")
_preview_split(test_dataset, "test", leading_newline=True)

In [None]:
class TextDataset(Dataset):
    """Minimal map-style dataset that exposes a sequence of texts by index.

    Items are returned exactly as stored; no tokenization or conversion happens here.
    """

    def __init__(self, texts):
        # Any sized, indexable container works (a list, a Hugging Face dataset column, ...).
        self.texts = texts

    def __getitem__(self, idx):
        stored = self.texts
        return stored[idx]

    def __len__(self):
        return len(self.texts)

# Wrap the raw text columns in map-style datasets so they can feed a DataLoader.
train_dataset, test_dataset = TextDataset(train_dataset), TextDataset(test_dataset)

In [None]:
def print_trainable_parameters(model):
    """Print a table of trainable vs. total parameter counts for ``model``.

    A parameter counts as trainable when ``requires_grad`` is set.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        tuple[int, int]: ``(trainable_params, all_params)``, handy for assertions.
    """
    trainable_params = 0
    all_params = 0
    # Single pass over the parameters instead of two separate sums.
    for param in model.parameters():
        count = param.numel()
        all_params += count
        if param.requires_grad:
            trainable_params += count

    # Guard against ZeroDivisionError for a model without parameters.
    trainable_pct = 100 * trainable_params / all_params if all_params else 0.0

    table = PrettyTable(["Parameter Type", "Count"])
    table.add_row(["Trainable Params", trainable_params])
    table.add_row(["All Params", all_params])
    table.add_row(["Trainable %", f"{trainable_pct:.2f}"])
    print(table)
    return trainable_params, all_params

In [None]:
def fine_tune(model, dataset, epochs=1, batch_size=8, learning_rate=1e-5, grad_accum_steps=4):
    """Fine-tune a causal language model on raw texts with gradient accumulation.

    Texts are tokenized on the fly with the module-level ``tokenizer``. Each padded
    batch serves as both input and label, and padded positions are masked out of
    the loss. Gradients of ``grad_accum_steps`` consecutive micro-batches are summed
    before each optimizer step; a final partial group at the end of an epoch is
    stepped too. On CUDA the forward pass runs under ``autocast`` together with a
    ``GradScaler`` (mixed precision).

    Args:
        model: causal LM that returns ``.loss`` when called with ``labels``. Inputs
            are moved to ``__device__()``, so the model must already live there.
        dataset: map-style dataset whose items are texts (each is passed through ``str``).
        epochs: number of passes over ``dataset``.
        batch_size: micro-batch size per forward pass.
        learning_rate: AdamW learning rate.
        grad_accum_steps: number of micro-batches accumulated per optimizer step (>= 1).

    Returns:
        float | None: average (unscaled) loss of the last epoch (``nan`` if that
        epoch had no batches), or ``None`` when ``epochs`` is 0.

    Raises:
        ValueError: if ``grad_accum_steps`` is smaller than 1.
    """
    if grad_accum_steps < 1:
        raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}")

    device = __device__()
    use_amp = device.type == 'cuda'
    # GradScaler only pays off together with autocast (fp16 activations), so both
    # are enabled on CUDA only.
    scaler = GradScaler() if use_amp else None

    # Only parameters that receive gradients need optimizer state (e.g. just the LoRA
    # adapters after freezing). torch's AdamW is named explicitly because the bare
    # name `AdamW` is rebound by `from transformers import ..., AdamW` in this notebook.
    # eps / weight_decay mirror the defaults of that (deprecated) transformers optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, eps=1e-6, weight_decay=0.0)

    model.train()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(loader)
    avg_loss = None

    for epoch in range(epochs):
        print(f"EPOCH: {epoch + 1}/{epochs}")
        progress_bar = tqdm(enumerate(loader), total=num_batches, desc=f"Epoch {epoch + 1}/{epochs}")
        running_loss = 0.0
        avg_loss = float('nan')
        # Clear once per epoch only: zeroing at the top of every iteration would
        # discard the gradients being accumulated across micro-batches.
        optimizer.zero_grad()

        for idx, batch in progress_bar:
            texts = [str(text) for text in batch]
            inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)
            # Label -100 is ignored by the Hugging Face LM loss, so padding is not learned.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            with autocast(enabled=use_amp):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            micro_loss = outputs.loss.item()
            loss = outputs.loss / grad_accum_steps

            # Step every `grad_accum_steps` micro-batches and on the last batch, so a
            # trailing partial group is not silently dropped.
            is_update_step = (idx + 1) % grad_accum_steps == 0 or (idx + 1) == num_batches
            if scaler is not None:
                scaler.scale(loss).backward()
                if is_update_step:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
            else:
                loss.backward()
                if is_update_step:
                    optimizer.step()
                    optimizer.zero_grad()

            # Track the unscaled loss so 'Avg Loss' is on the same scale as 'Loss'.
            running_loss += micro_loss
            avg_loss = running_loss / (idx + 1)
            progress_bar.set_postfix({
                'Loss': f"{micro_loss:.4f}",
                'Avg Loss': f"{avg_loss:.4f}"
            })

        print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.4f}")

    return avg_loss

In [None]:
# Baseline: the full pretrained model, where every parameter is trainable.
# NOTE: `full_model` is never freed in this notebook, so it stays resident on `device`
# while the LoRA model is trained later; `del full_model` if memory is tight.
full_model = AutoModelForCausalLM.from_pretrained(model_name)
full_model = full_model.to(device)
print_trainable_parameters(full_model)

In [None]:
# ! RUNNING THIS LINE WILL RESET THE COLAB SESSION: FULL FINE-TUNING RUNS OUT OF MEMORY
# fine_tune(full_model, train_dataset, epochs=1)

In [None]:
class LoRALinear(nn.Module):
    """Linear layer with a trainable low-rank (LoRA) update added to its output.

    ``forward(x) = x @ W.T + b + x @ lora_right @ lora_left``, where ``W`` is
    ``(out, inp)``, ``lora_right`` is ``(inp, lora_dim)`` and ``lora_left`` is
    ``(lora_dim, out)``. ``lora_left`` starts at zero, so a freshly built layer
    reproduces the wrapped linear map exactly.

    Args:
        weight: tensor of shape ``(out, inp)`` (``nn.Linear`` layout), wrapped as a Parameter.
        bias: tensor of shape ``(out,)`` or ``None``.
        lora_dim: rank of the low-rank update.
    """

    def __init__(self, weight, bias, lora_dim):
        super().__init__()
        out_features, in_features = weight.shape
        has_bias = bias is not None

        # Build a linear layer of the right shape, then swap in the supplied tensors.
        self.linear = nn.Linear(in_features, out_features, bias=has_bias)
        self.linear.weight = nn.Parameter(weight)
        if has_bias:
            self.linear.bias = nn.Parameter(bias)

        self.lora_right = nn.Parameter(torch.zeros(in_features, lora_dim))
        # NOTE(review): kaiming_uniform_ takes fan_in from dim 1, which here is lora_dim
        # (the reference loralib stores A as (r, in), i.e. fan_in = in_features), so the
        # init range is +/- 1/sqrt(lora_dim). Confirm this larger scale is intended.
        nn.init.kaiming_uniform_(self.lora_right, a=math.sqrt(5))
        self.lora_left = nn.Parameter(torch.zeros(lora_dim, out_features))

    def forward(self, input):
        right = self.lora_right.to(input.device)
        left = self.lora_left.to(input.device)
        frozen_output = self.linear(input)
        low_rank_update = (input @ right) @ left
        return frozen_output + low_rank_update

In [None]:
def integrate_lora(model, lora_dim=8, target="attn.c_attn"):
    """Replace every module whose qualified name ends with ``target`` by a ``LoRALinear``.

    The default target is GPT-2's fused QKV projection (``...attn.c_attn``). That module
    is a transformers ``Conv1D`` whose weight is stored as ``(in, out)``, so the weight
    is transposed into ``nn.Linear``'s ``(out, in)`` layout before wrapping. The new
    layer is placed on the same device as the weight it replaces, rather than on a
    global device, so models living anywhere are handled consistently.

    Args:
        model: model modified in place.
        lora_dim: rank passed to ``LoRALinear``.
        target: suffix of the qualified module names to wrap.

    Returns:
        list[str]: qualified names of the modules that were replaced.
    """
    # Collect names first, because replacing modules while iterating named_modules()
    # would mutate the tree being walked. Suffix matching plus skipping existing
    # LoRALinear layers keeps a repeated call from also matching `...c_attn.linear`.
    targets = [
        name for name, module in model.named_modules()
        if name.endswith(target) and not isinstance(module, LoRALinear)
    ]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        original = getattr(parent, child_name)
        lora_layer = LoRALinear(
            weight=torch.transpose(original.weight, 0, 1),
            bias=original.bias,
            lora_dim=lora_dim,
        ).to(original.weight.device)
        setattr(parent, child_name, lora_layer)
    return targets

In [None]:
def freeze_non_lora_params(model):
    """Leave only LoRA adapter weights trainable.

    A parameter keeps ``requires_grad=True`` iff its qualified name contains
    ``lora_right`` or ``lora_left``; every other parameter is frozen.
    """
    lora_markers = ("lora_right", "lora_left")
    for param_name, param in model.named_parameters():
        param.requires_grad = any(marker in param_name for marker in lora_markers)

In [None]:
# Fresh GPT-2, LoRA adapters on the attention projections, everything else frozen,
# then fine-tune only the adapters.
try:
    lora_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    integrate_lora(lora_model, lora_dim=8)
    freeze_non_lora_params(lora_model)
    print_trainable_parameters(lora_model)
    fine_tune(lora_model, train_dataset, epochs=1)
except Exception as e:
    # Report, then re-raise: swallowing the error would let the evaluation cell run
    # against a missing or half-trained `lora_model`.
    print(f"An error occurred: {e}")
    raise

In [None]:
def choose_from_top(probs, n=1):
    """Sample one index from among the ``n`` largest entries of ``probs``.

    The top-``n`` probabilities are renormalised to sum to 1 and one of them is
    drawn with ``np.random.choice``, so ``n=1`` always returns an index of the maximum.

    Args:
        probs: 1-D array of non-negative probabilities.
        n: number of top candidates to sample from.

    Returns:
        int: the chosen index into ``probs``.
    """
    top_indices = np.argpartition(probs, -n)[-n:]
    top_probs = probs[top_indices]
    weights = top_probs / top_probs.sum()
    picked = np.random.choice(n, 1, p=weights)
    return int(top_indices[picked[0]])

def generate_example(model, query, max_length=100, top_n=1):
    """Generate up to ``max_length`` tokens continuing ``query`` and stream them to stdout.

    At each step the next token is drawn by ``choose_from_top`` from the ``top_n`` most
    probable tokens (``top_n=1`` is greedy decoding). Generation stops early when the
    model emits the tokenizer's EOS token, which is not printed.

    Args:
        model: causal LM returning ``.logits``; expected to live on ``device``.
        query: prompt text.
        max_length: maximum number of new tokens.
        top_n: candidate pool size for sampling.

    Returns:
        str: the decoded completion, excluding ``query`` and any EOS token.
    """
    eos_id = tokenizer.eos_token_id
    # An empty prompt would yield a (1, 0) input the model cannot run on, so start
    # from EOS (GPT-2's document separator) instead.
    prompt_ids = tokenizer.encode(query) or [eos_id]
    cur_ids = torch.tensor([prompt_ids], device=device)
    generated = []

    # Inference only: without no_grad every step would build an autograd graph
    # through any trainable (e.g. LoRA) parameters.
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(cur_ids).logits
            probs = torch.softmax(logits[0, -1], dim=0)
            next_token_id = choose_from_top(probs.cpu().numpy(), n=top_n)
            # Stop at end-of-text. Comparing against tokenizer.encode('') (an empty
            # list) never matched, so generation previously never stopped early.
            if next_token_id == eos_id:
                break
            cur_ids = torch.cat([cur_ids, torch.tensor([[next_token_id]], device=device)], dim=1)
            generated.append(next_token_id)
            print(tokenizer.decode([next_token_id]), end='')

    return tokenizer.decode(generated)

def evaluate_model(model, query="TITLE: Important news from Paris! DESCRIPTION:", max_length=100, top_n=1):
    """Print a prompt and stream the model's completion of it.

    NOTE(review): training feeds only the 'Description' column, so the model never
    sees this TITLE/DESCRIPTION layout during fine-tuning. Confirm the prompt format.

    Args:
        model: causal LM to query.
        query: prompt text. The default fixes the typo "new" -> "news" in the original prompt.
        max_length: maximum number of generated tokens, forwarded to ``generate_example``.
        top_n: sampling pool size, forwarded to ``generate_example``.

    Returns:
        Whatever ``generate_example`` returns.
    """
    print(f"Model Input: {query}\n")
    print("Model Completion: ")
    return generate_example(model, query, max_length=max_length, top_n=top_n)

In [None]:
# nn.Module.eval() returns the module itself, so switch to inference mode and evaluate in one call.
evaluate_model(lora_model.eval())