# 2023-09-04 11 - Selecting constraint levels.ipynb
For the final experiment runs, we want to do 3 seed runs, but first we need to decide on constraint levels so that we don't run forever.

## Fetching runs

In [110]:
import wandb
from math import isnan 
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import wandb
from tqdm import tqdm
# from cycler import cycler
import matplotlib as mpl
from IPython.display import Markdown, display
    
api = wandb.Api()
project = "Autoformer"
workspace = "alelab"

# Get our two main experiments so far
experiment_tags = ["e1_weather_10e_statinformed","e2_electricity_10e_statinformed"]

# get all runs that both: 1.  match any experiment tag and 2. are finished
runs = api.runs(f"{workspace}/{project}",
                {"$and": [
                    {"tags": {"$in": experiment_tags}},
                    {"state": "finished"}
                ]})

def tag_experiment(run):
    for tag in experiment_tags:
        if tag in run.tags:
            return tag
    return ''

all_runs = []
run_counter = 0
for run in tqdm(runs):
    run_counter += 1
    for split in ["train", "test"]:
        for metric in ["mse",]:
            pred_len = run.config["pred_len"]
            metrics = np.zeros(pred_len)
            for i in range(pred_len):
                run_dict = {**run.config}
                #run_dict["constraint_level"] = constraint_level
                run_dict[f"{metric}"] = run.summary[f"{metric}/{split}/{i}"]
                #run_dict[f"{metric}"] = run.summary.get(f"{metric}/{split}/{i}",run.summary.get(f"mse/{split}/{i}",np.nan)) #god forgive me for this line
                run_dict["step"]=i
                run_dict["epoch"]=run.summary["epoch"]
                run_dict["infeasible_rate"]=run.summary[f"infeasible_rate/{split}"]
                run_dict["split"] = split
                run_dict["run_id"] = run.id
                # Get either Constrained/ or ERM/ from the run name, then append model name.
                #print("run.name", run.name)
                #debug if ERM run
                run_dict["Algorithm"] = f"{run.name.split('/')[0]} {run.config['model']}"
                run_dict["sweep_id"] = run.sweep.id
                #print("Algorithm", run_dict["Algorithm"])

                # To better plot constrained vs ERM
                #TODO this is a hack while I consolidate the tags. 
                run_dict["type"] = "ERM" if run.config['dual_lr'] == 0 else "Constrained"
                run_dict["experiment_tag"] = tag_experiment(run)

                all_runs.append(run_dict)
print(f"Fetched {run_counter} runs")
df = pd.DataFrame(all_runs)
print(f"Total records: {(df.shape)}")

100%|██████████| 64/64 [00:01<00:00, 44.28it/s] 


Fetched 64 runs
Total records: (43008, 61)


There should be 3 runs per model, per pred length (3 x 2 x 4 = 24), plus ERM runs: 2 per model per pred_len (2 x 4 = 8) Total 32. 

In [111]:
df[['run_id',"sweep_id",'Algorithm','model','constraint_level','pred_len','epoch']].drop_duplicates().sort_values(["pred_len",'Algorithm','model',"constraint_level"])

Unnamed: 0,run_id,sweep_id,Algorithm,model,constraint_level,pred_len,epoch
15936,iij1xyjo,jk3es6p9,Electricity-StatInformed-10e-Constrained Autof...,Autoformer,0.157,96,10
15552,4a1mgu20,jk3es6p9,Electricity-StatInformed-10e-Constrained Autof...,Autoformer,0.169,96,10
15168,a38bewyi,jk3es6p9,Electricity-StatInformed-10e-Constrained Autof...,Autoformer,0.170,96,10
15744,wegbe0c2,jk3es6p9,Electricity-StatInformed-10e-Constrained Reformer,Reformer,0.157,96,10
15360,avemcc84,jk3es6p9,Electricity-StatInformed-10e-Constrained Reformer,Reformer,0.169,96,10
...,...,...,...,...,...,...,...
38688,n1dluqy8,viiqubln,StatInformed-10e Reformer,Reformer,0.912,720,10
18816,5gf925sy,plrzt70h,StatInformed-ERM-10e Autoformer,Autoformer,-1.000,720,5
32256,5qsivph7,fi1lnhc1,StatInformed-ERM-10e Autoformer,Autoformer,-1.000,720,10
16128,39cj5po7,plrzt70h,StatInformed-ERM-10e Reformer,Reformer,-1.000,720,10


## Separating into two datasets: 

In [112]:
weather = df[df.experiment_tag == "e1_weather_10e_statinformed"]
electricity = df[df.experiment_tag == "e2_electricity_10e_statinformed"]
display(weather.shape)
display(electricity.shape)

(21504, 61)

(21504, 61)

## Pivot table of constraint violations in constrained vs ERM per pred len, model, constraint

In [118]:
for dataset in [weather, electricity]:
    tp2=dataset.copy(deep=True)
    #tp2=weather.query('split=="test"').copy()
    # Compute 'infeasible amount' as the difference between the mse and the constraint_level, 0 if mse < constraint
    #tp2['mean_violation']=tp2.apply(lambda x: max(0,x['mse']-x['constraint_level']),axis=1)
    
    # Split into constrained and ERM
    tp2_constrained=tp2.query('type=="Constrained"').copy()
    tp2_erm_baseline=tp2.query('type=="ERM"').copy()

    distinct_constraints = tp2_constrained[['pred_len','model','constraint_level']].drop_duplicates()

    # crossjoin the ERM mse to all constraint levels
    tp2_erm_baseline_broadcasted = tp2_erm_baseline.drop(columns=['constraint_level']).merge(distinct_constraints,on=['pred_len','model'],how='left')

    display(tp2_erm_baseline_broadcasted[['type','model','constraint_level','mse']].head())

    # concat back together to compute mean_violation
    tp2=pd.concat([tp2_constrained,tp2_erm_baseline_broadcasted],axis=0)

    tp2['mean_violation']=tp2.apply(lambda x: max(0,x['mse']-x['constraint_level']),axis=1)

    # Sanity check show the mean_violation for the ERM runs
    display(tp2.query('type=="ERM"')[['type','model','constraint_level','mse','mean_violation']].head())

    grouped_values=tp2.groupby(['pred_len','constraint_level','model','epoch','type'])[['mse','mean_violation']].mean().reset_index()
    #grouped_constrained = grouped_values[grouped_values.type=='Constrained'].copy()
    #grouped_erm_baseline=grouped_values[grouped_values.type=='ERM'].copy()

    # best_constraints is the constraint for each model and pred_len that minimizes mean_violation. show min and argmin
    #best_constraints_per_window = grouped_values.query('type=="Constrained"').groupby(['pred_len']).apply(lambda x: x.loc[x['mean_violation'].idxmin()]).reset_index(drop=True)[['pred_len','type','constraint_level','mean_violation','mse']]
    # instead of best constraints, show all constraint levels.

    # join best_constraints and grouped_values to get the mse and mean_violation for the best constraint
    #best_constrained=grouped_values.merge(best_constraints_per_window[['pred_len','constraint_level']],on=['pred_len','constraint_level'],how='inner')


    # crossjoin the ERM mse to all constraint levels
    ##grouped_erm_baseline_broadcasted = grouped_erm_baseline.drop(columns=['constraint_level']).merge(grouped_constrained[['pred_len','model','constraint_level']].drop_duplicates(),on=['pred_len','model'],how='left')

    # Pivot of models versus constrained and ERM, values are mse and mse_diff
    all_results = grouped_values#pd.concat([grouped_constrained,grouped_erm_baseline_broadcasted],axis=0)
    # Pivot all results, sort ascending by pred_len, then mean_violation on constrained
    #display(grouped_constrained)

    #display(all_results)
    #.sort_values(['model','pred_len','constraint_level'],ascending=[True,True,True]) \
    pivoted=all_results.pivot(index=['pred_len','constraint_level'], columns=['type','model'], values=['mse','mean_violation']).query('constraint_level>0')
        
    display(pivoted)

Unnamed: 0,type,model,constraint_level,mse
0,ERM,Reformer,0.912,0.331805
1,ERM,Reformer,0.862,0.331805
2,ERM,Reformer,0.698,0.331805
3,ERM,Reformer,0.912,0.30806
4,ERM,Reformer,0.862,0.30806


Unnamed: 0,type,model,constraint_level,mse,mean_violation
0,ERM,Reformer,0.912,0.331805,0.0
1,ERM,Reformer,0.862,0.331805,0.0
2,ERM,Reformer,0.698,0.331805,0.0
3,ERM,Reformer,0.912,0.30806,0.0
4,ERM,Reformer,0.862,0.30806,0.0


Unnamed: 0_level_0,Unnamed: 1_level_0,mse,mse,mse,mse,mean_violation,mean_violation,mean_violation,mean_violation
Unnamed: 0_level_1,type,ERM,Constrained,ERM,Constrained,ERM,Constrained,ERM,Constrained
Unnamed: 0_level_2,model,Autoformer,Autoformer,Reformer,Reformer,Autoformer,Autoformer,Reformer,Reformer
pred_len,constraint_level,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3
96,0.516,0.436612,0.452793,0.39875,0.413915,0.013588,0.012527,0.012296,0.006349
96,0.553,0.436612,0.451352,0.39875,0.417641,0.001179,0.000517,0.003862,0.001528
96,0.556,0.436612,0.452094,0.39875,0.417982,0.000464,9e-05,0.003414,0.001298
192,0.553,0.548902,0.593456,0.458042,0.451292,0.05359,0.051218,0.033241,0.005131
192,0.567,0.548902,0.565525,0.458042,0.454524,0.045882,0.031538,0.0275,0.003795
192,0.665,0.548902,0.578781,0.458042,0.479605,0.01122,0.001391,0.000324,0.000264
336,0.682,0.626727,0.696768,0.540936,0.526838,0.022535,0.031421,0.048676,0.003928
336,0.741,0.626727,0.63235,0.540936,0.543223,0.008558,0.008165,0.024393,0.001666
336,0.97,0.626727,0.648039,0.540936,0.586993,2e-06,0.0,0.0,0.0
720,0.698,0.743079,0.766667,0.615023,0.641697,0.083799,0.078279,0.124449,0.075516


Unnamed: 0,type,model,constraint_level,mse
0,ERM,Reformer,0.246,0.240694
1,ERM,Reformer,0.23,0.240694
2,ERM,Reformer,0.215,0.240694
3,ERM,Reformer,0.246,0.219102
4,ERM,Reformer,0.23,0.219102


Unnamed: 0,type,model,constraint_level,mse,mean_violation
0,ERM,Reformer,0.246,0.240694,0.0
1,ERM,Reformer,0.23,0.240694,0.010694
2,ERM,Reformer,0.215,0.240694,0.025694
3,ERM,Reformer,0.246,0.219102,0.0
4,ERM,Reformer,0.23,0.219102,0.0


Unnamed: 0_level_0,Unnamed: 1_level_0,mse,mse,mse,mse,mean_violation,mean_violation,mean_violation,mean_violation
Unnamed: 0_level_1,type,Constrained,ERM,Constrained,ERM,Constrained,ERM,Constrained,ERM
Unnamed: 0_level_2,model,Autoformer,Autoformer,Reformer,Reformer,Autoformer,Autoformer,Reformer,Reformer
pred_len,constraint_level,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3
96,0.157,0.144695,0.143767,0.173552,0.177397,0.003594,0.005007,0.021704,0.023016
96,0.169,0.148177,0.143767,0.174374,0.177397,0.000978,0.000869,0.015946,0.016917
96,0.17,0.145579,0.143767,0.174441,0.177397,0.000683,0.000632,0.0155,0.016412
192,0.173,0.168626,0.168303,0.194665,0.198567,0.009551,0.008023,0.02999,0.031623
192,0.184,0.166888,0.168303,0.195759,0.198567,0.00333,0.004417,0.025068,0.026071
192,0.2,0.172916,0.168303,0.197867,0.198567,0.002806,0.001393,0.018213,0.018063
336,0.232,0.181441,0.205281,0.20454,0.202233,4.4e-05,0.008363,0.005495,0.005059
336,0.245,0.184334,0.205281,0.20632,0.202233,5e-06,0.004215,0.002383,0.002167
336,0.256,0.183049,0.205281,0.208139,0.202233,0.0,0.002104,0.000646,0.000558
720,0.215,0.224589,0.217492,0.216084,0.215306,0.025414,0.016884,0.01892,0.018912
