# LLM-Lasso: Large-Scale Lung Cancer Dataset

In [None]:
from llm_lasso.data_splits import read_train_test_splits, read_baseline_splits, save_train_test_splits
from llm_lasso.task_specific_lasso.llm_lasso import *
from llm_lasso.task_specific_lasso.plotting import plot_llm_lasso_result, plot_heatmap
import os
import json
import matplotlib.pyplot as plt
import pickle
import pandas as pd

In [None]:
%load_ext autoreload
%autoreload 2

## Step 1: Generate Data Splits

In [None]:
# Expression table for the Lung TCGA cohort.
X = pd.read_csv("../data/Lung_TCGA/expression.csv")

# Labels are stored one per line, wrapped in literal double quotes.
# LUAD is encoded as class 0 and LUSC as class 1.
LABEL_CODES = {"\"LUAD\"": 0, "\"LUSC\"": 1}
with open("../data/Lung_TCGA/labels.txt", "r") as f:
    raw_labels = [line.strip() for line in f]

# Fail loudly, naming the offending values, instead of encoding unknown labels
# with a sentinel and asserting afterwards.
unknown_labels = sorted(set(raw_labels) - set(LABEL_CODES))
if unknown_labels:
    raise ValueError(
        f"Unrecognized labels in ../data/Lung_TCGA/labels.txt: {unknown_labels}"
    )

y = pd.Series([LABEL_CODES[label] for label in raw_labels])

In [None]:
# Generate the train/test splits from X and y under ../data/splits/Lung_TCGA.
# n_splits=10 matches N_SPLITS used when the splits are read back in Step 3;
# seed=42 is fixed, presumably for reproducibility — TODO confirm.
save_train_test_splits(X, y, "../data/splits/Lung_TCGA", balanced=True, n_splits=10, seed=42)

## Step 2: Command Line Portion

Run the following in your command line:
```
./shell_scripts/Lung_TCGA/step_02_baselines.sh

./shell_scripts/Lung_TCGA/step_03_llm_score_baseline.sh
```

### Penalties
Scripts for running LLM-Lasso are available in `shell_scripts/Lung_TCGA`. Note that setting up OMIM RAG is somewhat time-intensive, but you can run plain LLM-Lasso without setting up OMIM RAG.

For setting up OMIM RAG, refer to **`examples/omim_rag_tutorial.ipynb`**.

## Step 3: Evaluation

In [None]:
# Load in splits
# N_SPLITS must match the n_splits value used when the splits were saved in
# Step 1 (same directory).
N_SPLITS = 10
splits = read_train_test_splits("../data/splits/Lung_TCGA", N_SPLITS)
# Number of columns in the first split's training matrix.
n_features = splits[0].x_train.shape[1]

In [None]:
# Load in LLM-Lasso Penalties


def _load_trial_scores(path):
    """Read one trial-score JSON file and return the per-trial score array.

    The file holds a list of trial records; the first entry of each record's
    "scores" field is taken, and every score is shifted by +2.
    NOTE(review): the reason for the +2 offset is not visible here — confirm.
    """
    with open(path) as f:
        trials = json.load(f)
    return np.array([trial["scores"][0] for trial in trials]) + 2


# One row per trial; the penalty vector is the mean over trials.
trial_scores_rag = _load_trial_scores("../data/Lung_TCGA/trial_scores_RAG.json")
penalties_rag = trial_scores_rag.mean(axis=0)

trial_scores_plain = _load_trial_scores("../data/Lung_TCGA/trial_scores_plain.json")
penalties_plain = trial_scores_plain.mean(axis=0)

# Keys name the LLM-Lasso variants.
penalty_list = {"plain": penalties_plain, "rag": penalties_rag}

In [None]:
# Load in baseline features
# NOTE(review): the baselines are read with n_features=49 while the LLM-Select
# list below is taken from key "50" — confirm the mismatch is intentional.
feature_baseline = read_baseline_splits(
    "../data/baselines/Lung_TCGA",
    n_splits=N_SPLITS,
    n_features=49,
)

llm_select_path = "../data/llm-score/Lung_TCGA/llmselect_selected_features.json"
with open(llm_select_path, "r") as f:
    llm_select_by_size = json.load(f)
llm_select_genes = llm_select_by_size["50"]

# LLM-Select does not depend on the split, so the same list object is reused
# for every split.
feature_baseline["llm_score"] = [llm_select_genes for _ in range(N_SPLITS)]

In [None]:
# Shared experiment settings; the same config object is passed to each of the
# run_* calls in the cells below.
# NOTE(review): max_imp_power, run_pure_lasso_after and lasso_downstream_l2 are
# not documented here — see LLMLassoExperimentConfig for their meaning.
config = LLMLassoExperimentConfig(
    folds_cv=5, # number of cross-validation folds
    regression=False,
    score_type=PenaltyType.PF, # We have penalty factors from the LLM,
                               # not importance scores.
    max_imp_power=1,
    lambda_min_ratio=0.001, # Lasso parameter
    n_threads=8, # number of threads to use for computation
    run_pure_lasso_after=10,
    lasso_downstream_l2=True,
    cross_val_metric=CrossValMetric.ERROR
)

In [None]:
# Evaluate the baseline feature selections (including the "llm_score" entry of
# feature_baseline) on every split, using the shared config.
baselines = run_downstream_baselines_for_splits(
    splits=splits,
    feature_baseline=feature_baseline,
    config=config
)

In [None]:
# Plain Lasso reference run on every split (no LLM penalties are passed).
lasso = run_lasso_baseline_for_splits(
    splits=splits,
    config=config
)

In [None]:
# LLM-Lasso on every split. `scores` holds the trial-averaged penalties and
# `score_trial_list` the per-trial scores, both keyed by variant
# ("plain" / "rag").
llm_lasso = run_llm_lasso_cv_for_splits(
    splits=splits,
    scores=penalty_list,
    config=config,
    score_trial_list={
        "plain": trial_scores_plain,
        "rag": trial_scores_rag
    },
    verbose=False
)

In [None]:
# Drop rows where no feature was selected before plotting.
dataframes_to_plot = [
    results.loc[results["n_features"] > 0] for results in (lasso, llm_lasso)
]

# Axis limits are hand-tuned for this dataset; only the first 30 features
# are shown.
plot_options = dict(
    bolded_methods=["1/imp - rag"],
    plot_error_bars=False,
    test_error_y_lim=(0.04, 0.081),
    auroc_y_lim=(0.95, 0.99),
    x_lim=30,
)
plot_llm_lasso_result(dataframes_to_plot, **plot_options)

## Win Ratio Bar Plot
Direct comparison between RAG LLM-Lasso and Lasso: at how many points (feature counts) LLM-Lasso is strictly better than Lasso, and vice versa.

In [None]:
# Largest feature count included in the win-ratio comparison.
N_FEAT = 30

# Stack the plain Lasso results with only the RAG variant of LLM-Lasso.
is_rag = llm_lasso["method_model"] == "1/imp - rag"
all_results = pd.concat([lasso, llm_lasso[is_rag]], ignore_index=True).copy()

In [None]:
# Forward-fill the results grid: for every (method, split) pair, a feature
# count in 0..N_FEAT that has no row inherits a copy of the most recent
# single-row result seen at a smaller feature count.
for model_name in all_results["method_model"].unique():
    for split_idx in range(N_SPLITS):
        last_seen = None
        for k in range(N_FEAT + 1):
            mask = (
                (all_results["method_model"] == model_name)
                & (all_results["split"] == split_idx)
                & (all_results["n_features"] == k)
            )
            matches = all_results[mask]
            n_matches = matches.shape[0]

            if n_matches == 1:
                # Remember this row as the fill value for later gaps.
                last_seen = matches.copy()
                continue
            if n_matches != 0 or last_seen is None:
                # Duplicate rows, or a gap before any result exists: leave as is.
                continue

            # Gap: reuse the previous result at this feature count.
            last_seen["n_features"] = k
            all_results = pd.concat([all_results, last_seen], ignore_index=True)

In [None]:
# Keep only the rows with at most N_FEAT features.
all_results = all_results[all_results["n_features"] <= N_FEAT]

In [None]:
# For every split, count the feature counts at which each method has a
# strictly lower test error than every other method. Ties count for nobody.
#
# Every method gets a row for every split, with count 0 when it never wins, so
# the per-method mean/sd computed afterwards cover all splits rather than only
# the splits in which the method won at least once.
method_names = all_results["method_model"].unique()
per_split_counts = []

for s in range(N_SPLITS):
    split = all_results[all_results["split"] == s]
    # Same rows in reverse order: idxmin returns the first minimum, so scanning
    # the reversed frame breaks ties the other way round.
    split_reversed = split.iloc[::-1].reset_index(drop=True)
    best_methods = (
        split.loc[
            split.groupby('n_features')['test_error'].idxmin()
        ][['n_features', 'method_model', 'test_error']]
        .reset_index(drop=True)
    )
    best_methods_rev = (
        split_reversed.loc[
            split_reversed.groupby('n_features')['test_error'].idxmin()
        ][['n_features', 'method_model', 'test_error']]
        .reset_index(drop=True)
    )
    # If the two scans disagree on the winner, the minimum was shared by
    # different methods, so that feature count is a tie and is dropped.
    strict_win = best_methods["method_model"] == best_methods_rev["method_model"]
    best_methods = best_methods[strict_win]

    # Build the columns explicitly ("method_model", "count") so the result does
    # not depend on how the installed pandas names value_counts() output.
    df = (
        best_methods["method_model"]
        .value_counts()
        .reindex(method_names, fill_value=0)
        .rename_axis("method_model")
        .reset_index(name="count")
    )
    df["split"] = s
    per_split_counts.append(df)

# Concatenate once instead of growing the frame inside the loop.
method_counts = (
    pd.concat(per_split_counts, axis=0) if per_split_counts else pd.DataFrame()
)

In [None]:
# Mean and standard deviation of the "count" column for each method, taken
# over the rows of method_counts (one row per method per split in which it
# appears).
dataframe = method_counts.groupby("method_model").agg(
    mean=("count", "mean"),
    sd=("count", "std")
).reset_index()

In [None]:
# Tick label for the single bar group in the win-ratio plot.
barplot_x = "LUAD vs. LUSC"

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def _win_stat(method, column):
    """Return `column` of `dataframe` for `method`, divided by N_FEAT."""
    return dataframe[dataframe["method_model"] == method][column].tolist()[0] / N_FEAT


# Win statistics per method, expressed as a fraction of N_FEAT.
our_mean = _win_stat("1/imp - rag", "mean")
lasso_mean = _win_stat("Lasso", "mean")
our_sd = _win_stat("1/imp - rag", "sd")
lasso_sd = _win_stat("Lasso", "sd")

# A single bar group, with the two methods drawn side by side around it.
x = np.arange(1)
width = 0.35  # width of the bars

fig, ax = plt.subplots(figsize=(5, 5))
plt.grid(zorder=0)

# Bars are drawn above the grid (zorder 3); error bars above the bars (zorder 5).
bar_style = dict(alpha=0.8, edgecolor="black", zorder=3)
rects1 = ax.bar(
    x - width/2, [our_mean], width,
    label='RAG LLM-Lasso', color="#FE6100", **bar_style
)
rects2 = ax.bar(
    x + width/2, [lasso_mean], width,
    label='Lasso', color="#aaaaaa", **bar_style
)

whisker_style = dict(fmt='none', c='black', capsize=5, zorder=5)
ax.errorbar(x=x - width/2, y=[our_mean], yerr=[our_sd], **whisker_style)
ax.errorbar(x=x + width/2, y=[lasso_mean], yerr=[lasso_sd], **whisker_style)

# Title, tick labels and a legend placed outside the axes.
ax.set_title(f'Win Ratio Over First {N_FEAT} Features', fontdict={"size": 22})
ax.set_xticks(x)
ax.set_xticklabels([barplot_x])
ax.legend(fontsize=16, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tick_params(axis='both', labelsize=14)

plt.show()