In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import re

# Directory that both the tokenizer and the LM-head model are loaded from.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
save_directory = 'D:/gpt2-large'
tokenizer = GPT2Tokenizer.from_pretrained(save_directory)
base_model = GPT2LMHeadModel.from_pretrained(save_directory)

# When the tokenizer has no pad token, reuse the end-of-sequence token for
# padding. NotesDataset encodes every entry with padding="max_length", which
# needs a pad token to be set.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
class NotesDataset(Dataset):
    """Dataset of note entries read from a single UTF-8 text file.

    The file is split immediately before every line that starts with
    ``Topic:``; each resulting chunk is one entry. Chunks that contain only
    whitespace (e.g. when the file is empty or begins with a newline) are
    dropped so they do not turn into pure-padding samples.

    Args:
        file_path: Path of the text file to read.
        tokenizer: Callable tokenizer; it is invoked with the entry text and
            the keyword arguments ``add_special_tokens``, ``max_length``,
            ``truncation``, ``padding`` and ``return_tensors``, and its result
            must expose ``input_ids`` and ``attention_mask``.
        block_size: Length every encoded entry is truncated/padded to.
    """

    def __init__(self, file_path, tokenizer, block_size=512):
        with open(file_path, encoding="utf-8") as f:
            raw_text = f.read()
        # The lookahead keeps the "Topic:" marker at the start of each entry.
        chunks = re.split(r'\n(?=Topic:)', raw_text)
        # Entries are kept verbatim; only whitespace-only chunks are discarded.
        self.entries = [chunk for chunk in chunks if chunk.strip()]
        self.tokenizer = tokenizer
        self.block_size = block_size

    def __len__(self):
        """Return the number of non-blank entries."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Encode entry ``idx`` and return ``(input_ids, attention_mask)``.

        A task prefix is prepended before tokenization: ``"summarize: "`` when
        the entry contains the substring ``summary``, otherwise
        ``"define_terms: "`` when it contains ``keywords``, otherwise
        ``"autocomplete: "``. The substring checks are case-sensitive.
        """
        entry = self.entries[idx]
        if "summary" in entry:
            task_prefix = "summarize: "
        elif "keywords" in entry:
            task_prefix = "define_terms: "
        else:
            task_prefix = "autocomplete: "

        # Calling the tokenizer directly is the documented entry point and
        # replaces the legacy encode_plus method.
        encoding = self.tokenizer(
            task_prefix + entry,
            add_special_tokens=True,
            max_length=self.block_size,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch axis of size 1; drop it.
        return encoding.input_ids[0], encoding.attention_mask[0]

# MODEL EXT: Multi-Task Learning
class MultiTaskGPT2(nn.Module):
    """Shared backbone with one linear vocabulary head per task.

    Three heads are created — ``autocomplete``, ``summarize`` and
    ``define_terms`` — each mapping a hidden vector of size
    ``base_model.config.n_embd`` to ``tokenizer.vocab_size`` logits.

    Args:
        base_model: Either an LM-head model that exposes its backbone as the
            ``transformer`` attribute, or a bare backbone whose output has a
            ``last_hidden_state`` attribute.
        tokenizer: Object providing ``vocab_size``.
    """

    def __init__(self, base_model, tokenizer):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.n_embd
        vocab_size = tokenizer.vocab_size
        self.heads = nn.ModuleDict({
            'autocomplete': nn.Linear(hidden_size, vocab_size),
            'summarize': nn.Linear(hidden_size, vocab_size),
            'define_terms': nn.Linear(hidden_size, vocab_size),
        })

    def forward(self, input_ids, attention_mask):
        """Return a dict mapping each task name to that head's logits.

        Every head is applied to the same final hidden states, so each value
        has the hidden states' leading dimensions followed by ``vocab_size``.
        """
        # An LM-head model's output carries `logits` but no
        # `last_hidden_state`, so run its backbone directly. This also skips
        # the built-in LM-head projection, whose logits are never used here.
        backbone = getattr(self.base_model, "transformer", self.base_model)
        outputs = backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state

        return {task: head(last_hidden_state) for task, head in self.heads.items()}

# Use the GPU when one is available; everything below is moved onto `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiTaskGPT2(base_model, tokenizer).to(device)

# TRAIN LOOP (DA GOOD STUFF)
# NOTE(review): `data_loader` is not defined in any visible cell. It must be a
# DataLoader over the training NotesDataset created before this cell runs.
# NOTE(review): every head is trained on every entry regardless of its task
# prefix — confirm this is the intended multi-task setup.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# Positions labelled -100 (padding) are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
epochs = 3
epoch_losses = []  # average loss per epoch; read by the plotting cell
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}/{epochs}")

    for input_ids, attention_mask in progress_bar:
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

        # Causal-LM targets: the logits at position t are scored against the
        # token at position t + 1. Scoring them against token t (which the
        # model already sees) only teaches it to copy its input. The pad token
        # is the EOS token, so padding is identified through the attention
        # mask rather than by token id.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        shift_labels = labels[:, 1:].reshape(-1)

        task_logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = 0
        for task, logits in task_logits.items():
            # Drop the final position: it has no next token to predict.
            shift_logits = logits[:, :-1, :]
            loss = loss + loss_fn(
                shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels
            )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    epoch_losses.append(avg_loss)
    print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss:.4f}")

Epoch 1/3:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
from sklearn.metrics import accuracy_score
import numpy as np

def validate_model(model, data_loader, device, tokenizer):
    """Evaluate every task head with next-token loss and accuracy.

    Logits at position t are compared with the token at position t + 1, and
    positions whose target is padding (attention mask 0) are excluded from
    both the loss and the accuracy.

    Args:
        model: Module with a ``heads`` mapping; calling it with ``input_ids``
            and ``attention_mask`` returns a dict of per-task logits.
        data_loader: Sized iterable of ``(input_ids, attention_mask)`` batches.
        device: Device the batches are moved to.
        tokenizer: Unused; kept so existing callers keep working. The
            vocabulary size is read from the logits themselves.

    Returns:
        ``(avg_loss, avg_accuracy)``: ``avg_loss`` is the per-task losses
        summed over all batches and divided by the number of batches;
        ``avg_accuracy`` maps each task to its mean per-batch accuracy as a
        fraction between 0 and 1.
    """
    model.eval()  # MY BOY IN EVAL MODE
    total_loss = 0
    total_accuracy = {task: [] for task in model.heads.keys()}
    # -100 marks target positions that must not contribute to the loss.
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    with torch.no_grad():
        for input_ids, attention_mask in data_loader:
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

            # Build next-token targets; padding is located via the attention
            # mask because the pad token may be a real token id (e.g. EOS).
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            shift_labels = labels[:, 1:]
            valid = shift_labels != -100
            if not valid.any():
                # No scorable target in this batch; the loss would be NaN.
                continue
            num_valid = valid.sum()

            # FORWARD PASS THROUGH MODEL
            task_logits = model(input_ids=input_ids, attention_mask=attention_mask)
            for task, logits in task_logits.items():
                # The last position has no following token to predict.
                shift_logits = logits[:, :-1, :]

                loss = loss_fn(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1),
                )
                total_loss += loss.item()

                predicted = shift_logits.argmax(dim=-1)
                correct = ((predicted == shift_labels) & valid).sum()
                accuracy = correct.float() / num_valid
                total_accuracy[task].append(accuracy.item())

    avg_loss = total_loss / len(data_loader)
    avg_accuracy = {task: np.mean(acc) for task, acc in total_accuracy.items()}
    return avg_loss, avg_accuracy


# Build the validation loader; order is kept fixed (shuffle=False).
validation_dataset = NotesDataset('./data/ValidationNotes.txt', tokenizer)
validation_loader = DataLoader(validation_dataset, batch_size=2, shuffle=False)

# VALIDATE MODEL
validation_loss, validation_accuracy = validate_model(model, validation_loader, device, tokenizer)
print(f"Validation Loss: {validation_loss:.4f}")
for task, acc in validation_accuracy.items():
    # validate_model returns accuracy as a fraction (correct / total), so the
    # "%" format spec is used to scale it by 100 before appending the sign.
    print(f"Average Accuracy for {task}: {acc:.2%}")

In [None]:
import matplotlib.pyplot as plt

# Plot the recorded per-epoch training loss on an explicit Figure/Axes pair.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_losses, label='Training Loss')
ax.set_title('Training Loss Over Time')
ax.set_xlabel('Epoch')
ax.set_ylabel('Average Loss')
ax.legend()
plt.show()

In [None]:
# Saving is currently commented out.
# NOTE(review): before re-enabling, note that `model` is the MultiTaskGPT2
# nn.Module wrapper, which defines no `save_pretrained` method, and that this
# directory is the same one the pretrained tokenizer/model were loaded from,
# so writing here would overwrite those files.
# save_directory = r'D:/gpt2-large'
# tokenizer.save_pretrained(save_directory)
# model.save_pretrained(save_directory)
print("Training Complete.")