In [1]:
import torch
import pyro
import tyxe

import random
import functools
import copy

import numpy as np

from pyro.infer import SVI, TraceMeanField_ELBO, Trace_ELBO

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset

from datasets import load_dataset  # Added to load SuperNI dataset

from typing import Optional, List
from model.mle_prior import MLEPrior


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
import torch

print("CUDA Available:", torch.cuda.is_available())

current_device = torch.cuda.current_device()
print("Current Device Index:", current_device)

device_name = torch.cuda.get_device_name(current_device)
print("Current Device Name:", device_name)

num_gpus = torch.cuda.device_count()
print("Number of GPUs:", num_gpus)

for device_id in range(num_gpus):
    print(f"Device {device_id}: {torch.cuda.get_device_name(device_id)}")


CUDA Available: True
Current Device Index: 0
Current Device Name: NVIDIA A100-SXM4-80GB
Number of GPUs: 1
Device 0: NVIDIA A100-SXM4-80GB


### Task1 -QA LoRA+EVCL

In [3]:
def compute_fisher_info(
    model, 
    data_loader, 
    prev_fisher_info=None, 
    ewc_gamma=1.0, 
    num_epochs=1, 
    head_modules=None, 
    n_samples=None
):

    fisher = {}
    
    # Initialize Fisher matrix for LoRA parameters, excluding head modules if provided
    for name, param in model.named_parameters():
        if 'lora' in name and (head_modules is None or not any(name.startswith(head) for head in head_modules)):
            fisher[name] = torch.zeros_like(param).to(DEVICE)
    
    # Save the model's current training state and set to eval
    old_training_state = model.training
    model.eval()
    
    scaler = GradScaler(device='cuda')

    batch_count = 0

    for epoch in range(num_epochs):
        print(f"Starting Epoch {epoch + 1}/{num_epochs}")
        for i, batch in enumerate(data_loader):
            if n_samples is not None and batch_count >= n_samples:
                break

            print(f"Processing batch {batch_count + 1}")
            model.zero_grad()
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            try:
                # with autocast(device_type='cuda'):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss
                loss.backward()
            # scaler.scale(loss).backward()
            except RuntimeError as e:
                print(f"Error in batch {batch_count + 1}: {e}")
                break

            # Accumulate Fisher information for LoRA parameters
            for name, param in model.named_parameters():
                if 'lora' in name and param.grad is not None and (head_modules is None or not any(name.startswith(head) for head in head_modules)):
                    fisher[name] += param.grad.data ** 2

            print(f"Completed batch {batch_count + 1}")
            batch_count += 1

    # Normalize Fisher information by the number of processed batches or samples
    normalization_factor = batch_count if n_samples is None else min(batch_count, n_samples)
    for name in fisher:
        fisher[name] = fisher[name] / normalization_factor

    # Integrate previous Fisher information with EWC scaling
    if prev_fisher_info is not None:
        for name in fisher:
            if name in prev_fisher_info:
                fisher[name] += ewc_gamma * prev_fisher_info[name]

    # Restore the model's original training state
    model.train(old_training_state)
    
    return fisher

# Function to get variational posterior means
def get_variational_posterior_means(model):
    posterior_means = {}
    for name, module in model.named_modules():
        if hasattr(module, 'lora_A'):
            # print('yes')
            for key in module.lora_A:
                param_name = f"{name}.lora_A.{key}"
                loc_name = f"{param_name}_loc"
                if loc_name in pyro.get_param_store():
                    lora_A_loc = pyro.param(loc_name).detach().clone()
                    # Add '.weight' to the parameter name
                    posterior_means[f"{param_name}.weight"] = lora_A_loc
        if hasattr(module, 'lora_B'):
            # print('yes')
            for key in module.lora_B:
                param_name = f"{name}.lora_B.{key}"
                loc_name = f"{param_name}_loc"
                if loc_name in pyro.get_param_store():
                    lora_B_loc = pyro.param(loc_name).detach().clone()
                    # Add '.weight' to the parameter name
                    posterior_means[f"{param_name}.weight"] = lora_B_loc
    return posterior_means

In [4]:
from peft.tuners.lora import LoraLayer

In [5]:
import os
import torch
import zipfile
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling,BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from peft import PeftConfig, PeftModel
from accelerate import init_empty_weights
from datasets import Dataset
from huggingface_hub import login
from peft.tuners.lora import LoraLayer
from pyro.nn.module import to_pyro_module_
import bitsandbytes

def deterministic_lora_task():
    login("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")
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    base_model_repo_id = "meta-llama/Meta-Llama-3-8B"  
    adapter_model_dir = r"/home/pranav24/cs-546-project/SSR/Latest_Weights/QA_Weights/finetuned-weights/QA_Final"

    os.chdir(r'/home/pranav24/cs-546-project')
    
   
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,  
        device_map="auto",  
        offload_folder="offload",  
        offload_state_dict=True,  
    )
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_repo_id)
    tokenizer.pad_token = tokenizer.eos_token
    

    model = AutoModelForCausalLM.from_pretrained(
        base_model_repo_id,
        quantization_config=bnb_config,
        torch_dtype=torch.float16, 
    )
    model.config.reduction = "mean" 
    

    peft_config = PeftConfig.from_pretrained(adapter_model_dir)
    model = PeftModel.from_pretrained(model, adapter_model_dir, config=peft_config)
        
    for name, param in model.named_parameters():
        if 'lora' in name:
            print(name)
    return model,tokenizer


# def initialize_lora():
#     login("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")
#     # Set environment variable to manage memory fragmentation
#     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    
     
#     # Specify directories and the path to the zip file
#     offload_dir = os.path.expanduser("llama_offload_evcl/")
     
#     os.makedirs(offload_dir, exist_ok=True)
     
#     # Extract only the specified JSON file from the zip archive
#     os.chdir('/home/pranav24/cs-546-project/SSR/Latest_Weights/QA_Weights')
#     target_file = "task024_cosmosqa_answer_generation.json"
     
#     # Load tokenizer from Hugging Face
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
#     tokenizer.pad_token = tokenizer.eos_token


#     # Load the model with accelerate's offloading and device map auto-setup
#     with init_empty_weights():
#         model = AutoModelForCausalLM.from_pretrained(
#             "meta-llama/Meta-Llama-3-8B",
#             device_map="auto",
#             # max_memory=max_memory,
#             offload_folder=offload_dir,
#             load_in_8bit=True,
#             llm_int8_enable_fp32_cpu_offload=True
#         )
     
#     # Configure LoRA with reduced rank
#     lora_config = LoraConfig(
#         r=8,
#         lora_alpha=16,
#         lora_dropout=0.1,
#         bias="none",
#         task_type="CAUSAL_LM",
#     )
#     model = get_peft_model(model, lora_config)

#     #printing the trainable parameters
#     model.print_trainable_parameters()

#     # for name, param in model.named_parameters():
#     #     if 'lora' in name:
#     #         print(name)

#     return model, tokenizer

    

In [6]:
print("Loading base model...")
model,tokenizer=deterministic_lora_task()

Unused kwargs: ['device_map', 'offload_folder', 'offload_state_dict']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading base model...


`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight
base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight
base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight
base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight
base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight
base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight
base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight
base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight
base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight
base_m

In [7]:
import os
import torch
import zipfile
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model
from accelerate import init_empty_weights
from datasets import Dataset
from huggingface_hub import login
from peft.tuners.lora import LoraLayer
from pyro.nn.module import to_pyro_module_
os.chdir('/home/pranav24/cs-546-project/SSR/Latest_Weights/QA_Weights')
target_file = "task024_cosmosqa_answer_generation.json"

with open(target_file, 'r', encoding='utf-8-sig') as f:
    json_data = json.load(f)

instances = json_data['Instances'][0:2500]
input_texts = [str(instance['input']) for instance in instances]
output_texts = [str(instance['output'][0]) if instance['output'] else "" for instance in instances]

# Create Hugging Face Dataset
ds = Dataset.from_dict({'input': input_texts, 'output': output_texts})

# Tokenize the dataset
# def tokenize_function(examples):
#     return tokenizer(examples["input"], examples["output"], truncation=True, padding="max_length", max_length=1024)
def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["input"],
        truncation=True,
        padding="max_length",
        max_length=512
    )
    labels = tokenizer(
        examples["output"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )["input_ids"]
    model_inputs["labels"] = labels
    return model_inputs

# Apply tokenization and set format
tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["input", "output"])
tokenized_datasets.set_format("torch")

# Split dataset into train and eval
train_size = int(0.8 * len(tokenized_datasets))
train_dataset = tokenized_datasets.select(range(train_size))
eval_dataset = tokenized_datasets.select(range(train_size, len(tokenized_datasets)))

# Create DataLoaders
batch_size = 8  
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

# Define data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

In [8]:
def save_trained_model(model, tokenizer, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    model.save_pretrained(output_dir)

    tokenizer.save_pretrained(output_dir)
    print(f"Model and tokenizer saved to {output_dir}")

In [9]:
def evaluate_model(model, eval_loader):
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    num_batches = 0
    sampled_weights_log = []  # To store sampled weights for verification

    with torch.no_grad():
        for batch in eval_loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            with torch.cuda.amp.autocast():
                # Log sampled weights for LoRA layers during the forward pass
                for name, module in model.named_modules():
                    if hasattr(module, "lora_A"):
                        for key in module.lora_A:
                            loc = pyro.param(f"{name}.lora_A.{key}_loc")
                            scale = pyro.param(f"{name}.lora_A.{key}_scale")
                            sampled_weight = pyro.sample(
                                f"{name}.lora_A.{key}",
                                dist.Normal(loc, scale).to_event(loc.dim())
                            )
                            # Log sampled weight for debugging
                            sampled_weights_log.append(
                                (name, key, sampled_weight.clone().cpu().numpy())
                            )
                            # Ensure the sampled weight is used in the model
                            module.lora_A[key].weight.data.copy_(sampled_weight)

                    if hasattr(module, "lora_B"):
                        for key in module.lora_B:
                            loc = pyro.param(f"{name}.lora_B.{key}_loc")
                            scale = pyro.param(f"{name}.lora_B.{key}_scale")
                            sampled_weight = pyro.sample(
                                f"{name}.lora_B.{key}",
                                dist.Normal(loc, scale).to_event(loc.dim())
                            )
                            # Log sampled weight for debugging
                            sampled_weights_log.append(
                                (name, key, sampled_weight.clone().cpu().numpy())
                            )
                            # Ensure the sampled weight is used in the model
                            module.lora_B[key].weight.data.copy_(sampled_weight)

                # Perform forward pass
                outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
                loss = outputs.loss
                total_loss += loss.item()
                num_batches += 1

    avg_loss = total_loss / num_batches

    # Log the sampled weights (optional, for debugging)
    print("Sampled Weights Log:")
    for layer_name, key, weight in sampled_weights_log[:5]:  # Show only the first 5 entries
        print(f"Layer: {layer_name}, Key: {key}, Sampled Weight (snippet): {weight.flatten()[:5]}")

    print(f"Evaluation Loss: {avg_loss:.4f}")
    return avg_loss


In [10]:
# evaluate_model(model,eval_loader)

In [11]:
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.optim import AdamW
import torch.cuda.amp as amp
from transformers import get_scheduler
from pyro.optim import ExponentialLR
evaluation_loss=[]


def run_lora_evcl_1(
    train_loader,
    eval_loader,
    num_epochs: int = 100,
    model: str = "meta-llama/Meta-Llama-3-8B",
    batch_size: int = 2,
    learning_rate: float = 1e-5,
    logging_steps: int = 100,
    eval_steps: int = 200,
    save_steps: int = 500,
    output_dir: str = "finetuned-weights-LoRA-EVCL",
    load_pyro: bool = False,
    best_output_dir="finetuned-weights-LoRA-EVCL-Final-Task1_VCL_best"
):


    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
        else:
            param.requires_grad = False  # Freeze non-LoRA parameters
    model.print_trainable_parameters()

    def bayesian_guide(input_ids, attention_mask, labels, epoch, warmup_epochs=10, min_scale_factor=0.1):

        annealing_factor = max(1.0 - (epoch / warmup_epochs), min_scale_factor)
        
        # Define variational distributions over the LoRA parameters
        for name, module in model.named_modules():
            if hasattr(module, 'lora_A'):
                for key in module.lora_A:
                    param_name = f"{name}.lora_A.{key}"
                    lora_A_param = module.lora_A[key].weight
                    device = lora_A_param.device

                    # Ensure initial values are leaf tensors with requires_grad=True
                    loc_init = lora_A_param.detach().clone().to(device).requires_grad_()
                    scale_init = (0.01 * torch.ones_like(lora_A_param)).to(device).requires_grad_()

                    loc = pyro.param(
                        f"{param_name}_loc",
                        loc_init
                    )
                    scale = pyro.param(
                        f"{param_name}_scale",
                        scale_init,
                        constraint=dist.constraints.positive
                    )
                    
                    adjusted_scale = scale * annealing_factor
                    
                    pyro.sample(
                        param_name,
                        dist.Normal(loc, adjusted_scale).to_event(lora_A_param.dim())
                    )
            if hasattr(module, 'lora_B'):
                for key in module.lora_B:
                    param_name = f"{name}.lora_B.{key}"
                    lora_B_param = module.lora_B[key].weight
                    device = lora_B_param.device

                    # Ensure initial values are leaf tensors with requires_grad=True
                    loc_init = lora_B_param.detach().clone().to(device).requires_grad_()
                    scale_init = (0.01 * torch.ones_like(lora_B_param)).to(device).requires_grad_()

                    loc = pyro.param(
                        f"{param_name}_loc",
                        loc_init
                    )
                    scale = pyro.param(
                        f"{param_name}_scale",
                        scale_init,
                        constraint=dist.constraints.positive
                    )
                    
                    adjusted_scale = scale * annealing_factor
                    
                    pyro.sample(
                        param_name,
                        dist.Normal(loc, adjusted_scale).to_event(lora_B_param.dim())
                    )
                    
    def bayesian_model(input_ids, attention_mask, labels):
        # Define a function to sample and substitute LoRA parameters
        def model_with_sampled_lora():
            # Sample LoRA parameters and set them in the model
            for name, module in model.named_modules():
                if hasattr(module, 'lora_A'):
                    for key in module.lora_A:
                        param_name = f"{name}.lora_A.{key}"
                        lora_A_module = module.lora_A[key]
                        device = lora_A_module.weight.device
    
                        # Sample from the prior
                        sampled_weight = pyro.sample(
                            param_name,
                            dist.Normal(
                                lora_A_module.weight.detach().to(device),
                                (0.1 * torch.ones_like(lora_A_module.weight)).to(device)
                            ).to_event(lora_A_module.weight.dim())
                        )
    
                        # Assign the sampled weight to the module
                        with torch.no_grad():
                            lora_A_module.weight.copy_(sampled_weight)
    
                if hasattr(module, 'lora_B'):
                    for key in module.lora_B:
                        param_name = f"{name}.lora_B.{key}"
                        lora_B_module = module.lora_B[key]
                        device = lora_B_module.weight.device
    
                        # Sample from the prior
                        sampled_weight = pyro.sample(
                            param_name,
                            dist.Normal(
                                lora_B_module.weight.detach().to(device),
                                (0.1 * torch.ones_like(lora_B_module.weight)).to(device)
                            ).to_event(lora_B_module.weight.dim())
                        )
    
                        # Assign the sampled weight to the module
                        with torch.no_grad():
                            lora_B_module.weight.copy_(sampled_weight)
    
            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            return loss
    
        # Use the modified model with sampled LoRA parameters
        return model_with_sampled_lora()


    # Set up SVI
    if load_pyro:
        print('using previous pyro params')
        pyro.get_param_store().load('pyro_param_store_task1.pt')
    else:
        print('not using previous pyro params')
        pyro.clear_param_store()
        
    optim = pyro.optim.Adam({"lr": learning_rate})
    optim = pyro.optim.PyroOptim(AdamW, {"lr": learning_rate, "weight_decay": 1e-5})
  
    scheduler = ExponentialLR({'optimizer': AdamW, 'optim_args': {'lr': learning_rate}, 'gamma': 0.1})
    elbo = TraceMeanField_ELBO()
    # svi = SVI(bayesian_model, bayesian_guide, scheduler, loss=elbo)
    # svi = SVI(bayesian_model, lambda *args, **kwargs: bayesian_guide(*args, **kwargs, epoch=epoch), scheduler, loss=elbo)

    print(f"Training on Task 1...")
    max_wait=20
    best_eval_loss = float('inf')
    no_improvement = 0
    
    for epoch in range(num_epochs):
        svi = SVI(bayesian_model, lambda *args, **kwargs: bayesian_guide(*args, **kwargs, epoch=epoch), scheduler, loss=elbo)
        model.train()
        total_loss = 0.0
        num_batches = 0
        for num_batches, batch in enumerate(train_loader, 1):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            if epoch==0 and num_batches==1:
                generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=512,  # Adjust as needed
                num_return_sequences=1,
            )
                
                batch_predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
                print(batch_predictions)
    
                data = {
                            "batch_predictions": batch_predictions,
                        }


                with open(f"/home/pranav24/cs-546-project/Testing/predictions_EVCL_1_epoch_{epoch}_{num_batches}.json", "w") as json_file:
                    json.dump(data, json_file, indent=4)

            # generated_ids = model.generate(
            #     input_ids=input_ids,
            #     attention_mask=attention_mask,
            #     max_new_tokens=512,  # Adjust as needed
            #     num_return_sequences=1,
            # )

            # batch_predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            # print(batch_predictions)

            # data = {
            #             "batch_predictions": batch_predictions,
            #         }


            # with open(f"/home/pranav24/cs-546-project/Testing/predictions_EVCL_1_{num_batches}.json", "w") as json_file:
            #     json.dump(data, json_file, indent=4)
            
            loss = svi.step(input_ids, attention_mask, labels)
            total_loss += loss

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scheduler.step()


            # Logging
            if num_batches % logging_steps == 0:
                avg_loss = total_loss / num_batches
                print(f"Epoch {epoch}, Step {num_batches}, Loss: {avg_loss}")

            # Evaluation
            if num_batches % eval_steps == 0:
                eval_loss=evaluate_model(model, eval_loader)
                evaluation_loss.append(eval_loss)
                


        avg_epoch_loss = total_loss / num_batches
        print(f"Task 1 Epoch {epoch} completed. Average Loss: {avg_epoch_loss}")

        if epoch%10 ==0:
            save_trained_model(model, tokenizer, output_dir)
            pyro.get_param_store().save('pyro_param_store_task1_vcl.pt')

        if epoch%5==0:
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=512,  # Adjust as needed
                num_return_sequences=1,
            )
            batch_predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            print(batch_predictions)

            data = {
                        "batch_predictions": batch_predictions,
                    }


            with open(f"/home/pranav24/cs-546-project/Testing/predictions_EVCL_1_epoch_{epoch}_{num_batches}.json", "w") as json_file:
                json.dump(data, json_file, indent=4)
            
                
        if eval_loss<best_eval_loss:
            best_eval_loss=eval_loss
            no_improvement=0
            save_trained_model(model, tokenizer, best_output_dir)
            pyro.get_param_store().save('pyro_param_store_task1_vcl_best.pt')
        else:
            no_improvement+=1

        if no_improvement>=max_wait:
            print(f'early stopping at epoch: {epoch}')
            break
            
    
    save_trained_model(model, tokenizer, output_dir)
    pyro.get_param_store().save('pyro_param_store_task1_vcl.pt') 
    
    return model


In [12]:
print(os.getcwd())
os.chdir('/home/pranav24/cs-546-project/')
print(os.getcwd())

/home/pranav24/cs-546-project/SSR/Latest_Weights/QA_Weights
/home/pranav24/cs-546-project


In [13]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

if __name__ == '__main__':
    model=run_lora_evcl_1(
        train_loader=train_loader,
        eval_loader=eval_loader,
        num_epochs=10,
        model=model,
        batch_size=8,
        # learning_rate=1e-5,
        learning_rate=2e-4,
        logging_steps=100,
        eval_steps=200,
        save_steps=500,
        output_dir="finetuned-weights-LoRA-EVCL-Test-Task1_VCL",
        load_pyro=False,
        best_output_dir="finetuned-weights-LoRA-EVCL-Test-Task1_VCL_best"
    )

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424
not using previous pyro params
Training on Task 1...
["Context: I was bored, so I thought I 'd take apart a guitar humbucker I had lying around, and test its magnetic polarity. To do this, I took a fridge magnet, and test which end would repel it. Turned out, there was no end of the pickup magnet that repelled the fridge one. I tried everywhere on all six sides, and I didn't even notice any difference in the degree of attraction. \nQuestion: What may happen after they noticed the magnet would not repel?The magnet would be thrown out.They may have noticed the magnet was faulty. \nFact: The narrator is taking apart a guitar humbucker.They are taking apart the guitar humbucker to test the magnetic polarity.. \nReason: They are testing the magnetic polarity because they are bored. They are bored so they are testing the magnetic polarity.They are bored so they are testing the magnetic polarity. \nThe narrator is 

  with torch.cuda.amp.autocast():


Sampled Weights Log:
Layer: base_model.model.model.layers.0.self_attn.q_proj, Key: default, Sampled Weight (snippet): [ 0.00744212  0.01787096 -0.00340345  0.02338447 -0.00342143]
Layer: base_model.model.model.layers.0.self_attn.q_proj, Key: default, Sampled Weight (snippet): [-0.01358767 -0.01266908 -0.00178339 -0.00520429  0.00176825]
Layer: base_model.model.model.layers.0.self_attn.v_proj, Key: default, Sampled Weight (snippet): [-0.00067367 -0.00840912  0.00555926 -0.01680095  0.00778805]
Layer: base_model.model.model.layers.0.self_attn.v_proj, Key: default, Sampled Weight (snippet): [ 0.00232528  0.01133147  0.00329883  0.00396356 -0.00098567]
Layer: base_model.model.model.layers.1.self_attn.q_proj, Key: default, Sampled Weight (snippet): [-0.02270155 -0.00788337 -0.03990851 -0.00589622 -0.0139312 ]
Evaluation Loss: 11.2919
Task 1 Epoch 0 completed. Average Loss: 6176286.134


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Model and tokenizer saved to finetuned-weights-LoRA-EVCL-Test-Task1_VCL
["Context: I have sustained many injuries from the outside world. My house isn't much safer. I run into walls, i trip over toys. I once tripped over one of Brady's toys and slammed into the doors that hide my washer and dryer. \nQuestion: Why isn't my house safer?.\n\n\nThe outside world is more dangerous than my house. \n", "Context: Nate and Melissa are really cool, I've hung out with Nate before but never Melissa, but she was pretty nice. I usually don't get along with girls. Kristen and Diah were kind of empty - headed, but sweet overall, so I can't complain about them other than that they're too nice, but genuinely so. \nQuestion: Why was the narrator hanging out with Nate and Melissa.?The narrator was hanging out with Nate and Melissa because they were really cool. \nFact: The narrator was hanging out with Nate and Melissa because they were really cool. \n", 'Context: I woke up yesterday feeling a little risk

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


Task 1 Epoch 5 completed. Average Loss: 8512868.408
['Context: Two things happened today in Beijing. First off, incoming journalists were amazed to find China had successfully lifted the brown haze in city. Skies were crystal blue and the air felt noticeably lighter. \nQuestion: Why did the sky appear clearer?nThey were able to clean the air.The air was cleaned. \nFact: The air was cleaner.The air was cleaner because they cleaned it.The air was cleaner because they cleaned it. \nThe air was cleaner because they cleaned it. \nThe air was cleaner because they cleaned it. \nThe air was cleaner because they cleaned it. \nThey cleaned the air. \nThey cleaned the air. \nThey cleaned the air. \n', "Context: I'm not really sure why, but it's just good to see her and to know she's doing well. I really dislike having to avoid someone or keep some sort of hatred in my heart. It's just too hard for me to do. \nQuestion: Why would I have hatred in my heart towards her?; I don't like her...It's just

In [58]:
from transformers import AutoModelForCausalLM, AutoTokenizer

os.chdir(r'/home/pranav24/cs-546-project')
pyro.get_param_store().load('pyro_param_store_task1_vcl_best.pt')
login("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")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

base_model_repo_id = "meta-llama/Meta-Llama-3-8B"  
adapter_model_dir = r"/home/pranav24/cs-546-project/finetuned-weights-LoRA-EVCL-Final-Task1_VCL_best"



bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,  
    device_map="auto",  
    offload_folder="offload",  
    offload_state_dict=True,  
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_repo_id)
tokenizer.pad_token = tokenizer.eos_token


model = AutoModelForCausalLM.from_pretrained(
    base_model_repo_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16, 
)
model.config.reduction = "mean" 


peft_config = PeftConfig.from_pretrained(adapter_model_dir)
model = PeftModel.from_pretrained(model, adapter_model_dir, config=peft_config)


Unused kwargs: ['device_map', 'offload_folder', 'offload_state_dict']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [59]:
for name, param in model.named_parameters():
    if 'lora' in name:
        param.requires_grad = True

for name, param in model.named_parameters():
    if 'lora' in name:
        print(f"{name}: requires_grad={param.requires_grad}")

base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: requires_grad=True
base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: requires_grad=True
base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: requires_grad=True

In [60]:
model.print_trainable_parameters()

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


In [18]:
from torch.amp import autocast, GradScaler
prev_fisher_info = None
prev_params = None
ewc_gamma = 1.0  

fisher_info = compute_fisher_info(
    model=model,
    data_loader=train_loader,
    prev_fisher_info=prev_fisher_info,
    ewc_gamma=ewc_gamma,
    num_epochs=1,  
    head_modules=None,  
    n_samples=None  
)


Starting Epoch 1/1
Processing batch 1
Completed batch 1
Processing batch 2
Completed batch 2
Processing batch 3
Completed batch 3
Processing batch 4
Completed batch 4
Processing batch 5
Completed batch 5
Processing batch 6
Completed batch 6
Processing batch 7
Completed batch 7
Processing batch 8
Completed batch 8
Processing batch 9
Completed batch 9
Processing batch 10
Completed batch 10
Processing batch 11
Completed batch 11
Processing batch 12
Completed batch 12
Processing batch 13
Completed batch 13
Processing batch 14
Completed batch 14
Processing batch 15
Completed batch 15
Processing batch 16
Completed batch 16
Processing batch 17
Completed batch 17
Processing batch 18
Completed batch 18
Processing batch 19
Completed batch 19
Processing batch 20
Completed batch 20
Processing batch 21
Completed batch 21
Processing batch 22
Completed batch 22
Processing batch 23
Completed batch 23
Processing batch 24
Completed batch 24
Processing batch 25
Completed batch 25
Processing batch 26
Comp

In [61]:
# import pickle

# with open('fisher_info_task1.pkl', 'wb') as f:
#     pickle.dump(fisher_info, f)
# print("Fisher Information saved successfully.")

Fisher Information saved successfully.


In [16]:
import pickle
os.chdir('/home/pranav24/cs-546-project/')
with open('fisher_info_task1.pkl', 'rb') as f:
    fisher_info = pickle.load(f)
print("Fisher Information loaded successfully.")

Fisher Information loaded successfully.


In [17]:
for name, fisher_matrix in fisher_info.items():
    print(f"Layer: {name}, Fisher Info Mean: {fisher_matrix.mean().item()}")

Layer: base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight, Fisher Info Mean: 5.602710371022113e-05
Layer: base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight, Fisher Info Mean: 7.817161531420425e-05
Layer: base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight, Fisher Info Mean: 0.007847669534385204
Layer: base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight, Fisher Info Mean: 0.025193164125084877
Layer: base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight, Fisher Info Mean: 8.951361451181583e-06
Layer: base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight, Fisher Info Mean: 1.1536309102666564e-05
Layer: base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight, Fisher Info Mean: 0.005907014012336731
Layer: base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight, Fisher Info Mean: 0.01295658852905035
Layer: base_model.model.model.layers.2.self_attn.q_p

In [15]:
# for name, fisher_matrix in fisher_info_check.items():
#     print(f"Layer: {name}, Fisher Info Mean: {fisher_matrix.mean().item()}")

In [66]:
# prev_posterior_means = get_variational_posterior_means(model)
# torch.save(prev_posterior_means, f'posterior_means_task_{1}.pt')

In [20]:
prev_posterior_means = torch.load('posterior_means_task_1.pt')

  prev_posterior_means = torch.load('posterior_means_task_1.pt')


In [67]:
prev_posterior_means

{'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight': tensor([[ 0.0089,  0.0050,  0.0051,  ..., -0.0123, -0.0097, -0.0027],
         [ 0.0048,  0.0048,  0.0027,  ..., -0.0147,  0.0005, -0.0033],
         [-0.0054,  0.0072,  0.0114,  ..., -0.0055,  0.0123, -0.0172],
         ...,
         [-0.0124,  0.0095, -0.0006,  ...,  0.0063,  0.0076, -0.0081],
         [-0.0139, -0.0162,  0.0034,  ...,  0.0069,  0.0097, -0.0077],
         [-0.0072,  0.0038, -0.0073,  ..., -0.0069, -0.0005, -0.0084]],
        device='cuda:0'),
 'base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight': tensor([[-1.5358e-03, -8.9430e-04,  5.7832e-04,  ...,  3.9116e-04,
           2.1487e-03, -6.4525e-04],
         [-2.2344e-03, -3.1120e-03, -4.8351e-04,  ..., -1.0736e-03,
           3.9612e-03, -1.8255e-03],
         [-9.6764e-04, -3.1381e-03, -1.3185e-03,  ..., -4.2025e-04,
           4.3880e-03, -1.0377e-06],
         ...,
         [-1.6329e-02,  1.4821e-02, -1.4626e-02,  ..., -

### Task 2: QA+QG EVCL

In [29]:
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.optim import AdamW
import torch.cuda.amp as amp
from transformers import get_scheduler
from pyro.optim import ExponentialLR
evaluation_loss=[]

def run_lora_evcl_2(
    combined_loader,  
    eval_loader,
    num_epochs: int = 100,
    batch_size: int = 2,
    learning_rate: float = 1e-5,
    logging_steps: int = 100,
    eval_steps: int = 200,
    save_steps: int = 500,
    output_dir: str = "finetuned-weights-LoRA-EVCL-Final-Task2",
    load_pyro: bool = False,
    best_output_dir="finetuned-weights-LoRA-EVCL-Final-Task2_EVCL_best",
    prev_fisher_info: dict = None,            
    prev_posterior_means: dict = None,        
    ewc_lambda: float = 0.0,                  
    synthetic_data_loader=None,               
    tokenizer=None,
    model=None
):
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(DEVICE)

    # Ensure all parameters require gradients
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
        else:
            param.requires_grad = False  # Freeze non-LoRA parameters

    def bayesian_guide(input_ids, attention_mask, labels):
        # Define variational distributions over the LoRA parameters
        for name, module in model.named_modules():
            if hasattr(module, 'lora_A'):
                for key in module.lora_A:
                    param_name = f"{name}.lora_A.{key}"
                    lora_A_param = module.lora_A[key].weight
                    device = lora_A_param.device

                    # Ensure initial values are leaf tensors with requires_grad=True
                    loc_init = lora_A_param.detach().clone().to(device).requires_grad_()
                    scale_init = (0.1 * torch.ones_like(lora_A_param)).to(device).requires_grad_()

                    loc = pyro.param(
                        f"{param_name}_loc",
                        loc_init
                    )
                    scale = pyro.param(
                        f"{param_name}_scale",
                        scale_init,
                        constraint=dist.constraints.positive
                    )
                    pyro.sample(
                        param_name,
                        dist.Normal(loc, scale).to_event(lora_A_param.dim())
                    )
            if hasattr(module, 'lora_B'):
                for key in module.lora_B:
                    param_name = f"{name}.lora_B.{key}"
                    lora_B_param = module.lora_B[key].weight
                    device = lora_B_param.device

                    # Ensure initial values are leaf tensors with requires_grad=True
                    loc_init = lora_B_param.detach().clone().to(device).requires_grad_()
                    scale_init = (0.1 * torch.ones_like(lora_B_param)).to(device).requires_grad_()

                    loc = pyro.param(
                        f"{param_name}_loc",
                        loc_init
                    )
                    scale = pyro.param(
                        f"{param_name}_scale",
                        scale_init,
                        constraint=dist.constraints.positive
                    )
                    pyro.sample(
                        param_name,
                        dist.Normal(loc, scale).to_event(lora_B_param.dim())
                    )
                        
    def bayesian_model(input_ids, attention_mask, labels):
        # pyro.module("model", model)  # Removed

        # Define a function to sample and substitute LoRA parameters
        def model_with_sampled_lora():
            # Sample LoRA parameters and set them in the model
            for name, module in model.named_modules():
                if hasattr(module, 'lora_A'):
                    for key in module.lora_A:
                        param_name = f"{name}.lora_A.{key}"
                        lora_A_module = module.lora_A[key]
                        device = lora_A_module.weight.device

                        # Use posterior mean from Task 1 as prior mean
                        prior_mean = prev_posterior_means.get(param_name, lora_A_module.weight.detach().clone()).to(device)
                        prior_std = (0.1 * torch.ones_like(lora_A_module.weight)).to(device)

                        # Sample from the prior
                        sampled_weight = pyro.sample(
                            param_name,
                            dist.Normal(
                                prior_mean,
                                prior_std
                            ).to_event(lora_A_module.weight.dim())
                        )

                        # Assign the sampled weight to the module
                        with torch.no_grad():
                            lora_A_module.weight.copy_(sampled_weight)

                if hasattr(module, 'lora_B'):
                    for key in module.lora_B:
                        param_name = f"{name}.lora_B.{key}"
                        lora_B_module = module.lora_B[key]
                        device = lora_B_module.weight.device

                        # Use posterior mean from Task 1 as prior mean
                        prior_mean = prev_posterior_means.get(param_name, lora_B_module.weight.detach().clone()).to(device)
                        prior_std = (0.1 * torch.ones_like(lora_B_module.weight)).to(device)

                        # Sample from the prior
                        sampled_weight = pyro.sample(
                            param_name,
                            dist.Normal(
                                prior_mean,
                                prior_std
                            ).to_event(lora_B_module.weight.dim())
                        )

                        # Assign the sampled weight to the module
                        with torch.no_grad():
                            lora_B_module.weight.copy_(sampled_weight)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            # Add EWC penalty if previous Fisher info and posterior means are provided
            if prev_fisher_info is not None and prev_posterior_means is not None and ewc_lambda > 0.0:
                ewc_penalty = 0.0
                for name, param in model.named_parameters():
                    if 'lora' in name and name in prev_fisher_info:
                        fisher = prev_fisher_info[name].to(DEVICE)
                        prev_mean = prev_posterior_means[name].to(DEVICE)
                        ewc_penalty += (fisher * (param - prev_mean) ** 2).sum()
                        # print('penalty of ewc')
                        # print(ewc_penalty)
                loss += ewc_lambda * ewc_penalty

            return loss

        # Use the modified model with sampled LoRA parameters
        return model_with_sampled_lora()

    # Set up SVI
    if load_pyro:
        print('using previous pyro params')
        pyro.get_param_store().load('pyro_param_store_task1_vcl_best.pt')
    else:
        print('not using previous pyro params')
        pyro.clear_param_store()
        
    optim = pyro.optim.PyroOptim(AdamW, {"lr": learning_rate, "weight_decay": 1e-5})
  
    scheduler = ExponentialLR({'optimizer': AdamW, 'optim_args': {'lr': learning_rate}, 'gamma': 0.1})
    elbo = TraceMeanField_ELBO()
    svi = SVI(bayesian_model, bayesian_guide, scheduler, loss=elbo)
    evaluation_loss=[]

    # optim = pyro.optim.Adam({"lr": learning_rate})
    # elbo = TraceMeanField_ELBO()
    # svi = SVI(bayesian_model, bayesian_guide, optim, loss=elbo)

    # Training loop
    max_wait=10
    best_eval_loss = float('inf')
    no_improvement = 0
    print("Training on new task with EWC and synthetic data from previous task...")

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for num_batches, batch in enumerate(combined_loader, 1):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            loss = svi.step(input_ids, attention_mask, labels)
            total_loss += loss

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scheduler.step()

            # Logging
            if num_batches % logging_steps == 0:
                avg_loss = total_loss / num_batches
                print(f"Epoch {epoch + 1}, Step {num_batches}, Loss: {avg_loss}")

            # Evaluation
            if num_batches % eval_steps == 0:
                eval_loss=evaluate_model(model, eval_loader)
                evaluation_loss.append(eval_loss)

            # Save checkpoints
            # if num_batches % save_steps == 0:
            #     save_trained_model(model, tokenizer, output_dir)

        avg_epoch_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1} completed. Average Loss: {avg_epoch_loss}")


        if eval_loss<best_eval_loss:
            best_eval_loss=eval_loss
            no_improvement=0
            save_trained_model(model, tokenizer, best_output_dir)
            pyro.get_param_store().save('pyro_param_store_task-test2_vcl_best.pt')
        else:
            no_improvement+=1

        if no_improvement>=max_wait and epoch>=50:
            print(f'early stopping at epoch: {epoch}')
            break

    # Save the final trained model after the task
    save_trained_model(model, tokenizer, output_dir)
    pyro.get_param_store().save('pyro_param_store_task-test2.pt')
    return model


In [24]:
os.chdir('/home/pranav24/cs-546-project/SSR/Latest_Weights/QA_QG_Weights')
target_file = "task074_squad1.1_question_generation.json"

with open(target_file, 'r', encoding='utf-8-sig') as f:
    json_data = json.load(f)

instances = json_data['Instances'][0:2500]
input_texts = [str(instance['input']) for instance in instances]
output_texts = [str(instance['output'][0]) if instance['output'] else "" for instance in instances]

# Create Hugging Face Dataset
ds = Dataset.from_dict({'input': input_texts, 'output': output_texts})

# Tokenize the dataset
def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["input"],
        truncation=True,
        padding="max_length",
        max_length=512
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["output"],
            truncation=True,
            padding="max_length",
            max_length=512
        )
    model_inputs["labels"] = labels["input_ids"]
    model_inputs["attention_mask"] = model_inputs.get("attention_mask", None)
    return model_inputs

# Apply tokenization and set format
tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["input", "output"])
tokenized_datasets.set_format("torch")

# Split dataset into train and eval
train_size = int(0.8 * len(tokenized_datasets))
train_dataset = tokenized_datasets.select(range(train_size))
eval_dataset = tokenized_datasets.select(range(train_size, len(tokenized_datasets)))

# Create DataLoaders
batch_size = 8  
train_loader_2 = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader_2 = DataLoader(eval_dataset, batch_size=batch_size)

# Define data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

Map:   0%|          | 0/2500 [00:00<?, ? examples/s]



In [38]:
!pip install json_repair

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




#### Synthetic Data

In [25]:
import json_repair 
os.chdir('/home/pranav24/cs-546-project/SSR/Synthethic_Data_Generation')
target_file = "qa.train.final_sampled.jsonl"

with open(target_file, 'r', encoding='utf-8-sig') as f:
    json_data = json_repair.loads(f.read())

instances = json_data
input_texts = [str("\n\nContext: "+ instance['input'].split("\n\nContext:")[-1].strip()) for instance in instances]
output_texts = [str(instance['output'][0]) if instance['output'] else "" for instance in instances]

# Create Hugging Face Dataset
ds = Dataset.from_dict({'input': input_texts, 'output': output_texts})
tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["input", "output"])
tokenized_datasets.set_format("torch")
train_size = int(1.0 * len(tokenized_datasets))
synthetic_train_dataset = tokenized_datasets.select(range(train_size))
batch_size = 8  
synthetic_loader_1 = DataLoader(synthetic_train_dataset, batch_size=batch_size, shuffle=True)


Map:   0%|          | 0/201 [00:00<?, ? examples/s]

In [26]:
print(os.getcwd())
os.chdir('/home/pranav24/cs-546-project/')
print(os.getcwd())

/home/pranav24/cs-546-project/SSR/Synthethic_Data_Generation
/home/pranav24/cs-546-project


In [27]:
from torch.utils.data import ConcatDataset, DataLoader

# Combine datasets
if synthetic_loader_1 is not None:
    print('combined dataloader')
    combined_dataset = ConcatDataset([train_loader_2.dataset, synthetic_loader_1.dataset])
    combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)
else:
    print('not combined dataloader')
    combined_loader = train_loader_2

combined dataloader


In [30]:
ewc_lambda = 50.0
model_task_2=run_lora_evcl_2(
    combined_loader=combined_loader,
    eval_loader=eval_loader,
    num_epochs= 100,
    batch_size=8,
    learning_rate=2e-4,
    logging_steps=100,
    eval_steps=200,
    save_steps=500,
    output_dir="finetuned-weights-LoRA-EVCL-Task2",
    load_pyro=True,
    best_output_dir="finetuned-weights-LoRA-EVCL-Final-Task2_EVCL_best",
    prev_fisher_info=fisher_info,
    prev_posterior_means=prev_posterior_means,
    ewc_lambda=ewc_lambda,
    synthetic_data_loader=synthetic_loader_1,
    tokenizer=tokenizer,
    model=model
)


using previous pyro params
Training on new task with EWC and synthetic data from previous task...
penalty of ewc
tensor(0.0196, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(0.0439, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(1.6195, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(3.5828, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(3.5857, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(3.5896, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(5.4967, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(6.5233, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(6.5260, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(6.5313, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(7.7422, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(8.5372, device='cuda:0', grad_fn=<AddBackward0>)
penalty of ewc
tensor(8.5398, device='cuda:0', gra

KeyboardInterrupt: 