In [None]:
def get_hidden_states(df, model, tokenizer, device, batch_size=1, max_length=512):
    """Embed every ``question_f + ' ' + answer_f`` text with the model's penultimate layer.

    Each text is tokenized (truncated to ``max_length`` tokens) and run through
    ``model`` with ``output_hidden_states=True``. The hidden state of the last
    attended (non-padding) token from ``outputs.hidden_states[-2]`` is kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain string columns ``'question_f'`` and ``'answer_f'``.
    model :
        Called as ``model(**inputs, output_hidden_states=True)``; put in eval mode here.
    tokenizer :
        Called with ``padding=True``. NOTE(review): presumably needs a pad token
        configured; confirm for the tokenizer in use.
    device :
        Device the tokenized inputs are moved to.
    batch_size : int
        Number of texts per forward pass.
    max_length : int
        Truncation length in tokens.

    Returns
    -------
    numpy.ndarray
        float32 array with one row per row of ``df``.

    Raises
    ------
    ValueError
        If ``df`` has no rows.
    """
    texts = (df['question_f'] + ' ' + df['answer_f']).tolist()
    if not texts:
        raise ValueError("df has no rows to embed")
    all_embeddings = []

    model.eval()
    print('Now extracting hidden reps')

    for i in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        # Indexed below as [batch_row, token_position] -> one vector per row.
        hidden = outputs.hidden_states[-2]

        # Position of the last attended token in each row. Taking the largest
        # position where mask == 1 is correct for both right- and left-padded
        # batches; ``mask.sum(-1) - 1`` is only correct for right padding.
        # All index tensors live on the hidden-state device, which can differ
        # from ``model.device`` / ``device`` when the model is sharded.
        mask = inputs['attention_mask'].to(hidden.device)
        positions = torch.arange(mask.shape[-1], device=hidden.device)
        last_idx = (positions * mask).amax(dim=-1)
        rows = torch.arange(hidden.shape[0], device=hidden.device)

        all_embeddings.append(hidden[rows, last_idx].float().cpu().numpy())

    return np.vstack(all_embeddings)


def get_reps(df, model, tokenizer, device, batch_size=1, max_length=512):
    """Return a copy of ``df`` with a ``'representation'`` column of embeddings.

    Embeddings come from ``get_hidden_states``; row i of the result holds the
    embedding computed for row i of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input rows; not modified.
    model, tokenizer, device :
        Forwarded unchanged to ``get_hidden_states``.
    batch_size : int
        Texts per forward pass.
    max_length : int
        Tokenizer truncation length (previously fixed at the 512 default).

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` with the added ``'representation'`` column.
    """
    embeddings = get_hidden_states(
        df=df,
        model=model,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size,
        max_length=max_length,
    )

    annotated_df = df.copy()
    # Store one array per row so the column lines up with df's rows.
    annotated_df['representation'] = list(embeddings)

    return annotated_df

In [None]:
def cluster_and_select_mod(
    annotated_df: pd.DataFrame,
    selection_percent: float,
    n_clusters: int,
    represent_col: str = 'representation'):
    """Cluster row embeddings with K-means and select a proportional subset per cluster.

    Within each cluster, rows are ranked by how close their Euclidean distance to
    the cluster centroid is to the median such distance in that cluster. The
    top-ranked rows are kept. Each cluster contributes
    ``int(target * cluster_size / len(df))`` rows, but at least one, where
    ``target = int(len(df) * selection_percent)``. Because of this, the number
    of rows selected can differ slightly from ``target``.

    Parameters
    ----------
    annotated_df : pd.DataFrame
        Rows to cluster; ``represent_col`` holds one embedding vector per row.
    selection_percent : float
        Fraction of rows to target, in the interval (0, 1].
    n_clusters : int
        Number of K-means clusters; must be between 1 and ``len(annotated_df)``.
    represent_col : str
        Column containing the embeddings.

    Returns
    -------
    cluster_df : pd.DataFrame
        Copy of ``annotated_df`` with an added ``'cluster'`` label column.
    mod_df : pd.DataFrame
        Selected rows of ``cluster_df``, grouped by cluster id.

    Raises
    ------
    ValueError
        On an out-of-range ``selection_percent`` or ``n_clusters``.
    """
    if not (0 < selection_percent <= 1.0):
        raise ValueError("selection_percent must be between 0 and 1")

    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1")

    if len(annotated_df) < n_clusters:
        raise ValueError("n_clusters must be less than or equal to the number of rows in df")

    embeddings = np.vstack(annotated_df[represent_col].tolist())

    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
    cluster_labels = kmeans.fit_predict(embeddings)
    centroids = kmeans.cluster_centers_

    cluster_df = annotated_df.copy()
    cluster_df['cluster'] = cluster_labels

    # Select, per cluster, the samples whose centroid distance is closest to the
    # cluster's median centroid distance.
    selected_indices = []
    total_samples_to_select = int(len(cluster_df) * selection_percent)

    print('Selecting samples closest to the median distance in each cluster...')

    for i in range(n_clusters):
        indices_in_cluster = np.where(cluster_labels == i)[0]
        if len(indices_in_cluster) == 0:
            continue

        distances = cdist(embeddings[indices_in_cluster], centroids[i].reshape(1, -1)).flatten()
        distances_from_median = np.abs(distances - np.median(distances))
        sorted_indices = indices_in_cluster[np.argsort(distances_from_median)]

        # Proportional share of the target, floored, but never zero for a
        # non-empty cluster (the cluster is non-empty here thanks to the guard above).
        proportion_of_cluster = len(indices_in_cluster) / len(cluster_df)
        num_to_select_from_cluster = max(1, int(total_samples_to_select * proportion_of_cluster))

        selected_indices.extend(sorted_indices[:num_to_select_from_cluster])

    print(f"Targeted {total_samples_to_select} samples, selected {len(selected_indices)}.")
    mod_df = cluster_df.iloc[selected_indices].copy()

    return cluster_df, mod_df

In [None]:
def make_template_format(df, eos_token=None, template=None):
    """Add chat-formatted ``'question_f'`` and ``'answer_f'`` columns to ``df`` in place.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``'question'`` and ``'answer'`` columns. It is modified in place.
    eos_token : str, optional
        Appended to every answer. Defaults to the global ``tokenizer.eos_token``.
    template : str, optional
        ``str.format`` template with a ``{question}`` field. Defaults to the
        global ``LLAMA3_CHAT_TEMPLATE``.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object (returned for convenience/chaining).
    """
    # Fall back to the notebook-level globals so existing calls keep working.
    if eos_token is None:
        eos_token = tokenizer.eos_token
    if template is None:
        template = LLAMA3_CHAT_TEMPLATE

    df['question_f'] = df['question'].apply(lambda q: template.format(question=q))
    df['answer_f'] = df['answer'].apply(lambda a: a + eos_token)
    return df

In [None]:
def _fixed_quota_per_cluster(cluster_sizes, total_samples):
    """Split ``total_samples`` across clusters as evenly as their sizes allow.

    Each round shares the still-unassigned samples equally among clusters that
    have spare rows, with lower cluster ids receiving the remainder. A cluster
    that cannot absorb its share is filled completely, and the shortfall returns
    to the pool for the next round. Requires ``total_samples <= sum(cluster_sizes)``.
    """
    quota = [0] * len(cluster_sizes)
    remaining = total_samples
    while remaining > 0:
        open_clusters = [c for c, size in enumerate(cluster_sizes) if quota[c] < size]
        share, extra = divmod(remaining, len(open_clusters))
        for rank, c in enumerate(open_clusters):
            take = min(share + (1 if rank < extra else 0), cluster_sizes[c] - quota[c])
            quota[c] += take
            remaining -= take
    return quota


def cluster_and_select_fixed(
    annotated_df: pd.DataFrame,
    total_samples: int,
    n_clusters: int,
    represent_col: str = 'representation'):
    """
    Cluster embeddings and select a fixed number of samples closest to median distance.

    The target is split as evenly as possible across clusters: ``total_samples //
    n_clusters`` each, with the remainder going to the lowest cluster ids. If a
    cluster has fewer rows than its share, the shortfall is redistributed to
    clusters with spare rows, so exactly ``total_samples`` rows are selected.
    Within a cluster, rows are ranked by how close their Euclidean distance to the
    centroid is to that cluster's median centroid distance.

    Parameters:
    -----------
    annotated_df : pd.DataFrame
        DataFrame containing the representation column
    total_samples : int
        Total number of samples to select across all clusters (1..len(df))
    n_clusters : int
        Number of clusters to create (1..len(df))
    represent_col : str
        Column name containing the embeddings/representations

    Returns:
    --------
    cluster_df : pd.DataFrame
        Copy of the input DataFrame with a 'cluster' label column added
    mod_df : pd.DataFrame
        Selected samples DataFrame, grouped by cluster id

    Raises:
    -------
    ValueError
        If total_samples or n_clusters is out of range.
    """
    if total_samples <= 0:
        raise ValueError("total_samples must be greater than 0")

    if total_samples > len(annotated_df):
        raise ValueError("total_samples cannot exceed the number of rows in df")

    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1")

    if len(annotated_df) < n_clusters:
        raise ValueError("n_clusters must be less than or equal to the number of rows in df")

    embeddings = np.vstack(annotated_df[represent_col].tolist())

    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
    cluster_labels = kmeans.fit_predict(embeddings)
    centroids = kmeans.cluster_centers_

    cluster_df = annotated_df.copy()
    cluster_df['cluster'] = cluster_labels

    # Per-cluster quotas; sum(cluster_sizes) == len(df) >= total_samples, so
    # the whole target can always be placed.
    cluster_sizes = np.bincount(cluster_labels, minlength=n_clusters).tolist()
    quota = _fixed_quota_per_cluster(cluster_sizes, total_samples)

    print('Selecting samples closest to the median distance in each cluster...')

    selected_indices = []
    for i in range(n_clusters):
        if quota[i] == 0:
            continue  # empty cluster, or the target was smaller than n_clusters

        indices_in_cluster = np.where(cluster_labels == i)[0]
        distances = cdist(embeddings[indices_in_cluster], centroids[i].reshape(1, -1)).flatten()
        distances_from_median = np.abs(distances - np.median(distances))
        sorted_indices = indices_in_cluster[np.argsort(distances_from_median)]

        selected_indices.extend(sorted_indices[:quota[i]])

    print(f"Targeted {total_samples} samples, selected {len(selected_indices)}.")
    print(f"Samples per cluster: {quota}")

    mod_df = cluster_df.iloc[selected_indices].copy()

    return cluster_df, mod_df

In [None]:
# Placeholder: replace with the real data-loading code. The frame must provide
# 'question_f' and 'answer_f' columns (e.g. built with make_template_format).
df = None
if df is None:
    raise NotImplementedError("Load the data into `df` before running this cell.")

full_rep_df = get_reps(df=df, model=model, tokenizer=tokenizer, device=device, batch_size=1)

# Cluster the frame that carries the 'representation' column; the raw `df` does not.
cluster_df, moderate_1 = cluster_and_select_fixed(
    annotated_df=full_rep_df,
    total_samples=98,
    n_clusters=4,
    represent_col='representation')