# Continual Learning Modeling Tasks - Reducing Catastrophic Forgetting

## Introduction
Overview and Context

Continual learning in large language models (LLMs) aims to let a model learn new tasks and adapt to new data without forgetting previously learned information. This project addresses catastrophic forgetting, where training on new data overwrites what the model learned earlier, by adding continual learning mechanisms to GPT-2. Potential applications include automated customer service and dynamic content creation, where a deployed model needs to absorb new information over time.

Goal

To explore methods for enabling LLMs to continually learn and adapt to new data or tasks without forgetting previously learned information, thereby addressing catastrophic forgetting.

Objectives
1. Mitigate Catastrophic Forgetting: Implement and test Elastic Weight Consolidation (EWC) on GPT-2.
2. Adapt GPT for Continual Learning: Integrate continual learning mechanisms within the Transformer architecture.
3. Evaluate Model Performance: Use backward and forward transfer metrics to measure performance on old vs. new tasks.
4. Understand Transformer Architecture: Explore the self-attention mechanisms and their scalability in transformers.

## Dataset Description
Dataset: WikiText-103
• Description: A collection of over 100 million tokens from verified Good and Featured articles on Wikipedia.
• Usage: To test the model’s ability to learn continually and adapt over time.


## Step 1: Elastic Weight Consolidation (EWC)

- **Model Setup**: Load GPT-2 model.

In [51]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of torch.optim.AdamW
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader

In [52]:
model_name = 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

In [53]:
# Add pad token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))

Embedding(50258, 768)

#### Read the Dataset

In [54]:
# Load dataset
dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

# Select a smaller subset of the dataset
small_dataset = dataset['train'].select(range(100)) 

# Convert the dataset to DataLoader
train_dataloader = DataLoader(small_dataset, batch_size=4, shuffle=True)

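**Fisher Information Matrix Calculation:** computing the Fisher Information Matrix (FIM) is an essential part of Elastic Weight Consolidation (EWC), a method for mitigating catastrophic forgetting in neural networks. EWC uses the diagonal of the FIM as a per-parameter measure of how important each weight is to the old task. `compute_fisher_information` estimates this diagonal from squared gradients of the loss, accumulated over the batches of the dataloader.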
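For reference, the EWC objective of Kirkpatrick et al. (2017), which the `EWC` class below implements, adds a quadratic penalty to the new-task loss. Here θ* are the parameters saved after the old task (`optimal_params`), λ is `lambda_`, L_n is the loss on batch n, and F_i is the i-th diagonal Fisher entry. Note that the notebook squares the gradient of each batch's mean loss, which is a coarser approximation than averaging per-example squared gradients:

```latex
\mathcal{L}(\theta) = \mathcal{L}_{\text{new}}(\theta) + \frac{\lambda}{2} \sum_i F_i \left(\theta_i - \theta^{*}_i\right)^2,
\qquad
F_i \approx \frac{1}{N} \sum_{n=1}^{N} \left( \left. \frac{\partial \mathcal{L}_n}{\partial \theta_i} \right|_{\theta = \theta^{*}} \right)^2
```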

In [55]:
# Define Fisher Information Matrix Calculation
def compute_fisher_information(model, dataloader):
    model.eval()
    fisher_information = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    criterion = nn.CrossEntropyLoss()
    
    for batch in dataloader:
        inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        targets = inputs['input_ids'].clone()
        
        if targets.numel() == 0:  # Check for empty sequences
            continue
        
        model.zero_grad()
        outputs = model(**inputs)
        loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
        loss.backward()

        for n, p in model.named_parameters():
            if p.requires_grad:
                fisher_information[n] += p.grad.pow(2)
    
    for n in fisher_information:
        fisher_information[n] /= len(dataloader)
    
    return fisher_information

- **EWC Regularization**

In [56]:
# Define EWC Regularization Term in the Loss Function
class EWC:
    def __init__(self, model, dataloader, lambda_=0.4):
        self.model = model
        self.lambda_ = lambda_
        self.fisher_information = compute_fisher_information(model, dataloader)
        self.optimal_params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}

    def penalty(self, model):
        loss = 0
        for n, p in model.named_parameters():
            if p.requires_grad:
                _loss = self.fisher_information[n] * (p - self.optimal_params[n]).pow(2)
                loss += _loss.sum()
        return (self.lambda_ / 2) * loss
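The penalty in `EWC.penalty` can be sanity-checked on a tiny model. The sketch below reimplements it as a standalone `ewc_penalty` function (a hypothetical helper, not part of the notebook) and uses made-up Fisher values on an `nn.Linear(2, 1)` stand-in for GPT-2, showing that moving a high-Fisher weight costs far more than moving a low-Fisher one by the same amount:

```python
import torch
import torch.nn as nn


def ewc_penalty(model, fisher, anchor, lam):
    # (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2 over parameters present in `fisher`
    total = torch.zeros(())
    for name, param in model.named_parameters():
        if param.requires_grad and name in fisher:
            total = total + (fisher[name] * (param - anchor[name]).pow(2)).sum()
    return (lam / 2) * total


torch.manual_seed(0)
toy_model = nn.Linear(2, 1)  # stand-in for GPT-2: weight shape (1, 2), bias shape (1,)
anchor = {n: p.detach().clone() for n, p in toy_model.named_parameters()}
# Hypothetical diagonal Fisher: the first weight is "important", the others are not
fisher = {"weight": torch.tensor([[10.0, 0.1]]), "bias": torch.tensor([0.1])}


def penalty_after_shift(delta):
    # Move the weights by `delta` away from the anchor (bias unchanged) and return the penalty
    with torch.no_grad():
        toy_model.weight.copy_(anchor["weight"] + delta)
    return ewc_penalty(toy_model, fisher, anchor, lam=0.4).item()


p_anchor = penalty_after_shift(torch.tensor([[0.0, 0.0]]))       # zero at the anchor
p_important = penalty_after_shift(torch.tensor([[1.0, 0.0]]))    # about 0.2 * 10.0 = 2.0
p_unimportant = penalty_after_shift(torch.tensor([[0.0, 1.0]]))  # about 0.2 * 0.1 = 0.02
print(p_anchor, p_important, p_unimportant)
```

With `lambda_=0.4`, the same displacement is penalized 100 times more when its Fisher entry is 100 times larger, which is how EWC protects parameters that mattered for the old task.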

In [57]:
# Define evaluation function for old task
def evaluate_on_old_task(model, dataloader):
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch in dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            targets = inputs['input_ids'].clone()
            if targets.numel() == 0:
                continue
            
            outputs = model(**inputs)
            loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Average Loss on Old Task: {avg_loss}")
    return avg_loss

In [58]:
# Evaluate initial performance on the old task
initial_loss_old_task = evaluate_on_old_task(model, train_dataloader)  

Average Loss on Old Task: 36.94977851867676


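Before any fine-tuning (with or without EWC), the average loss is high: 45.75 on the 50-sample subset and 36.95 on the 100-sample subset. Both values are well above ln(50258) ≈ 10.82, the loss of a uniform guess over the vocabulary. This is consistent with how the loss is computed here: the logits at position t are compared with the token at position t itself (the targets are the unshifted `input_ids`, padding included), whereas GPT-2 is trained to predict the token at position t+1.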
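For comparison, the standard causal language-modeling loss shifts the targets by one position (the logits at position t score token t+1) and skips padded positions; passing `labels=` to `GPT2LMHeadModel`, with padded labels set to -100, performs this shift internally. A minimal sketch of the same computation by hand, where `causal_lm_loss` is a hypothetical helper and the toy tensors stand in for the tokenizer output (`input_ids`, `attention_mask`) and `outputs.logits`:

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits, input_ids, attention_mask):
    # Mean next-token cross-entropy that ignores padded target positions.
    # logits: (batch, seq, vocab); input_ids, attention_mask: (batch, seq)
    shift_logits = logits[:, :-1, :]                  # predictions made at positions 0..T-2
    shift_labels = input_ids[:, 1:].clone()           # tokens they should predict: 1..T-1
    shift_labels[attention_mask[:, 1:] == 0] = -100   # padding targets are skipped
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )


# Toy batch: vocabulary of 5, two sequences, the second right-padded with pad id 4
torch.manual_seed(0)
logits = torch.randn(2, 4, 5)
input_ids = torch.tensor([[0, 1, 2, 3], [2, 3, 4, 4]])
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

loss = causal_lm_loss(logits, input_ids, attention_mask)

# Same value computed one target at a time, to check the shift and the masking
terms = []
for b in range(2):
    for t in range(3):
        if attention_mask[b, t + 1] == 1:
            terms.append(-F.log_softmax(logits[b, t], dim=-1)[input_ids[b, t + 1]])
manual = torch.stack(terms).mean()
print(loss.item(), manual.item())
```

In the notebook's loops, `causal_lm_loss(outputs.logits, inputs['input_ids'], inputs['attention_mask'])` could replace the `criterion(...)` call; losses computed that way are not comparable with the values printed in this section.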

In [59]:
# Train model on new task without EWC
def train_without_ewc(model, dataloader, optimizer, epochs=3):
    criterion = nn.CrossEntropyLoss()
    model.train()
    
    for epoch in range(epochs):
        for batch in dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            targets = inputs['input_ids'].clone()
            if targets.numel() == 0:
                continue
            
            optimizer.zero_grad()
            outputs = model(**inputs)
            loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
            loss.backward()
            optimizer.step()
            
            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

In [60]:
# Initialize optimizer for training without EWC
optimizer_without_ewc = optim.Adam(model.parameters(), lr=5e-5)

In [61]:
# Train the model on the new task
train_without_ewc(model, train_dataloader, optimizer_without_ewc)

Epoch 1, Loss: 39.88058090209961
Epoch 1, Loss: 30.928239822387695
Epoch 1, Loss: 23.43995475769043
Epoch 1, Loss: 18.15724754333496
Epoch 1, Loss: 9.123306274414062
Epoch 1, Loss: 10.308950424194336
Epoch 1, Loss: 11.038313865661621
Epoch 1, Loss: 6.5043253898620605
Epoch 1, Loss: 10.436527252197266
Epoch 1, Loss: 6.038282871246338
Epoch 1, Loss: 4.376773357391357
Epoch 1, Loss: 6.483856678009033
Epoch 1, Loss: 4.479620456695557
Epoch 1, Loss: 3.5722339153289795
Epoch 1, Loss: 3.2003958225250244
Epoch 1, Loss: 4.9409894943237305
Epoch 1, Loss: 4.029125690460205
Epoch 1, Loss: 2.415905475616455
Epoch 1, Loss: 7.3519415855407715
Epoch 1, Loss: 3.558713436126709
Epoch 1, Loss: 3.0454163551330566
Epoch 1, Loss: 5.623571395874023
Epoch 1, Loss: 2.1449193954467773
Epoch 1, Loss: 6.932829856872559
Epoch 1, Loss: 3.3559679985046387
Epoch 2, Loss: 2.047531843185425
Epoch 2, Loss: 4.4559407234191895
Epoch 2, Loss: 3.3338167667388916
Epoch 2, Loss: 2.953453540802002
Epoch 2, Loss: 2.510204553604

In [62]:
# Evaluate performance on the old task after training on the new task without EWC
loss_after_training_without_ewc = evaluate_on_old_task(model, train_dataloader)

Average Loss on Old Task: 0.30981976598501204


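After training without EWC, the average loss drops sharply (2.76 on the 50-sample subset, 0.31 on the 100-sample subset). Note that the "old task" evaluation and the "new task" training use the same `train_dataloader`, so this drop reflects the model fitting that subset rather than retention of an earlier task.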
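Measuring forgetting requires disjoint tasks (for example, two non-overlapping slices of WikiText-103, or two different corpora) and a record of every task's loss after each training stage. The sketch below captures that protocol as a hypothetical helper `run_sequential_protocol`; `train_fn` and `eval_fn` would wrap functions such as `train_with_ewc` and `evaluate_model`, and the toy stand-ins exist only to make the example runnable:

```python
def run_sequential_protocol(model, tasks, train_fn, eval_fn):
    # Train on tasks in order; R[i][j] is the loss on task j after training on tasks 0..i
    R = []
    for train_task in tasks:
        train_fn(model, train_task)
        R.append([eval_fn(model, eval_task) for eval_task in tasks])
    return R


# Toy stand-ins: the "model" only remembers the latest task, so it forgets the
# earlier one completely (an extreme case of catastrophic forgetting)
def toy_train(model, task):
    model["last_task"] = task


def toy_eval(model, task):
    return 0.5 if model.get("last_task") == task else 3.0


R = run_sequential_protocol({}, ["task_A", "task_B"], toy_train, toy_eval)
print(R)  # [[0.5, 3.0], [3.0, 0.5]]
forgetting_A = R[1][0] - R[0][0]  # loss increase on task A caused by training on task B
print(forgetting_A)  # 2.5
```

Rows of `R` are exactly what the backward- and forward-transfer metrics listed in the Evaluation section need.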

- **Training**: Implement training with EWC.

In [63]:
# Training Loop with EWC
def train_with_ewc(model, train_dataloader, ewc, optimizer, epochs=3):
    criterion = nn.CrossEntropyLoss()
    model.train()
    
    for epoch in range(epochs):
        for batch in train_dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            targets = inputs['input_ids'].clone()
            
            if targets.numel() == 0:  # Check for empty sequences
                continue
            
            outputs = model(**inputs)
            loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
            ewc_loss = ewc.penalty(model)
            total_loss = loss + ewc_loss
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            print(f"Epoch {epoch+1}, Loss: {total_loss.item()}")

In [64]:
# Evaluate Model Performance
def evaluate_model(model, dataloader):
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch in dataloader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            targets = inputs['input_ids'].clone()
            
            if targets.numel() == 0:  # Check for empty sequences
                continue
            
            outputs = model(**inputs)
            loss = criterion(outputs.logits.view(-1, model.config.vocab_size), targets.view(-1))
            total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Average Loss: {avg_loss}")
    return avg_loss


In [65]:
# Initialize EWC
ewc = EWC(model, train_dataloader)
optimizer = optim.Adam(model.parameters(), lr=5e-5)

# Train the model with EWC
train_with_ewc(model, train_dataloader, ewc, optimizer)

# Evaluate the model
evaluate_model(model, train_dataloader)

Epoch 1, Loss: 0.36142903566360474
Epoch 1, Loss: 0.37427768111228943
Epoch 1, Loss: 0.21519066393375397
Epoch 1, Loss: 0.260184645652771
Epoch 1, Loss: 0.27208632230758667
Epoch 1, Loss: 0.7719264626502991
Epoch 1, Loss: 0.2463414967060089
Epoch 1, Loss: 0.1594952493906021
Epoch 1, Loss: 0.22791917622089386
Epoch 1, Loss: 0.339248925447464
Epoch 1, Loss: 0.1523289531469345
Epoch 1, Loss: 0.7252788543701172
Epoch 1, Loss: 0.26625919342041016
Epoch 1, Loss: 0.1510256677865982
Epoch 1, Loss: 0.2001023143529892
Epoch 1, Loss: 0.1504552662372589
Epoch 1, Loss: 0.24973703920841217
Epoch 1, Loss: 0.1652706414461136
Epoch 1, Loss: 0.1875583976507187
Epoch 1, Loss: 0.15624909102916718
Epoch 1, Loss: 0.1555921733379364
Epoch 1, Loss: 0.11466293036937714
Epoch 1, Loss: 0.1890484094619751
Epoch 1, Loss: 0.09210159629583359
Epoch 1, Loss: 0.1794268935918808
Epoch 2, Loss: 0.13263536989688873
Epoch 2, Loss: 0.13463184237480164
Epoch 2, Loss: 0.07543867826461792
Epoch 2, Loss: 0.1082070842385292
...

0.09918553963303565

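With EWC, the average loss falls further: 0.113 on the 50-sample subset (training time 8 min 50 s) and 0.099 on the 100-sample subset (training time 14 min 38 s). Two caveats apply. The printed per-step losses include the EWC penalty. More importantly, the EWC run continues from the `model` that was already fine-tuned without EWC on the same data (and its Fisher information is computed on that data too), so the lower loss partly reflects three additional epochs of training rather than a like-for-like comparison.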
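A like-for-like comparison starts both arms from identical weights and times them the same way. The sketch below shows that pattern on a toy linear-regression model: `copy.deepcopy` gives each arm its own copy of `base`, and `time.perf_counter` times each run. `train_steps` and `anchor_penalty` are hypothetical stand-ins for `train_without_ewc`/`train_with_ewc` and `EWC.penalty` (with every Fisher entry set to 1 for simplicity):

```python
import copy
import time

import torch
import torch.nn as nn


def train_steps(model, x, y, steps, penalty_fn=None, lr=0.1):
    # Plain SGD on mean-squared error; `penalty_fn(model)`, if given, is added to the loss
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        if penalty_fn is not None:
            loss = loss + penalty_fn(model)
        loss.backward()
        opt.step()
    return model


torch.manual_seed(0)
base = nn.Linear(3, 1)  # shared starting weights for both arms
x, y = torch.randn(32, 3), torch.randn(32, 1)
anchor = {n: p.detach().clone() for n, p in base.named_parameters()}


def anchor_penalty(model, lam=1.0):
    # EWC-style penalty with every Fisher entry set to 1 (a simplifying assumption for this toy)
    return (lam / 2) * sum(((p - anchor[n]) ** 2).sum() for n, p in model.named_parameters())


results = {}
for name, penalty in [("plain", None), ("penalized", anchor_penalty)]:
    arm = copy.deepcopy(base)  # each arm starts from an identical copy of `base`
    start = time.perf_counter()
    train_steps(arm, x, y, steps=50, penalty_fn=penalty)
    results[name] = {"model": arm, "seconds": time.perf_counter() - start}

# Squared distance each arm moved away from the shared starting weights
drift = {
    name: sum(((p - anchor[n]) ** 2).sum().item() for n, p in r["model"].named_parameters())
    for name, r in results.items()
}
print({name: round(r["seconds"], 4) for name, r in results.items()}, drift)
```

The penalized arm drifts less from the shared starting point, which is the effect EWC relies on. With GPT-2, each arm would be a `copy.deepcopy` of the pretrained model taken before any fine-tuning.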

## Step 2: Progressive Prompts

#### Define the progressive prompt class

In [66]:
class ProgressivePrompt(nn.Module):
    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_embeddings = nn.ParameterList()
        self.prompt_texts = []  # text each prompt was initialized from
    
    def add_prompt(self, new_prompt):
        # Initialize a new soft prompt from the token embeddings of the prompt text
        prompt_ids = self.tokenizer(new_prompt, return_tensors='pt').input_ids.to(self.model.device)
        with torch.no_grad():
            prompt_embeddings = self.model.transformer.wte(prompt_ids).squeeze(0).clone()
        self.prompt_embeddings.append(nn.Parameter(prompt_embeddings))
        self.prompt_texts.append(new_prompt)

    def forward(self, input_ids):
        batch_size = input_ids.size(0)
        # Concatenate all prompts learned so far and prepend them to every sequence in the batch
        prompt_embeds = torch.cat([prompt_embed.unsqueeze(0) for prompt_embed in self.prompt_embeddings], dim=1)
        prompt_embeds = prompt_embeds.expand(batch_size, -1, -1)
        inputs_embeds = self.model.transformer.wte(input_ids)
        inputs_embeds = torch.cat([prompt_embeds, inputs_embeds], dim=1)
        # Prompt positions get label -100 so they are excluded from the loss
        prompt_labels = torch.full((batch_size, prompt_embeds.size(1)), -100, dtype=torch.long, device=input_ids.device)
        labels = torch.cat([prompt_labels, input_ids], dim=1)
        return self.model(inputs_embeds=inputs_embeds, labels=labels)

    def get_combined_prompt_text(self):
        # Learned prompts are continuous embeddings and cannot be decoded to tokens,
        # so return the texts the prompts were initialized from
        return " ".join(self.prompt_texts)
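Progressive Prompts, as described by Razdaibiedina et al. (2023), keeps the language model frozen, learns a new soft prompt for each task, and freezes the prompts of earlier tasks while concatenating them in front of the input. The sketch below shows only that parameter handling, with a tiny embedding-plus-linear model standing in for GPT-2. `TinyLM` and `FrozenProgressivePrompts` are hypothetical names, and the random prompt initialization is an assumption (the class above initializes prompts from text embeddings instead):

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    # Stand-in for GPT-2: token embedding followed by a linear head, fed via inputs_embeds
    def __init__(self, vocab=10, dim=8):
        super().__init__()
        self.wte = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, inputs_embeds):
        return self.head(inputs_embeds)


class FrozenProgressivePrompts(nn.Module):
    def __init__(self, lm, prompt_len=3):
        super().__init__()
        self.lm = lm
        self.prompt_len = prompt_len
        self.prompts = nn.ParameterList()
        for p in self.lm.parameters():
            p.requires_grad_(False)  # the base model stays frozen for every task

    def start_new_task(self):
        for old in self.prompts:
            old.requires_grad_(False)  # freeze prompts learned on earlier tasks
        dim = self.lm.wte.embedding_dim
        self.prompts.append(nn.Parameter(torch.randn(self.prompt_len, dim) * 0.02))

    def forward(self, input_ids):
        batch = input_ids.size(0)
        # Concatenate all prompts (oldest first) and prepend them to the token embeddings
        prompt = torch.cat(list(self.prompts), dim=0).unsqueeze(0).expand(batch, -1, -1)
        embeds = torch.cat([prompt, self.lm.wte(input_ids)], dim=1)
        return self.lm(inputs_embeds=embeds)


torch.manual_seed(0)
pp = FrozenProgressivePrompts(TinyLM())
pp.start_new_task()  # prompt for task 1
pp.start_new_task()  # prompt for task 2; the task 1 prompt is now frozen

trainable = [p for p in pp.parameters() if p.requires_grad]
prompt_optimizer = torch.optim.AdamW(trainable, lr=1e-3)  # optimizes only the newest prompt

logits = pp(torch.tensor([[1, 2, 3]]))
logits.sum().backward()
print(logits.shape)  # (1, 9, 10): 3 + 3 prompt positions followed by 3 token positions
```

Building the optimizer from the parameters that still have `requires_grad=True` is what keeps the base model and the older prompts unchanged while a new task is learned.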

#### Fine-tune the model with trainable prompts

In [67]:
def fine_tune_with_progressive_prompts(model, tokenizer, data_loader, progressive_prompt, new_prompt_text, epochs=3, lr=5e-5):
    # Add the new prompt to the model
    progressive_prompt.add_prompt(new_prompt_text)
    
    optimizer = AdamW(progressive_prompt.parameters(), lr=lr)
    
    model.train()
    for epoch in range(epochs):
        for batch in data_loader:
            inputs = tokenizer(batch['text'], return_tensors='pt', padding=True, truncation=True).input_ids
            inputs = inputs.to(model.device).long()
            
            optimizer.zero_grad()
            
            outputs = progressive_prompt(inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
    
    return model

In [69]:
def evaluate_model(model, tokenizer, test_data):
    model.eval()
    losses = []
    for text in test_data:
        inputs = tokenizer(text, return_tensors='pt').input_ids
        inputs = inputs.to(model.device)
        labels = inputs.clone()
        
        with torch.no_grad():
            outputs = model(inputs, labels=labels)
            loss = outputs.loss.item()
            losses.append(loss)
    
    avg_loss = sum(losses) / len(losses)
    print(f"Average Loss: {avg_loss}")
    return avg_loss

In [70]:
# Initialize progressive prompts
progressive_prompt = ProgressivePrompt(model, tokenizer)

# Fine-tune the model on the new task
new_prompt_text = "This is a new task prompt."
model = fine_tune_with_progressive_prompts(model, tokenizer, train_dataloader, progressive_prompt, new_prompt_text)

# Evaluate the model
test_data = ["Some test sentence for evaluation."]
evaluate_model(model, tokenizer, test_data)

Epoch 1, Loss: 1.7333935499191284
Epoch 1, Loss: 2.808159112930298
Epoch 1, Loss: 2.3256754875183105
Epoch 1, Loss: 2.296097993850708
Epoch 1, Loss: 1.0251165628433228
Epoch 1, Loss: 1.7925424575805664
Epoch 1, Loss: 1.0886930227279663
Epoch 1, Loss: 1.1101934909820557
Epoch 1, Loss: 1.1167850494384766
Epoch 1, Loss: 1.499448537826538
Epoch 1, Loss: 1.4998842477798462
Epoch 1, Loss: 1.2980015277862549
Epoch 1, Loss: 1.7631909847259521
Epoch 1, Loss: 2.013593912124634
Epoch 1, Loss: 2.419537305831909
Epoch 1, Loss: 1.0160659551620483
Epoch 1, Loss: 1.3119125366210938
Epoch 1, Loss: 1.5501508712768555
Epoch 1, Loss: 1.0795613527297974
Epoch 1, Loss: 1.796526551246643
Epoch 1, Loss: 1.7399340867996216
Epoch 1, Loss: 1.2518612146377563
Epoch 1, Loss: 1.521407961845398
Epoch 1, Loss: 2.1559202671051025
Epoch 1, Loss: 2.4155631065368652
Epoch 2, Loss: 1.6019731760025024
Epoch 2, Loss: 1.3940032720565796
Epoch 2, Loss: 1.4069116115570068
Epoch 2, Loss: 1.085863471031189
Epoch 2, Loss: 1.33255

7.695197105407715

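The average loss on the test sentence is 12.93 after fine-tuning on the 50-sample subset and 7.695 after fine-tuning on the 100-sample subset. These numbers come from a single test sentence and are not directly comparable with the Step 1 results: passing `labels=` to `GPT2LMHeadModel` computes the standard shifted next-token loss, unlike the unshifted loss used in Step 1. Also note that `evaluate_model` calls `model` directly, so the learned prompts are not prepended at evaluation time, and that `progressive_prompt.parameters()` includes the wrapped GPT-2 weights, so the optimizer updates the whole model together with the prompt.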

## Step 3: Low-Rank Adaptation (LoRA)
- **Model Setup**: Apply LoRA to GPT-2.
- **Training**: Code for fine-tuning with LoRA.

In [None]:
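# LoRA via Hugging Face PEFT (assumes `pip install peft`); rank 4 as in the original cell, other values illustrative
from peft import LoraConfig, get_peft_model

# GPT-2's fused attention projection `c_attn` is a Conv1D layer, hence fan_in_fan_out=True
lora_config = LoraConfig(r=4, lora_alpha=16, lora_dropout=0.05, target_modules=["c_attn"], fan_in_fan_out=True, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)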
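To show what a rank-4 adapter adds, here is a from-scratch sketch of a LoRA-wrapped linear layer. The class name `LoRALinear` and the `alpha` value are illustrative, not taken from any library. The pretrained weight W stays frozen and the layer computes W x + (alpha / r) · B A x, with B initialized to zero so the wrapped layer starts out identical to the original:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank=4, alpha=8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weight and bias are frozen
        # Low-rank factors: A is (rank x in_features), B is (out_features x rank); B starts at zero
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        # Frozen base output plus the scaled low-rank update B A x
        return self.base(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)


torch.manual_seed(0)
base = nn.Linear(16, 16)
layer = LoRALinear(base, rank=4)
x = torch.randn(2, 16)

same_at_init = torch.allclose(layer(x), base(x))  # True, because B = 0 makes the update zero
n_trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
n_frozen = sum(p.numel() for p in layer.parameters() if not p.requires_grad)
print(same_at_init, n_trainable, n_frozen)  # True 128 272
```

Only 128 of the 400 parameters in this layer are trained. GPT-2's attention projections are `Conv1D` modules whose weights are stored transposed relative to `nn.Linear`, so this sketch would need that transpose to wrap them directly.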

## Evaluation
- **Metrics**: Define and calculate backward transfer, forward transfer, perplexity, and other metrics.
- **Results**: Present results for each method and compare.
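One common way to define the transfer metrics, following Lopez-Paz and Ranzato's GEM paper, uses a matrix R where R[i][j] is the score on task j after training through task i, plus a baseline b[j] for each task before any training. With T tasks (1-indexed), BWT = 1/(T−1) · Σ_{i=1}^{T−1} (R[T][i] − R[i][i]) and FWT = 1/(T−1) · Σ_{i=2}^{T} (R[i−1][i] − b[i]). These definitions assume higher is better (accuracy), so the sketch below negates losses before applying them; perplexity is exp of the mean per-token cross-entropy. Function names and the example numbers are hypothetical, and the code is 0-indexed:

```python
import math


def backward_transfer(R):
    # Average change on earlier tasks after training on the last task (negative = forgetting)
    T = len(R)
    return sum(R[T - 1][i] - R[i][i] for i in range(T - 1)) / (T - 1)


def forward_transfer(R, baseline):
    # Average score on each task just before training on it, relative to the untrained baseline
    T = len(R)
    return sum(R[i - 1][i] - baseline[i] for i in range(1, T)) / (T - 1)


def perplexity(mean_token_loss):
    # exp of the mean per-token cross-entropy (in nats)
    return math.exp(mean_token_loss)


# Hypothetical losses for 3 tasks: losses[i][j] = loss on task j after training through task i
losses = [
    [2.0, 4.0, 4.5],
    [2.6, 1.8, 4.2],
    [3.0, 2.4, 1.5],
]
baseline_losses = [4.0, 4.4, 4.6]

# Negate losses so that higher is better, matching the accuracy-based definitions
R = [[-v for v in row] for row in losses]
b = [-v for v in baseline_losses]

bwt = backward_transfer(R)  # -0.8: losses on earlier tasks went up
fwt = forward_transfer(R, b)  # 0.4: earlier training lowered losses on unseen tasks
print(round(bwt, 3), round(fwt, 3), round(perplexity(1.5), 3))
```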

In [None]:
def calculate_metrics(model, old_task_dataset, new_task_dataset):
    # Implement backward transfer, forward transfer, and perplexity calculations
    pass

old_task_metrics = calculate_metrics(model, old_task_dataset, new_task_dataset)
new_task_metrics = calculate_metrics(model, new_task_dataset, old_task_dataset)

In [None]:
import matplotlib.pyplot as plt

def plot_results(old_task_metrics, new_task_metrics):
    metrics = ['backward_transfer', 'forward_transfer', 'perplexity']
    for metric in metrics:
        plt.figure()
        plt.plot(old_task_metrics[metric], label='Old Task')
        plt.plot(new_task_metrics[metric], label='New Task')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()
        plt.title(f'{metric} over time')
        plt.show()

plot_results(old_task_metrics, new_task_metrics)