In [None]:
import os
from typing import Dict, Tuple

import numpy as np
import pandas as pd
import torch
from peft import PeftModel
from scipy.spatial.distance import cdist
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from config import Config2
from template import LLAMA3_CHAT_TEMPLATE
from utils import read_file

  from .autonotebook import tqdm as notebook_tqdm


### Details to run

- First run a warm-up model (`ascent.py`, `sim_vanilla.py`) and check that the requirements, paths, `model_id`, and `adaptor_id` are all set correctly.
- Check your data setup (`ds`). This is easy to get wrong if you run all the setups.
- Calculate the representations.
- Calculate the shift. This has to be done again for every other model.

In [None]:
ds = 'data_setup_1'  # change this accordingly; the value is embedded in output file names
# pandas exposes no `read_file`; the dataset is a parquet file, so load it with `read_parquet`.
df = pd.read_parquet('./data/full_data.parquet')

In [None]:
# Checkpoint of the fine-tuned (pre-unlearning) model; loaded as the base model.
model_id = 'path to the finetuned model, which is basically pre unlearning model'

# Warm-up adapter that is applied on top of `model_id`.
adaptor_id = 'your warm up model path'

# Algorithm tag that ends up in the saved representation file name: 'gd' or 'snpo'.
algorithm = 'gd'

In [None]:
# Pre-unlearning reference model. It must stay untouched so that its
# representations can be compared against the adapter-merged model.
base_model = AutoModelForCausalLM.from_pretrained(model_id,  device_map = "auto", torch_dtype=torch.bfloat16)

# PEFT injects the adapter layers into the wrapped model in place, and
# merge_and_unload() returns that same (now merged) object. Wrapping
# `base_model` directly would therefore turn `base_model` and `model` into the
# same merged network. Load a separate backbone for the adapter instead.
# NOTE: this keeps two copies of the weights in memory. If that does not fit,
# compute and save the pre-unlearning representations before loading the adapter.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16),
    adaptor_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = model.merge_and_unload()


tokenizer = AutoTokenizer.from_pretrained(model_id)
# Batched tokenization with padding=True needs a pad token; reuse EOS for it.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def get_hidden_states_mean(
    df, model, tokenizer, device, batch_size=1,
    exclude_special_tokens: bool = True,
    max_length: int = 512
):
    """
    Returns a (N, H) float32 array of mean-pooled penultimate-layer embeddings.

    Args:
        df: DataFrame with string columns 'question_f' and 'answer_f'; the two
            are concatenated row-wise to form the input texts.
        model: Model called as ``model(**inputs, output_hidden_states=True)``;
            ``outputs.hidden_states[-2]`` is pooled.
        tokenizer: Tokenizer used to encode each batch (padding + truncation).
        device: Device the tokenized batch is moved to.
        batch_size: Number of texts per forward pass.
        exclude_special_tokens: If True, positions flagged in the tokenizer's
            ``special_tokens_mask`` are left out of the average.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        np.ndarray of shape (N, H). For an empty ``df`` an empty (0, 0) array
        is returned instead of failing inside ``np.vstack``.

    Notes:
        - Padding positions (attention_mask == 0) never contribute.
        - Pooling is done in float32 so that summing over many tokens does not
          lose precision when the model runs in a reduced-precision dtype.
    """
    texts = (df['question_f'] + df['answer_f']).tolist()
    all_embeddings = []

    model.eval()
    print('Now extracting mean-pooled hidden reps')

    for i in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[i:i+batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
            return_special_tokens_mask=True  # needed to optionally drop specials
        ).to(device)

        # The special-tokens mask is tokenizer bookkeeping, not a model input:
        # take it out before unpacking `inputs` into the forward call.
        specials = inputs.pop('special_tokens_mask', None)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # (B, T, H)
        penultimate = outputs.hidden_states[-2]

        # Build mask: start from attention_mask (exclude padding)
        # Shape: (B, T, 1)
        mask = inputs['attention_mask'].unsqueeze(-1)

        if exclude_special_tokens and specials is not None:
            # special_tokens_mask: 1 for special tokens -> set them to 0 in our pooling mask
            mask = mask * (1 - specials.unsqueeze(-1))

        # Match the hidden states' device (they may live elsewhere than the
        # inputs when the model is sharded) and pool in float32.
        mask = mask.to(device=penultimate.device, dtype=torch.float32)

        # Avoid division by zero in rare degenerate cases
        lengths = mask.sum(dim=1).clamp(min=1)  # (B, 1)

        # Mean pool
        summed = (penultimate.float() * mask).sum(dim=1)   # (B, H)
        mean_pooled = (summed / lengths).cpu().numpy()     # (B, H)

        all_embeddings.append(mean_pooled)

    if not all_embeddings:
        # Nothing to embed: np.vstack([]) would raise.
        return np.empty((0, 0), dtype=np.float32)

    return np.vstack(all_embeddings)


def get_reps_mean(df, model, tokenizer, device, batch_size=1, **kwargs):
    """
    Compute mean-pooled embeddings for every row of ``df`` and attach them.

    The input frame is left untouched; a copy is returned whose
    'representation_mean' column holds one embedding vector per row.
    Any extra keyword arguments are forwarded to get_hidden_states_mean.
    """
    pooled = get_hidden_states_mean(
        df=df,
        model=model,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size,
        **kwargs,
    )

    result = df.copy()
    # One array per row: iterating the 2-D result yields its rows.
    result['representation_mean'] = [row for row in pooled]
    return result




In [None]:
def make_template_format(df):
    """Add chat-formatted text columns to ``df`` in place and return it.

    'question_f' is each 'question' rendered through LLAMA3_CHAT_TEMPLATE and
    'answer_f' is each 'answer' followed by the tokenizer's EOS token. Relies
    on the notebook-level ``tokenizer`` and ``LLAMA3_CHAT_TEMPLATE``.
    """
    def render_question(question):
        return LLAMA3_CHAT_TEMPLATE.format(question=question)

    def close_answer(answer):
        return answer + tokenizer.eos_token

    df['question_f'] = df['question'].apply(render_question)
    df['answer_f'] = df['answer'].apply(close_answer)
    return df

# Add the templated 'question_f' / 'answer_f' columns that the representation extraction reads.
df = make_template_format(df)

In [None]:
### this is for pre unlearning representations
# NOTE: `base_model` must be the untouched fine-tuned model. PEFT merges adapters
# into the wrapped model in place, so make sure `base_model` is not the same
# object that was wrapped and merged to produce `model`.

# np.save does not create missing directories; make sure the target exists
# before spending time on the forward passes.
os.makedirs('./reps', exist_ok=True)

df = get_reps_mean(df=df, model=base_model, tokenizer=tokenizer, device=torch.device('cuda'), batch_size=16)
representations = np.stack(df['representation_mean'].values)
np.save('./reps/pre_ul_reps.npy', representations) # save this where you want

In [None]:
### this is for adaptor representations

# np.save does not create missing directories; make sure the target exists
# before spending time on the forward passes.
os.makedirs('./reps', exist_ok=True)

df = get_reps_mean(df=df, model=model, tokenizer=tokenizer, device=torch.device('cuda'), batch_size=16)
representations = np.stack(df['representation_mean'].values)
# File name carries the data setup and the algorithm so that runs do not overwrite each other.
np.save(f'./reps/{ds}_{algorithm}_reps.npy', representations) # save this where you want

In [4]:
def get_retain_sets(df, n_forget, k_values=(1, 2, 5), strategy='orthogonal'):
    """
    Select retain subsets of size ``n_forget * k`` for every ``k`` in ``k_values``.

    df: The dataframe with 'shift_score'
    n_forget: Size of the forget set
    k_values: Multipliers applied to ``n_forget`` (default 1, 2, 5)
    strategy: 'orthogonal' (lowest shift) or 'hard' (highest shift) or 'random'

    Returns a dict mapping 'k{k}' to the selected rows. If fewer rows than
    requested are available, all rows are returned for that k.

    Raises ValueError for an unknown strategy.
    """
    # Order the dataframe once; every k takes a prefix of this ordering.
    if strategy == 'orthogonal':
        # Sort ascending (Smallest shift first) -> Safest
        sorted_df = df.sort_values(by='shift_score', ascending=True)
    elif strategy == 'hard':
        # Sort descending (Largest shift first) -> Most Protective
        sorted_df = df.sort_values(by='shift_score', ascending=False)
    elif strategy == 'random':
        sorted_df = df.sample(frac=1, random_state=42) # Shuffle (fixed seed for reproducibility)
    else:
        # Previously an unknown name fell through and failed later on an unbound variable.
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected 'orthogonal', 'hard' or 'random'"
        )

    sets = {}
    for k in k_values:
        count = n_forget * k
        selected_samples = sorted_df.head(count)
        sets[f'k{k}'] = selected_samples
        print(f"Strategy: {strategy} | k={k} | Count: {len(selected_samples)}")
        print(f"  -> Mean Shift: {selected_samples['shift_score'].mean():.4f}")

    return sets

In [6]:
def get_shift_score(df):
    """Add a 'shift_score' column to ``df`` in place and return it.

    The score is the row-wise L2 norm of ``post_ul - pre_ul``, where both
    columns hold one vector per row.
    """
    before = np.array(df['pre_ul'].tolist())
    after = np.array(df['post_ul'].tolist())
    # axis=1 -> one norm per row of the (N, H) difference matrix
    df['shift_score'] = np.linalg.norm(after - before, axis=1)
    return df

In [7]:
def _rows_matching_ids(df, path, shape_message):
    """Return the rows of ``df`` whose 'id' appears in the file at ``path``.

    The file is loaded with ``read_file``; its shape is printed after
    ``shape_message``.
    """
    subset = read_file(path)
    print(shape_message, subset.shape)
    wanted_ids = subset['id'].tolist()
    return df.loc[df['id'].isin(wanted_ids)]


def get_retain_shift_score(df, retain_path):
    """Rows of ``df`` (matched on 'id') that belong to the retain set stored at ``retain_path``."""
    return _rows_matching_ids(df, retain_path, 'retain shape is: ')


def get_forget_shift_score(df, forget_path):
    """Rows of ``df`` (matched on 'id') that belong to the forget set stored at ``forget_path``."""
    return _rows_matching_ids(df, forget_path, 'forget shape is: ')

In [None]:
## please check if you are loading the correct representation files (dataset and model wise)

pre_ul_reps = np.load('./reps/pre_ul_reps.npy')
# f-string is required here: without the prefix the literal text "{ds}_{algorithm}" is looked up on disk.
post_ul_reps = np.load(f'./reps/{ds}_{algorithm}_reps.npy')

# Both arrays must describe the same rows, in the same order, as `df`.
assert pre_ul_reps.shape == post_ul_reps.shape, (
    f'pre/post representation shapes differ: {pre_ul_reps.shape} vs {post_ul_reps.shape}'
)
assert len(pre_ul_reps) == len(df), (
    f'representations have {len(pre_ul_reps)} rows but df has {len(df)}'
)

# Store one vector per row.
df['pre_ul'] = list(pre_ul_reps)
df['post_ul'] = list(post_ul_reps)

In [None]:
# Score every row, then split the scored frame by membership in the retain / forget files.
df = get_shift_score(df)

# check the path: retain_1, retain_2 etc based on your data setup
retain_shift = get_retain_shift_score(df=df, retain_path='./data/datasets/retain_1.parquet')

# check the path
forget_shift = get_forget_shift_score(df=df, forget_path='./data/datasets/forget_1.parquet')

In [None]:
# Base selection size: the forget-set size scaled by `k`.
# get_retain_sets multiplies this again by each of its own k_values.
k = 1
n = k * len(forget_shift)

retain_sets_ortho = get_retain_sets(df=retain_shift, n_forget=n, strategy='orthogonal')
retain_sets_hard = get_retain_sets(df=retain_shift, n_forget=n, strategy='hard')

In [None]:
# Pull the k=1, 2 and 5 selections out of each strategy's result dict.
k1_ortho, k2_ortho, k5_ortho = (retain_sets_ortho[key] for key in ('k1', 'k2', 'k5'))
k1_hard, k2_hard, k5_hard = (retain_sets_hard[key] for key in ('k1', 'k2', 'k5'))

In [None]:
# Write every selection to `{ds}_{algorithm}_{strategy}_{k}.parquet`.
# The algorithm tag comes from `algorithm` instead of a hard-coded 'gd', so an
# 'snpo' run is no longer mislabelled (and no longer overwrites the 'gd' files).
# Underscore-prefixed loop names avoid clobbering notebook variables such as `k`.
for _strategy, _k, _frame in (
    ('ortho', 1, k1_ortho),
    ('ortho', 2, k2_ortho),
    ('ortho', 5, k5_ortho),
    ('hard', 1, k1_hard),
    ('hard', 2, k2_hard),
    ('hard', 5, k5_hard),
):
    _frame.to_parquet(f'{ds}_{algorithm}_{_strategy}_{_k}.parquet', index = False)