In [1]:
# Move three directories up so the relative `experiments-2_0/...` paths used
# in the cells below resolve from there.
%cd ../../../

/Users/nseverin/MyData/Projects/Science/LLM/sasrec-bert4rec-recsys23


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
from collections import defaultdict
import re
import numpy as np
import pandas as pd

def open_text(filename, encoding="utf-8"):
    """Return the full contents of a text file.

    Args:
        filename: path of the file to read.
        encoding: text encoding used to decode the file. Defaults to UTF-8
            rather than the platform locale default, so the same log file is
            decoded identically on every OS.

    Returns:
        The decoded file contents as a single string.
    """
    with open(filename, encoding=encoding) as f:
        return f.read()


def parse_file_llm(log_data):
    """Parse a training log into per-config lists of validation/test metrics.

    The log is scanned line by line:
    * a line containing ``experiments<...>.yaml`` sets the current config; the
      key is the text after ``experiments`` up to and including ``.yaml``;
    * ``Epoch <n>/<total>, Loss: <x>`` records the current epoch and loss;
    * ``Validation Metrics: {...}`` / ``Test Metrics: {...}`` appends that dict,
      tagged with the most recent epoch and loss, to the current config.

    Metrics lines seen before any config line are ignored. Metrics seen before
    the first Epoch line of their config get ``epoch``/``loss`` set to None.

    Args:
        log_data: full text of one log file.

    Returns:
        defaultdict mapping config key ->
        ``{"validation_metrics": [dict, ...], "test_metrics": [dict, ...]}``.
    """
    config_pattern = re.compile(r"experiments(.+?\.yaml)")
    # Accept plain decimals, scientific notation (e.g. 1.5e-05) and nan/inf;
    # a bare "[\d.]+" would read 1.5e-05 as 1.5 and would not match "nan".
    loss_number = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|[-+]?inf"
    epoch_pattern = re.compile(rf"Epoch (\d+)/\d+, Loss: ({loss_number})")
    metrics_pattern = re.compile(r"(Validation|Test) Metrics: ({.+?})")

    # The logged dicts contain reprs such as ``np.float64(10.34)``, so they are
    # evaluated with only ``np`` (plus nan/inf) in scope and builtins disabled.
    # NOTE(review): still eval on file content -- only use on trusted, locally
    # produced training logs.
    eval_namespace = {
        "__builtins__": {},
        "np": np,
        "nan": float("nan"),
        "inf": float("inf"),
    }

    results = defaultdict(lambda: {"validation_metrics": [], "test_metrics": []})
    current_config = None
    # Defined up front so a metrics line before any "Epoch" line cannot raise
    # UnboundLocalError.
    epoch_num = None
    loss = None

    for line in log_data.splitlines():
        config_match = config_pattern.search(line)
        if config_match and config_match.group(1) != current_config:
            current_config = config_match.group(1)
            # A new config must not inherit the previous config's epoch/loss.
            epoch_num = None
            loss = None

        epoch_match = epoch_pattern.search(line)
        if epoch_match:
            epoch_num = int(epoch_match.group(1))
            loss = float(epoch_match.group(2))

        metrics_match = metrics_pattern.search(line)
        if metrics_match and current_config:
            metric_type = metrics_match.group(1).lower()  # "validation" or "test"
            metrics = eval(metrics_match.group(2), eval_namespace)  # str -> dict
            results[current_config][f"{metric_type}_metrics"].append(
                {"epoch": epoch_num, "loss": loss, **metrics}
            )
    return results


def find_max(results, config, metric='NDCG@10'):
    """Select the best validation epoch for ``config`` and the matching test metrics.

    Args:
        results: mapping as returned by ``parse_file_llm``
            (``config -> {'validation_metrics': [...], 'test_metrics': [...]}``).
        config: key into ``results``.
        metric: validation metric to maximise. On ties the earliest logged
            entry wins.

    Returns:
        ``{'test': [dict], 'validation': [dict]}``. Each dict holds ``NDCG@10``
        and ``Recall@10`` rounded to 5 decimals plus ``epoch``, whichever
        ``metric`` was used for selection.

    Raises:
        ValueError: if ``config`` has no validation metrics.
        IndexError: if ``config`` has no test metrics at all.
    """
    def get_dct_epoch(lst, epoch):
        # Test entry logged at `epoch`; falls back to the last test entry when
        # no test evaluation was logged at that epoch.
        return next((entry for entry in lst if entry['epoch'] == epoch), lst[-1])

    def postprocess(lst):
        # Keep only the headline metrics, rounded for display.
        return [
            {
                'NDCG@10': round(dct['NDCG@10'], 5),
                'Recall@10': round(dct['Recall@10'], 5),
                'epoch': dct['epoch'],
            }
            for dct in lst
        ]

    val_entries = results[config]['validation_metrics']
    if not val_entries:
        raise ValueError(f"No validation metrics logged for config {config!r}")

    # max() keeps the first maximal element, i.e. the earliest best epoch.
    best_val_dct = max(val_entries, key=lambda dct: dct[metric])
    best_val_epoch = best_val_dct['epoch']

    return {
        'test': postprocess([get_dct_epoch(results[config]['test_metrics'], best_val_epoch)]),
        'validation': postprocess([best_val_dct]),
    }


# Columns identifying a run: the config file name, followed by the
# hyper-parameters encoded (in exactly this order) in that file name.
PARAMS = [
    "config_file",
    "weighting_scheme",
    "alpha",
    "fine_tune_epoch",
    "reconstruct_loss",
    "reconstruction_layer",
    "weight_scale",
    "use_down_scale",
    "use_upscale",
    "multi_profile",
    "multi_profile_aggr_scheme",
    "scale_guide_loss",
    "user_profile_embeddings_files",
]


def create_dataframe(exp_data_with_max_val):
    """Build one DataFrame row per run from best-epoch metrics keyed by config path.

    Args:
        exp_data_with_max_val: mapping
            ``config_path -> {'validation': [dict], 'test': [dict]}`` as produced
            by ``find_max``. The file name of ``config_path`` must follow
            ``[sasrec-]{weighting_scheme}-{alpha}-{fine_tune_epoch}-{reconstruct_loss}-
            {reconstruction_layer}-{weight_scale}-{use_down_scale}-{use_upscale}-
            {multi_profile}-{multi_profile_aggr_scheme}-{scale_guide_loss}-
            {user_profile_embeddings_files}-{seed}.yaml``, i.e. the names in
            ``PARAMS[1:]`` in order, followed by the seed.

    Returns:
        DataFrame with columns ``PARAMS + ['seed', 'val_NDCG@10', 'val_Recall@10',
        'val_epoch', 'test_NDCG@10', 'test_Recall@10', 'test_epoch']``. Metrics
        missing from the input end up as None/NaN.

    Raises:
        ValueError: if a file name does not split into ``len(PARAMS)`` fields
            (the hyper-parameters plus the seed).
    """
    # Hyper-parameters converted to numbers. All others, including the
    # 'True'/'False' flags, stay raw strings, which later cells filter on.
    numeric_fields = {
        'alpha': float,
        'fine_tune_epoch': int,
        'reconstruction_layer': int,
        'weight_scale': float,
    }
    param_names = PARAMS[1:]
    model_prefix = 'sasrec-'
    yaml_suffix = '.yaml'

    rows = []
    for config_path, metrics in exp_data_with_max_val.items():
        config_file = config_path.split("/")[-1]  # file name only

        stem = config_file
        if stem.endswith(yaml_suffix):
            stem = stem[:-len(yaml_suffix)]
        # Remove the model prefix as a literal prefix. str.lstrip('sasrec-')
        # strips any leading run of the characters {s, a, r, e, c, -}, which
        # turned 'attention-...' into 'ttention-...'.
        if stem.startswith(model_prefix):
            stem = stem[len(model_prefix):]

        fields = stem.split("-")
        if len(fields) != len(param_names) + 1:
            raise ValueError(
                f"Config file {config_file!r} has {len(fields)} '-'-separated fields, "
                f"expected {len(param_names) + 1} ({', '.join(param_names)}, seed)"
            )
        param_values = [
            numeric_fields.get(name, str)(raw) for name, raw in zip(param_names, fields)
        ]
        seed = int(fields[-1])

        # Best-validation-epoch metrics and the test metrics at that epoch.
        val_metrics = metrics.get("validation", [{}])[0]
        test_metrics = metrics.get("test", [{}])[0]

        rows.append([
            config_file, *param_values, seed,
            val_metrics.get("NDCG@10"), val_metrics.get("Recall@10"), val_metrics.get("epoch"),
            test_metrics.get("NDCG@10"), test_metrics.get("Recall@10"), test_metrics.get("epoch"),
        ])

    columns = PARAMS + [
        'seed',
        "val_NDCG@10", "val_Recall@10", "val_epoch",
        "test_NDCG@10", "test_Recall@10", "test_epoch",
    ]
    return pd.DataFrame(rows, columns=columns)


def apply_seed_info(df):
    """Attach to every run the seeds that share its hyper-parameters.

    Adds two columns to ``df`` in place and returns the same frame:
    * ``params``: tuple of the ``PARAMS[1:]`` values of the row (the seed is
      not part of it);
    * ``all_seeds``: seeds of all rows with the same ``params``, in row order.
      Each row gets its own list object.

    Works on an empty frame, producing empty ``params``/``all_seeds`` columns.
    """
    param_cols = PARAMS[1:]
    # itertuples instead of DataFrame.apply(axis=1): on an empty frame apply
    # returns a DataFrame, which cannot be assigned to a single column.
    params = list(df[param_cols].itertuples(index=False, name=None))

    seeds_by_params = defaultdict(list)
    for key, seed in zip(params, df['seed']):
        seeds_by_params[key].append(seed)

    df['params'] = pd.Series(params, index=df.index, dtype=object)
    # Copy per row so mutating one row's list cannot leak into its group siblings.
    df['all_seeds'] = pd.Series(
        [list(seeds_by_params[key]) for key in params], index=df.index, dtype=object
    )
    return df

In [3]:
# Sanity check: the config-file naming template splits into 13 hyphen-separated
# fields (12 hyper-parameters + seed), the count that create_dataframe indexes into.
len('{weighting_scheme}-{alpha}-{fine_tune_epoch}-{reconstruct_loss}-{reconstruction_layer}-{weight_scale}-{use_down_scale}-{use_upscale}-{multi_profile}-{multi_profile_aggr_scheme}-{scale_guide_loss}-{user_profile_embeddings_files}-{seed}.yaml'.split('-'))

13

In [4]:
# Which model / dataset / experiment family to aggregate.
MODEL_NAME = 'sasrec'
DATASET = 'beauty'
EXPERIMENT_NAME = 'BEAUTY_50_50_single_new'
SPLIT_NAME = 'general'  # NOTE(review): not referenced by the visible cells
MODE = 'LLM'  # NOTE(review): not referenced by the visible cells

# Folder holding one sub-folder of raw training logs per seed group,
# relative to the working directory set by the `%cd` cell.
local_directory = "/".join(
    ["experiments-2_0", "results", MODEL_NAME, DATASET, EXPERIMENT_NAME]
)

# Sub-folders of `local_directory` whose log files are parsed.
seed_folders = ['other_seed']

In [5]:
import os


# Parse every log file under each seed folder into one dict keyed by config.
results = {}
for seed_folder in seed_folders:
    cur_folder = os.path.join(local_directory, seed_folder)
    # Sorted for a deterministic read order: when the same config appears in
    # several logs, the file read last wins in `results.update`.
    for file in sorted(os.listdir(cur_folder)):
        log_path = os.path.join(cur_folder, file)
        # Skip hidden entries (e.g. .DS_Store) and sub-directories; they are
        # not training logs and cannot be read as one.
        if file.startswith('.') or not os.path.isfile(log_path):
            continue
        cur_results = parse_file_llm(open_text(log_path))
        results.update(cur_results)


# For each config: metrics at the best validation NDCG@10 epoch.
exp_data_with_max_val = {
    config: find_max(results, config, metric='NDCG@10')
    for config in results
}
exp_data_with_max_val.keys()

dict_keys(['-2_0/configs/sasrec/beauty/BEAUTY_50_50_single_new/other_seed/mean-0.5-12-RMSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-1.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_50_50_single_new/other_seed/mean-0.5-12-RMSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-256.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_50_50_single_new/other_seed/mean-0.5-12-RMSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml'])

In [6]:
# One row per run, built from the best-validation-epoch metrics.
df = create_dataframe(exp_data_with_max_val)
# Map the 'ttention' artefact (leading 'a' of 'attention' lost while stripping a
# 'sasrec-' prefix) back to 'attention'; every other value is left unchanged.
df['weighting_scheme'] = df['weighting_scheme'].replace('ttention', 'attention')
df

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,multi_profile_aggr_scheme,scale_guide_loss,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch
0,mean-0.5-12-RMSE-1-0.1-True-False-False-mean-T...,mean,0.5,12,RMSE,1,0.1,True,False,False,mean,True,gemma_short_large_umap_single,1,0.01881,0.0382,15,0.00029,0.00065,15
1,mean-0.5-12-RMSE-1-0.1-True-False-False-mean-T...,mean,0.5,12,RMSE,1,0.1,True,False,False,mean,True,gemma_short_large_umap_single,256,0.02063,0.03984,18,0.00038,0.00101,18
2,mean-0.5-12-RMSE-1-0.1-True-False-False-mean-T...,mean,0.5,12,RMSE,1,0.1,True,False,False,mean,True,gemma_short_large_umap_single,42,0.00278,0.00624,2,0.00015,0.00036,2


In [21]:
# Number of parsed runs (one row per config file) per weighting scheme.
df['weighting_scheme'].value_counts()

weighting_scheme
mean         35
attention    19
Name: count, dtype: int64

In [22]:
# Tag each run with its hyper-parameter tuple and sibling seeds, then rank runs
# by test NDCG@10 (best first).
df = apply_seed_info(df).sort_values(by='test_NDCG@10', ascending=False)

In [23]:
# All runs, best test NDCG@10 first.
df

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,...,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch,params,all_seeds
2,attention-0.5-12-MSE-1-0.1-True-False-False-me...,attention,0.5,12,MSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.0235,0.04653,20,0.01078,0.02271,20,"(attention, 0.5, 12, MSE, 1, 0.1, True, False,...","[1, 256, 42]"
47,mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tru...,mean,0.5,6,MSE,1,0.1,True,False,False,...,gemma_long_large_umap_single,42,0.02271,0.04445,16,0.0107,0.02259,16,"(mean, 0.5, 6, MSE, 1, 0.1, True, False, False...","[1, 256, 42]"
50,mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tru...,mean,0.5,6,MSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.02316,0.04608,14,0.01067,0.02223,14,"(mean, 0.5, 6, MSE, 1, 0.1, True, False, False...","[1, 256, 42]"
16,mean-0.6-12-RMSE-1-0.1-True-False-False-mean-T...,mean,0.6,12,RMSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.0223,0.04445,19,0.01066,0.02253,19,"(mean, 0.6, 12, RMSE, 1, 0.1, True, False, Fal...","[1, 256, 42]"
32,mean-0.5-12-MSE-1-0.1-True-False-False-mean-Tr...,mean,0.5,12,MSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,256,0.0236,0.04772,19,0.01066,0.02253,19,"(mean, 0.5, 12, MSE, 1, 0.1, True, False, Fals...","[1, 256, 42]"
25,mean-0.7-6-RMSE-1-0.1-True-False-False-mean-Tr...,mean,0.7,6,RMSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,1,0.02322,0.04802,12,0.01065,0.02294,12,"(mean, 0.7, 6, RMSE, 1, 0.1, True, False, Fals...","[1, 256]"
40,attention-0.5-6-RMSE-2-0.1-True-False-False-me...,attention,0.5,6,RMSE,2,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.02335,0.04638,14,0.01065,0.02259,14,"(attention, 0.5, 6, RMSE, 2, 0.1, True, False,...","[1, 256, 42]"
48,mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tru...,mean,0.5,6,MSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,1,0.02341,0.04638,15,0.01063,0.02241,15,"(mean, 0.5, 6, MSE, 1, 0.1, True, False, False...","[1, 256, 42]"
39,attention-0.5-6-RMSE-2-0.1-True-False-False-me...,attention,0.5,6,RMSE,2,0.1,True,False,False,...,gemma_short_large_umap_single,256,0.02498,0.04965,14,0.01062,0.023,14,"(attention, 0.5, 6, RMSE, 2, 0.1, True, False,...","[1, 256, 42]"
45,mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tru...,mean,0.5,6,MSE,1,0.1,True,False,False,...,gemma_long_large_umap_single,1,0.02199,0.04237,15,0.0106,0.02265,15,"(mean, 0.5, 6, MSE, 1, 0.1, True, False, False...","[1, 256, 42]"


In [49]:
# Optional: export the per-run (per-seed) table before seed aggregation.
# df.to_csv('beauty_llm_all_runs.csv')

In [17]:
def get_stats_seeds(df_all):
    """Collapse per-seed runs into one row per hyper-parameter combination.

    Rows of ``df_all`` are grouped by their ``params`` value after dropping
    ``config_file``, ``all_seeds`` and ``seed``. Per group and column, str/bool
    values pass through (first value of the group) and numeric values are
    averaged. For each metric column a ``<col>__std`` column is appended with
    the sample standard deviation across the group (0 for single-row groups).

    Returns:
        DataFrame indexed by ``params``.
    """
    metric_cols = ["val_NDCG@10", "val_Recall@10", "val_epoch", "test_NDCG@10", "test_Recall@10"]

    def _label(series):
        # (True, value) when the column holds str/bool labels, else (False, None).
        head = list(series)[0]
        if isinstance(head, (str, bool)):
            return True, head
        return False, None

    def aggregate_mean(series):
        is_label, value = _label(series)
        return value if is_label else series.mean()

    def aggregate_std(series):
        is_label, value = _label(series)
        if is_label:
            return value
        return 0 if len(series) == 1 else series.std()

    grouped = df_all.drop(columns=['config_file', 'all_seeds', 'seed']).groupby('params')
    summary = grouped.agg(aggregate_mean)
    spread = grouped.agg(aggregate_std)
    for col in metric_cols:
        summary[f'{col}__std'] = spread[col]
    return summary


def reorder_cols(df):
    """Return ``df`` restricted to, and ordered by, the summary-table columns.

    Order: validation epoch (mean, std), the hyper-parameters ``PARAMS[1:]``,
    then NDCG@10 / Recall@10 mean and std for validation, then for test.
    Any other column is dropped.
    """
    metric_cols = [
        f'{split}_{metric}{suffix}'
        for split in ('val', 'test')
        for metric in ('NDCG@10', 'Recall@10')
        for suffix in ('', '__std')
    ]
    return df[['val_epoch', 'val_epoch__std', *PARAMS[1:], *metric_cols]]


# Keep only hyper-parameter combinations that have exactly three runs (seeds),
# aggregate mean/std across those runs, and rank by mean test NDCG@10.
df_all = df[df['all_seeds'].map(len) == 3]
df_final = (
    reorder_cols(get_stats_seeds(df_all))
    .sort_values(by='test_NDCG@10', ascending=False)
)
df_final

Unnamed: 0_level_0,val_epoch,val_epoch__std,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,...,scale_guide_loss,user_profile_embeddings_files,val_NDCG@10,val_NDCG@10__std,val_Recall@10,val_Recall@10__std,test_NDCG@10,test_NDCG@10__std,test_Recall@10,test_Recall@10__std
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(attention, 0.5, 6, RMSE, 2, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",15.666667,2.081666,attention,0.5,6.0,RMSE,2.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.023837,0.000643,0.046973,0.002085,0.010387,0.0002,0.021877,0.000687
"(mean, 0.5, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",20.333333,3.21455,mean,0.5,12.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.023343,0.000246,0.045343,0.002629,0.01034,0.000255,0.021597,0.000809
"(mean, 0.5, 6, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",15.0,1.0,mean,0.5,6.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.022893,0.00062,0.04504,0.002245,0.010303,0.000116,0.021877,0.00027
"(mean, 0.5, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_long_large_umap_single)",19.0,2.0,mean,0.5,12.0,MSE,1.0,0.1,True,False,...,True,gemma_long_large_umap_single,0.023167,0.000309,0.04633,0.000966,0.010223,0.000579,0.021597,0.00136
"(attention, 0.5, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_umap_single)",18.333333,1.154701,attention,0.5,12.0,MSE,2.0,0.1,True,False,...,True,gemma_long_large_umap_single,0.023683,0.000391,0.04727,0.00026,0.01019,0.000192,0.021697,0.000475
"(attention, 0.5, 6, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",15.666667,2.081666,attention,0.5,6.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.023853,0.000224,0.046927,0.001299,0.01018,0.000201,0.021693,0.000567
"(mean, 0.6, 12, RMSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",19.666667,3.05505,mean,0.6,12.0,RMSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.023137,0.000397,0.046033,0.001994,0.010173,0.000474,0.02134,0.000777
"(attention, 0.5, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",19.0,2.0,attention,0.5,12.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.023563,0.000167,0.047073,0.000764,0.010077,0.000627,0.021363,0.001222
"(mean, 0.5, 6, MSE, 2, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",14.0,1.732051,mean,0.5,6.0,MSE,2.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.023403,0.000362,0.04638,0.00119,0.010043,0.000311,0.020943,0.000595
"(mean, 0.5, 6, MSE, 1, 0.1, True, False, False, mean, True, gemma_long_large_umap_single)",13.0,2.645751,mean,0.5,6.0,MSE,1.0,0.1,True,False,...,True,gemma_long_large_umap_single,0.023093,8.1e-05,0.045687,0.001691,0.01004,0.00063,0.02118,0.001797


In [11]:
# Number of seed-aggregated configurations per weighting scheme.
df_final['weighting_scheme'].value_counts()

weighting_scheme
mean         9
attention    5
Name: count, dtype: int64

In [43]:
# Aggregated configs trained with scale_guide_loss disabled. Flags parsed from
# the config file name are kept as strings, hence the comparison with 'False'.
df_final[df_final['scale_guide_loss'] == 'False']

Unnamed: 0_level_0,val_epoch,val_epoch__std,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,...,scale_guide_loss,user_profile_embeddings_files,val_NDCG@10,val_NDCG@10__std,val_Recall@10,val_Recall@10__std,test_NDCG@10,test_NDCG@10__std,test_Recall@10,test_Recall@10__std
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(mean, 0.5, 6, MSE, 2, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",9.0,0.0,mean,0.5,6.0,MSE,2.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.02356,0.000114,0.045833,0.000167,0.010613,0.000415,0.022747,0.000685
"(mean, 0.5, 6, RMSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",8.666667,0.57735,mean,0.5,6.0,RMSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.02321,0.000332,0.04519,0.001808,0.010583,0.000435,0.022547,0.001066
"(attention, 0.7, 6, MSE, 1, 0.1, True, False, False, mean, False, gemma_long_large_single)",8.333333,1.527525,attention,0.7,6.0,MSE,1.0,0.1,True,False,...,False,gemma_long_large_single,0.023493,0.000313,0.046527,0.001419,0.01058,0.000401,0.022507,0.000856
"(attention, 0.5, 6, RMSE, 1, 0.1, True, False, False, mean, False, gemma_long_small_single)",9.666667,0.57735,attention,0.5,6.0,RMSE,1.0,0.1,True,False,...,False,gemma_long_small_single,0.023437,0.000418,0.04638,0.000397,0.010577,0.000172,0.02255,0.000308
"(mean, 0.5, 6, MSE, 2, 0.1, True, False, False, mean, False, gemma_long_large_umap_single)",8.666667,0.57735,mean,0.5,6.0,MSE,2.0,0.1,True,False,...,False,gemma_long_large_umap_single,0.023687,0.000535,0.046283,0.001561,0.010537,0.000355,0.02247,0.000334
"(mean, 0.7, 12, RMSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",9.0,0.0,mean,0.7,12.0,RMSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.023563,0.000221,0.046527,0.000391,0.010507,0.000483,0.022727,0.000745
"(attention, 0.5, 6, MSE, 2, 0.1, True, False, False, mean, False, gemma_long_small_single)",10.0,1.0,attention,0.5,6.0,MSE,2.0,0.1,True,False,...,False,gemma_long_small_single,0.023577,0.000717,0.045637,0.001853,0.010477,0.000424,0.02239,0.001345
"(attention, 0.5, 12, RMSE, 1, 0.1, True, False, False, mean, False, gemma_long_large_umap_single)",10.333333,0.57735,attention,0.5,12.0,RMSE,1.0,0.1,True,False,...,False,gemma_long_large_umap_single,0.02354,0.000215,0.046977,0.001179,0.010443,0.000346,0.02239,0.000711
"(attention, 0.7, 12, MSE, 1, 0.1, True, False, False, mean, False, gemma_long_large_single)",8.333333,1.527525,attention,0.7,12.0,MSE,1.0,0.1,True,False,...,False,gemma_long_large_single,0.023597,0.000616,0.047523,0.001644,0.010443,0.00017,0.02243,0.000606
"(attention, 0.6, 12, MSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",9.666667,0.57735,attention,0.6,12.0,MSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.023347,0.000211,0.04544,0.000312,0.01044,0.000178,0.02245,0.00033


In [40]:
# Number of runs (seeds) per hyper-parameter combination.
df['params'].value_counts()

params
(mean, 0.7, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)         3
(mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, False, gemma_long_large_umap_single)         3
(mean, 0.6, 6, RMSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_umap_single)          3
(attention, 0.5, 6, RMSE, 1, 0.1, True, False, False, mean, False, gemma_long_large_umap_single)    3
(mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, False, gemma_long_large_single)              3
                                                                                                   ..
(mean, 0.6, 12, MSE, 1, 0.1, True, False, False, mean, False, gemma_long_large_umap_single)         1
(mean, 0.7, 6, RMSE, 1, 0.1, False, False, False, mean, True, gemma_long_large_umap_single)         1
(attention, 0.5, 6, MSE, 2, 0.1, True, False, False, mean, False, gemma_long_large_single)          1
(attention, 0.6, 6, MSE, 2, 0.1, False, False, False, mean, True, gemma_lon

Unnamed: 0_level_0,val_epoch,val_epoch__std,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,...,scale_guide_loss,user_profile_embeddings_files,val_NDCG@10,val_NDCG@10__std,val_Recall@10,val_Recall@10__std,test_NDCG@10,test_NDCG@10__std,test_Recall@10,test_Recall@10__std
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(mean-0.7-12-MSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, mean, 0.7, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",16.0,0,mean,0.7,12.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02298,0,0.04593,0,0.01162,0,0.02514,0
"(attention-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, ttention, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",14.0,0,attention,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02330,0,0.04802,0,0.01138,0,0.02413,0
"(mean-0.5-12-RMSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-256.yaml, mean, 0.5, 12, RMSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",14.0,0,mean,0.5,12.0,RMSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02293,0,0.04608,0,0.01134,0,0.02431,0
"(mean-0.7-6-RMSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, mean, 0.7, 6, RMSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",10.0,0,mean,0.7,6.0,RMSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02265,0,0.04549,0,0.01128,0,0.02378,0
"(mean-0.5-12-RMSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-1.yaml, mean, 0.5, 12, RMSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",14.0,0,mean,0.5,12.0,RMSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02346,0,0.04727,0,0.01123,0,0.02336,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"(mean-0.7-6-MSE-2-0.1-True-False-False-mean-True-gemma_long_large_single-42.yaml, mean, 0.7, 6, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_single)",19.0,0,mean,0.7,6.0,MSE,2.0,0.1,True,False,...,True,gemma_long_large_single,0.01945,0,0.03776,0,0.00842,0,0.01712,0
"(mean-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_long_large_single-42.yaml, mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_single)",24.0,0,mean,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_long_large_single,0.01829,0,0.03687,0,0.00827,0,0.01694,0
"(mean-0.6-6-MSE-2-0.1-True-False-False-mean-True-gemma_long_small_single-42.yaml, mean, 0.6, 6, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_small_single)",21.0,0,mean,0.6,6.0,MSE,2.0,0.1,True,False,...,True,gemma_long_small_single,0.02085,0,0.03820,0,0.00826,0,0.01718,0
"(mean-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_long_small_single-42.yaml, mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_small_single)",24.0,0,mean,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_long_small_single,0.01899,0,0.03479,0,0.00816,0,0.01670,0


In [42]:
# Export the hyper-parameters of every seed-aggregated config (ranked by mean
# test NDCG@10). The file name follows the DATASET set in the config cell.
df_final[PARAMS[1:]].to_json(f'best_{DATASET}_3seeds.json', index=False, orient="records")

In [17]:
# Config keys parsed from the logs.
results.keys()

dict_keys(['-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-True-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-True-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-False-mean-False-gemma_long_large_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-False-mean-False-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-Fal

In [23]:
# Raw parsed metrics of one config. The key is taken from the current `results`
# instead of being hard-coded, so the cell still runs after EXPERIMENT_NAME
# changes (a hard-coded key from another experiment raises KeyError).
example_config = next(iter(results))
results[example_config]

{'validation_metrics': [{'epoch': 1,
   'loss': 10.828,
   'Recall@5': 0.004162330905306972,
   'NDCG@5': 0.0023390187078764875,
   'Recall@10': 0.006243496357960458,
   'NDCG@10': 0.00302255539974203,
   'Recall@20': 0.010851791288836034,
   'NDCG@20': 0.004198139363666408,
   'loss_recsys': np.float64(10.347532872800473),
   'loss_guide': np.float64(0.0)},
  {'epoch': 2,
   'loss': 5.2138,
   'Recall@5': 0.004013676230117437,
   'NDCG@5': 0.0024368647296024288,
   'Recall@10': 0.008324661810613945,
   'NDCG@10': 0.003826736836797046,
   'Recall@20': 0.014122194143005798,
   'NDCG@20': 0.0052561474396673165,
   'loss_recsys': np.float64(10.325713298938892),
   'loss_guide': np.float64(0.0)},
  {'epoch': 3,
   'loss': 4.8,
   'Recall@5': 0.006094841682770923,
   'NDCG@5': 0.003388410317555781,
   'Recall@10': 0.009216589861751152,
   'NDCG@10': 0.004382003442624782,
   'Recall@20': 0.015014122194143005,
   'NDCG@20': 0.005846106928100518,
   'loss_recsys': np.float64(10.315859370761448

In [22]:
# Config file names of all parsed runs.
df['config_file']

426    mean-0.7-12-MSE-1-0.1-True-False-False-mean-Tr...
11     attention-0.7-12-MSE-2-0.1-True-False-False-me...
88     mean-0.7-6-RMSE-1-0.1-True-False-False-mean-Tr...
283    mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tru...
373    attention-0.5-12-MSE-1-0.1-True-False-False-me...
                             ...                        
73     mean-0.7-6-MSE-2-0.1-True-False-False-mean-Tru...
257    mean-0.7-12-MSE-2-0.1-True-False-False-mean-Tr...
234    mean-0.6-6-MSE-2-0.1-True-False-False-mean-Tru...
259    mean-0.7-12-MSE-2-0.1-True-False-False-mean-Tr...
2      attention-0.7-12-MSE-2-0.1-False-False-False-m...
Name: config_file, Length: 558, dtype: object

In [1]:
# The config-file column followed by the hyper-parameters encoded in the file name.
PARAMS

NameError: name 'PARAMS' is not defined

In [47]:
# Export the seed-aggregated table; the file name follows the DATASET set in
# the config cell.
df_final.to_csv(f'{DATASET}_llm_runs.csv')