# LLM-Lasso: Spotify Model Ablations

In [None]:
from llm_lasso.data_splits import read_train_test_splits, read_baseline_splits
from llm_lasso.task_specific_lasso.llm_lasso import *
from llm_lasso.task_specific_lasso.plotting import plot_llm_lasso_result, plot_heatmap, \
    LLM_LASSO_COLORS, BASELINE_COLORS, LASSO_COLOR
import os
import json
import matplotlib.pyplot as plt
import pickle
import seaborn as sns

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the llm_lasso package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Step 1: Command-Line Portion

Run the following commands in your terminal:
```
./shell_scripts/spotify/step_01_splits.sh

./shell_scripts/model_ablation_spotify/gpt_4o.sh

./shell_scripts/model_ablation_spotify/o1.sh

./shell_scripts/model_ablation_spotify/deepseek.sh

./shell_scripts/model_ablation_spotify/gpt3_5.sh

./shell_scripts/model_ablation_spotify/llama8b.sh

./shell_scripts/model_ablation_spotify/llama405b.sh
```

**Note**: After the results presented in the LLM-Lasso paper were collected, but before our submission, `qwen/qvq-72b-preview` stopped being openly available on `openrouter` (access must now be requested). This notebook therefore does not include `qwen` results.

## Step 2: Evaluation

In [None]:
# Number of train/test splits to load; passed to read_train_test_splits below.
N_SPLITS = 10
# Dataset identifier used to build the splits folder and the results folder.
DATASET="spotify"
# Root folder under which this notebook writes its result CSVs.
BASE_FOLDER="../data/experiment-results"
# Make sure the per-dataset results folder exists before any CSV is written.
os.makedirs(f"{BASE_FOLDER}/{DATASET}", exist_ok=True)
# Load the train/test splits from disk.
# NOTE(review): presumably these are produced by step_01_splits.sh — confirm.
splits = read_train_test_splits(f"../data/splits/{DATASET}", N_SPLITS)

In [None]:
# Experiment configuration shared by the Lasso baseline and every LLM-Lasso run
# in this notebook (it is passed as `config=` to both runner functions).
# NOTE(review): the meaning of each field is defined by LLMLassoExperimentConfig,
# which is not visible here — consult its definition before changing values.
config = LLMLassoExperimentConfig(
    folds_cv=5,
    # NOTE(review): regression=True, yet EXPERIMENT_NAME is "logistic" and an
    # "auroc" column is selected later in this notebook — confirm this is intended.
    regression=True,
    max_features_for_baselines=30,
    n_threads=8,

    # Lasso config
    lambda_min_ratio=0.001,
    relaxed_lasso=False,
    lasso_downstream_l2=True,
    max_imp_power=2,

    run_pure_lasso_after=2,
    cross_val_metric=CrossValMetric.ERROR
)

In [None]:
# Plain-Lasso baseline: recompute it, or reuse a cached CSV when one exists
# and RERUN_LASSO is switched off.
RERUN_LASSO = True
EXPERIMENT_NAME = "logistic"

lasso_csv = f"{BASE_FOLDER}/{DATASET}/lasso_{EXPERIMENT_NAME}.csv"

if RERUN_LASSO or not os.path.exists(lasso_csv):
    # Either a rerun was requested or there is no cache yet: compute and cache.
    lasso = run_lasso_baseline_for_splits(splits=splits, config=config)
    lasso.to_csv(lasso_csv, index=False)
else:
    print(f"CSV found at {lasso_csv}. Loading.")
    lasso = pd.read_csv(lasso_csv)

# Label the rows so they can be grouped alongside the per-LLM results.
lasso["llm"] = "Lasso"

In [None]:
# LLMs included in the ablation. Each name is used verbatim as a folder name
# when locating that model's score file, as part of the cached results CSV
# filename, and as the value of the "llm" column in the results.
models = [
    "gpt-3.5-turbo-0613",
    "gpt-4o",
    "o1",
    "llama-3-8b-instruct",
    "llama-3.1-405b-instruct",
    "deepseek"
]

In [None]:
# Collect one results frame per method, starting with the plain-Lasso baseline,
# so that all methods can be compared in a single table.
dataframes_plain = [lasso]

# Set to False to reuse cached per-model CSVs instead of re-running LLM-Lasso.
RERUN_LLM_LASSO = True

for model in models:
    print(model)
    llm_lasso_csv = f"{BASE_FOLDER}/{DATASET}/llm_lasso_{model}.csv"

    if not RERUN_LLM_LASSO and os.path.exists(llm_lasso_csv):
        print(f"CSV found at {llm_lasso_csv}. Loading.")
        llm_lasso = pd.read_csv(llm_lasso_csv)
    else:
        # The LLM scores are only consumed when LLM-Lasso is actually run, so
        # they are read here; cached results can then be loaded even when the
        # raw score files are not available.
        scores_path = f"../data/llm-lasso/{DATASET}/{model}/final_scores_plain.txt"
        with open(scores_path) as f:
            # One score per line; skip blank lines (e.g. a trailing newline at
            # the end of the file) instead of failing in float("").
            plain_scores = np.array([float(line) for line in f if line.strip()])

        llm_lasso = run_llm_lasso_cv_for_splits(
            splits=splits,
            scores={
                "plain": plain_scores,
            },
            config=config,
            verbose=False,
        )
        llm_lasso.to_csv(llm_lasso_csv, index=False)

    # Tag the rows with the model name, then keep only the "1/imp - plain" rows.
    llm_lasso["llm"] = model
    dataframes_plain.append(llm_lasso[llm_lasso["method_model"] == "1/imp - plain"])

# Stack all methods and keep only the rows with n_features == 5, restricted to
# the columns used downstream. The trailing copy() makes the result an
# independent frame rather than a slice of the concatenated one.
all_results_plain = pd.concat(dataframes_plain, ignore_index=True)
all_results_plain = all_results_plain.loc[
    all_results_plain["n_features"] == 5,
    ["test_error", "auroc", "llm"],
].copy()

In [None]:
# Per-method summary of test error: mean plus the 5% and 95% quantiles.
summary = all_results_plain.groupby("llm", dropna=False).agg(
    mean=pd.NamedAgg(column="test_error", aggfunc="mean"),
    qlow=pd.NamedAgg(column="test_error", aggfunc=lambda err: err.quantile(0.05)),
    qhigh=pd.NamedAgg(column="test_error", aggfunc=lambda err: err.quantile(0.95)),
)
summary = summary.reset_index()

# Rank methods by mean test error (ascending). The pre-sort row positions are
# kept in `argsort` so that each method keeps the colour at its original
# position in the palette.
ranked = summary.sort_values(by="mean")
argsort = ranked.index.tolist()

palette_pool = ["#aaaaaa"] + LLM_LASSO_COLORS + BASELINE_COLORS
colors = [palette_pool[position] for position in argsort]

summary = ranked.reset_index(drop=True)

In [None]:
# Asymmetric error-bar lengths in the (2, N) layout used for `yerr`:
# row 0 is the distance from the mean down to the 5% quantile,
# row 1 is the distance from the mean up to the 95% quantile.
errors = np.vstack([
    (summary["mean"] - summary["qlow"]).to_numpy(),
    (summary["qhigh"] - summary["mean"]).to_numpy(),
]).astype(float)

In [None]:
# Plot
plt.figure(figsize=(10, 4))
plt.grid(zorder=0)
barplot = sns.barplot(
    data=summary,
    x='llm',
    y='mean',
    hue="llm",
    palette=colors,
    alpha=0.8,
    edgecolor="black",
    errorbar=None,
    zorder=3
)

# Add error bars manually
plt.errorbar(
    x=range(len(summary)),
    y=summary['mean'],
    yerr=errors,
    fmt='none',
    c='black',
    capsize=5,
    zorder=5
)

# Customize
plt.title('Plain LLM-Lasso Test Error: Spotify (Model Ablation)', fontdict={"size": 21})
plt.xlabel(None)
plt.ylabel('Test Error', fontdict={"size": 18})
plt.xticks([])
plt.tick_params(axis='both', labelsize=14) 

labels = summary['llm'].tolist()
plt.legend(labels=labels, bbox_to_anchor=(1.05, 1), loc='upper left',
           fontsize=16)
# plt.tight_layout()
plt.show()