In [None]:
import torch
import pyro
import tyxe

import random
import functools
import copy

import numpy as np

from pyro.infer import SVI, TraceMeanField_ELBO, Trace_ELBO

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset

from datasets import load_dataset  # Added to load SuperNI dataset

from typing import Optional, List

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
def compute_fisher_info_llm(bnn, prev_fisher_info, data_loader, n_samples=5000, ewc_gamma=1.0):
    """Estimate a diagonal Fisher information for the LoRA parameters of ``bnn``.

    For every batch the loss returned by ``bnn.net`` is back-propagated and the
    squared gradient of each parameter whose name contains ``'lora'`` is
    accumulated. The accumulated values are divided by the number of examples
    seen and, if a previous estimate is supplied, ``ewc_gamma * previous`` is
    added on top.

    Args:
        bnn: Object exposing ``named_parameters()`` and a ``net`` module whose
            forward call accepts ``input_ids`` / ``labels`` (and optionally
            ``attention_mask``) and returns an object with a ``.loss`` attribute.
        prev_fisher_info: Mapping ``name -> tensor`` from earlier tasks, or ``None``.
        data_loader: Iterable of batches indexable by ``'input_ids'`` and
            ``'labels'``; an ``'attention_mask'`` entry is forwarded when present.
        n_samples: Stop once at least this many examples were processed
            (``None`` processes the whole loader).
        ewc_gamma: Weight applied to ``prev_fisher_info`` before it is added.

    Returns:
        dict mapping each LoRA parameter name to its Fisher estimate. If the
        loader yields no batches the estimate is all zeros (plus the weighted
        previous estimate) instead of NaN.
    """
    # Only compute Fisher information for LoRA parameters.
    lora_params = [(name, param) for name, param in bnn.named_parameters() if 'lora' in name]
    est_fisher_info = {name: torch.zeros_like(param) for name, param in lora_params}

    old_training_state = bnn.net.training
    bnn.net.eval()

    num_samples = 0
    try:
        for batch in data_loader:
            if n_samples is not None and num_samples >= n_samples:
                break

            input_ids = batch['input_ids'].to(DEVICE)
            model_kwargs = {'labels': batch['labels'].to(DEVICE)}
            # Forward the padding mask when the batch carries one so that pad
            # positions are not attended to.
            if 'attention_mask' in batch:
                model_kwargs['attention_mask'] = batch['attention_mask'].to(DEVICE)

            bnn.net.zero_grad()
            loss = bnn.net(input_ids, **model_kwargs).loss
            loss.backward()

            # NOTE(review): the squared gradient is taken per *batch* loss while
            # the normalisation below is per *example*; confirm this scaling is
            # what ewc_lambda was tuned for.
            for name, param in lora_params:
                if param.grad is not None:
                    est_fisher_info[name] += param.grad.detach() ** 2

            num_samples += input_ids.size(0)
    finally:
        # Do not leak gradients into later optimisation steps, and always put
        # the network back into the mode the caller had it in.
        bnn.net.zero_grad()
        bnn.net.train(old_training_state)

    # Normalize the estimated Fisher information (skip when nothing was seen,
    # otherwise 0 / 0 would turn every entry into NaN).
    if num_samples > 0:
        est_fisher_info = {name: fisher / num_samples for name, fisher in est_fisher_info.items()}

    if prev_fisher_info is not None:
        for name in est_fisher_info:
            if name in prev_fisher_info:
                est_fisher_info[name] += ewc_gamma * prev_fisher_info[name]

    return est_fisher_info


In [None]:
def fetch_nlp_datasets(tokenizer, batch_size, num_tasks, start_task=1):
    """Build train/test ``DataLoader`` lists for tasks ``start_task..num_tasks``.

    Task 1 is the QA task and task 2 the QG task; each is obtained through
    ``load_superni_task_dataset`` using the ``'train'`` and ``'validation'``
    splits.

    Args:
        tokenizer: Tokenizer forwarded to ``load_superni_task_dataset``.
        batch_size: Batch size for every created ``DataLoader``.
        num_tasks: Index of the last task to load (inclusive, 1-based).
        start_task: Index of the first task to load (inclusive, 1-based).

    Returns:
        ``(train_loaders, test_loaders)``: two lists with one loader per
        requested task, in task order. Both are empty when the range is empty.

    Raises:
        ValueError: If the range contains a task index with no known task type.
            Previously such an index silently re-used the datasets of the task
            before it (or failed with an unbound local variable).
    """
    # Task index -> task type understood by load_superni_task_dataset.
    task_types = {1: 'qa', 2: 'qg'}

    task_indices = list(range(start_task, num_tasks + 1))
    unsupported = [index for index in task_indices if index not in task_types]
    if unsupported:
        raise ValueError(
            f"No dataset is defined for task index(es) {unsupported}; "
            f"supported task indexes are {sorted(task_types)}"
        )

    train_loaders = []
    test_loaders = []
    if not task_indices:
        # Nothing requested: avoid downloading the dataset for no reason.
        return train_loaders, test_loaders

    # Load the SuperNI dataset.
    # NOTE(review): 'super_glue' / 'ni' is a placeholder identifier; verify it
    # against the dataset hub before running.
    superni_dataset = load_dataset('super_glue', 'ni')  # Adjust if necessary

    for task_index in task_indices:
        task_type = task_types[task_index]
        train_dataset = load_superni_task_dataset(
            superni_dataset, tokenizer, task_type=task_type, split='train'
        )
        test_dataset = load_superni_task_dataset(
            superni_dataset, tokenizer, task_type=task_type, split='validation'
        )

        train_loaders.append(DataLoader(train_dataset, batch_size=batch_size, shuffle=True))
        test_loaders.append(DataLoader(test_dataset, batch_size=batch_size, shuffle=False))

    return train_loaders, test_loaders


In [None]:
def load_superni_task_dataset(superni_dataset, tokenizer, task_type='qa', split='train', max_length=512):
    """Filter one task type out of a SuperNI split and tokenize it for a causal LM.

    Each example becomes a single sequence ``prompt tokens + target tokens +
    EOS``, right-padded to ``max_length``. ``labels`` has the same length as
    ``input_ids`` and is ``-100`` on prompt and padding positions, so only the
    target tokens contribute to the language-modelling loss.

    Args:
        superni_dataset: Mapping from split name to a dataset object providing
            ``filter``, ``map`` and ``set_format``.
        tokenizer: Callable tokenizer with ``pad_token``/``eos_token`` attributes.
            If it has no pad token, its EOS token is installed as pad token
            (this mutates the tokenizer).
        task_type: ``'qa'`` (question answering) or ``'qg'`` (question generation).
        split: Name of the split to read from ``superni_dataset``.
        max_length: Fixed length of every produced sequence (at least 2).

    Returns:
        The filtered, tokenized dataset, formatted as torch tensors with the
        columns ``input_ids``, ``attention_mask`` and ``labels``.

    Raises:
        ValueError: On an unsupported ``task_type``, ``max_length < 2``, or a
            tokenizer that has neither a pad token nor an EOS token.
    """
    # Substring that must appear in an example's 'Task' field for each task type.
    # NOTE(review): the 'Task' / 'Input' / 'Output' column names and these
    # keywords are assumptions about the dataset schema — confirm against the
    # dataset actually loaded. 'Output' is assumed to be a string per example.
    task_keywords = {'qa': 'question answering', 'qg': 'question generation'}
    if task_type not in task_keywords:
        raise ValueError(f"Unsupported task type: {task_type}")
    if max_length < 2:
        raise ValueError(f"max_length must be at least 2, got {max_length}")
    keyword = task_keywords[task_type]

    dataset = superni_dataset[split].filter(lambda ex: keyword in ex['Task'])

    # Tokenizers such as Llama's ship without a pad token; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    if pad_id is None:
        raise ValueError("Tokenizer has neither a pad token nor an EOS token to pad with")
    eos_suffix = [eos_id] if eos_id is not None else []
    ignore_index = -100  # label value skipped by the cross-entropy loss

    def preprocess_function(examples):
        # For SuperNI, inputs and outputs are in 'Input' and 'Output' fields.
        prompt_batch = tokenizer(
            examples['Input'], add_special_tokens=True, truncation=True, max_length=max_length
        )['input_ids']
        target_batch = tokenizer(
            examples['Output'], add_special_tokens=False, truncation=True, max_length=max_length
        )['input_ids']

        model_inputs = {'input_ids': [], 'attention_mask': [], 'labels': []}
        for prompt_ids, target_ids in zip(prompt_batch, target_batch):
            # Keep room for at least one prompt token, then let the prompt use
            # whatever the target leaves free.
            target_ids = list(target_ids)[:max(max_length - 1 - len(eos_suffix), 0)] + eos_suffix
            prompt_ids = list(prompt_ids)[:max(max_length - len(target_ids), 0)]

            sequence = prompt_ids + target_ids
            n_pad = max_length - len(sequence)

            model_inputs['input_ids'].append(sequence + [pad_id] * n_pad)
            model_inputs['attention_mask'].append([1] * len(sequence) + [0] * n_pad)
            # Masking is positional, so a real EOS stays supervised even when
            # the pad token is the EOS token.
            model_inputs['labels'].append(
                [ignore_index] * len(prompt_ids) + target_ids + [ignore_index] * n_pad
            )
        return model_inputs

    dataset = dataset.map(preprocess_function, batched=True)
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    return dataset


In [None]:
def get_data_loader_for_task1(tokenizer, batch_size):
    """Return the ``(train_loader, test_loader)`` pair for task 1 (QA).

    Delegates to ``fetch_nlp_datasets`` restricted to task 1 instead of
    repeating its dataset-loading and ``DataLoader`` construction, so the two
    functions cannot drift apart.

    Args:
        tokenizer: Tokenizer forwarded to ``fetch_nlp_datasets``.
        batch_size: Batch size of both returned loaders.

    Returns:
        Tuple ``(train_loader, test_loader)`` for the QA task.
    """
    train_loaders, test_loaders = fetch_nlp_datasets(
        tokenizer, batch_size, num_tasks=1, start_task=1
    )
    return train_loaders[0], test_loaders[0]


In [None]:
def _evaluate_average_loss(bnn, test_loader):
    """Return the mean per-batch loss of ``bnn.net`` over ``test_loader``.

    The network is switched to eval mode for the duration of the evaluation and
    put back into its previous mode afterwards. Returns ``nan`` when the loader
    yields no batches.
    """
    was_training = bnn.net.training
    bnn.net.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch["input_ids"].to(DEVICE)
                model_kwargs = {"labels": batch["labels"].to(DEVICE)}
                # Forward the padding mask when the batch carries one.
                if "attention_mask" in batch:
                    model_kwargs["attention_mask"] = batch["attention_mask"].to(DEVICE)
                total_loss += bnn.net(input_ids, **model_kwargs).loss.item()
                num_batches += 1
    finally:
        bnn.net.train(was_training)
    return total_loss / num_batches if num_batches else float("nan")


def run_evcl(
    num_tasks: int = 2,  # Assuming tasks 1 (QA) and 2 (QG)
    num_epochs: int = 3,
    experiment_name: str = 'llama_evcl_superni',
    base_model_name: str = "meta-llama/Llama-2-7b-hf",
    lora_model_path: str = 'path/to/your/lora/model',
    batch_size: int = 8,
    coreset_size: int = 200,
    coreset_method: str = 'random',
    ewc_lambda: float = 100.0,
    ewc_gamma: float = 1.0,
):
    """Continue training a LoRA-adapted causal LM on tasks 2..``num_tasks``.

    Steps, in order: load the base model and the LoRA adapter, wrap the model
    in a variational BNN, build the task-1 coreset / Fisher estimate / parameter
    snapshot, then for each later task update the coreset, run
    ``update_variational_approx``, refresh the Fisher estimate and parameter
    snapshot, move the prior to the current posterior, and print the average
    loss on every task seen so far (including the one just trained).

    NOTE(review): ``MLEPrior``, ``VariationalBNNWithEWC``, ``update_coreset``
    and ``update_variational_approx`` are not defined in the cells shown here;
    they must already exist in the kernel for this function to run.

    Args:
        num_tasks: Index of the last task (task 1 is assumed already learned).
        num_epochs: Epochs passed to ``update_variational_approx`` per task.
        experiment_name: Currently unused inside this function.
        base_model_name: Identifier of the base causal LM and its tokenizer.
        lora_model_path: Location of the LoRA adapter trained on task 1.
        batch_size: Batch size for all data loaders.
        coreset_size: Size passed to ``update_coreset``; ``0`` disables the
            per-task coreset update.
        coreset_method: Selection method passed to ``update_coreset``.
        ewc_lambda: EWC strength passed to ``update_variational_approx``.
        ewc_gamma: Weight of the previous Fisher estimate when merging.
    """
    print("Loading base model...")
    # Load the model already fine-tuned on the first task
    model = AutoModelForCausalLM.from_pretrained(base_model_name)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # Llama tokenizers have no pad token, which padded tokenization requires.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Applying LoRA adapter...")
    # is_trainable=True: by default PEFT loads adapters frozen for inference,
    # which would leave nothing to differentiate or train below.
    model = PeftModel.from_pretrained(model, lora_model_path, is_trainable=True)
    model.to(DEVICE)
    model.eval()

    # Prepare the prior using the fine-tuned model
    prior = MLEPrior(model)
    obs = tyxe.likelihoods.Categorical(-1)
    guide = functools.partial(
        tyxe.guides.AutoNormal,
        init_scale=1e-4,
        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(model, prefix="net")
    )

    # Initialize Bayesian model
    bnn = VariationalBNNWithEWC(model, prior, obs, guide)

    # Load the first task's data
    train_loader_task1, test_loader_task1 = get_data_loader_for_task1(tokenizer, batch_size)

    # Generate the initial coreset from the first task's data
    prev_coreset = update_coreset(prev_coreset=[], train_loader=train_loader_task1, coreset_size=coreset_size, selection_method=coreset_method)

    # Compute the initial Fisher Information Matrix and previous parameters
    prev_fisher_info = compute_fisher_info_llm(
        bnn, prev_fisher_info=None, data_loader=train_loader_task1, n_samples=5000, ewc_gamma=ewc_gamma
    )
    prev_params = {
        name: param.detach().clone()
        for name, param in bnn.named_parameters()
        if 'lora' in name
    }

    # Now proceed with tasks 2 and onwards
    # Prepare tasks 2 to num_tasks
    train_loaders, test_loaders = fetch_nlp_datasets(tokenizer, batch_size, num_tasks, start_task=2)

    # Progress callback handed to the training routine (same for every task).
    def callback(epoch, step, loss):
        print(f"Epoch {epoch}, Step {step}, Loss: {loss}")

    for task_index, train_loader in enumerate(train_loaders, 2):  # Start from task_index=2
        print(f"Training on Task {task_index}...")

        # Update coreset
        if coreset_size > 0:
            curr_coreset = update_coreset(prev_coreset, train_loader, coreset_size, coreset_method)
        else:
            curr_coreset = []

        # Fine-tune with variational inference and EWC
        update_variational_approx(
            bnn, train_loader, curr_coreset, num_epochs, callback, ewc_lambda,
            fisher_info=prev_fisher_info, prev_params=prev_params
        )

        # Compute Fisher Information Matrix for current task
        fisher_info = compute_fisher_info_llm(
            bnn, prev_fisher_info, train_loader, n_samples=5000, ewc_gamma=ewc_gamma
        )

        # Update prev_params and prev_fisher_info
        prev_params = {
            name: param.detach().clone()
            for name, param in bnn.named_parameters()
            if 'lora' in name
        }
        prev_fisher_info = fisher_info

        # Update prior with posterior from current task
        site_names = [site for site in tyxe.util.pyro_sample_sites(bnn) if 'lora' in site]
        params_to_update = tyxe.priors.DictPrior({
            site: list(bnn.net_guide.get_detached_distributions(site).values())[0]
            for site in site_names
        })
        bnn.update_prior(params_to_update)

        # Update prev_coreset
        prev_coreset = curr_coreset

        # Evaluate on all tasks up to and including the current one.
        # test_loaders[0] belongs to task 2, so tasks 2..task_index are the
        # first (task_index - 1) entries.
        seen_test_loaders = [test_loader_task1] + test_loaders[:task_index - 1]
        for j, test_loader in enumerate(seen_test_loaders, 1):
            print(f"Evaluating Task {j}...")
            avg_loss = _evaluate_average_loss(bnn, test_loader)
            print(f"Task {j} Average Loss: {avg_loss:.4f}")

    print("Training completed.")


In [None]:
if __name__ == '__main__':
    # Settings for the two-task run: QA (task 1) followed by QG (task 2).
    evcl_settings = dict(
        num_tasks=2,
        num_epochs=3,
        experiment_name='llama_evcl_superni',
        base_model_name='meta-llama/Llama-2-7b-hf',
        lora_model_path='path/to/your/lora/model',
        batch_size=8,
        coreset_size=200,  # Adjust as needed
        ewc_lambda=100.0,
        ewc_gamma=1.0,
    )
    run_evcl(**evcl_settings)


How to Run the Process with SuperNI Dataset

Step 1: Environment Setup
(Same as previously described)

Step 2: Preparing the SuperNI Dataset
Install the `datasets` library:
Make sure the Hugging Face `datasets` package is installed in the environment the notebook kernel uses:

pip install datasets
Inspect the SuperNI Dataset:
The SuperNI dataset can be loaded using:

from datasets import load_dataset

superni_dataset = load_dataset('super_nat_instruct', 'v1_1')
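Note: 'super_nat_instruct' and 'v1_1' are placeholders; replace them with the correct dataset identifier and configuration. Any `load_dataset('super_glue', 'ni')` call in the code cells above is likewise a placeholder and must be changed to the same identifier.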
Identify QA and QG Tasks:
SuperNI contains multiple tasks with task descriptions.
You need to identify the task IDs or names corresponding to QA and QG.
You can print out the tasks to find the ones you need:
for task in superni_dataset['train']['Task']:
    print(task)
Adjust the load_superni_task_dataset Function:
Modify the task_filter in load_superni_task_dataset to match the task identifiers for QA and QG.
For example:
if task_type == 'qa':
    task_ids = ['task_id_for_qa1', 'task_id_for_qa2']  # Replace with actual task IDs
    task_filter = lambda ex: ex['TaskID'] in task_ids
elif task_type == 'qg':
    task_ids = ['task_id_for_qg1', 'task_id_for_qg2']  # Replace with actual task IDs
    task_filter = lambda ex: ex['TaskID'] in task_ids
Adjust Data Preprocessing:
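Ensure that `preprocess_function` reads the columns that actually hold the model input and the expected output (`Input` and `Output` in the code above).
For some tasks, you might need to concatenate the context and the question into a single input string.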
Step 3: Running the Code
(Same as previously described)

Step 4: Monitoring and Evaluation
(Same as previously described)