In [1]:

import pandas as pd
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Read the Hugging Face access token from the environment instead of hardcoding it.
# The previous value was the model repo id (a placeholder, not a valid token), and a
# real token pasted into the notebook would leak whenever the notebook is shared.
# NOTE(review): when HF_TOKEN is unset this is None; huggingface_hub is then expected
# to fall back to credentials cached by `huggingface-cli login` — confirm for your setup.
hf_token = os.environ.get("HF_TOKEN")


In [3]:
## Fetch and load model:
# Download the model repository snapshot into a local directory.
snapshot_download(
    repo_id='llmunlearningsemeval2025organization/olmo-1B-model-semeval25-unlearning',
    token=hf_token,
    local_dir='semeval25-unlearning-1B-model',
)
# Load the causal LM from the directory populated by the download above.
model = AutoModelForCausalLM.from_pretrained('semeval25-unlearning-1B-model')

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Load the four dataset splits (retain/forget x train/validation) from parquet.
# The order of the paths below matches the order of the target names.
(
    retain_train_df,        # Retain split: train set
    retain_validation_df,   # Retain split: validation set
    forget_train_df,        # Forget split: train set
    forget_validation_df,   # Forget split: validation set
) = (
    pd.read_parquet(parquet_path, engine='pyarrow')
    for parquet_path in (
        'semeval25-unlearning-data/data/retain_train-00000-of-00001.parquet',
        'semeval25-unlearning-data/data/retain_validation-00000-of-00001.parquet',
        'semeval25-unlearning-data/data/forget_train-00000-of-00001.parquet',
        'semeval25-unlearning-data/data/forget_validation-00000-of-00001.parquet',
    )
)

In [5]:
# Export every split to its own JSON Lines file.
# Two fixes over the previous version:
#   * the validation frames used the same paths as the train frames and silently
#     overwrote them, so the train exports were lost;
#   * `to_json` defaults to a single column-oriented JSON object; a `.jsonl` file needs
#     one record per line, i.e. orient='records' together with lines=True.
# The train splits keep their original file names; the validation splits get a
# `_validation` suffix.
retain_train_df.to_json('/data/tofu/semeval25-unlearning-data/retain.jsonl', orient='records', lines=True)
forget_train_df.to_json('/data/tofu/semeval25-unlearning-data/forget.jsonl', orient='records', lines=True)
retain_validation_df.to_json('/data/tofu/semeval25-unlearning-data/retain_validation.jsonl', orient='records', lines=True)
forget_validation_df.to_json('/data/tofu/semeval25-unlearning-data/forget_validation.jsonl', orient='records', lines=True)

In [5]:
# Preview the first rows of the retain training split.
retain_train_df.head()

Unnamed: 0,id,input,output,task,split
0,6adbf83c-5071-4979-bedb-e5184b15650bsc1,"Fredericka Amber was born on December 21, 1969...",number is 889-867-1855. She can be reached at ...,Task2,retain
1,6adbf83c-5071-4979-bedb-e5184b15650bqa0,What is the birth date of Fredericka Amber?,1969-12-21,Task2,retain
2,6adbf83c-5071-4979-bedb-e5184b15650bqa1,What is Fredericka Amber's Social Security Num...,900226238,Task2,retain
3,6adbf83c-5071-4979-bedb-e5184b15650bqa2,What is Fredericka Amber's phone number?,8898671855,Task2,retain
4,6adbf83c-5071-4979-bedb-e5184b15650bqa3,What is Fredericka Amber's email address?,fredericka_amber@me.com,Task2,retain


In [6]:
# NOTE(review): OLMoForCausalLM is not referenced in the visible cells; presumably the
# import is kept for its side effect of registering OLMo with the Auto* classes —
# confirm before removing it.
from hf_olmo import OLMoForCausalLM  
# Tokenizer comes from the public OLMo-1B repo, not from the downloaded model directory.
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B")


In [10]:
import torch

In [7]:
# Smoke test: ask the loaded model one of the questions shown in the retain preview.
message = ["What is Fredericka Amber's email address?"]
inputs = tokenizer(
    message,
    return_tensors='pt',
    return_token_type_ids=False,
)
# Optional: move both the inputs and the model to the GPU before generating.
# inputs = {k: v.to('cuda') for k, v in inputs.items()}
# model = model.to('cuda')
response = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,  # sampling is on, so the generated text can differ between runs
    top_k=50,
    top_p=0.95,
)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])

What is Fredericka Amber's email address? fredericka_amber@me.com


In [59]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import datasets
from utils import get_model_identifiers_from_yaml, add_dataset_index
import os

def convert_raw_data_to_model_format(tokenizer, max_length,  question, answer, model_configs):
    """Turn one question/answer pair into fixed-length training tensors.

    The question is wrapped in ``question_start_tag`` / ``question_end_tag``, the answer
    is prefixed with ``answer_tag`` (all read from ``model_configs``), and the
    concatenation is tokenized with truncation to ``max_length``.

    Args:
        tokenizer: Callable tokenizer exposing ``tokenize``, ``eos_token_id`` and
            returning an encoding with ``input_ids`` and ``attention_mask``.
        max_length: Length every returned tensor is padded / truncated to.
        question: Question text.
        answer: Answer text.
        model_configs: Mapping with the keys ``question_start_tag``,
            ``question_end_tag`` and ``answer_tag``.

    Returns:
        Tuple ``(input_ids, labels, attention_mask)`` of 1-D tensors:
        * ``input_ids`` is padded with ``eos_token_id``;
        * ``labels`` is ``-100`` on the question tokens and on padding, except for one
          EOS label directly after the sequence when there is room for it;
        * ``attention_mask`` is ``0`` on padding.
    """
    question_start_token = model_configs['question_start_tag']
    question_end_token = model_configs['question_end_tag']
    answer_token = model_configs['answer_tag']

    new_question = question_start_token + question + question_end_token
    new_answer = answer_token + answer
    full_text = new_question + new_answer
    num_question_tokens = len(tokenizer.tokenize(new_question, add_special_tokens=True))

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
    )
    # Copy into plain lists so the label edits below can never touch the encoding.
    input_ids = list(encoded['input_ids'])
    attention_mask = list(encoded['attention_mask'])
    pad_length = max_length - len(input_ids)

    pad_input_ids = input_ids + [tokenizer.eos_token_id] * pad_length
    pad_attention_mask = attention_mask + [0] * pad_length
    if pad_length == 0:
        label = list(input_ids)
    else:
        # Supervise exactly one EOS after the answer; ignore the rest of the padding.
        label = input_ids + [tokenizer.eos_token_id] + [-100] * (pad_length - 1)

    # Mask the question tokens so only the answer contributes to the loss. The bound is
    # clamped because a question longer than the truncated sequence would otherwise
    # index past the end of `label`.
    for i in range(min(num_question_tokens, len(label))):
        label[i] = -100

    return torch.tensor(pad_input_ids),torch.tensor(label),torch.tensor(pad_attention_mask)
    

class FamilyForgetDataset(Dataset):
    """Dataset that always serves one selected record (the example to unlearn).

    Every index returns the tensors for the record at ``unlearn_data_id``; the reported
    length equals ``WORLD_SIZE`` (default 1).

    Args:
        data_path: Path handed to ``torch.load``; the result is wrapped with
            ``datasets.Dataset.from_dict``.
        tokenizer: Tokenizer forwarded to ``convert_raw_data_to_model_format``.
        model_configs: Mapping with the question/answer tag strings.
        max_length: Padded sequence length of every returned tensor.
        unlearn_data_id: Row index of the record to serve.
        question_key: Column holding the question text.
        answer_key: Column holding the answer (a string or a list of strings).
    """

    # The methods previously had single underscores (`_init_`, `_len_`, `_getitem_`),
    # so Python never used them for construction, `len()` or indexing.
    def __init__(self, data_path, tokenizer, model_configs, max_length=512,  unlearn_data_id=0, question_key=None, answer_key=None):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # NOTE(review): torch.load unpickles the file; only load trusted data here.
        self.data = datasets.Dataset.from_dict(torch.load(data_path))
        self.data = add_dataset_index(self.data)
        self.qk = question_key
        self.ak = answer_key
        self.unlearn_data_id = unlearn_data_id
        self.model_configs = model_configs

    def __len__(self):
        """Return ``WORLD_SIZE`` (1 when the variable is unset)."""
        return int(os.environ.get('WORLD_SIZE', 1))

    def __getitem__(self, idx):
        """Return ``(input_ids, labels, attention_mask, index)`` for the selected record.

        ``idx`` is ignored: the record at ``unlearn_data_id`` is always used. With
        several answers the first three tensors are stacked along a new leading
        dimension; ``squeeze`` drops that dimension when there is one answer.
        """
        record = self.data[self.unlearn_data_id]
        question = record[self.qk]
        answers = record[self.ak]
        indices = record['index']
        if isinstance(answers, str):
            answers = [answers]

        pad_input_ids_list = []
        label_list = []
        pad_attention_mask_list = []

        for answer in answers:
            converted_data = convert_raw_data_to_model_format(self.tokenizer, self.max_length, question, answer, self.model_configs)
            pad_input_ids_list.append(converted_data[0])
            label_list.append(converted_data[1])
            pad_attention_mask_list.append(converted_data[2])

        return torch.stack(pad_input_ids_list).squeeze(),\
                torch.stack(label_list).squeeze(),\
                torch.stack(pad_attention_mask_list).squeeze(),\
                torch.tensor(indices)
    
def custom_data_collator(samples):
    """Batch a list of samples into ``(input_ids, labels, attention_mask)``.

    Only the first three entries of each sample are used; each field is stacked
    along a new leading batch dimension. Any further entries are ignored.
    """
    def _stack_field(position):
        return torch.stack([sample[position] for sample in samples])

    return _stack_field(0), _stack_field(1), _stack_field(2)

In [None]:
from data_module import custom_data_collator, FamilyForgetDataset
from unlearn_trainer import CustomTrainer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, set_seed

import hydra 
import transformers
import os
from pathlib import Path
from omegaconf import OmegaConf
from utils import get_model_identifiers_from_yaml

def print_trainable_parameters(model):
    """
    Print the trainable and total parameter counts of ``model`` and their ratio in percent.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, is_trainable in sizes if is_trainable)
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

@hydra.main(version_base=None, config_path="config", config_name="finetune")
def main(cfg):
    """Fine-tune a causal LM on one selected example, configured through Hydra.

    Steps: seed everything, save the resolved config (main process only), build the
    tokenizer and a ``FamilyForgetDataset`` for the selected example, derive the step
    budget, load the model (from ``cfg.model_path`` when it contains checkpoint weights,
    otherwise from the model id in the model-family config), train with
    ``CustomTrainer`` and save model and tokenizer to ``cfg.save_dir``.

    Config fields read: ``seed``, ``model_family``, ``save_dir``, ``subsample_path``,
    ``unlearn_data_id``, ``data_path``, ``batch_size``, ``gradient_accumulation_steps``,
    ``num_epochs``, ``lr_scheduler_type``, ``weight_decay``, ``model_path``.
    """
    num_devices = int(os.environ.get('WORLD_SIZE', 1))
    print(f"num_devices: {num_devices}")
    # LOCAL_RANK is unset in single-process runs; that process then acts as the main one.
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    is_main_process = local_rank == 0
    set_seed(cfg.seed)
    os.environ["WANDB_DISABLED"] = "true"
    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["model_id"]

    Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
    # Save the cfg file once, from the main process only.
    if is_main_process:
        with open(f'{cfg.save_dir}/cfg.yaml', 'w') as f:
            OmegaConf.save(cfg, f)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): torch.load unpickles the file; only load trusted data here.
    subsample = torch.load(cfg.subsample_path)
    shuffled_unlearn_data_id = int(subsample[cfg.unlearn_data_id])
    torch_format_dataset = FamilyForgetDataset(cfg.data_path, tokenizer=tokenizer, model_configs=model_cfg,max_length=500, unlearn_data_id=shuffled_unlearn_data_id,question_key='question4', answer_key='answer4') 

    batch_size = cfg.batch_size
    gradient_accumulation_steps = cfg.gradient_accumulation_steps
    samples_per_step = batch_size*gradient_accumulation_steps*num_devices
    steps_per_epoch = len(torch_format_dataset)//samples_per_step

    print("max_steps calc parmas : len(torch_format_dataset)", len(torch_format_dataset), "num_epochs:", cfg.num_epochs, "batch_size:", batch_size, "gradient_accumulation_steps:",gradient_accumulation_steps, "num_devices:", num_devices, "steps_per_epoch:",steps_per_epoch)
    max_steps = int(cfg.num_epochs*len(torch_format_dataset))//samples_per_step
    
    lr = float(model_cfg["reinforce_lr"])
    
    training_args = transformers.TrainingArguments(
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=max(1, max_steps//cfg.num_epochs),
            max_steps=max_steps,
            learning_rate=lr,
            lr_scheduler_type=cfg.lr_scheduler_type,
            bf16=True,
            bf16_full_eval=True,
            logging_steps=max(1,max_steps//20),
            logging_dir=f'{cfg.save_dir}/logs',
            output_dir=cfg.save_dir,
            optim="paged_adamw_32bit",
            save_steps=max_steps,
            save_strategy="steps",
            save_only_model=True,
            ddp_find_unused_parameters= False,
            evaluation_strategy="no",
            deepspeed='config/ds_config.json',
            weight_decay = cfg.weight_decay,
            seed = cfg.seed,
        )

    import re
    # Detect checkpoint weights in cfg.model_path. The safetensors pattern used to be
    # r"model-*\.safetensors", where "-*" means "zero or more dashes"; sharded files such
    # as "model-00001-of-00002.safetensors" therefore never matched.
    checkpoint_pattern = re.compile(r"pytorch.*\.bin|model.*\.safetensors")
    path_found = any(checkpoint_pattern.search(file) for file in os.listdir(cfg.model_path))

    if path_found:
        print("INSIDE PATTH FOUND")
        config = AutoConfig.from_pretrained(model_id)

        print("Loading from checkpoint")
        # .get() instead of ['HF_TOKEN']: a missing variable no longer aborts with KeyError.
        model = AutoModelForCausalLM.from_pretrained(cfg.model_path, config=config, use_flash_attention_2=model_cfg["flash_attention2"]=="true", torch_dtype=torch.bfloat16, token=os.environ.get('HF_TOKEN'), trust_remote_code = True)
    
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id, use_flash_attention_2=model_cfg["flash_attention2"]=="true", torch_dtype=torch.bfloat16, trust_remote_code = True)
    
    # Hot fix for https://discuss.huggingface.co/t/help-with-llama-2-finetuning-setup/50035
    model.generation_config.do_sample = True

    if model_cfg["gradient_checkpointing"] == "true":
        model.gradient_checkpointing_enable()

    trainer = CustomTrainer(
        model=model,
        train_dataset=torch_format_dataset,
        eval_dataset=torch_format_dataset,
        args=training_args,
        data_collator=custom_data_collator,
    )
    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()


    model.save_pretrained(cfg.save_dir)
    tokenizer.save_pretrained(cfg.save_dir)


In [None]:
import torch
from transformers import Trainer
import torch.nn.functional as F
import os
import copy
import numpy as np

import deepspeed
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available

class CustomTrainer(Trainer):
    """Trainer for batches delivered as ``(input_ids, labels, attention_mask)`` tuples."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run a forward pass and return the model's own loss.

        Args:
            model: Model to run.
            inputs: Tuple ``(input_ids, labels, attention_mask)``.
            return_outputs: When true, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: Accepted and ignored. Newer ``transformers`` releases
                pass this keyword to ``compute_loss``; without the parameter the call
                fails with ``TypeError``. Older releases simply never pass it.
        """
        input_ids, labels, attention_mask = inputs
        # forward pass
        outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
    
    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Evaluate one batch without gradients and return ``(loss, logits, labels)``."""
        input_ids, labels, attention_mask = inputs
        # forward pass
        with torch.no_grad():
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            logits = outputs.logits
            loss = outputs.loss
        return (loss, logits, labels)


class CustomFamilyTrainerForgetting(Trainer):
    """Trainer that applies a forgetting objective to ``(input_ids, labels, attention_mask)`` batches.

    Extra keyword arguments consumed by the constructor (all required):
        forget_loss: ``"ga"`` (gradient ascent) or ``"npo"``.
        save_dir: Directory in which ``evaluate`` writes ``checkpoint-<step>`` folders.
        save_step_pattern: ``"log"`` saves only at steps 1, 2, 4, 8, 16, 32; any other
            value saves on every ``evaluate`` call.
    """

    # Previously spelled `_init_`, so it was never run as the constructor and the custom
    # keyword arguments were passed straight through to Trainer's own constructor.
    def __init__(self, *args, **kwargs):
        # Remove the custom options before delegating to Trainer.
        self.loss_type = kwargs.pop('forget_loss')
        self.save_dir = kwargs.pop('save_dir')
        self.save_step_pattern = kwargs.pop('save_step_pattern')
        super().__init__(*args, **kwargs)
        
        if self.loss_type == "npo":
            self.beta = 0.1
            # Must be filled with reference logits before training; nothing in this
            # class sets it.
            self.outputs_f_ref_logits = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the forgetting loss for one batch.

        ``num_items_in_batch`` is accepted and ignored so that ``transformers``
        releases which pass this keyword do not fail with ``TypeError``.

        Raises:
            RuntimeError: ``"npo"`` is selected but no reference logits were provided.
            ValueError: ``forget_loss`` is neither ``"ga"`` nor ``"npo"``.
        """
        input_ids, labels, attention_mask = inputs

        if self.loss_type == "ga":
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            # Gradient ascent: maximise the LM loss by minimising its negation.
            loss = outputs.loss * -1

        elif self.loss_type == 'npo':
            if self.outputs_f_ref_logits is None:
                # Previously this surfaced as an opaque TypeError on `None - tensor`.
                raise RuntimeError(
                    "NPO loss requires reference logits: set `outputs_f_ref_logits` before training."
                )
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            # NOTE(review): this subtracts raw logits over the whole vocabulary rather
            # than label log-probabilities — confirm this is the intended NPO form.
            neg_log_ratio = self.outputs_f_ref_logits - outputs.logits
            loss = -F.logsigmoid(self.beta * neg_log_ratio).mean() * 2 / self.beta

        else:
            # Previously fell through to an unbound `loss` (UnboundLocalError).
            raise ValueError(f"Unknown forget loss type: {self.loss_type}")

        return (loss, outputs) if return_outputs else loss
        
    
    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Evaluate one batch without gradients and return ``(loss, logits, labels)``."""
        input_ids, labels, attention_mask = inputs
        # forward pass
        with torch.no_grad():
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            logits = outputs.logits
            loss = outputs.loss
        return (loss, logits, labels)

    def evaluate(
        self,
        eval_dataset = None,
        ignore_keys = None,
        metric_key_prefix = "eval",
    ):
        """Save a checkpoint for the current step instead of running an evaluation.

        All arguments are ignored. Returns ``None``.
        NOTE(review): the base ``Trainer.evaluate`` returns a metrics dict — confirm no
        caller relies on one.
        """
        curr_step = self.state.global_step
        if self.save_step_pattern == "log":
            if curr_step not in [1, 2, 4, 8, 16, 32]: 
                return

        curr_save_dir = os.path.join(self.save_dir, f"checkpoint-{curr_step}")
        self.save_model(curr_save_dir)
                        
    def e_prepare_deepspeed(self, model):
        """Wrap ``model`` in a DeepSpeed engine set to eval mode with frozen parameters.

        Returns the DeepSpeed-initialised model.
        """
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        # Deep copy so the config edits below do not touch the plugin's own config.
        config_kwargs = copy.deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that stage3_prefetch_bucket_size can produce DeepSpeed messages like: Invalidate trace cache @ step 0: expected module 1, but got module 0
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        config_kwargs["optimizer"] = {"type": None}
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        # Turn off gradients for every parameter.
        for param in model.parameters():
            param.requires_grad = False
        
        return model

In [11]:
def find_all_linear_names(model):
    """Return the distinct leaf names of every ``torch.nn.Linear`` submodule in ``model``.

    The leaf name is the last dot-separated component of the module's qualified name.
    ``'lm_head'`` is always excluded. The order of the returned list is unspecified.
    """
    linear_leaf_names = {
        qualified_name.rsplit('.', 1)[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    linear_leaf_names.discard('lm_head')  # needed for 16-bit
    return list(linear_leaf_names)

In [12]:
# Show the distinct nn.Linear leaf names of the loaded model (lm_head excluded).
find_all_linear_names(model)

['down_proj', 'v_proj', 'up_proj', 'gate_proj', 'q_proj', 'o_proj', 'k_proj']

In [13]:
def print_trainable_parameters(model):
    """
    Print the trainable and total parameter counts of ``model`` and their ratio in percent.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, is_trainable in sizes if is_trainable)
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [14]:
# Report trainable vs. total parameter counts for the loaded model.
print_trainable_parameters(model)

trainable params: 1279787008 || all params: 1279787008 || trainable%: 100.0


In [None]:
# Collect both training frames in one list: index 0 = retain split, index 1 = forget split.
dataset = [retain_train_df, forget_train_df]

In [None]:
import torch
import torch.nn.functional as F


def cross_entropy_loss(logits, labels, state_size):
    """Mean cross-entropy over all positions whose label is not -100.

    ``logits`` is flattened to ``(-1, state_size)`` and ``labels`` to ``(-1,)`` before
    the loss is computed.
    """
    flat_logits = logits.view(-1, state_size)
    flat_labels = labels.view(-1)
    return F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)


def grad_ascent_loss(model, X_f, y_f, state_size):
    """Negated cross-entropy on the forget batch, so minimising it performs gradient ascent."""
    forget_ce = cross_entropy_loss(model.forward(X_f), y_f, state_size)
    return -forget_ce


def grad_descent_loss(model, X_r, y_r, state_size):
    """Plain cross-entropy on the retain batch (ordinary gradient descent)."""
    retain_logits = model.forward(X_r)
    return cross_entropy_loss(retain_logits, y_r, state_size)


def grad_diff_loss(model, X_f, y_f, X_r, y_r, state_size):
    """Gradient-difference loss: negated forget cross-entropy plus retain cross-entropy."""
    # The forget batch is run first, then the retain batch.
    ascent_term = -cross_entropy_loss(model.forward(X_f), y_f, state_size)
    descent_term = cross_entropy_loss(model.forward(X_r), y_r, state_size)
    return ascent_term + descent_term




def grad_diff_kl_forget_loss(model, X_f, y_f, X_r, y_r, finetuned_model, state_size):
    """Gradient-difference loss plus a masked KL term on the forget batch.

    Sum of three parts:
      * negated cross-entropy on the forget batch,
      * cross-entropy on the retain batch,
      * KL between ``finetuned_model``'s and ``model``'s output distributions on the
        forget batch, summed over positions whose label is not -100 and divided by the
        batch size.
    """
    # Ascent term on the forget batch.
    forget_logits = model.forward(X_f)
    ascent_term = -cross_entropy_loss(forget_logits, y_f, state_size)

    # Descent term on the retain batch.
    descent_term = cross_entropy_loss(model.forward(X_r), y_r, state_size)

    # Reference distribution on the forget batch, computed without gradients.
    with torch.no_grad():
        reference_logits = finetuned_model.forward(X_f)

    token_is_valid = (y_f != -100).float()
    per_token_kl = F.kl_div(
        F.log_softmax(forget_logits, dim=-1),
        F.softmax(reference_logits, dim=-1),
        reduction='none',
    ).sum(dim=-1)
    # Drop ignored positions, then average over the batch dimension.
    kl_term = (per_token_kl * token_is_valid).sum() / forget_logits.size(0)

    return ascent_term + descent_term + kl_term


def kl_loss(model, X_f, y_f, X_r, y_r, finetuned_model, state_size):
    """Negated forget cross-entropy plus a masked KL term on the retain batch.

    The KL term compares ``finetuned_model``'s and ``model``'s output distributions on
    the retain batch, summed over positions whose label is not -100 and divided by the
    batch size.
    """
    # Ascent term on the forget batch.
    ascent_term = -cross_entropy_loss(model.forward(X_f), y_f, state_size)

    # Current model and gradient-free reference on the retain batch.
    retain_logits = model.forward(X_r)
    with torch.no_grad():
        reference_logits = finetuned_model.forward(X_r)

    token_is_valid = (y_r != -100).float()
    per_token_kl = F.kl_div(
        F.log_softmax(retain_logits, dim=-1),
        F.softmax(reference_logits, dim=-1),
        reduction='none',
    ).sum(dim=-1)
    # Drop ignored positions, then average over the batch dimension.
    kl_term = (per_token_kl * token_is_valid).sum() / retain_logits.size(0)

    return ascent_term + kl_term


def NPO_loss(model, X_f, y_f, finetuned_model, beta, state_size):
    """NPO forget loss with the per-token log-ratio summed over each sequence.

    For every sample the label log-probability under ``finetuned_model`` minus the one
    under ``model`` is summed over positions whose label is not -100. The result is
    ``mean(-logsigmoid(beta * sum)) * 2 / beta``. ``state_size`` is unused.
    """
    token_is_valid = (y_f != -100).float()
    # gather needs in-range indices; masked positions are zeroed out again below.
    gather_index = y_f.masked_fill(y_f == -100, 0).unsqueeze(-1)

    # Label log-probabilities under the current model.
    current_logits = model.forward(X_f)
    current_logp = F.log_softmax(current_logits, dim=-1).gather(2, gather_index).squeeze(-1)

    # Label log-probabilities under the reference model, computed without gradients.
    with torch.no_grad():
        reference_logits = finetuned_model.forward(X_f)
        reference_logp = F.log_softmax(reference_logits, dim=-1).gather(2, gather_index).squeeze(-1)
        assert reference_logp.shape  == reference_logits.shape[:2]  # Check shape

    # Per-sample sum of (reference - current) over the valid positions.
    per_sample_ratio = ((reference_logp - current_logp) * token_is_valid).sum(dim=1)

    return -F.logsigmoid(beta * per_sample_ratio).mean() * 2 / beta




# def NPO_AVE(model, X_f, y_f, finetuned_model, beta, state_size):

#     """Non-Paired Objective (NPO) loss using cross-entropy and average over tokens per sample."""

#     valid_mask_f = (y_f != -100).float()  # Mask for valid tokens, shape [batch_size, seq_len]

#     # Forward pass for forget data with the current model
#     logits_f = model.forward(X_f)  # Shape: [batch_size, seq_len, state_size]

#     y_f_masked = y_f.clone()
#     y_f_masked[y_f_masked == -100] = 0  # Replace -100 with 0 so gather can function

#     # Gather the log probabilities corresponding to the true labels
#     outputs_f = F.log_softmax(logits_f, dim=-1).gather(2, y_f_masked.unsqueeze(-1)).squeeze(-1)  # Shape: [batch_size, seq_len]

#     # Forward pass for forget data with the fine-tuned model
#     with torch.no_grad():
#         logits_f_finetuned = finetuned_model.forward(X_f)
#         outputs_f_finetuned = F.log_softmax(logits_f_finetuned, dim=-1).gather(2, y_f_masked.unsqueeze(-1)).squeeze(-1)
#         assert outputs_f_finetuned.shape  == logits_f_finetuned.shape[:2]  # Check shape

#     # Compute the negative log-ratio, applying the valid mask to ignore padding tokens
#     neg_log_ratio = (outputs_f_finetuned - outputs_f) * valid_mask_f  # Shape: [batch_size, seq_len]

#     # Average over the tokens per sample (along the sequence length dimension), this is the only line different from NPO_loss
#     neg_log_ratio_sum = neg_log_ratio.sum(dim=1)/valid_mask_f.sum(dim=1)  # Shape: [batch_size]
#     #import pdb; pdb.set_trace()
#     # Compute the NPO loss by averaging over the batch, ignoring padding 
#     loss = -F.logsigmoid(beta * neg_log_ratio_sum).mean() * 2 / beta

#     return loss



# def NPO_NO_REF(model, X_f, y_f, finetuned_model, beta, state_size):

#     """Non-Paired Objective (NPO) loss using cross-entropy and average over tokens per sample."""

#     valid_mask_f = (y_f != -100).float()  # Mask for valid tokens, shape [batch_size, seq_len]

#     # Forward pass for forget data with the current model
#     logits_f = model.forward(X_f)  # Shape: [batch_size, seq_len, state_size]

#     y_f_masked = y_f.clone()
#     y_f_masked[y_f_masked == -100] = 0  # Replace -100 with 0 so gather can function

#     # Gather the log probabilities corresponding to the true labels
#     outputs_f = F.log_softmax(logits_f, dim=-1).gather(2, y_f_masked.unsqueeze(-1)).squeeze(-1)  # Shape: [batch_size, seq_len]

   

#     # Compute the negative log-ratio, applying the valid mask to ignore padding tokens
#     neg_log=  - outputs_f * valid_mask_f  # Shape: [batch_size, seq_len]

#     # Sum over the tokens per sample (along the sequence length dimension) this is the only line different from SimNPO_loss
#     neg_log_sum = neg_log.sum(dim=1)  # Shape: [batch_size]

#     #import pdb; pdb.set_trace()
#     # Compute the NPO loss by averaging over the batch, ignoring padding 
#     loss = -F.logsigmoid(beta * neg_log_sum).mean() * 2 / beta

#     return loss



def NPO_KL(model, X_f, y_f, X_r, y_r, finetuned_model, beta, state_size):
    """NPO forget loss plus a masked KL retain term.

    The retain term compares ``finetuned_model``'s and ``model``'s output distributions
    on the retain batch, summed over positions whose label is not -100 and divided by
    the batch size.
    """
    # Forget term (runs its own forward passes on the forget batch).
    npo_term = NPO_loss(model, X_f, y_f, finetuned_model, beta, state_size)

    # Current model and gradient-free reference on the retain batch.
    retain_logits = model.forward(X_r)
    with torch.no_grad():
        reference_logits = finetuned_model.forward(X_r)

    token_is_valid = (y_r != -100).float()
    per_token_kl = F.kl_div(
        F.log_softmax(retain_logits, dim=-1),
        F.softmax(reference_logits, dim=-1),
        reduction='none',
    ).sum(dim=-1)
    # Drop ignored positions, then average over the batch dimension.
    kl_term = (per_token_kl * token_is_valid).sum() / retain_logits.size(0)

    return npo_term + kl_term




def NPO_RT(model, X_f, y_f, X_r, y_r, finetuned_model, beta, state_size):
    """NPO forget loss plus plain cross-entropy on the retain batch."""
    # Forget term first, then the retain forward pass.
    npo_term = NPO_loss(model, X_f, y_f, finetuned_model, beta, state_size)
    retain_ce = cross_entropy_loss(model.forward(X_r), y_r, state_size)
    return npo_term + retain_ce




def SimNPO_loss(model, X_f, y_f, finetuned_model, beta, state_size):
    """Reference-free NPO variant using the length-normalised negative log-likelihood.

    For every sample the negative label log-probability under ``model`` is averaged
    over positions whose label is not -100. The result is
    ``mean(-logsigmoid(beta * average)) * 2 / beta``. ``finetuned_model`` and
    ``state_size`` are unused.
    """
    token_is_valid = (y_f != -100).float()
    # gather needs in-range indices; masked positions are zeroed out again below.
    gather_index = y_f.masked_fill(y_f == -100, 0).unsqueeze(-1)

    # Label log-probabilities under the current model.
    current_logits = model.forward(X_f)
    current_logp = F.log_softmax(current_logits, dim=-1).gather(2, gather_index).squeeze(-1)

    # Negative log-likelihood per token, zero on ignored positions.
    masked_nll = -current_logp * token_is_valid

    # Average over the valid positions of each sample.
    per_sample_nll = masked_nll.sum(dim=1) / token_is_valid.sum(dim=1)

    return -F.logsigmoid(beta * per_sample_nll).mean() * 2 / beta

def SimNPO_KL(model, X_f, y_f, X_r, y_r, finetuned_model, beta, state_size):
    """SimNPO forget loss plus a masked KL retain term.

    The retain term compares ``finetuned_model``'s and ``model``'s output distributions
    on the retain batch, summed over positions whose label is not -100 and divided by
    the batch size.
    """
    # Forget term (runs its own forward pass on the forget batch).
    simnpo_term = SimNPO_loss(model, X_f, y_f, finetuned_model, beta, state_size)

    # Current model and gradient-free reference on the retain batch.
    retain_logits = model.forward(X_r)
    with torch.no_grad():
        reference_logits = finetuned_model.forward(X_r)

    token_is_valid = (y_r != -100).float()
    per_token_kl = F.kl_div(
        F.log_softmax(retain_logits, dim=-1),
        F.softmax(reference_logits, dim=-1),
        reduction='none',
    ).sum(dim=-1)
    # Drop ignored positions, then average over the batch dimension.
    kl_term = (per_token_kl * token_is_valid).sum() / retain_logits.size(0)

    return simnpo_term + kl_term



def SimNPO_RT(model, X_f, y_f, X_r, y_r, finetuned_model, beta, state_size):
    """SimNPO forget loss plus plain cross-entropy on the retain batch."""
    # Forget term first, then the retain forward pass.
    simnpo_term = SimNPO_loss(model, X_f, y_f, finetuned_model, beta, state_size)
    retain_ce = cross_entropy_loss(model.forward(X_r), y_r, state_size)
    return simnpo_term + retain_ce


def compute_loss(model, loss_type, X_f, y_f, X_r=None, y_r=None, y_idk=None, finetuned_model=None, state_size=None, beta=None):
    """
    Dispatch to the loss function selected by ``loss_type``.

    Parameters:
    - model: The model being trained.
    - loss_type: One of 'grad_ascent', 'grad_descent', 'grad_diff',
      'grad_diff_kl_forget', 'kl', 'NPO', 'NPO_KL', 'NPO_RT', 'SimNPO', 'SimNPO_KL',
      'SimNPO_RT'.
    - X_f, y_f: Forget data ids and labels.
    - X_r, y_r: Retain data ids and labels (used by the retain-aware losses).
    - y_idk: Accepted but not used by any of the losses dispatched here.
    - finetuned_model: Reference model for the KL / NPO losses.
    - state_size: Size of the last logits dimension, forwarded to the loss functions.
    - beta: Scaling factor for the NPO-style losses.

    Returns:
    - The loss computed by the selected function.

    Raises:
    - ValueError: ``loss_type`` is not one of the names listed above.
    """
    # Each entry is a thunk so only the selected loss is evaluated.
    dispatch = {
        'grad_ascent': lambda: grad_ascent_loss(model, X_f, y_f, state_size),
        'grad_descent': lambda: grad_descent_loss(model, X_r, y_r, state_size),
        'grad_diff': lambda: grad_diff_loss(model, X_f, y_f, X_r, y_r, state_size),
        'grad_diff_kl_forget': lambda: grad_diff_kl_forget_loss(model, X_f, y_f, X_r, y_r, finetuned_model, state_size),
        'kl': lambda: kl_loss(model, X_f, y_f, X_r, y_r, finetuned_model, state_size),
        'NPO': lambda: NPO_loss(model, X_f, y_f, finetuned_model, beta, state_size),
        'NPO_KL': lambda: NPO_KL(model, X_f, y_f, X_r, y_r, finetuned_model, beta, state_size),
        'NPO_RT': lambda: NPO_RT(model, X_f, y_f, X_r, y_r, finetuned_model, beta, state_size),
        'SimNPO': lambda: SimNPO_loss(model, X_f, y_f, finetuned_model, beta, state_size),
        'SimNPO_KL': lambda: SimNPO_KL(model, X_f, y_f, X_r, y_r, finetuned_model, beta, state_size),
        'SimNPO_RT': lambda: SimNPO_RT(model, X_f, y_f, X_r, y_r, finetuned_model, beta, state_size),
    }
    try:
        selected = dispatch[loss_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable loss_type values, which are equally unknown.
        raise ValueError(f"Unknown loss type: {loss_type}") from None
    return selected()

In [44]:
import numpy as np
import random
import torch
import os
import pickle
from sklearn.model_selection import train_test_split


def generate_sequences(transition_matrix, initial_state_probs, num_sequences, seq_length, state_size):
    """
    Sample ``num_sequences`` state sequences of length ``seq_length`` from a Markov chain.

    All sequences are advanced together, one time step at a time, using
    ``torch.multinomial`` (so results depend on the torch RNG state).

    Parameters:
    - transition_matrix: Row ``s`` holds the next-state weights for state ``s``.
    - initial_state_probs: Weights used to draw the first state of every sequence.
    - num_sequences: Number of sequences to draw.
    - seq_length: Number of states per sequence.
    - state_size: Accepted for interface compatibility; not used in the body.

    Returns:
    - A list of ``num_sequences`` lists, each holding ``seq_length`` integer states.
    """
    # Tensor copies of the inputs so sampling can be done with torch.
    trans = torch.tensor(transition_matrix, dtype=torch.float32)
    start = torch.tensor(initial_state_probs, dtype=torch.float32)

    # One row per sequence, one column per time step.
    chains = torch.zeros((num_sequences, seq_length), dtype=torch.long)

    # First column: independent draws from the initial distribution.
    chains[:, 0] = torch.multinomial(start, num_samples=num_sequences, replacement=True)

    for step in range(seq_length - 1):
        # Look up the transition row of each sequence's current state ...
        row_probs = trans[chains[:, step]]  # [num_sequences, number of columns in trans]
        # ... and draw exactly one successor per sequence.
        chains[:, step + 1] = torch.multinomial(row_probs, num_samples=1).squeeze(1)

    return chains.numpy().tolist()

def _build_subset_chain(state_size, subset_states, leakage):
    """
    Build the transition matrix and initial distribution that favour ``subset_states``.

    Rows belonging to ``subset_states`` put ``1 - leakage`` of their mass (split
    evenly) on the subset and ``leakage`` (split evenly) on every other state.
    Rows of states outside the subset keep the constant fill value and therefore
    become uniform after the row normalisation.

    Returns:
    - (transition_matrix, initial_state_probs) as numpy arrays of shape
      ``(state_size, state_size)`` and ``(state_size,)``.
    """
    outside_prob = leakage / (state_size - len(subset_states))
    inside_prob = (1.0 - leakage) / len(subset_states)

    transition_matrix = np.full((state_size, state_size), outside_prob)
    for s in subset_states:
        transition_matrix[s, subset_states] = inside_prob
    transition_matrix = transition_matrix / transition_matrix.sum(axis=1, keepdims=True)

    initial_state_probs = np.full(state_size, outside_prob)
    initial_state_probs[subset_states] = inside_prob
    initial_state_probs = initial_state_probs / initial_state_probs.sum()

    return transition_matrix, initial_state_probs


def prepare_datasets(state_size=10, seq_lengths=[20,20,20], num_sequences=[10000,5000,5000], seed=42, test_size=0.2,leakage=0.2):
    """
    Generate the synthetic retain / forget1 / forget2 Markov-chain datasets.

    The first ``3 * (state_size // 3)`` states are split into three disjoint,
    equally sized subsets (retain, forget1, forget2). Any remaining state
    indices belong to no subset. One chain per subset is built, sequences are
    sampled from each with ``generate_sequences``, and every sequence is tagged
    with its source name as a ``(label, sequence)`` tuple.

    Parameters:
    - state_size: Total number of states; must be at least 3.
    - seq_lengths: Sequence lengths for [retain, forget1, forget2].
    - num_sequences: Number of sequences for [retain, forget1, forget2].
    - seed: Seed applied to numpy, ``random`` and torch, and passed to
      ``train_test_split`` as ``random_state``.
    - test_size: Passed to ``train_test_split`` for each of the three sources.
    - leakage: Probability mass a subset row assigns to states outside its subset.

    Returns:
    - dict with the train/test splits per source
      ('retain_*', 'forget1_*', 'forget2_*'), the combined and shuffled
      'forget_*' and 'all_*' splits, 'max_seq_length', the generation
      parameters, and the initial distributions / transition matrices used.

    Raises:
    - ValueError: if ``state_size`` is smaller than 3 (the subsets would be empty).
    """
    # Copy the list arguments: they are stored in the returned dict, and the
    # defaults are shared mutable lists that a caller must not be able to alter.
    seq_lengths = list(seq_lengths)
    num_sequences = list(num_sequences)

    if state_size < 3:
        raise ValueError(f"state_size must be at least 3, got {state_size}")

    # Set random seed for reproducibility
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # Define non-intersecting subsets for each distribution
    subset_size = state_size // 3
    retain_states = list(range(0, subset_size))
    forget1_states = list(range(subset_size, 2 * subset_size))
    forget2_states = list(range(2 * subset_size, 3 * subset_size))

    # One (transition matrix, initial distribution) pair per source
    retain_transition_matrix, retain_initial_state_probs = _build_subset_chain(
        state_size, retain_states, leakage
    )
    forget_transition_matrix1, forget_initial_state_probs1 = _build_subset_chain(
        state_size, forget1_states, leakage
    )
    forget_transition_matrix2, forget_initial_state_probs2 = _build_subset_chain(
        state_size, forget2_states, leakage
    )

    # Generate sequences (order matters: all three draw from the same torch RNG)
    retain_sequences = generate_sequences(
        retain_transition_matrix, retain_initial_state_probs,
        num_sequences[0], seq_lengths[0], state_size
    )
    forget_sequences1 = generate_sequences(
        forget_transition_matrix1, forget_initial_state_probs1,
        num_sequences[1], seq_lengths[1], state_size
    )
    forget_sequences2 = generate_sequences(
        forget_transition_matrix2, forget_initial_state_probs2,
        num_sequences[2], seq_lengths[2], state_size
    )

    # Label sequences with their source transition matrix
    retain_sequences_labeled = [('retain', seq) for seq in retain_sequences]
    forget_sequences1_labeled = [('forget1', seq) for seq in forget_sequences1]
    forget_sequences2_labeled = [('forget2', seq) for seq in forget_sequences2]

    # Split retain, forget1, and forget2 sequences into train and test sets
    retain_train_seqs, retain_test_seqs = train_test_split(retain_sequences_labeled, test_size=test_size, random_state=seed)
    forget1_train_seqs, forget1_test_seqs = train_test_split(forget_sequences1_labeled, test_size=test_size, random_state=seed)
    forget2_train_seqs, forget2_test_seqs = train_test_split(forget_sequences2_labeled, test_size=test_size, random_state=seed)

    # Combine forget1 and forget2 sequences, then shuffle the combined lists
    forget_train_seqs = forget1_train_seqs + forget2_train_seqs
    forget_test_seqs = forget1_test_seqs + forget2_test_seqs
    random.shuffle(forget_train_seqs)
    random.shuffle(forget_test_seqs)

    # Combine retain and forget sequences, then shuffle the combined lists
    all_train_sequences = retain_train_seqs + forget_train_seqs
    all_test_sequences = retain_test_seqs + forget_test_seqs
    random.shuffle(all_train_sequences)
    random.shuffle(all_test_sequences)

    # Longest requested sequence length (used downstream for padding)
    max_seq_length = max(seq_lengths)

    return {
        'retain_train_sequences': retain_train_seqs,
        'retain_test_sequences': retain_test_seqs,
        'forget1_train_sequences': forget1_train_seqs,
        'forget1_test_sequences': forget1_test_seqs,
        'forget2_train_sequences': forget2_train_seqs,
        'forget2_test_sequences': forget2_test_seqs,
        'forget_train_sequences': forget_train_seqs,
        'forget_test_sequences': forget_test_seqs,
        'all_train_sequences': all_train_sequences,
        'all_test_sequences': all_test_sequences,
        'max_seq_length': max_seq_length,
        'state_size': state_size,
        'seq_lengths': seq_lengths,
        'num_sequences': num_sequences,
        'seed': seed,
        'retain_initial_state_probs': retain_initial_state_probs,
        'forget_initial_state_probs1': forget_initial_state_probs1,
        'forget_initial_state_probs2': forget_initial_state_probs2,
        'retain_transition_matrix': retain_transition_matrix,
        'forget_transition_matrix1': forget_transition_matrix1,
        'forget_transition_matrix2': forget_transition_matrix2,
        'leakage': leakage,
    }

In [45]:
# Build the synthetic retain/forget datasets using prepare_datasets' default arguments.
data = prepare_datasets()

In [49]:
# Inspect the transition matrix that prepare_datasets stored under 'forget_transition_matrix2'.
data['forget_transition_matrix2']

array([[0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,
        0.1       , 0.1       , 0.1       , 0.1       , 0.1       ],
       [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,
        0.1       , 0.1       , 0.1       , 0.1       , 0.1       ],
       [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,
        0.1       , 0.1       , 0.1       , 0.1       , 0.1       ],
       [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,
        0.1       , 0.1       , 0.1       , 0.1       , 0.1       ],
       [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,
        0.1       , 0.1       , 0.1       , 0.1       , 0.1       ],
       [0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,
        0.1       , 0.1       , 0.1       , 0.1       , 0.1       ],
       [0.02857143, 0.02857143, 0.02857143, 0.02857143, 0.02857143,
        0.02857143, 0.26666667, 0.26666667, 0.26666667, 0.02857143],
       [0.02857143, 0.02857143, 0.0285714

In [50]:
# Map each sequence label to the transition matrix its sequences were drawn from.
transition_matrices = {
    source: data[matrix_key]
    for source, matrix_key in (
        ('retain', 'retain_transition_matrix'),
        ('forget1', 'forget_transition_matrix1'),
        ('forget2', 'forget_transition_matrix2'),
    )
}


In [17]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader


# Custom Dataset class with padding
class OneHotDataset(Dataset):
    """
    Next-token dataset over labelled integer sequences, padded to a fixed length.

    ``sequences_labeled`` must support ``len()`` and integer indexing and yield
    ``(label, seq)`` pairs, where ``seq`` is a sliceable sequence of integer ids.
    Each item is turned into an (input, target) pair shifted by one position.
    """

    def __init__(self, sequences_labeled, max_seq_length):
        """
        Parameters:
        - sequences_labeled: Indexable collection of ``(label, seq)`` pairs.
        - max_seq_length: Maximum length of a raw sequence; inputs and targets
          are one shorter because of the one-position shift.
        """
        self.sequences_labeled = sequences_labeled
        self.max_seq_length = max_seq_length - 1  # Adjust for input_ids and labels length

    def __len__(self):
        """Number of labelled sequences."""
        return len(self.sequences_labeled)

    def __getitem__(self, idx):
        """
        Return the ``idx``-th example as a dict with
        - 'input_ids': long tensor, all tokens except the last, padded with 0;
        - 'labels': long tensor, all tokens except the first, padded with -100;
        - 'label': the source label stored with the sequence.

        Both tensors always have length ``max_seq_length - 1``: shorter sequences
        are padded and longer ones are truncated, so items can be stacked into a
        batch by the default collate function.
        """
        label, seq = self.sequences_labeled[idx]

        # Keep at most max_seq_length raw tokens. Without this, an over-long
        # sequence yields a negative padding length (a silent no-op) and tensors
        # of a different size from the rest of the batch.
        window = list(seq[:self.max_seq_length + 1])
        input_ids = window[:-1]  # All except last token
        labels = window[1:]      # All except first token

        # Padding
        padding_length = self.max_seq_length - len(input_ids)
        input_ids_padded = input_ids + [0] * padding_length
        labels_padded = labels + [-100] * padding_length  # Use -100 to ignore padding in loss

        return {
            'input_ids': torch.tensor(input_ids_padded, dtype=torch.long),
            'labels': torch.tensor(labels_padded, dtype=torch.long),
            'label': label
        }
# NOTE(review): OneHotDataset.__getitem__ unpacks `label, seq = self.sequences_labeled[idx]`,
# i.e. it expects an integer-indexable collection of (label, token-id list) pairs.
# The objects passed here are the DataFrames loaded with pd.read_parquet, which is
# presumably why the cells below record KeyErrors — TODO convert/tokenize the frames
# into (label, ids) pairs before wrapping them.
max_seq_length = 512
retain_train_dataset = OneHotDataset(retain_train_df, max_seq_length)
forget_train_dataset = OneHotDataset(forget_train_df, max_seq_length)
 
retain_test_dataset = OneHotDataset(retain_validation_df, max_seq_length)
forget_test_dataset = OneHotDataset(forget_validation_df, max_seq_length)

In [52]:
# Batch the forget training split: 8 examples per batch, shuffling enabled.
forget_train_dataloader = DataLoader(forget_train_dataset, batch_size=8, shuffle=True)
forget_train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7fab469bbfd0>

In [53]:
# Batch the retain training split: 8 examples per batch, shuffling enabled.
retain_train_dataloader = DataLoader(retain_train_dataset, batch_size=8, shuffle=True)
# The original (misspelled) name is kept as an alias because the batch-iteration
# cell further down refers to it. "retrain" is easy to confuse with the unrelated
# `retrain_model` argument of the evaluation helpers, so prefer the name above.
retrain_train_dataloader = retain_train_dataloader
retain_train_dataset

<__main__.OneHotDataset at 0x7fa98f5ce5f0>

In [58]:
# NOTE(review): debugging probe. OneHotDataset.__getitem__ forwards its argument
# straight to `self.sequences_labeled[idx]`, so it is not a lookup by record id;
# this call raises the KeyError recorded below. Remove or replace once the dataset
# is built from (label, ids) pairs.
retain_train_dataset.__getitem__("6adbf83c-5071-4979-bedb-e5184b15650bsc1")

KeyError: '6adbf83c-5071-4979-bedb-e5184b15650bsc1'

In [54]:
# Walk the forget and retain loaders in lockstep (zip stops at the shorter one)
# and print the label tensors of every pair of batches.
for forget_batch, retain_batch in zip(forget_train_dataloader, retrain_train_dataloader):
    # Forget side of the pair
    input_ids_f, labels_f = forget_batch['input_ids'], forget_batch['labels']

    # Retain side of the pair
    input_ids_r, labels_r = retain_batch['input_ids'], retain_batch['labels']

    print(labels_f, labels_r)

KeyError: 1052

In [35]:
# --- Synthetic-data settings ---
# NOTE(review): these mirror prepare_datasets' default arguments, but the call
# `prepare_datasets()` in this notebook does not pass them — confirm they are used.
state_size = 10
seq_length_retain = seq_length_forget1 = seq_length_forget2 = 20
num_retain_sequences = 10000
num_forget_sequences1 = num_forget_sequences2 = 5000
data_dir = "data"
data_seed = training_seed = unlearning_seed = 42
test_size = 0.2
leakage = 0.2

# --- Optimisation settings ---
learning_rate = 5e-4
batch_size = 8
epoch = 5


In [28]:
# Count the nn.Embedding modules found anywhere in the model's module tree.
embedding_layers = [layer for layer in model.modules() if isinstance(layer, nn.Embedding)]
embedding_count = len(embedding_layers)
print(f"Total number of embedding layers: {embedding_count}")

Total number of embedding layers: 1


In [24]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [29]:
# Loss selector; 'NPO_KL' is one of the names compute_loss dispatches on.
loss_type ='NPO_KL'

In [32]:
import copy

# nn.Module.to() moves the module in place and returns the same object, so
# `intial_model` is just another name for `model`. The (misspelled) name is kept
# as-is in case other cells refer to it.
intial_model = model.to(device)

# The reference model for the KL / NPO losses must NOT be the object being
# trained, otherwise it changes with every optimiser step and the reference term
# becomes degenerate. Take a frozen snapshot instead. Run this cell before any
# training step so the snapshot holds the fine-tuned weights.
finetuned_model = copy.deepcopy(intial_model).eval().requires_grad_(False)

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

def evaluate_model(model, dataloaders, criterion, state_size, device, retrain_model=None):
    """
    Evaluate ``model`` on every dataloader and return one per-token average per split.

    Without ``retrain_model`` the score is ``criterion`` applied to the flattened
    logits and labels, weighted by the number of scored tokens in each batch
    (this assumes ``criterion`` returns a mean over non-ignored tokens).
    With ``retrain_model`` the score is
    ``F.kl_div(log_softmax(model logits), softmax(retrain logits))`` summed over
    the last dimension and averaged over scored positions.
    A position is scored when its label is not -100.

    Parameters:
    - model: Model under evaluation. Its forward may return either a logits
      tensor or an object exposing ``.logits`` (e.g. Hugging Face model outputs).
    - dataloaders: Mapping of split name to an iterable of batches. Each batch is
      a dict with 'input_ids' and 'labels' tensors.
    - criterion: Loss used when ``retrain_model`` is not given.
    - state_size: Size of the last logits dimension, used to flatten the logits.
    - device: Device the batch tensors are moved to.
    - retrain_model: Optional reference model; same output convention as ``model``.

    Returns:
    - dict mapping split name to the average score (float). A split with no
      scored tokens maps to ``nan``.
    """
    def _as_logits(output):
        # Model-output objects carry the tensor in `.logits`; tensors pass through.
        return getattr(output, 'logits', output)

    # Explicit None check: module truthiness can be False for container modules
    # that define __len__ (e.g. an empty nn.Sequential).
    use_reference = retrain_model is not None

    model.eval()
    if use_reference:
        retrain_model.eval()  # Put retrain model in eval mode if provided

    results = {}
    with torch.no_grad():
        for name, dataloader in dataloaders.items():
            total_loss = 0.0
            total_tokens = 0

            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)

                # Mask to identify non-padded tokens
                non_pad_mask = (labels != -100)
                non_pad_tokens = non_pad_mask.sum().item()
                if non_pad_tokens == 0:
                    # Nothing to score. A mean-reduced criterion would return NaN
                    # here and poison the running total.
                    continue

                # Forward pass through the main model
                logits = _as_logits(model(input_ids))  # assumed [batch_size, seq_length, state_size]

                if use_reference:
                    retrain_logits = _as_logits(retrain_model(input_ids))

                    log_probs = F.log_softmax(logits, dim=-1)
                    retrain_probs = F.softmax(retrain_logits, dim=-1)

                    # Per-position KL, summed over the state dimension
                    kl_div = F.kl_div(log_probs, retrain_probs, reduction='none').sum(dim=-1)
                    total_loss += kl_div[non_pad_mask].sum().item()
                else:
                    # reshape (not view) so non-contiguous logits are accepted too
                    loss = criterion(logits.reshape(-1, state_size), labels.reshape(-1))
                    total_loss += loss.item() * non_pad_tokens

                total_tokens += non_pad_tokens

            results[name] = total_loss / total_tokens if total_tokens else float('nan')

    return results

def evaluate_with_transitions(model, dataloaders, transition_matrices, state_size, device, retrain_model=None):
    """
    Average per-token KL divergence between ``model`` and a reference distribution.

    The reference is either ``retrain_model``'s softmax output (when given) or,
    for each sequence, the rows of the transition matrix registered for that
    sequence's label, looked up by the input token at each position. The value
    computed is ``F.kl_div(log_softmax(model logits), reference)`` summed over
    the last dimension and averaged over scored positions. A position is scored
    when its label is not -100 and a reference distribution exists for it.

    Parameters:
    - model: Model under evaluation. Its forward may return a logits tensor or
      an object exposing ``.logits`` (e.g. Hugging Face model outputs).
    - dataloaders: Mapping of split name to an iterable of batches. Each batch is
      a dict with 'input_ids' and 'labels' tensors, plus 'label' (one source
      label per sequence) when ``retrain_model`` is not given.
    - transition_matrices: dict mapping a sequence label to an array indexable
      by state id. Only used when ``retrain_model`` is not given.
    - state_size: Size of the last dimension of the reference distribution.
    - device: Device the batch tensors are moved to.
    - retrain_model: Optional reference model; same output convention as ``model``.

    Returns:
    - dict mapping split name to the average KL divergence (float). A split with
      no scored tokens maps to ``nan``.
    """
    def _as_logits(output):
        # Model-output objects carry the tensor in `.logits`; tensors pass through.
        return getattr(output, 'logits', output)

    # Explicit None check: module truthiness can be False for container modules
    # that define __len__.
    use_reference = retrain_model is not None

    model.eval()
    if use_reference:
        retrain_model.eval()  # Set the retrain model in evaluation mode if provided

    results = {}

    with torch.no_grad():
        for name, dataloader in dataloaders.items():
            total_kl_div = 0.0
            total_tokens = 0

            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)  # [batch_size, seq_length]
                labels = batch['labels'].to(device)  # [batch_size, seq_length]

                logits = _as_logits(model(input_ids))  # assumed [batch_size, seq_length, state_size]
                log_probs = F.log_softmax(logits, dim=-1)

                # Positions that carry a real target (not padding)
                scored_mask = (labels != -100)  # [batch_size, seq_length]

                if use_reference:
                    target_probs = torch.softmax(_as_logits(retrain_model(input_ids)), dim=-1)
                else:
                    labels_seq = batch['label']  # one source label per sequence
                    batch_size, seq_length = input_ids.shape
                    transition_probs = np.zeros((batch_size, seq_length, state_size))
                    has_matrix = np.zeros(batch_size, dtype=bool)

                    for i, seq_label in enumerate(labels_seq):
                        trans_matrix = transition_matrices.get(seq_label)
                        if trans_matrix is None:
                            continue
                        current_states = input_ids[i].cpu().numpy()
                        # Row lookup: next-state distribution for each input token
                        transition_probs[i] = np.asarray(trans_matrix)[current_states]
                        has_matrix[i] = True

                    target_probs = torch.tensor(transition_probs, dtype=torch.float32, device=device)

                    # Sequences without a registered matrix have no reference
                    # distribution. Exclude them from the token count as well,
                    # otherwise their zero contribution dilutes the average.
                    known = torch.tensor(has_matrix, dtype=torch.bool, device=device).unsqueeze(1)
                    scored_mask = scored_mask & known

                kl_div = F.kl_div(log_probs, target_probs, reduction='none').sum(-1)  # [batch_size, seq_length]

                # Keep only the scored positions
                kl_div = kl_div * scored_mask.float()

                total_kl_div += kl_div.sum().item()
                total_tokens += scored_mask.sum().item()

            # Average KL divergence over scored tokens
            results[name] = total_kl_div / total_tokens if total_tokens else float('nan')

    return results

In [None]:
# transformers' own AdamW is deprecated upstream in favour of torch.optim.AdamW,
# so use the PyTorch implementation directly.
from torch.optim import AdamW

# eps and weight_decay are pinned to the defaults of the transformers
# implementation (1e-6 and 0.0) so the optimiser behaves as before;
# torch.optim.AdamW would otherwise default to eps=1e-8 and weight_decay=0.01.
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)
# -100 marks padded label positions (see OneHotDataset) and is skipped by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)