In [1]:
import torch
import pyro
import tyxe

import random
import functools
import copy

import numpy as np

from pyro.infer import SVI, TraceMeanField_ELBO, Trace_ELBO

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset

from datasets import load_dataset  # Added to load SuperNI dataset

from typing import Optional, List
from model.mle_prior import MLEPrior


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
import torch

# Check if CUDA is available
print("CUDA Available:", torch.cuda.is_available())

# Get the current device index
current_device = torch.cuda.current_device()
print("Current Device Index:", current_device)

# Get the name of the current device
device_name = torch.cuda.get_device_name(current_device)
print("Current Device Name:", device_name)

# Get the number of GPUs
num_gpus = torch.cuda.device_count()
print("Number of GPUs:", num_gpus)

# List all GPUs
for device_id in range(num_gpus):
    print(f"Device {device_id}: {torch.cuda.get_device_name(device_id)}")


CUDA Available: True
Current Device Index: 0
Current Device Name: NVIDIA A100-SXM4-80GB
Number of GPUs: 1
Device 0: NVIDIA A100-SXM4-80GB


In [4]:
# def fetch_nlp_datasets(tokenizer, batch_size, num_tasks, start_task=1):
#     train_loaders = []
#     test_loaders = []

#     # Load the SuperNI dataset
#     # You can specify the split and tasks you need
#     superni_dataset = load_dataset('super_glue', 'ni')  # Adjust if necessary

#     # Assuming tasks are numbered starting from 1
#     for task_index in range(start_task, num_tasks + 1):
#         if task_index == 1:
#             # Load QA task from SuperNI
#             train_dataset = load_superni_task_dataset(superni_dataset, tokenizer, task_type='qa', split='train')
#             test_dataset = load_superni_task_dataset(superni_dataset, tokenizer, task_type='qa', split='validation')
#         elif task_index == 2:
#             # Load QG task from SuperNI
#             train_dataset = load_superni_task_dataset(superni_dataset, tokenizer, task_type='qg', split='train')
#             test_dataset = load_superni_task_dataset(superni_dataset, tokenizer, task_type='qg', split='validation')
#         else:
#             # Load additional tasks if needed
#             pass

#         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#         test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

#         train_loaders.append(train_loader)
#         test_loaders.append(test_loader)

#     return train_loaders, test_loaders


In [5]:
# import os
# def load_superni_task_dataset(tokenizer, task_type='qa', split='train'):
#     # Filter the dataset for the specific task type
#     # SuperNI tasks are identified by their task names or IDs
#     # For example, you can filter tasks that contain 'question answering' or 'question generation'

#     # Example of filtering:
#     if task_type == 'qa':
#         path='/home/pranav24/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting/QA_FineTuned'
#         os.chdir(path)
#         target_file = r"task024_cosmosqa_answer_generation.json"
#         with open(target_file, 'r', encoding='utf-8-sig') as f:
#             json_data = json.load(f)

#         dataset = json_data['Instances'][0:2223]
        
#         # task_filter = lambda ex: 'question answering' in ex['Task']
#     elif task_type == 'qg':
#         # task_filter = lambda ex: 'question generation' in ex['Task']
#         path='/home/pranav24/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting/QG_FineTuned/QG_FineTuned'
#         os.chdir(path)
#         target_file = r"task074_squad1.1_question_generation.json"
#         with open(target_file, 'r', encoding='utf-8-sig') as f:
#             json_data = json.load(f)

#         dataset = json_data['Instances'][0:2223]
#     else:
#         raise ValueError(f"Unsupported task type: {task_type}")


#     def preprocess_function(examples):
#         # For SuperNI, inputs and outputs are in 'Input' and 'Output' fields
#         inputs = examples['Input']
#         targets = examples['Output']
    
#         # Tokenize inputs and targets
#         model_inputs = tokenizer(inputs, truncation=True, padding='max_length', max_length=512)
#         with tokenizer.as_target_tokenizer():
#             labels = tokenizer(targets, truncation=True, padding='max_length', max_length=512)
    
#         model_inputs['labels'] = labels['input_ids']
#         return model_inputs
    
#     dataset = dataset.map(preprocess_function, batched=True)
#     dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    
#     return dataset


In [6]:
# def get_data_loader_for_task1(tokenizer, batch_size):
#     # Load the SuperNI dataset

#     # Load QA task
#     train_dataset = load_superni_task_dataset(tokenizer, task_type='qa', split='train')
#     test_dataset = load_superni_task_dataset(tokenizer, task_type='qa', split='validation')

#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

#     return train_loader, test_loader


### Task1 -QA LoRA+EVCL

In [7]:
# class TrainingConfig:
#     def __init__(
#         self,
#         output_dir,
#         num_train_epochs,
#         per_device_train_batch_size,
#         gradient_accumulation_steps,
#         learning_rate,
#         logging_steps,
#         eval_steps,
#         save_steps,
#         save_total_limit,
#         fp16,
#     ):
#         self.output_dir = output_dir
#         self.num_train_epochs = num_train_epochs
#         self.per_device_train_batch_size = per_device_train_batch_size
#         self.gradient_accumulation_steps = gradient_accumulation_steps
#         self.learning_rate = learning_rate
#         self.logging_steps = logging_steps
#         self.eval_steps = eval_steps
#         self.save_steps = save_steps
#         self.save_total_limit = save_total_limit
#         self.fp16 = fp16


In [20]:
def compute_fisher_info(
    model, 
    data_loader, 
    prev_fisher_info=None, 
    ewc_gamma=1.0, 
    num_epochs=1, 
    head_modules=None, 
    n_samples=None
):

    fisher = {}
    
    # Initialize Fisher matrix for LoRA parameters, excluding head modules if provided
    for name, param in model.named_parameters():
        if 'lora' in name and (head_modules is None or not any(name.startswith(head) for head in head_modules)):
            fisher[name] = torch.zeros_like(param).to(DEVICE)
    
    # Save the model's current training state and set to eval
    old_training_state = model.training
    model.eval()
    
    scaler = GradScaler(device='cuda')

    batch_count = 0

    for epoch in range(num_epochs):
        print(f"Starting Epoch {epoch + 1}/{num_epochs}")
        for i, batch in enumerate(data_loader):
            if n_samples is not None and batch_count >= n_samples:
                break

            print(f"Processing batch {batch_count + 1}")
            model.zero_grad()
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            try:
                # with autocast(device_type='cuda'):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                loss.backward()
            # scaler.scale(loss).backward()
            except RuntimeError as e:
                print(f"Error in batch {batch_count + 1}: {e}")
                break

            # Accumulate Fisher information for LoRA parameters
            for name, param in model.named_parameters():
                if 'lora' in name and param.grad is not None and (head_modules is None or not any(name.startswith(head) for head in head_modules)):
                    fisher[name] += param.grad.data ** 2

            print(f"Completed batch {batch_count + 1}")
            batch_count += 1

    # Normalize Fisher information by the number of processed batches or samples
    normalization_factor = batch_count if n_samples is None else min(batch_count, n_samples)
    for name in fisher:
        fisher[name] = fisher[name] / normalization_factor

    # Integrate previous Fisher information with EWC scaling
    if prev_fisher_info is not None:
        for name in fisher:
            if name in prev_fisher_info:
                fisher[name] += ewc_gamma * prev_fisher_info[name]

    # Restore the model's original training state
    model.train(old_training_state)
    
    return fisher

# Function to get variational posterior means
def get_variational_posterior_means():
    posterior_means = {}
    for name, module in model.named_modules():
        if hasattr(module, 'lora_A'):
            lora_A_loc = pyro.param(f"{name}.lora_A_loc").detach().clone()
            posterior_means[f"{name}.lora_A"] = lora_A_loc
        if hasattr(module, 'lora_B'):
            lora_B_loc = pyro.param(f"{name}.lora_B_loc").detach().clone()
            posterior_means[f"{name}.lora_B"] = lora_B_loc
    return posterior_means


In [4]:
from peft.tuners.lora import LoraLayer

In [5]:
import os
import torch
import zipfile
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model
from accelerate import init_empty_weights
from datasets import Dataset
from huggingface_hub import login
from peft.tuners.lora import LoraLayer
from pyro.nn.module import to_pyro_module_


def initialize_lora():
    login("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")
    # Set environment variable to manage memory fragmentation
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    
     
    # Specify directories and the path to the zip file
    offload_dir = os.path.expanduser("llama_offload_evcl/")
     
    os.makedirs(offload_dir, exist_ok=True)
     
    # Extract only the specified JSON file from the zip archive
    os.chdir('/home/kakde2/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting/SSR/Latest_Weights/QA_Weights')
    target_file = "task024_cosmosqa_answer_generation.json"
     
    # Load tokenizer from Hugging Face
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    tokenizer.pad_token = tokenizer.eos_token


    # Load the model with accelerate's offloading and device map auto-setup
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Meta-Llama-3-8B",
            device_map="auto",
            # max_memory=max_memory,
            offload_folder=offload_dir,
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True
        )
     
    # Configure LoRA with reduced rank
    lora_config = LoraConfig(
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    #printing the trainable parameters
    model.print_trainable_parameters()

    # for name, param in model.named_parameters():
    #     if 'lora' in name:
    #         print(name)

    return model, tokenizer

    

In [6]:
print("Loading base model...")
model,tokenizer=initialize_lora()

Loading base model...


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 1,703,936 || all params: 8,031,965,184 || trainable%: 0.0212


In [7]:
os.chdir('/home/kakde2/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting/SSR/Latest_Weights/QA_Weights')
target_file = "task024_cosmosqa_answer_generation.json"

with open(target_file, 'r', encoding='utf-8-sig') as f:
    json_data = json.load(f)

instances = json_data['Instances'][0:2223]
input_texts = [str(instance['input']) for instance in instances]
output_texts = [str(instance['output'][0]) if instance['output'] else "" for instance in instances]

# Create Hugging Face Dataset
ds = Dataset.from_dict({'input': input_texts, 'output': output_texts})

# Tokenize the dataset
def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["input"],
        truncation=True,
        padding="max_length",
        max_length=512
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["output"],
            truncation=True,
            padding="max_length",
            max_length=512
        )
    model_inputs["labels"] = labels["input_ids"]
    model_inputs["attention_mask"] = model_inputs.get("attention_mask", None)
    return model_inputs

# Apply tokenization and set format
tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["input", "output"])
tokenized_datasets.set_format("torch")

# Split dataset into train and eval
train_size = int(0.9 * len(tokenized_datasets))
train_dataset = tokenized_datasets.select(range(train_size))
eval_dataset = tokenized_datasets.select(range(train_size, len(tokenized_datasets)))

# Create DataLoaders
batch_size = 8  
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

# Define data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

Map:   0%|          | 0/2223 [00:00<?, ? examples/s]



In [8]:
def save_trained_model(model, tokenizer, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    model.save_pretrained(output_dir)

    tokenizer.save_pretrained(output_dir)
    print(f"Model and tokenizer saved to {output_dir}")

In [9]:
def evaluate_model(model, eval_loader):
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in eval_loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            loss = outputs.loss
            total_loss += loss.item()
            num_batches += 1
    avg_loss = total_loss / num_batches
    print(f"Evaluation Loss: {avg_loss:.4f}")

In [21]:
import pyro.distributions as dist
import pyro.poutine as poutine

def run_lora_evcl_1(
    num_epochs: int = 3,
    base_model_name: str = "meta-llama/Meta-Llama-3-8B",
    batch_size: int = 2,
    learning_rate: float = 1e-5,
    logging_steps: int = 100,
    eval_steps: int = 200,
    save_steps: int = 500,
    output_dir: str = "finetuned-weights-LoRA-EVCL",
):


    def bayesian_guide(input_ids, attention_mask, labels):
        # Define variational distributions over the LoRA parameters
        for name, module in model.named_modules():
            if hasattr(module, 'lora_A'):
                for key in module.lora_A:
                    param_name = f"{name}.lora_A.{key}"
                    lora_A_param = module.lora_A[key].weight
                    device = lora_A_param.device
                    loc = pyro.param(
                        f"{param_name}_loc",
                        lora_A_param.clone().detach().to(device)
                    )
                    scale = pyro.param(
                        f"{param_name}_scale",
                        (0.1 * torch.ones_like(lora_A_param)).to(device),
                        constraint=dist.constraints.positive
                    )
                    pyro.sample(
                        param_name,
                        dist.Normal(loc, scale).to_event(lora_A_param.dim())
                    )
            if hasattr(module, 'lora_B'):
                for key in module.lora_B:
                    param_name = f"{name}.lora_B.{key}"
                    lora_B_param = module.lora_B[key].weight
                    device = lora_B_param.device
                    loc = pyro.param(
                        f"{param_name}_loc",
                        lora_B_param.clone().detach().to(device)
                    )
                    scale = pyro.param(
                        f"{param_name}_scale",
                        (0.1 * torch.ones_like(lora_B_param)).to(device),
                        constraint=dist.constraints.positive
                    )
                    pyro.sample(
                        param_name,
                        dist.Normal(loc, scale).to_event(lora_B_param.dim())
                    )
                    
    def bayesian_model(input_ids, attention_mask, labels):
        # Define a function to sample and substitute LoRA parameters
        def model_with_sampled_lora():
            # Sample LoRA parameters and set them in the model
            for name, module in model.named_modules():
                if hasattr(module, 'lora_A'):
                    for key in module.lora_A:
                        param_name = f"{name}.lora_A.{key}"
                        lora_A_module = module.lora_A[key]
                        device = lora_A_module.weight.device
    
                        # Sample from the prior
                        sampled_weight = pyro.sample(
                            param_name,
                            dist.Normal(
                                lora_A_module.weight.detach().to(device),
                                (0.1 * torch.ones_like(lora_A_module.weight)).to(device)
                            ).to_event(lora_A_module.weight.dim())
                        )
    
                        # Assign the sampled weight to the module
                        # lora_A_module.weight = torch.nn.Parameter(sampled_weight)
                        lora_A_module.weight.data.copy_(sampled_weight)
    
                if hasattr(module, 'lora_B'):
                    for key in module.lora_B:
                        param_name = f"{name}.lora_B.{key}"
                        lora_B_module = module.lora_B[key]
                        device = lora_B_module.weight.device
    
                        # Sample from the prior
                        sampled_weight = pyro.sample(
                            param_name,
                            dist.Normal(
                                lora_B_module.weight.detach().to(device),
                                (0.1 * torch.ones_like(lora_B_module.weight)).to(device)
                            ).to_event(lora_B_module.weight.dim())
                        )
    
                        # Assign the sampled weight to the module
                        # lora_B_module.weight = torch.nn.Parameter(sampled_weight)
                        lora_B_module.weight.data.copy_(sampled_weight)
    
            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            return loss
    
        # Use the modified model with sampled LoRA parameters
        return model_with_sampled_lora()


    # Set up SVI
    pyro.clear_param_store()
    optim = pyro.optim.Adam({"lr": learning_rate})
    elbo = TraceMeanField_ELBO()
    svi = SVI(bayesian_model, bayesian_guide, optim, loss=elbo)

    
    # Training loop for Task 1
    print(f"Training on Task 1...")

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for num_batches, batch in enumerate(train_loader, 1):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            loss = svi.step(input_ids, attention_mask, labels)
            total_loss += loss

            # Logging
            if num_batches % logging_steps == 0:
                avg_loss = total_loss / num_batches
                print(f"Epoch {epoch}, Step {num_batches}, Loss: {avg_loss}")

            # Evaluation
            if num_batches % eval_steps == 0:
                evaluate_model(model, eval_loader)

            # Save checkpoints
            if num_batches % save_steps == 0:
                save_trained_model(model, tokenizer, output_dir)

        avg_epoch_loss = total_loss / num_batches
        print(f"Task 1 Epoch {epoch} completed. Average Loss: {avg_epoch_loss}")

    # Save the final trained model after Task 1
    save_trained_model(model, tokenizer, output_dir)
    return model

    # After Task 1, compute FIM and save posterior means
    # fisher_info = compute_fisher_info(model, train_loader)
    # prev_posterior_means = get_variational_posterior_means()


In [22]:
print(os.getcwd())
os.chdir('/home/kakde2/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting/')
print(os.getcwd())

/home/kakde2/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting
/home/kakde2/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting


In [23]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

if __name__ == '__main__':
    model=run_lora_evcl_1(
        num_epochs=3,
        base_model_name="meta-llama/Meta-Llama-3-8B",
        batch_size=2,
        learning_rate=1e-5,
        logging_steps=100,
        eval_steps=200,
        save_steps=500,
        output_dir="finetuned-weights-LoRA-EVCL",
    )

Training on Task 1...
Epoch 0, Step 100, Loss: 852121.69
Epoch 0, Step 200, Loss: 852002.860625
Evaluation Loss: 9.9737
Task 1 Epoch 0 completed. Average Loss: 852012.5255
Epoch 1, Step 100, Loss: 851845.203125
Epoch 1, Step 200, Loss: 851937.7140625
Evaluation Loss: 15.9449
Task 1 Epoch 1 completed. Average Loss: 851954.0025
Epoch 2, Step 100, Loss: 851915.589375
Epoch 2, Step 200, Loss: 851983.643125
Evaluation Loss: 14.9926
Task 1 Epoch 2 completed. Average Loss: 851933.47325
Model and tokenizer saved to finetuned-weights-LoRA-EVCL


In [24]:
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B",
            device_map="auto",
            offload_folder='/home/kakde2/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting/llama_offload_evcl',
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True)

lora_model_path = "/home/kakde2/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting/finetuned-weights-LoRA-EVCL"
model = PeftModel.from_pretrained(base_model, lora_model_path)


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [26]:
for name, param in model.named_parameters():
    if 'lora' in name:
        param.requires_grad = True

# Verify that LoRA parameters require gradients
for name, param in model.named_parameters():
    if 'lora' in name:
        print(f"{name}: requires_grad={param.requires_grad}")

base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: requires_grad=True

In [None]:
from torch.amp import autocast, GradScaler
prev_fisher_info = None
prev_params = None
ewc_gamma = 1.0  

fisher_info = compute_fisher_info(
    model=model,
    data_loader=train_loader,
    prev_fisher_info=prev_fisher_info,
    ewc_gamma=ewc_gamma,
    num_epochs=1,  
    head_modules=None,  
    n_samples=None  
)


Starting Epoch 1/1
Processing batch 1
Completed batch 1
Processing batch 2
Completed batch 2
Processing batch 3
Completed batch 3
Processing batch 4
Completed batch 4
Processing batch 5
Completed batch 5
Processing batch 6
Completed batch 6
Processing batch 7
Completed batch 7
Processing batch 8
Completed batch 8
Processing batch 9
Completed batch 9
Processing batch 10
Completed batch 10
Processing batch 11
Completed batch 11
Processing batch 12
Completed batch 12
Processing batch 13
Completed batch 13
Processing batch 14
Completed batch 14
Processing batch 15
Completed batch 15
Processing batch 16
Completed batch 16
Processing batch 17
Completed batch 17
Processing batch 18
Completed batch 18
Processing batch 19
Completed batch 19
Processing batch 20
Completed batch 20
Processing batch 21
Completed batch 21
Processing batch 22
Completed batch 22
Processing batch 23
Completed batch 23
Processing batch 24
Completed batch 24
Processing batch 25
Completed batch 25
Processing batch 26
Comp

In [14]:
# def update_variational_approx(bnn, train_loader, curr_coreset, num_epochs, callback, ewc_lambda, fisher_info=None, prev_params=None):
#     if curr_coreset:
#         # Create a dataset from the coreset
#         print('Coreset is true')
#         coreset_input_ids = torch.stack([item[0] for item in curr_coreset])
#         coreset_labels = torch.stack([item[1] for item in curr_coreset])
#         coreset_dataset = TensorDataset(coreset_input_ids, coreset_labels)

#         # Combine coreset and current task data
#         combined_dataset = ConcatDataset([train_loader.dataset, coreset_dataset])
#         data_loader = DataLoader(combined_dataset, batch_size=train_loader.batch_size, shuffle=True)
#     else:
#         print('coreset not true')
#         data_loader = train_loader

#     optim = pyro.optim.Adam({"lr": 1e-5})  # Adjust learning rate as needed

#     with tyxe.poutine.local_reparameterization():
#         bnn.fit(data_loader, optim, num_epochs, device=DEVICE, callback=callback, ewc_lambda=ewc_lambda, fisher_info=fisher_info, prev_params=prev_params)


In [15]:
# from transformers import Trainer

# class EWCTrainer(Trainer):
#     def __init__(self, ewc_lambda, fisher_info, prev_params, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.ewc_lambda = ewc_lambda
#         self.fisher_info = fisher_info
#         self.prev_params = prev_params

#     def compute_loss(self, model, inputs, return_outputs=False):
#         # Standard loss
#         outputs = model(**inputs)
#         loss = outputs.loss

#         # EWC loss
#         if self.fisher_info is not None and self.prev_params is not None:
#             ewc_loss = 0
#             for name, param in model.named_parameters():
#                 if 'lora' in name and name in self.fisher_info:
#                     ewc_loss += (self.fisher_info[name] * (param - self.prev_params[name]) ** 2).sum()
#             ewc_loss = 0.5 * self.ewc_lambda * ewc_loss
#             loss += ewc_loss

#         return (loss, outputs) if return_outputs else loss


In [16]:
# import pyro.nn as pynn
# import torch.nn as nn
# class BayesianLoRAModule(pynn.PyroModule):
#     def __init__(self, model):
#         super().__init__()
#         self.model = model  # The base model remains unchanged

#         # Replace LoRA parameters with PyroSample
#         for name, module in self.model.named_modules():
#             if isinstance(module, nn.Linear) and hasattr(module, 'lora_A'):
#                 # Replace lora_A and lora_B with PyroSample
#                 lora_A_name = f"{name}.lora_A"
#                 lora_B_name = f"{name}.lora_B"

#                 # Get the existing parameters
#                 lora_A = getattr(module, 'lora_A')
#                 lora_B = getattr(module, 'lora_B')

#                 # Register PyroSample parameters
#                 setattr(module, 'lora_A', pynn.PyroSample(dist.Normal(lora_A.data, 1.0).to_event(2)))
#                 setattr(module, 'lora_B', pynn.PyroSample(dist.Normal(lora_B.data, 1.0).to_event(2)))

#     def forward(self, *args, **kwargs):
#         return self.model(*args, **kwargs)


In [17]:
# def get_bayesian_model_and_guide(model):
#     bayesian_model = BayesianLoRAModule(model)

#     # Define the guide
#     def guide(*args, **kwargs):
#         for name, module in bayesian_model.named_modules():
#             if isinstance(module, nn.Linear) and hasattr(module, 'lora_A'):
#                 # Define variational distributions for lora_A and lora_B
#                 lora_A_loc = pyro.param(f"{name}.lora_A_loc", torch.zeros_like(module.lora_A.data))
#                 lora_A_scale = pyro.param(f"{name}.lora_A_scale", torch.ones_like(module.lora_A.data), constraint=pyro.distributions.constraints.positive)
#                 lora_B_loc = pyro.param(f"{name}.lora_B_loc", torch.zeros_like(module.lora_B.data))
#                 lora_B_scale = pyro.param(f"{name}.lora_B_scale", torch.ones_like(module.lora_B.data), constraint=pyro.distributions.constraints.positive)
#                 pyro.sample(f"{name}.lora_A", dist.Normal(lora_A_loc, lora_A_scale).to_event(2))
#                 pyro.sample(f"{name}.lora_B", dist.Normal(lora_B_loc, lora_B_scale).to_event(2))
#     return bayesian_model, guide


In [None]:
import torch

# Check if CUDA is available
print("CUDA Available:", torch.cuda.is_available())

# Get the current device index
current_device = torch.cuda.current_device()
print("Current Device Index:", current_device)

# Get the name of the current device
device_name = torch.cuda.get_device_name(current_device)
print("Current Device Name:", device_name)

# Get the number of GPUs
num_gpus = torch.cuda.device_count()
print("Number of GPUs:", num_gpus)

# List all GPUs
for device_id in range(num_gpus):
    print(f"Device {device_id}: {torch.cuda.get_device_name(device_id)}")


In [None]:
def run_evcl(
    num_tasks: int = 2,  # Assuming tasks 1 (QA) and 2 (QG)
    num_epochs: int = 3,
    experiment_name: str = 'llama_evcl_superni',
    base_model_name: str = "meta-llama/Llama-2-7b-hf",
    lora_model_path: str = 'fine-tuned-llama-lora',
    batch_size: int = 8,
    coreset_size: int = 200,
    coreset_method: str = 'random',
    ewc_lambda: float = 100.0,
    ewc_gamma: float = 1.0,
    task_folder='QA_FineTuned'
):
    os.chdir(f'/home/pranav24/cs-546/Iterative-SSR-and-EVCL-Catastrophic-Forgetting/{task_folder}/finetuned-weights')
    
    print("Loading base model...")
    # Load the model already fine-tuned on the first task
    model = AutoModelForCausalLM.from_pretrained(base_model_name)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    print("Applying LoRA adapter...")
    model = PeftModel.from_pretrained(model, lora_model_path)
    model.to(DEVICE)
    model.eval()

    # Prepare the prior using the fine-tuned model
    prior = MLEPrior(model)
    obs = tyxe.likelihoods.Categorical(-1)
    guide = functools.partial(
        tyxe.guides.AutoNormal,
        init_scale=1e-4,
        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(model, prefix="net")
    )

    # Initialize Bayesian model
    bnn = VariationalBNNWithEWC(model, prior, obs, guide)

    # Load the first task's data
    train_loader_task1, test_loader_task1 = get_data_loader_for_task1(tokenizer, batch_size)

    # Generate the initial coreset from the first task's data
    prev_coreset = update_coreset(prev_coreset=[], train_loader=train_loader_task1, coreset_size=coreset_size, selection_method=coreset_method)

    # Compute the initial Fisher Information Matrix and previous parameters
    prev_fisher_info = compute_fisher_info_llm(
        bnn, prev_fisher_info=None, data_loader=train_loader_task1, n_samples=5000, ewc_gamma=ewc_gamma
    )
    prev_params = {
        name: param.detach().clone()
        for name, param in bnn.named_parameters()
        if 'lora' in name
    }

    # Now proceed with tasks 2 and onwards
    # Prepare tasks 2 to num_tasks
    train_loaders, test_loaders = fetch_nlp_datasets(tokenizer, batch_size, num_tasks, start_task=2)

    for task_index, train_loader in enumerate(train_loaders, 2):  # Start from task_index=2
        print(f"Training on Task {task_index}...")

        # Update coreset
        if coreset_size > 0:
            curr_coreset = update_coreset(prev_coreset, train_loader, coreset_size, coreset_method)
            # curr_coreset = update_coreset(prev_coreset, train_loader_task1, coreset_size, coreset_method)
            # curr_coreset=prev_coreset
        else:
            curr_coreset = []

        # Training loop for current task
        def callback(epoch, step, loss):
            print(f"Epoch {epoch}, Step {step}, Loss: {loss}")

        # Fine-tune with variational inference and EWC
        update_variational_approx(
            bnn, train_loader, curr_coreset, num_epochs, callback, ewc_lambda,
            fisher_info=prev_fisher_info, prev_params=prev_params
        )

        # Compute Fisher Information Matrix for current task
        fisher_info = compute_fisher_info_llm(
            bnn, prev_fisher_info, train_loader, n_samples=5000, ewc_gamma=ewc_gamma
        )

        # Update prev_params and prev_fisher_info
        prev_params = {
            name: param.detach().clone()
            for name, param in bnn.named_parameters()
            if 'lora' in name
        }
        prev_fisher_info = fisher_info

        # Update prior with posterior from current task
        site_names = [site for site in tyxe.util.pyro_sample_sites(bnn) if 'lora' in site]
        params_to_update = tyxe.priors.DictPrior({
            site: list(bnn.net_guide.get_detached_distributions(site).values())[0]
            for site in site_names
        })
        bnn.update_prior(params_to_update)

        # Update prev_coreset
        prev_coreset = curr_coreset

        # Evaluate on all tasks up to current
        for j, test_loader in enumerate([test_loader_task1] + test_loaders[:task_index - 2], 1):
            print(f"Evaluating Task {j}...")
            total_loss = 0.0
            num_batches = 0
            for batch in test_loader:
                input_ids = batch["input_ids"].to(DEVICE)
                labels = batch["labels"].to(DEVICE)
                with torch.no_grad():
                    outputs = bnn.net(input_ids, labels=labels)
                    loss = outputs.loss
                total_loss += loss.item()
                num_batches += 1
            avg_loss = total_loss / num_batches
            print(f"Task {j} Average Loss: {avg_loss:.4f}")

    print("Training completed.")


In [None]:
if __name__ == '__main__':
    run_evcl(
        num_tasks=2,  # QA and QG tasks
        num_epochs=3,
        experiment_name='llama_evcl_superni',
        base_model_name='meta-llama/Llama-2-7b-hf',
        lora_model_path='path/to/your/lora/model',
        batch_size=8,
        coreset_size=200,  # Adjust as needed
        ewc_lambda=100.0,
        ewc_gamma=1.0,
    )


How to Run the Process with SuperNI Dataset

Step 1: Environment Setup
(Same as previously described)

Step 2: Preparing the SuperNI Dataset
Install the datasets Library:
Ensure you have the datasets library installed:

pip install datasets
Inspect the SuperNI Dataset:
The SuperNI dataset can be loaded using:

from datasets import load_dataset

superni_dataset = load_dataset('super_nat_instruct', 'v1_1')
Note: Replace 'super_nat_instruct' and 'v1_1' with the correct dataset identifier if necessary.
Identify QA and QG Tasks:
SuperNI contains multiple tasks with task descriptions.
You need to identify the task IDs or names corresponding to QA and QG.
You can print out the tasks to find the ones you need:
for task in superni_dataset['train']['Task']:
    print(task)
Adjust the load_superni_task_dataset Function:
Modify the task_filter in load_superni_task_dataset to match the task identifiers for QA and QG.
For example:
if task_type == 'qa':
    task_ids = ['task_id_for_qa1', 'task_id_for_qa2']  # Replace with actual task IDs
    task_filter = lambda ex: ex['TaskID'] in task_ids
elif task_type == 'qg':
    task_ids = ['task_id_for_qg1', 'task_id_for_qg2']  # Replace with actual task IDs
    task_filter = lambda ex: ex['TaskID'] in task_ids
Adjust Data Preprocessing:
Ensure that the Input and Output fields are correctly used.
For some tasks, you might need to concatenate context and question.
Step 3: Running the Code
(Same as previously described)

Step 4: Monitoring and Evaluation
(Same as previously described)