In [None]:
from copy import deepcopy
import torch
from tqdm import tqdm, trange
from matplotlib.gridspec import GridSpec
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from devinterp.slt.forms import get_osculating_circle
from skimage.measure import EllipseModel
from icl.other.parity import MultitaskSparseParity, MLP
import wandb
from torch import nn, optim
from torch.utils.data import DataLoader 

EVAL_BATCH_SIZE = 256
BATCH_SIZE = 1024
NUM_STEPS = 2_500
DATASET_SIZE = NUM_STEPS * BATCH_SIZE
NUM_TASKS = 8
NUM_FEATURES = 16
NUM_TASK_BITS = 3
# Model input width. Presumably task-indicator bits plus feature bits — confirm
# against MultitaskSparseParity's encoding.
NUM_BITS = NUM_TASKS + NUM_FEATURES
HIDDEN_DIM = 50
ALPHA = .5  # passed to MultitaskSparseParity; controls the task probabilities plotted below
NUM_CHECKPOINTS = 1_000
LR = 0.003
SEED = 0

# The notebook was written for Apple-silicon MPS; fall back to CUDA / CPU so it
# still runs elsewhere.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoint steps: NUM_CHECKPOINTS *linearly* spaced integer steps in [0, NUM_STEPS].
log_interval = list(np.linspace(0, NUM_STEPS, NUM_CHECKPOINTS).astype(int))

# Seed every RNG the notebook might touch for reproducibility.
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = MultitaskSparseParity(n=NUM_FEATURES, k=NUM_TASK_BITS, num_tasks=NUM_TASKS, alpha=ALPHA)

# Show task probabilities
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(range(NUM_TASKS), dataset.task_frequencies.detach().cpu().numpy())
ax.set_title(f"Task Probabilities ($\\alpha = {ALPHA}$)")
ax.set_xlabel("Task")
ax.set_ylabel("Probability")
plt.show()

# One fixed eval batch per task, generated once and reused at every checkpoint.
eval_sets = [dataset.generate_batch(EVAL_BATCH_SIZE, task_idx=t) for t in trange(NUM_TASKS, desc="Generating eval sets")]

model = MLP(input_dim=NUM_BITS, hidden_dim=HIDDEN_DIM, output_dim=2)
# Number of trainable parameters.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

log_wandb = True
wandb_entity = "devinterp"
wandb_project_name = "multitask-parity"
models = []  # state_dict snapshots, one per checkpoint step in log_interval

# Train the model
if log_wandb:
    wandb.init(entity=wandb_entity, project=wandb_project_name)
    wandb.config.update({
        "num_features": NUM_FEATURES,
        "num_tasks": NUM_TASKS,
        "hidden_dim": HIDDEN_DIM,
        "num_steps": NUM_STEPS,
        "batch_size": BATCH_SIZE,
        "eval_batch_size": EVAL_BATCH_SIZE,
        "alpha": ALPHA,
        "lr": LR,
        "num_checkpoints": NUM_CHECKPOINTS,
        "num_task_bits": NUM_TASK_BITS,
        "seed": SEED,
    })
    wandb.watch(model)

model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)

def accuracy_score(labels, predictions):
    """Return the fraction of entries where ``predictions`` equals ``labels``.

    Works for any array-like supporting elementwise ``==``, ``.sum()`` and
    ``len`` (e.g. torch tensors or numpy arrays). The result is the same
    container type as the comparison sum (a 0-d tensor for torch inputs).
    """
    num_correct = (predictions == labels).sum()
    num_total = len(labels)
    return num_correct / num_total

def eval_model(model, eval_sets, task_frequencies=None):
    """Evaluate ``model`` on each per-task eval set and aggregate the results.

    Args:
        model: classifier mapping a batch of inputs to logits indexed as
            ``(batch, num_classes)``. It is evaluated on the device its
            parameters live on.
        eval_sets: sequence of ``(inputs, labels)`` pairs; entry ``t`` is the
            eval batch for task ``t``.
        task_frequencies: optional per-task weights (tensor or array-like, one
            entry per eval set) used to aggregate the overall "Loss" and
            "Accuracy". Defaults to the global ``dataset.task_frequencies``.

    Returns:
        dict with the frequency-weighted "Loss" and "Accuracy", plus per-task
        "Loss/{t}" and "Accuracy/{t}" entries.

    Raises:
        ValueError: if the number of task frequencies does not match
            ``len(eval_sets)``.
    """
    if task_frequencies is None:
        task_frequencies = dataset.task_frequencies
    if torch.is_tensor(task_frequencies):
        task_frequencies = task_frequencies.detach().cpu().numpy()
    task_freqs = np.asarray(task_frequencies, dtype=float)
    if task_freqs.shape != (len(eval_sets),):
        raise ValueError(
            f"expected {len(eval_sets)} task frequencies, got shape {task_freqs.shape}"
        )

    # Evaluate wherever the model lives instead of hard-coding a device.
    device = next(model.parameters()).device

    task_losses = []
    task_accuracies = []
    with torch.no_grad():
        for inputs, labels in eval_sets:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # Identical to nn.CrossEntropyLoss() with its default mean reduction.
            task_losses.append(nn.functional.cross_entropy(logits, labels).item())
            predictions = logits.argmax(dim=1)
            task_accuracies.append((predictions == labels).float().mean().item())

    losses = np.array(task_losses)
    accuracies = np.array(task_accuracies)

    results = {
        "Loss": losses @ task_freqs,
        "Accuracy": accuracies @ task_freqs,
    }
    for task_idx, (task_loss, task_acc) in enumerate(zip(losses, accuracies)):
        results[f"Loss/{task_idx}"] = task_loss
        results[f"Accuracy/{task_idx}"] = task_acc

    return results

def log_to_wandb(step=None):
    """Evaluate the global ``model`` on the global ``eval_sets`` and log to wandb.

    Args:
        step: wandb step to attach the metrics to. ``None`` lets wandb use its
            own internal step counter.

    Returns:
        The metrics dict produced by ``eval_model`` (the same dict sent to wandb),
        so callers can inspect it without re-evaluating.
    """
    with torch.no_grad():
        results = eval_model(model, eval_sets)
    wandb.log(results, step=step)
    return results

# Checkpoint semantics: the state saved (and evaluated) at step `s` is the model
# after `s` optimizer updates, so step 0 is the untrained initialisation. Logging
# step 0 inside the loop (instead of an un-stepped wandb.log before it) keeps
# wandb steps monotonic, so the step-0 metrics are no longer dropped.
checkpoint_steps = set(log_interval)  # O(1) membership; np.int64 hashes like int
train_device = next(model.parameters()).device


def save_checkpoint(step):
    """Log eval metrics (if wandb is enabled) and snapshot the current weights."""
    if log_wandb:
        log_to_wandb(step=step)
    # Always keep checkpoints: the analysis cells below need them even without wandb.
    models.append(deepcopy(model.state_dict()))


step = 0
for inputs, labels in tqdm(dataloader, total=NUM_STEPS, desc="Training"):
    if step in checkpoint_steps:
        save_checkpoint(step)
    if step >= NUM_STEPS:
        break

    inputs = inputs.to(train_device)
    labels = labels.to(train_device)

    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    step += 1
else:
    # The dataloader ran out before reaching the `break`: record the final
    # state too if it is a checkpoint step (it has not been saved yet).
    if step in checkpoint_steps:
        save_checkpoint(step)


In [None]:
import itertools
from sklearn.decomposition import PCA
from icl.constants import FIGURES
import plotly
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
import plotly.offline as pyo
from plotly.subplots import make_subplots
import itertools
from math import isnan
from sklearn.decomposition import PCA
import numpy as np
from icl.analysis.smoothing import gaussian_filter1d_variable_sigma

EVAL_SEED = 42
INIT_SMOOTHING = 0.1   # Gaussian smoothing sigma at the first checkpoint
FINAL_SMOOTHING = 30   # sigma at the last checkpoint (linearly interpolated in between)
NUM_COMPONENTS = 5
torch.manual_seed(EVAL_SEED)

# Evaluate every checkpoint on the same concatenated eval inputs (all tasks, in
# task order) and keep the per-sample losses as the essential-dynamics features.
eval_device = next(model.parameters()).device
eval_inputs = torch.cat([x for x, _ in eval_sets], dim=0).to(eval_device)
eval_labels = torch.cat([y for _, y in eval_sets], dim=0).to(eval_device)

print(eval_inputs.shape, eval_labels.shape)

eval_criterion = nn.CrossEntropyLoss(reduction='none')

# Load checkpoints into a copy so the trained `model` is not overwritten.
checkpoint_model = deepcopy(model)

ed_outputs = []
losses = []
with torch.no_grad():  # inference only; avoid building an autograd graph per checkpoint
    for state_dict in tqdm(models, desc="Evaluating models"):
        checkpoint_model.load_state_dict(state_dict)
        per_sample_loss = eval_criterion(checkpoint_model(eval_inputs), eval_labels)
        losses.append(per_sample_loss.mean().item())
        ed_outputs.append(per_sample_loss.cpu().numpy().flatten())

ed_outputs_np = np.stack(ed_outputs)

# Fix random_state: for inputs of this size sklearn's svd_solver="auto" may pick
# randomized SVD, which is otherwise nondeterministic between runs.
pca = PCA(n_components=NUM_COMPONENTS, random_state=EVAL_SEED)

ed_projections = pca.fit_transform(ed_outputs_np)
ed_projections_smooth = gaussian_filter1d_variable_sigma(ed_projections, sigma=np.linspace(INIT_SMOOTHING, FINAL_SMOOTHING, ed_projections.shape[0]), axis=0)

print(pca.explained_variance_ratio_)

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Loss curve next to each principal component over training; black dots are raw
# projections, the red line is the variable-sigma smoothed trajectory.
fig = make_subplots(rows=1, cols=1+NUM_COMPONENTS, subplot_titles=["Loss"] + [f"Component {i} ({pca.explained_variance_ratio_[i-1]:.2f})" for i in range(1, NUM_COMPONENTS+1)])

fig.add_trace(go.Scatter(x=log_interval, y=losses, mode="lines", name="Loss", showlegend=False), row=1, col=1)
fig.update_xaxes(title_text="Step", row=1, col=1)

for i in range(1, NUM_COMPONENTS+1):
    fig.add_trace(go.Scatter(x=log_interval, y=ed_projections[:, i-1], mode="markers", marker=dict(color='black', opacity=0.1), name=f"Component {i}", showlegend=False), row=1, col=i+1)
    fig.add_trace(go.Scatter(x=log_interval, y=ed_projections_smooth[:, i-1], mode="lines", line=dict(color='red', width=4), name=f"Component {i} (Smooth)", showlegend=False), row=1, col=i+1)
    fig.update_xaxes(title_text="Step", row=1, col=i+1)

fig.update_layout(height=500, width=2000, title_text="PCA Components")
fig.show()

# Only log when a wandb run was started; wandb.log without a run raises.
if log_wandb:
    wandb.log({"PCA/Components": fig}, step=step, commit=True)

In [None]:
# Idealised per-sample loss profiles ("forms"): reference t has zero loss on
# the eval samples of tasks 0..t and -log(0.5) (uniform prediction over the 2
# output classes) everywhere else. Assumes the eval vector is laid out
# task-by-task in blocks of EVAL_BATCH_SIZE, as produced by concatenating eval_sets.
num_eval_samples = EVAL_BATCH_SIZE * NUM_TASKS
sample_index = np.arange(num_eval_samples)
task_references = np.stack([
    np.where(sample_index < (t + 1) * EVAL_BATCH_SIZE, 0.0, -np.log(0.5))
    for t in range(NUM_TASKS)
])
task_references_reduced = pca.transform(task_references)
task_references.shape, task_references_reduced.shape

In [None]:
# Form potential: squared Euclidean distance from every checkpoint's per-sample
# loss vector to each task reference profile. Row t is shape
# (num_checkpoints,) and is derived from the data, not from log_interval.
form_potentials = np.stack([
    np.sum((ed_outputs_np - reference) ** 2, axis=1)
    for reference in task_references
])

fig = make_subplots(rows=1, cols=NUM_TASKS, subplot_titles=[f"Task {i}" for i in range(1, NUM_TASKS+1)])

for i, potential in enumerate(form_potentials):
    fig.add_trace(go.Scatter(x=log_interval, y=potential, mode="lines", name=f"Task {i+1}", showlegend=False), row=1, col=i+1)
    fig.update_xaxes(title_text="Step", row=1, col=i+1)

fig.update_layout(height=500, width=2000, title_text="Form Potentials")
fig.show()

# Only log when a wandb run was started; wandb.log without a run raises.
if log_wandb:
    wandb.log({"PCA/Form Potentials": fig}, commit=True)

In [None]:
def to_color_string(color):
    """Convert a 0-1 float RGB(A) tuple (e.g. a matplotlib colormap output) to a CSS ``rgba(...)`` string.

    Channels are scaled by 255 (the CSS range; 256 would map 1.0 to an invalid
    256) and truncated to int. Alpha is passed through unchanged and defaults to
    1.0 when ``color`` has only three components.
    """
    red, green, blue = (int(255 * channel) for channel in color[:3])
    alpha = color[3] if len(color) > 3 else 1.0
    return f"rgba({red}, {green}, {blue}, {alpha})"


def plot_ed(pca, reduced, reduced_smooth, task_references, model_id, form_cmap='rainbow', evolute_cmap='Spectral', num_components=3, title="", slug="pca.html"):
    """Plot a pairwise grid of PCA components for an essential-dynamics trajectory.

    Cell (row i, col j) shows component j on x and component i on y, with:
    osculating circles of the smoothed trajectory (every third one), their
    centres (the evolute), the raw projected samples, the smoothed trajectory,
    and the projected task reference points.

    Args:
        pca: fitted PCA; its ``explained_variance_ratio_`` labels the axes and
            its ``transform`` projects ``task_references``.
        reduced: per-checkpoint projections, indexed ``[checkpoint, component]``.
        reduced_smooth: smoothed projections, indexed the same way.
        task_references: reference points in the *original* (unprojected)
            feature space; they are projected with ``pca.transform``.
        model_id: sub-directory of ``FIGURES`` to write into (created if missing).
        form_cmap, evolute_cmap: colormap names or callables for the samples
            and the circle centres.
        num_components: grid size (number of leading components to plot).
        title: figure title.
        slug: output file stem. A trailing ".html" is stripped, so the default
            writes "pca.html"/"pca.png" rather than "pca.html.html".

    Returns:
        The plotly figure (also saved as HTML and PNG).
    """
    labels = {
        str(i): f"PC {i+1} ({var:.1f}%)"
        for i, var in enumerate(pca.explained_variance_ratio_ * 100)
    }

    fig = make_subplots(rows=num_components, cols=num_components)

    if isinstance(form_cmap, str):
        form_cmap = sns.color_palette(form_cmap, as_cmap=True)
    if isinstance(evolute_cmap, str):
        evolute_cmap = sns.color_palette(evolute_cmap, as_cmap=True)

    # Project the references with the same PCA as the trajectory (the previous
    # version ignored this parameter and read a global instead).
    references_reduced = pca.transform(task_references)

    # Same index range as before: skip two points at each end (presumably
    # get_osculating_circle needs neighbours on both sides — confirm).
    ts = np.arange(2, len(reduced_smooth) - 2)
    theta = np.linspace(0, 2 * np.pi, 100)  # hoisted: identical for every circle

    colors = np.array([to_color_string(form_cmap(c)) for c in np.linspace(0, 1, reduced.shape[0])])
    evolute_colors = np.array([to_color_string(evolute_cmap(c)) for c in np.linspace(0, 1, len(ts))])

    for i, j in tqdm(itertools.product(range(num_components), range(num_components)), total=num_components ** 2):
        row, col = i + 1, j + 1

        ymin, ymax = reduced[:, i].min(), reduced[:, i].max()
        xmin, xmax = reduced[:, j].min(), reduced[:, j].max()

        centers = np.zeros((len(ts), 2))

        # Circles
        for ti, t in enumerate(ts):
            center, radius = get_osculating_circle(
                reduced_smooth[:, (j, i)], t
            )
            if ti % 3 == 0:
                # Drawing a parametric line seems cheaper than a plotly shape;
                # only every third circle is drawn to keep the figure light.
                # NOTE(review): rgba channels are 0-255, so (0.1, 0.1, 1) is
                # near-black; a blue (e.g. 25, 25, 255) may have been intended.
                circle = go.Scatter(
                    x=center[0] + radius * np.cos(theta),
                    y=center[1] + radius * np.sin(theta),
                    mode="lines",
                    line=dict(color="rgba(0.1, 0.1, 1, 0.05)", width=1),
                    showlegend=False,
                )
                fig.add_trace(circle, row=row, col=col)

            centers[ti] = center

        # Centers
        fig.add_trace(
            go.Scatter(
                x=centers[:, 0],
                y=centers[:, 1],
                mode="markers",
                marker=dict(size=4, symbol="x", color=evolute_colors),
                name="Centers",
            ),
            row=row,
            col=col,
        )

        # Original samples
        fig.add_trace(
            go.Scatter(
                x=reduced[:, j],
                y=reduced[:, i],
                mode="markers",
                marker=dict(color=colors, size=3),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        # Smoothed trajectory
        fig.add_trace(
            go.Scatter(
                x=reduced_smooth[:, j],
                y=reduced_smooth[:, i],
                mode="lines",
                line=dict(color="black", width=2),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        # Task references
        fig.add_trace(
            go.Scatter(
                x=references_reduced[:, j],
                y=references_reduced[:, i],
                mode="markers",
                marker=dict(color="rgba(0, 0, 0, 1)", size=10),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        if j == 0:
            fig.update_yaxes(title_text=labels[str(i)], row=row, col=col)

        fig.update_xaxes(title_text=labels[str(j)], row=row, col=col)

        # Pad each end by 25% of its distance from zero. For centred PCA scores
        # (min <= 0 <= max) this equals the old `min * 1.25, max * 1.25`, but
        # it also widens (instead of shrinking) ranges whose ends share a sign.
        fig.update_xaxes(
            range=(xmin - 0.25 * abs(xmin), xmax + 0.25 * abs(xmax)),
            row=row,
            col=col,
        )
        fig.update_yaxes(
            range=(ymin - 0.25 * abs(ymin), ymax + 0.25 * abs(ymax)),
            row=row,
            col=col,
        )

    fig.update_layout(width=2000, height=2000)  # Adjust the size as needed
    fig.update_layout(title_text=title, showlegend=False)

    # Save as html and png
    out_dir = FIGURES / model_id
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = slug[:-len(".html")] if slug.endswith(".html") else slug
    pyo.plot(fig, filename=str(out_dir / f"{stem}.html"), auto_open=False)
    fig.write_image(str(out_dir / f"{stem}.png"))

    return fig


# exist_ok avoids the check-then-create race and is a no-op on re-runs.
(FIGURES / "multitask-parity").mkdir(parents=True, exist_ok=True)

fig = plot_ed(pca, ed_projections, ed_projections_smooth, task_references, "multitask-parity", title="Multitask Parity PCA", slug="pca", num_components=NUM_COMPONENTS)

In [None]:
# End the wandb run started in the training cell (only started when log_wandb is True).
# NOTE(review): confirm wandb.finish() is a no-op when no run is active (log_wandb=False).
wandb.finish()