## Interactive subspaces

This notebook lets you explore the concept subspaces we learn (LsReFT and DiffMean weights) through an interactive 2-D UMAP projection.

In [3]:
import json
import torch
import pandas as pd
import altair as alt
from umap import UMAP

In [4]:
def load_jsonl(filepath):
    """Read a JSON Lines file into a list, one decoded value per non-blank line.

    Args:
        filepath: path to a UTF-8 encoded JSONL file.

    Returns:
        list of the decoded JSON values, in file order.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        # Skip whitespace-only lines (e.g. an extra blank line at end of file);
        # json.loads would otherwise raise on them.
        return [json.loads(line) for line in f if line.strip()]

def get_genre(metadata_entry):
    """Look up the first genre listed for a metadata record's own concept.

    The record's 'concept' value is used as the key into its
    'concept_genres_map', and element 0 of the mapped sequence is returned.
    Expected values are 'text', 'code', or 'math'.
    """
    key = metadata_entry['concept']
    genres_by_concept = metadata_entry['concept_genres_map']
    return genres_by_concept[key][0]

In [5]:
# The two result directories differ only by their method suffix.
RESULTS_DIR = "../results"

# map_location="cpu": don't require a GPU at load time even if the weights were
# saved from one; CPU tensors are also what the `.numpy()` conversion needs.
lsreft = torch.load(
    f"{RESULTS_DIR}/prod_2b_l20_concept16k_lsreft/train/LsReFT_weight.pt",
    map_location="cpu",
)
lsreft_metadata = load_jsonl(f"{RESULTS_DIR}/prod_2b_l20_concept16k_lsreft/train/metadata.jsonl")
assert lsreft.shape[0] == len(lsreft_metadata), "LsReFT: one weight row per metadata record expected"

diffmean = torch.load(
    f"{RESULTS_DIR}/prod_2b_l20_concept16k_diffmean/train/DiffMean_weight.pt",
    map_location="cpu",
)
diffmean_metadata = load_jsonl(f"{RESULTS_DIR}/prod_2b_l20_concept16k_diffmean/train/metadata.jsonl")
assert diffmean.shape[0] == len(diffmean_metadata), "DiffMean: one weight row per metadata record expected"
assert lsreft.shape[0] == diffmean.shape[0], "LsReFT and DiffMean row counts differ"

In [7]:
# UMAP is fit separately per method, so each method gets its own 2-D embedding
# in an unrelated coordinate frame. `.detach().cpu()` lets the NumPy conversion
# succeed even for tensors on a GPU or tensors that track gradients.
umap_lsreft = UMAP(n_components=2, random_state=42).fit_transform(
    lsreft.detach().cpu().float().numpy()
)
umap_diffmean = UMAP(n_components=2, random_state=42).fit_transform(
    diffmean.detach().cpu().float().numpy()
)


def umap_frame(embedding, metadata, method):
    """Attach per-row metadata to a 2-D UMAP embedding.

    Args:
        embedding: array of shape (n_rows, 2) from ``UMAP.fit_transform``.
        metadata: list of metadata dicts aligned row-for-row with ``embedding``.
        method: label written to every row's 'method' column.

    Returns:
        DataFrame with columns x, y, concept, genre, method.
    """
    return pd.DataFrame({
        'x': embedding[:, 0],
        'y': embedding[:, 1],
        'concept': [m['concept'] for m in metadata],
        'genre': [get_genre(m) for m in metadata],
        'method': method,
    })


df_lsreft = umap_frame(umap_lsreft, lsreft_metadata, 'LsReFT')
df_diffmean = umap_frame(umap_diffmean, diffmean_metadata, 'DiffMean')
df_all = pd.concat([df_lsreft, df_diffmean], ignore_index=True)

# Plot a random subset to keep the chart responsive; clamp the size so a
# smaller run does not make `sample` raise.
N_PLOT_POINTS = 2000
df_subsample = df_all.sample(n=min(N_PLOT_POINTS, len(df_all)), random_state=42)

alt.data_transformers.disable_max_rows()


def umap_panel(df, method):
    """Interactive scatter of one method's UMAP points, colored by genre."""
    return (
        alt.Chart(df[df['method'] == method])
        .mark_circle(size=30)
        .encode(
            # zero=False: UMAP coordinates have no meaningful origin, so forcing
            # 0 onto the axis would only squash the point cloud.
            x=alt.X('x:Q', title='UMAP-1', scale=alt.Scale(zero=False)),
            y=alt.Y('y:Q', title='UMAP-2', scale=alt.Scale(zero=False)),
            color=alt.Color('genre:N', legend=alt.Legend(title="Genre")),
            tooltip=['concept:N', 'method:N', 'genre:N'],  # fields that appear on hover
        )
        .properties(title=method)
        .interactive()  # pan & zoom within this panel
    )


# One panel per method instead of overlaying both on shared axes: the two
# embeddings' coordinates are not comparable with each other.
chart = alt.hconcat(
    *(umap_panel(df_subsample, method) for method in ('LsReFT', 'DiffMean'))
).resolve_scale(x='independent', y='independent')

chart  # Display in Jupyter

  warn(
  warn(
