# Continual Learning Modeling Tasks - Reducing Catastrophic Forgetting

## Introduction
Overview and Context

Continual learning aims to let large language models (LLMs) learn new tasks and adapt to new data without forgetting what they learned before. This project addresses catastrophic forgetting by adding continual learning capabilities to GPT models, with potential applications in automated customer service and dynamic content creation.

Goal

To explore methods for enabling LLMs to continually learn and adapt to new data or tasks without forgetting previously learned information, thereby addressing catastrophic forgetting.

Objectives
1. Mitigate Catastrophic Forgetting: Implement and test Elastic Weight Consolidation (EWC) on GPT-2.
2. Adapt GPT for Continual Learning: Integrate continual learning mechanisms within the Transformer architecture.
3. Evaluate Model Performance: Use backward and forward transfer metrics to measure performance on old vs. new tasks.
4. Understand Transformer Architecture: Explore the self-attention mechanisms and their scalability in transformers.

## Dataset Description
Dataset: WikiText-103
• Description: A collection of over 100 million tokens from verified Good and Featured articles on Wikipedia.
• Usage: To test the model’s ability to learn continually and adapt over time.


## Step 1: Elastic Weight Consolidation (EWC)

- **Model Setup**: Load GPT-2 model.

In [24]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import AdamW  # transformers' own AdamW is deprecated in favor of the PyTorch optimizer
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader

In [2]:
model_name = 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

In [3]:
# Add pad token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))

Embedding(50258, 768)

#### Read the Dataset

In [4]:
# Load dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

# Select a smaller subset of the dataset
small_dataset = dataset['train'].select(range(50)) 

# Convert the dataset to DataLoader
train_dataloader = DataLoader(small_dataset, batch_size=4, shuffle=True)

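**Fisher Information Matrix Calculation:** Computing the Fisher Information Matrix (FIM) is an essential part of the Elastic Weight Consolidation (EWC) method for mitigating catastrophic forgetting in neural networks. EWC uses the diagonal of the FIM as a per-parameter estimate of how important each weight is to the previously learned task.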
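As a rough guide to what the next cell estimates, the diagonal of the FIM can be approximated by averaging squared gradients of the loss over batches. The notation is an assumption introduced for illustration ($N$ batches, batch loss $\mathcal{L}_b$, parameter $\theta_i$), and it describes a batch-level approximation rather than the exact per-sample Fisher.

$$
F_i \approx \frac{1}{N} \sum_{b=1}^{N} \left( \frac{\partial \mathcal{L}_b(\theta)}{\partial \theta_i} \right)^2
$$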

In [5]:
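# Define Fisher Information Matrix Calculation
def compute_fisher_information(model, dataloader):
    """Estimate the diagonal of the Fisher Information Matrix from squared gradients."""
    model.eval()
    # One zero tensor per trainable parameter, on the same device and dtype
    fisher_information = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    criterion = nn.CrossEntropyLoss()
    num_batches = 0
    
    for batch in dataloader:
        inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        targets = inputs['input_ids'].clone()
        
        if targets.numel() == 0:  # Skip batches that contain only empty lines
            continue
        
        model.zero_grad()
        outputs = model(**inputs)
        # Note: logits at position t are scored against the token at position t (no shift),
        # and padding positions are included in the loss.
        loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
        loss.backward()

        # Accumulate the squared gradient of the batch loss for each parameter
        for n, p in model.named_parameters():
            if p.requires_grad and p.grad is not None:
                fisher_information[n] += p.grad.detach().pow(2)
        num_batches += 1
    
    # Average over the batches that were actually processed
    for n in fisher_information:
        fisher_information[n] /= max(num_batches, 1)
    
    return fisher_information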

- **EWC Regularization**

In [6]:
# Define EWC Regularization Term in the Loss Function
class EWC:
    def __init__(self, model, dataloader, lambda_=0.4):
        self.model = model
        self.lambda_ = lambda_
        self.fisher_information = compute_fisher_information(model, dataloader)
        self.optimal_params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}

    def penalty(self, model):
        loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                _loss = self.fisher_information[n] * (p - self.optimal_params[n]).pow(2)
                loss += _loss.sum()
        return (self.lambda_ / 2) * loss
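
To make the penalty concrete, the following sketch evaluates the same quantity, (λ/2) · Σ F_i · (θ_i − θ*_i)², on hand-picked scalar values. The function name `ewc_penalty` and the toy numbers are assumptions for illustration, and plain Python floats stand in for tensors.

```python
def ewc_penalty(params, optimal_params, fisher, lambda_=0.4):
    """Quadratic EWC penalty over dictionaries of scalar parameters."""
    total = 0.0
    for name, value in params.items():
        # Parameters with larger Fisher values are pulled back more strongly
        total += fisher[name] * (value - optimal_params[name]) ** 2
    return (lambda_ / 2) * total


# Toy values (hypothetical): w1 is "important" for the old task, w2 is not
optimal = {"w1": 1.0, "w2": -2.0}
fisher = {"w1": 10.0, "w2": 0.1}

moved_w1 = {"w1": 1.5, "w2": -2.0}   # shift the important weight by 0.5
moved_w2 = {"w1": 1.0, "w2": -1.5}   # shift the unimportant weight by 0.5

print(ewc_penalty(moved_w1, optimal, fisher))  # (0.4 / 2) * 10 * 0.5**2
print(ewc_penalty(moved_w2, optimal, fisher))  # (0.4 / 2) * 0.1 * 0.5**2
```

The same displacement costs a hundred times more on `w1` than on `w2`, which is how EWC steers updates toward weights the old task depends on less.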

In [7]:
# Define evaluation function for old task
def evaluate_on_old_task(model, dataloader):
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch in dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            targets = inputs['input_ids'].clone()
            if targets.numel() == 0:
                continue
            
            outputs = model(**inputs)
            loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Average Loss on Old Task: {avg_loss}")
    return avg_loss

In [8]:
# Evaluate initial performance on the old task
initial_loss_old_task = evaluate_on_old_task(model, train_dataloader)  

Average Loss on Old Task: 45.74768932049091


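Before any fine-tuning (no optimizer steps and no EWC), the average loss on this subset is high (45.75). This value comes from the loss as defined above, which scores the logits at each position against the token at that same position and also counts padding positions, so it is not a standard next-token loss.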
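
For comparison, a causal language-modeling loss usually scores the logits at position t against the token at position t+1 and skips padding. The sketch below shows that bookkeeping on toy data in plain Python. `next_token_loss`, the three-token vocabulary, and the numbers are assumptions for illustration and are not part of the notebook's code.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def next_token_loss(logits, token_ids, pad_id):
    """Mean cross-entropy where logits[t] predicts token_ids[t + 1]; padded targets are skipped."""
    total, count = 0.0, 0
    for t in range(len(token_ids) - 1):
        target = token_ids[t + 1]
        if target == pad_id:
            continue
        total -= log_softmax(logits[t])[target]
        count += 1
    return total / count if count else 0.0


# Toy sequence over a 3-token vocabulary; id 2 plays the role of padding
token_ids = [0, 1, 0, 2]
logits = [
    [0.0, 5.0, 0.0],   # after token 0, the model favors token 1 (correct)
    [5.0, 0.0, 0.0],   # after token 1, the model favors token 0 (correct)
    [0.0, 0.0, 0.0],   # next target is padding, so this row is ignored
    [0.0, 0.0, 0.0],   # last position has no next token
]
print(next_token_loss(logits, token_ids, pad_id=2))
```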

In [9]:
# Train model on new task without EWC
def train_without_ewc(model, dataloader, optimizer, epochs=3):
    criterion = nn.CrossEntropyLoss()
    model.train()
    
    for epoch in range(epochs):
        for batch in dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            targets = inputs['input_ids'].clone()
            if targets.numel() == 0:
                continue
            
            optimizer.zero_grad()
            outputs = model(**inputs)
            loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
            loss.backward()
            optimizer.step()
            
            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

In [10]:
# Initialize optimizer for training without EWC
optimizer_without_ewc = optim.Adam(model.parameters(), lr=5e-5)

In [11]:
# Train the model on the new task
train_without_ewc(model, train_dataloader, optimizer_without_ewc)

Epoch 1, Loss: 36.986759185791016
Epoch 1, Loss: 34.90500259399414
Epoch 1, Loss: 32.675880432128906
Epoch 1, Loss: 27.8104305267334
Epoch 1, Loss: 14.99581241607666
Epoch 1, Loss: 17.95258331298828
Epoch 1, Loss: 9.375332832336426
Epoch 1, Loss: 11.026851654052734
Epoch 1, Loss: 6.626292705535889
Epoch 1, Loss: 2.9959139823913574
Epoch 1, Loss: 4.661403656005859
Epoch 1, Loss: 7.900046348571777
Epoch 1, Loss: 4.165240287780762
Epoch 2, Loss: 6.586829662322998
Epoch 2, Loss: 3.544839382171631
Epoch 2, Loss: 3.9093141555786133
Epoch 2, Loss: 2.3981270790100098
Epoch 2, Loss: 4.0951433181762695
Epoch 2, Loss: 1.8137054443359375
Epoch 2, Loss: 6.375125885009766
Epoch 2, Loss: 3.5391931533813477
Epoch 2, Loss: 4.002185344696045
Epoch 2, Loss: 3.7483954429626465
Epoch 2, Loss: 4.445273399353027
Epoch 2, Loss: 2.932441473007202
Epoch 3, Loss: 3.5660438537597656
Epoch 3, Loss: 3.4430465698242188
Epoch 3, Loss: 1.8920410871505737
Epoch 3, Loss: 3.4411191940307617
Epoch 3, Loss: 2.6871933937072

In [12]:
# Evaluate performance on the old task after training on the new task without EWC
loss_after_training_without_ewc = evaluate_on_old_task(model, train_dataloader)

Average Loss on Old Task: 2.7562116292806773


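After three epochs of fine-tuning without EWC, the average loss drops to 2.76. The same `train_dataloader` serves as both the old and the new task here, so this figure reflects fit to the training subset rather than retention of an earlier task.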
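
Measuring forgetting needs an old task and a new task drawn from different data. One possible setup, sketched here under the assumption that disjoint row ranges of the training split can stand in for separate tasks, is to compute non-overlapping index ranges and pass each one to `Dataset.select`. The helper name `task_index_ranges` is hypothetical.

```python
def task_index_ranges(rows_per_task, num_tasks, start=0):
    """Return one non-overlapping range of row indices per task."""
    return [
        range(start + i * rows_per_task, start + (i + 1) * rows_per_task)
        for i in range(num_tasks)
    ]


old_rows, new_rows = task_index_ranges(rows_per_task=50, num_tasks=2)
print(list(old_rows)[:3], list(new_rows)[:3])

# In the notebook these ranges could then be used as, for example:
#   old_task_data = dataset['train'].select(old_rows)
#   new_task_data = dataset['train'].select(new_rows)
```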

- **Training**: Implement training with EWC.

In [13]:
# Training Loop with EWC
def train_with_ewc(model, train_dataloader, ewc, optimizer, epochs=3):
    criterion = nn.CrossEntropyLoss()
    model.train()
    
    for epoch in range(epochs):
        for batch in train_dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            targets = inputs['input_ids'].clone()
            
            if targets.numel() == 0:  # Check for empty sequences
                continue
            
            outputs = model(**inputs)
            loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
            ewc_loss = ewc.penalty(model)
            total_loss = loss + ewc_loss
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            print(f"Epoch {epoch+1}, Loss: {total_loss.item()}")

In [14]:
# Evaluate Model Performance
def evaluate_model(model, dataloader):
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch in dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            targets = inputs['input_ids'].clone()
            
            if targets.numel() == 0:  # Check for empty sequences
                continue
            
            outputs = model(**inputs)
            loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Average Loss: {avg_loss}")
    return avg_loss


In [15]:
# Initialize EWC
ewc = EWC(model, train_dataloader)
optimizer = optim.Adam(model.parameters(), lr=5e-5)

# Train the model with EWC
train_with_ewc(model, train_dataloader, ewc, optimizer)

# Evaluate the model
evaluate_model(model, train_dataloader)

Epoch 1, Loss: 2.975815534591675
Epoch 1, Loss: 3.889411687850952
Epoch 1, Loss: 1.4942246675491333
Epoch 1, Loss: 2.014646291732788
Epoch 1, Loss: 1.7066130638122559
Epoch 1, Loss: 1.3164194822311401
Epoch 1, Loss: 0.9031688570976257
Epoch 1, Loss: 1.8186219930648804
Epoch 1, Loss: 0.8964985013008118
Epoch 1, Loss: 0.6037851572036743
Epoch 1, Loss: 0.6174091696739197
Epoch 1, Loss: 0.29234981536865234
Epoch 1, Loss: 0.806888222694397
Epoch 2, Loss: 1.5767492055892944
Epoch 2, Loss: 0.39598917961120605
Epoch 2, Loss: 0.516277551651001
Epoch 2, Loss: 0.3632330894470215
Epoch 2, Loss: 0.6245399117469788
Epoch 2, Loss: 0.25948378443717957
Epoch 2, Loss: 0.18636707961559296
Epoch 2, Loss: 0.4856903553009033
Epoch 2, Loss: 0.19470864534378052
Epoch 2, Loss: 0.22150105237960815
Epoch 2, Loss: 0.9134004712104797
Epoch 2, Loss: 0.34479355812072754
Epoch 2, Loss: 0.23868535459041595
Epoch 3, Loss: 0.1113562360405922
Epoch 3, Loss: 0.37803328037261963
Epoch 3, Loss: 0.11955796927213669
Epoch 3, 

0.11365048300761443

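With EWC, the average loss falls further to 0.113. Training starts from the weights already fine-tuned in the previous step and reuses the same 50 samples, so part of this drop may simply come from three more epochs on the same data. Training took 8 minutes 50 seconds on this 50-sample subset.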

## Step 2: Progressive Prompts

#### Define the progressive prompt class

In [43]:
class ProgressivePrompt(nn.Module):
    def __init__(self, model, tokenizer):
        super(ProgressivePrompt, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_embeddings = nn.ParameterList()  # One trainable soft prompt per task
        self.prompt_texts = []  # Text used to initialize each prompt
    
    def add_prompt(self, new_prompt):
        # Initialize the soft prompt from the token embeddings of the prompt text
        prompt_ids = self.tokenizer(new_prompt, return_tensors='pt').input_ids.to(self.model.device)
        prompt_embeddings = self.model.transformer.wte(prompt_ids)
        self.prompt_embeddings.append(nn.Parameter(prompt_embeddings.squeeze(0).detach().clone()))
        self.prompt_texts.append(new_prompt)

    def forward(self, input_ids):
        batch_size = input_ids.size(0)
        # Concatenate all prompts along the sequence dimension: (1, total_prompt_len, hidden)
        prompt_embeds = torch.cat([prompt_embed.unsqueeze(0) for prompt_embed in self.prompt_embeddings], dim=1)
        prompt_embeds = prompt_embeds.expand(batch_size, -1, -1)
        inputs_embeds = self.model.transformer.wte(input_ids)
        inputs_embeds = torch.cat([prompt_embeds, inputs_embeds], dim=1)
        # Prompt positions get label -100 so the loss ignores them
        prompt_labels = torch.full((batch_size, prompt_embeds.size(1)), -100, dtype=input_ids.dtype, device=input_ids.device)
        labels = torch.cat([prompt_labels, input_ids], dim=1)
        return self.model(inputs_embeds=inputs_embeds, labels=labels)

    def get_combined_prompt_text(self):
        # Learned embeddings are continuous vectors and cannot be decoded, so return the initializing texts
        return " ".join(self.prompt_texts)
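
The label layout built in `forward` can be illustrated without a model. Hugging Face causal-LM heads skip positions labeled -100 when computing the loss, so the prompt positions contribute nothing to it. The sketch below uses plain lists and a hypothetical helper `build_labels`. Masking padded tokens through `pad_id` is an optional extra that the class above does not do.

```python
IGNORE_INDEX = -100  # label value skipped by the loss


def build_labels(prompt_len, input_ids, pad_id=None):
    """Prepend IGNORE_INDEX for each prompt position; optionally mask padding too."""
    labels = [IGNORE_INDEX] * prompt_len
    for token in input_ids:
        labels.append(IGNORE_INDEX if token == pad_id else token)
    return labels


# A 3-position soft prompt followed by four tokens, the last of which is padding (id 50257)
print(build_labels(3, [464, 3290, 13, 50257]))
print(build_labels(3, [464, 3290, 13, 50257], pad_id=50257))
```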

#### Fine-tune the model with trainable prompts

In [44]:
def fine_tune_with_progressive_prompts(model, tokenizer, data_loader, progressive_prompt, new_prompt_text, epochs=3, lr=5e-5):
    # Add the new prompt to the model
    progressive_prompt.add_prompt(new_prompt_text)
    
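    # progressive_prompt.parameters() also yields the wrapped GPT-2 weights,
    # so the base model is updated along with the prompts
    optimizer = AdamW(progressive_prompt.parameters(), lr=lr)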
    
    model.train()
    for epoch in range(epochs):
        for batch in data_loader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True).input_ids
            inputs = inputs.to(model.device).long()
            
            optimizer.zero_grad()
            
            outputs = progressive_prompt(inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
    
    return model
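
If the intent is to train only the prompt vectors, one common approach is to freeze the base model and hand the optimizer just the prompt parameters. The sketch below shows this on a small stand-in module rather than GPT-2. `TinyBackbone`, `PromptWrapper`, and the sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Stand-in for a pretrained model (hypothetical, for illustration only)."""

    def __init__(self, hidden=8):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.proj(x)


class PromptWrapper(nn.Module):
    """Holds a frozen backbone plus a list of trainable prompt tensors."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.prompts = nn.ParameterList()
        # Freeze every backbone weight so gradients only reach the prompts
        for p in self.backbone.parameters():
            p.requires_grad_(False)

    def add_prompt(self, length, hidden=8):
        self.prompts.append(nn.Parameter(torch.zeros(length, hidden)))

    def forward(self, x):
        prompt = torch.cat(list(self.prompts), dim=0)           # (prompt_len, hidden)
        prompt = prompt.unsqueeze(0).expand(x.size(0), -1, -1)  # (batch, prompt_len, hidden)
        return self.backbone(torch.cat([prompt, x], dim=1))


wrapper = PromptWrapper(TinyBackbone())
wrapper.add_prompt(length=3)

# Only the prompt parameters go to the optimizer
optimizer = torch.optim.AdamW(wrapper.prompts.parameters(), lr=1e-3)

before = wrapper.backbone.proj.weight.clone()
loss = wrapper(torch.randn(2, 5, 8)).pow(2).mean()
loss.backward()
optimizer.step()

# Compare the backbone weight before and after the update step
print(torch.equal(before, wrapper.backbone.proj.weight))
```

With this arrangement the optimizer never touches the backbone, so earlier prompts and the base weights can be kept fixed while a new prompt is learned.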

In [45]:
# Initialize progressive prompts
progressive_prompt = ProgressivePrompt(model, tokenizer)

# Fine-tune the model on the new task
new_prompt_text = "This is a new task prompt."
model = fine_tune_with_progressive_prompts(model, tokenizer, train_dataloader, progressive_prompt, new_prompt_text)

# Evaluate the model
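# evaluate_model (defined above) takes (model, dataloader), so wrap the test sentence in a DataLoader
test_data = [{"text": "Some test sentence for evaluation."}]
evaluate_model(model, DataLoader(test_data, batch_size=1))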

Epoch 1, Loss: 7.239480972290039
Epoch 1, Loss: 8.5606689453125
Epoch 1, Loss: 8.844741821289062
Epoch 1, Loss: 9.052119255065918
Epoch 1, Loss: 6.741456508636475
Epoch 1, Loss: 13.632518768310547
Epoch 1, Loss: 3.760716199874878
Epoch 1, Loss: 6.807703495025635
Epoch 1, Loss: 8.914515495300293
Epoch 1, Loss: 6.7754364013671875
Epoch 1, Loss: 7.414560317993164
Epoch 1, Loss: 5.2529168128967285
Epoch 1, Loss: 21.523788452148438
Epoch 2, Loss: 4.346595287322998
Epoch 2, Loss: 5.344362258911133
Epoch 2, Loss: 17.516157150268555
Epoch 2, Loss: 10.0841646194458
Epoch 2, Loss: 8.682435989379883
Epoch 2, Loss: 8.15819263458252
Epoch 2, Loss: 3.232306480407715
Epoch 2, Loss: 5.076308727264404
Epoch 2, Loss: 7.424879550933838
Epoch 2, Loss: 9.098407745361328
Epoch 2, Loss: 25.525550842285156
Epoch 2, Loss: 13.476286888122559
Epoch 2, Loss: 20.564903259277344
Epoch 3, Loss: 12.324183464050293
Epoch 3, Loss: 5.501623153686523
Epoch 3, Loss: 12.944634437561035
Epoch 3, Loss: 17.254173278808594
Epo

12.935417175292969

## Step 3: Low-Rank Adaptation (LoRA)
- **Model Setup**: Apply LoRA to GPT-2.
- **Training**: Code for fine-tuning with LoRA.
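
LoRA keeps a pretrained weight matrix W frozen and learns a low-rank update, so the effective weight is W + (α/r)·B·A, with B of shape (d_out, r) and A of shape (r, d_in). The arithmetic below sketches the parameter savings for a single matrix. The 768×2304 shape is assumed here as the fused attention projection of the base GPT-2 model, and the helper name is hypothetical.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (full, lora) trainable-parameter counts for one weight matrix."""
    full = d_in * d_out               # fine-tuning W directly
    lora = rank * (d_in + d_out)      # A is (rank, d_in), B is (d_out, rank)
    return full, lora


full, lora = lora_param_counts(d_in=768, d_out=2304, rank=4)
print(full, lora, f"{100 * lora / full:.2f}%")
```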

In [None]:
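# The `lora.apply_lora` helper is not defined in this notebook; this sketch assumes
# the Hugging Face `peft` library instead.
from peft import LoraConfig, get_peft_model

# GPT-2 stores its attention projection (c_attn) as a Conv1D layer, hence fan_in_fan_out=True
lora_config = LoraConfig(r=4, target_modules=["c_attn"], fan_in_fan_out=True, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)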

In [None]:
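num_epochs = 3  # Assumed to match the earlier cells
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=5e-5)  # Trainable weights only
model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        if inputs['input_ids'].numel() == 0:  # Skip batches that contain only empty lines
            continue
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()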

## Evaluation
- **Metrics**: Define and calculate backward transfer, forward transfer, perplexity, and other metrics.
- **Results**: Present results for each method and compare.
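
One common formulation (assumed here, not taken from this notebook) records a score matrix R where R[i][j] is the score on task j after training through task i, with higher meaning better, plus a baseline b[j] from a reference model that has not been trained on any of the tasks. Backward transfer averages R[T-1][j] − R[j][j] over the earlier tasks, forward transfer averages R[j-1][j] − b[j] over the later tasks, and perplexity is the exponential of the mean per-token cross-entropy. The numbers below are made up to exercise the functions.

```python
import math


def backward_transfer(R):
    """Mean change on earlier tasks after training on the final task (negative = forgetting)."""
    T = len(R)
    return sum(R[T - 1][j] - R[j][j] for j in range(T - 1)) / (T - 1)


def forward_transfer(R, baseline):
    """Mean gain on a task before training on it, relative to the baseline score."""
    T = len(R)
    return sum(R[j - 1][j] - baseline[j] for j in range(1, T)) / (T - 1)


def perplexity(mean_token_loss):
    """Perplexity from the mean per-token cross-entropy (natural log)."""
    return math.exp(mean_token_loss)


# Hypothetical scores for two tasks (e.g. negative loss, so higher is better)
R = [
    [-2.0, -4.0],   # after task 0: score on task 0, score on task 1
    [-2.5, -1.5],   # after task 1: score on task 0, score on task 1
]
baseline = [-5.0, -5.0]

print(backward_transfer(R))           # R[1][0] - R[0][0]
print(forward_transfer(R, baseline))  # R[0][1] - baseline[1]
print(perplexity(2.0))
```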

In [None]:
def calculate_metrics(model, old_task_dataset, new_task_dataset):
    # Implement backward transfer, forward transfer, and perplexity calculations
    pass

old_task_metrics = calculate_metrics(model, old_task_dataset, new_task_dataset)
new_task_metrics = calculate_metrics(model, new_task_dataset, old_task_dataset)

In [None]:
import matplotlib.pyplot as plt

def plot_results(old_task_metrics, new_task_metrics):
    metrics = ['backward_transfer', 'forward_transfer', 'perplexity']
    for metric in metrics:
        plt.figure()
        plt.plot(old_task_metrics[metric], label='Old Task')
        plt.plot(new_task_metrics[metric], label='New Task')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()
        plt.title(f'{metric} over time')
        plt.show()

plot_results(old_task_metrics, new_task_metrics)