In [None]:
import os 
# Expose only GPU 0 to this process. This runs in the first cell, ahead of the
# `import torch` further down, so the variable is already set when CUDA is first used.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
from pathlib import Path
import sys


# Put the parent of the working directory at the front of sys.path (once),
# presumably so the project-local `config` / `template` modules imported in the
# next cell can be found.
ROOT = Path.cwd().parent
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))

In [None]:
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
from tqdm import tqdm
from config import Config2
from template import LLAMA3_CHAT_TEMPLATE

In [None]:
def get_hidden_states(df, model, tokenizer, device, batch_size=1, max_length=256):
    """Embed each row of ``df`` as the penultimate-layer hidden state of its last real token.

    The text for a row is ``question_f + ' ' + answer_f``. Texts are tokenized in
    batches (``padding=True``, truncated to ``max_length``) and passed through
    ``model`` with ``output_hidden_states=True``; ``hidden_states[-2]`` is read
    at the position of the last token whose attention mask is 1.

    Args:
        df: DataFrame with string columns ``question_f`` and ``answer_f``.
        model: Callable model whose output exposes ``hidden_states``. It is
            switched to eval mode here.
        tokenizer: Callable tokenizer whose result contains ``attention_mask``
            and supports ``.to(device)``.
        device: Device the tokenized batch is moved to before the forward pass.
        batch_size: Number of texts per forward pass.
        max_length: Truncation length in tokens. Defaults to 256, the value that
            used to be hard-coded.

    Returns:
        float32 ``np.ndarray`` with one row per row of ``df``, in ``df`` order.

    Raises:
        ValueError: from ``np.vstack`` when ``df`` has no rows.
    """
    texts = (df['question_f'] + ' ' + df['answer_f']).tolist()
    all_embeddings = []

    model.eval()
    print('Now extracting hidden reps')

    for start in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        penultimate_hidden_states = outputs.hidden_states[-2]

        # Position of the last attended token in every row. `mask.sum() - 1` is
        # only correct for right-padded batches; with a left-padding tokenizer
        # it points into the middle of the sequence. The largest position whose
        # mask is 1 is correct for either padding side.
        attention_mask = inputs['attention_mask']
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        last_token_idx = (attention_mask * positions).argmax(dim=1)

        # Build the index tensors on the hidden-state tensor's own device rather
        # than on `model.device`, so indexing never mixes devices.
        target_device = penultimate_hidden_states.device
        row_idx = torch.arange(penultimate_hidden_states.shape[0], device=target_device)

        last_token_embeddings = penultimate_hidden_states[
            row_idx,
            last_token_idx.to(target_device)
        ].float().cpu().numpy()
        all_embeddings.append(last_token_embeddings)

    return np.vstack(all_embeddings)


def get_reps(df, model, tokenizer, device, batch_size=1):
    """Return a copy of ``df`` with a ``representation`` column holding one embedding per row.

    The embeddings come from ``get_hidden_states`` called with the same
    arguments; the input frame itself is left untouched.
    """
    reps = get_hidden_states(
        df=df,
        model=model,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size,
    )

    result = df.copy()
    # Split the 2-D array into one vector per DataFrame row.
    result['representation'] = [row for row in reps]
    return result

In [None]:
def _allocate_cluster_quotas(cluster_sizes, total):
    """Split ``total`` picks across clusters in proportion to their sizes.

    Largest-remainder apportionment: each cluster first receives the floor of
    its proportional share, every non-empty cluster is then raised to at least
    one pick, and any picks still missing go to the clusters with the largest
    fractional share. The result can exceed ``total`` only because of the
    one-per-cluster minimum.
    """
    sizes = np.asarray(cluster_sizes, dtype=np.int64)
    n_rows = int(sizes.sum())
    if n_rows == 0:
        return np.zeros_like(sizes)

    exact = total * sizes / n_rows
    # Small tolerance so a share like 2.9999999996 is treated as 3.
    floors = np.floor(exact + 1e-9).astype(np.int64)
    quotas = np.where(sizes > 0, np.maximum(floors, 1), 0)

    shortfall = int(total) - int(quotas.sum())
    if shortfall > 0:
        remainders = exact - floors
        # Clusters lifted from 0 to 1 were already rounded up; don't round them twice.
        remainders[floors == 0] = -1.0
        for idx in np.argsort(-remainders, kind='stable'):
            if shortfall == 0 or remainders[idx] <= 0:
                break
            if quotas[idx] < sizes[idx]:
                quotas[idx] += 1
                shortfall -= 1
    return quotas


def cluster_and_select_mod(
    annotated_df: pd.DataFrame,
    selection_percent: float,
    n_clusters: int,
    represent_col: str = 'representation'):
    """K-means the stored representations and keep, per cluster, the rows closest to the median centroid distance.

    Args:
        annotated_df: DataFrame whose ``represent_col`` holds one 1-D vector per row.
        selection_percent: Fraction of rows to keep, in ``(0, 1]``.
        n_clusters: Number of k-means clusters (``random_state=42``).
        represent_col: Name of the column holding the vectors.

    Returns:
        ``(cluster_df, mod_df)``: a copy of the input with a ``cluster`` column,
        and the selected subset of that copy.

    Raises:
        ValueError: if ``selection_percent`` is outside ``(0, 1]`` or
            ``n_clusters`` exceeds the number of rows.

    Notes:
        The target ``floor(len(df) * selection_percent)`` is split across
        clusters proportionally to cluster size. Every non-empty cluster
        contributes at least one row, so the selection can exceed the target
        when the target is smaller than the number of clusters. Previously each
        cluster's share was truncated independently, which could leave the
        selection short of the target.
    """
    if not (0 < selection_percent <= 1.0):
        raise ValueError("selection_percent must be between 0 and 1")

    if len(annotated_df) < n_clusters:
        # Equality is allowed, so the message says "exceed" rather than "less than".
        raise ValueError("n_clusters must not exceed the number of rows in df")

    embeddings = np.vstack(annotated_df[represent_col].tolist())

    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
    cluster_labels = np.asarray(kmeans.fit_predict(embeddings))
    centroids = kmeans.cluster_centers_

    cluster_df = annotated_df.copy()
    cluster_df['cluster'] = cluster_labels

    # The tolerance guards against products such as 0.57 * 100 == 56.99999999999999
    # being truncated one below the intended count.
    total_samples_to_select = int(len(cluster_df) * selection_percent + 1e-9)
    cluster_sizes = np.bincount(cluster_labels, minlength=n_clusters)
    quotas = _allocate_cluster_quotas(cluster_sizes, total_samples_to_select)

    print('Selecting samples closest to the median distance in each cluster...')

    # Positional row indices into cluster_df (used with .iloc below).
    selected_indices = []
    for i in range(n_clusters):
        indices_in_cluster = np.where(cluster_labels == i)[0]
        if len(indices_in_cluster) == 0:
            continue

        cluster_centroid = centroids[i].reshape(1, -1)
        distances = cdist(embeddings[indices_in_cluster], cluster_centroid).flatten()

        # Rank members by how far their centroid distance is from the cluster's median distance.
        distances_from_median = np.abs(distances - np.median(distances))
        sorted_indices = indices_in_cluster[np.argsort(distances_from_median)]

        selected_indices.extend(sorted_indices[:quotas[i]])

    print(f"Targeted {total_samples_to_select} samples, selected {len(selected_indices)}.")
    mod_df = cluster_df.iloc[selected_indices].copy()

    return cluster_df, mod_df

In [None]:
# Prefer the GPU, but fall back to CPU so the tokenized batches are not sent to
# a CUDA device that does not exist on this machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def make_template_format(df, eos_token=None):
    """Return a copy of ``df`` with ``question_f`` and ``answer_f`` text columns added.

    ``question_f`` is ``LLAMA3_CHAT_TEMPLATE`` formatted with the row's
    ``question``; ``answer_f`` is the row's ``answer`` followed by the
    end-of-sequence token.

    Args:
        df: DataFrame with string columns ``question`` and ``answer``. It is
            not modified (the previous version wrote the new columns into the
            caller's frame).
        eos_token: End-of-sequence string to append. When omitted, the
            notebook-level ``tokenizer.eos_token`` is used, so ``tokenizer``
            must already be loaded at call time.

    Returns:
        A new DataFrame containing the original columns plus ``question_f``
        and ``answer_f``.
    """
    if eos_token is None:
        eos_token = tokenizer.eos_token

    formatted = df.copy()
    formatted['question_f'] = formatted['question'].apply(
        lambda question: LLAMA3_CHAT_TEMPLATE.format(question=question)
    )
    formatted['answer_f'] = formatted['answer'].apply(lambda answer: answer + eos_token)
    return formatted

### For WPU

In [None]:
# `Config2` is the class imported from `config` in the imports cell; the bare
# name `Config` is not defined anywhere in this notebook.
cfg = Config2()
df = pd.read_csv('./data/wpu_data/retain_100.csv')

In [None]:
# NOTE(review): make_template_format reads the notebook-level `tokenizer`, which is
# first assigned in the model-loading cell under "### Mix" further down. On a fresh
# kernel a tokenizer has to be loaded before this cell can run.
df = make_template_format(df)

In [None]:
# NOTE(review): `model` and `tokenizer` are not assigned anywhere above this cell; the
# only visible assignment is in the "### Mix" section below. Load the WPU model and
# tokenizer before this point so the notebook survives Restart & Run All.
full_rep_df = get_reps(df=df, model=model, tokenizer=tokenizer, device=device, batch_size=1)

In [None]:
# Keep only the columns that are persisted alongside the representations.
full_rep_df = full_rep_df[['title', 'question', 'answer', 'type', 'representation']]

# to_parquet does not create missing parent directories, so make sure the target exists.
reps_path = Path('./data/wpu_data/coresets/moderate/dta_reps_moderate.parquet')
reps_path.parent.mkdir(parents=True, exist_ok=True)

df = full_rep_df.copy()
df.to_parquet(reps_path, index = False)

# A bare expression is only rendered when it is the last statement of the cell.
full_rep_df.head()

In [None]:
# Coreset sizes as fractions of the full set; only this value differs between
# the selections, so one comprehension replaces five copy-pasted calls.
MOD_SELECTION_PERCENTS = (0.01, 0.02, 0.05, 0.1, 0.2)

mod_1, mod_2, mod_5, mod_10, mod_20 = [
    cluster_and_select_mod(
        annotated_df = full_rep_df,
        selection_percent = pct,
        n_clusters = 4,
    )[1]  # keep the selected subset; the fully labelled frame is not used
    for pct in MOD_SELECTION_PERCENTS
]

In [None]:
# to_csv does not create missing directories, so create the output folder first.
wpu_coreset_dir = Path('./data/wpu_data/coresets/moderate')
wpu_coreset_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): the frames still carry the `representation` column of numpy arrays;
# CSV stores those as their printed text, which numpy may abbreviate for long
# vectors. Confirm downstream readers only need the text columns.
wpu_coresets = {
    'mod_1': mod_1,
    'mod_2': mod_2,
    'mod_5': mod_5,
    'mod_10': mod_10,
    'mod_20': mod_20,
}
for coreset_name, coreset_df in wpu_coresets.items():
    coreset_df.to_csv(wpu_coreset_dir / f'{coreset_name}.csv', index=False)

### Mix

In [None]:
# TODO: placeholder - point this at the real checkpoint directory or hub id.
# (The string literal was missing its closing quote, which is a SyntaxError.)
model_id = 'path/to/the/'

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token; get_hidden_states tokenizes with padding=True.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# Use the notebook-level `device` so the model and the tokenized inputs end up on the same device.
model = model.to(device)

In [None]:
# Load ./data/mix/full_retain.csv for the "Mix" section. This rebinds `df`,
# replacing the frame used in the WPU section above.
df = pd.read_csv('./data/mix/full_retain.csv')

In [None]:
def get_hidden_states(df, model, tokenizer, device, batch_size=1, max_length=512):
    """Embed each row of ``df`` as the penultimate-layer hidden state of its last real token.

    This cell redefines the ``get_hidden_states`` from the top of the notebook
    with a truncation length of 512 instead of 256; from here on every caller
    (including ``get_reps``) uses this version.

    The text for a row is ``question_f + ' ' + answer_f``. Texts are tokenized in
    batches (``padding=True``, truncated to ``max_length``) and passed through
    ``model`` with ``output_hidden_states=True``; ``hidden_states[-2]`` is read
    at the position of the last token whose attention mask is 1.

    Args:
        df: DataFrame with string columns ``question_f`` and ``answer_f``.
        model: Callable model whose output exposes ``hidden_states``. It is
            switched to eval mode here.
        tokenizer: Callable tokenizer whose result contains ``attention_mask``
            and supports ``.to(device)``.
        device: Device the tokenized batch is moved to before the forward pass.
        batch_size: Number of texts per forward pass.
        max_length: Truncation length in tokens. Defaults to 512, the value that
            used to be hard-coded in this cell.

    Returns:
        float32 ``np.ndarray`` with one row per row of ``df``, in ``df`` order.

    Raises:
        ValueError: from ``np.vstack`` when ``df`` has no rows.
    """
    texts = (df['question_f'] + ' ' + df['answer_f']).tolist()
    all_embeddings = []

    model.eval()
    print('Now extracting hidden reps')

    for start in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        penultimate_hidden_states = outputs.hidden_states[-2]

        # Position of the last attended token in every row. `mask.sum() - 1` is
        # only correct for right-padded batches; with a left-padding tokenizer
        # it points into the middle of the sequence. The largest position whose
        # mask is 1 is correct for either padding side.
        attention_mask = inputs['attention_mask']
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        last_token_idx = (attention_mask * positions).argmax(dim=1)

        # Build the index tensors on the hidden-state tensor's own device rather
        # than on `model.device`, so indexing never mixes devices.
        target_device = penultimate_hidden_states.device
        row_idx = torch.arange(penultimate_hidden_states.shape[0], device=target_device)

        last_token_embeddings = penultimate_hidden_states[
            row_idx,
            last_token_idx.to(target_device)
        ].float().cpu().numpy()
        all_embeddings.append(last_token_embeddings)

    return np.vstack(all_embeddings)


def get_reps(df, model, tokenizer, device, batch_size=1):
    """Attach per-row embeddings to a copy of ``df``.

    Runs ``get_hidden_states`` with the given arguments and stores each
    resulting vector in a new ``representation`` column. The caller's frame is
    not changed.
    """
    hidden = get_hidden_states(
        df=df, model=model, tokenizer=tokenizer, device=device, batch_size=batch_size
    )

    with_reps = df.copy()
    # One array per row: iterating the 2-D result yields its rows.
    with_reps['representation'] = list(hidden)
    return with_reps

In [None]:
# Add the `question_f` / `answer_f` columns; the EOS token comes from the tokenizer
# loaded in the model-loading cell of this section.
df = make_template_format(df)

In [None]:
# get_reps looks up get_hidden_states at call time, so this run uses the
# redefinition from this section rather than the one at the top of the notebook.
full_rep_df = get_reps(df=df, model=model, tokenizer=tokenizer, device=device, batch_size=1)

In [None]:
# Keep only the columns that are persisted alongside the representations.
full_rep_df = full_rep_df[['title', 'question', 'answer', 'type', 'representation']]

# to_parquet does not create missing parent directories, so make sure the target exists.
reps_path = Path('./data/mix/coresets/moderate/dta_reps_moderate.parquet')
reps_path.parent.mkdir(parents=True, exist_ok=True)

df = full_rep_df.copy()
df.to_parquet(reps_path, index = False)

# A bare expression is only rendered when it is the last statement of the cell.
full_rep_df.head()

In [None]:
# Coreset sizes as fractions of the full set; only this value differs between
# the selections, so one comprehension replaces five copy-pasted calls.
MOD_SELECTION_PERCENTS = (0.01, 0.02, 0.05, 0.1, 0.2)

mod_1, mod_2, mod_5, mod_10, mod_20 = [
    cluster_and_select_mod(
        annotated_df = full_rep_df,
        selection_percent = pct,
        n_clusters = 4,
    )[1]  # keep the selected subset; the fully labelled frame is not used
    for pct in MOD_SELECTION_PERCENTS
]

In [None]:
# to_csv does not create missing directories, so create the output folder first.
mix_coreset_dir = Path('./data/mix/coresets/moderate')
mix_coreset_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): the frames still carry the `representation` column of numpy arrays;
# CSV stores those as their printed text, which numpy may abbreviate for long
# vectors. Confirm downstream readers only need the text columns.
mix_coresets = {
    'mod_1': mod_1,
    'mod_2': mod_2,
    'mod_5': mod_5,
    'mod_10': mod_10,
    'mod_20': mod_20,
}
for coreset_name, coreset_df in mix_coresets.items():
    coreset_df.to_csv(mix_coreset_dir / f'{coreset_name}.csv', index=False)