In [3]:
%cd ../../../

/Users/nseverin/MyData/Projects/Science/LLM/sasrec-bert4rec-recsys23


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [5]:
from collections import defaultdict
import re
import numpy as np
import pandas as pd

def open_text(filename): 
    with open(filename) as f:
        return f.read()


def parse_file_llm(log_data):
    # Regex patterns
    config_pattern = re.compile(r"experiments(.+?\.yaml)")
    epoch_pattern = re.compile(r"Epoch (\d+)/\d+, Loss: ([\d.]+)")
    metrics_pattern = re.compile(r"(Validation|Test) Metrics: ({.+?})")
    
    # Parse the log
    results = defaultdict(lambda: {"validation_metrics": [], "test_metrics": []})
    current_config = None
        
    for line in log_data.splitlines():
        config_match = config_pattern.search(line)
        if config_match:
            current_config = config_match.group(1)
        
        epoch_match = epoch_pattern.search(line)
        if epoch_match:
            epoch_num = int(epoch_match.group(1))
            loss = float(epoch_match.group(2))
        
        metrics_match = metrics_pattern.search(line)
        if metrics_match and current_config:
            metric_type = metrics_match.group(1).lower()  # "validation" or "test"
            metrics = eval(metrics_match.group(2))  # Convert string to dict
            results[current_config][f"{metric_type}_metrics"].append({"epoch": epoch_num, "loss": loss, **metrics})
    return results


def find_max(results, config, metric='NDCG@10'):
    def get_dct_epoch(lst, epoch):
        for x in lst:
            if x['epoch'] == epoch:
                return x
        return lst[-1]
        # raise Exception()
    
    def postprocess(lst):
        res = []
        for dct in lst:
            res.append({
                'NDCG@10': round(dct['NDCG@10'], 5),
                'Recall@10': round(dct['Recall@10'], 5),
                'epoch': dct['epoch']
            })
        return res
    
    best_val_dct = None
    for dct in results[config]['validation_metrics']:
        if best_val_dct is None:
            best_val_dct = dct
        elif dct[metric] > best_val_dct[metric]:
            best_val_dct = dct
    
    best_val_epoch = best_val_dct['epoch']
    
    return {
        'test': postprocess([get_dct_epoch(results[config]['test_metrics'], best_val_epoch)]),
        'validation': postprocess([best_val_dct]),
    }


PARAMS = ["config_file", "weighting_scheme", "alpha", "fine_tune_epoch", 'reconstruct_loss', "reconstruction_layer", 'weight_scale', 'use_down_scale','use_upscale','multi_profile','multi_profile_aggr_scheme','scale_guide_loss','user_profile_embeddings_files']


def create_dataframe(exp_data_with_max_val):
    # List to store extracted experiment data
    data = []
    
    # Parse experiment results
    for config_path, metrics in exp_data_with_max_val.items():
        # Extract config details from the filename
        config_file = config_path.split("/")[-1]  # Get only the filename
        config_name = config_file.replace(".yaml", "").lstrip('sasrec-').split("-")
        
        
        # {weighting_scheme}-{alpha}-{fine_tune_epoch}-{reconstruct_loss}-{reconstruction_layer}-{weight_scale}-{use_down_scale}-{use_upscale}-{multi_profile}-{multi_profile_aggr_scheme}-{scale_guide_loss}-{user_profile_embeddings_files}-{seed}.yaml
        weighting_scheme = (config_name[0])
        alpha = float(config_name[1])
        fine_tune_epoch = int(config_name[2])
        reconstruct_loss = (config_name[3])
        reconstruction_layer = int(config_name[4])
        weight_scale = float(config_name[5])
        use_down_scale = (config_name[6])
        use_upscale = (config_name[7])
        multi_profile = (config_name[8])
        multi_profile_aggr_scheme = (config_name[9])
        scale_guide_loss = (config_name[10])
        user_profile_embeddings_files = (config_name[11])
        seed = int(config_name[12])
        
        # Extract validation and test metrics
        val_metrics = metrics.get("validation", [{}])[0]
        test_metrics = metrics.get("test", [{}])[0]
    
        # Append extracted values to data list
        data.append([
            config_file, weighting_scheme, alpha, fine_tune_epoch, reconstruct_loss, reconstruction_layer, weight_scale, use_down_scale,use_upscale,multi_profile,multi_profile_aggr_scheme,scale_guide_loss,user_profile_embeddings_files,seed,
            val_metrics.get("NDCG@10", None), val_metrics.get("Recall@10", None), val_metrics.get("epoch", None),
            test_metrics.get("NDCG@10", None), test_metrics.get("Recall@10", None), test_metrics.get("epoch", None)
        ])
    
    # Create DataFrame
    columns = PARAMS + ['seed',
        "val_NDCG@10", "val_Recall@10", "val_epoch", "test_NDCG@10", "test_Recall@10", "test_epoch"
        ]
    
    df = pd.DataFrame(data, columns=columns)
    return df


def apply_seed_info(df):
    df['params'] = df.apply(lambda row: tuple([row[param] for param in PARAMS]), axis=1)
    params2seeds = {}
    for params, seed in zip(df['params'], df['seed']):
        if params not in params2seeds:
            params2seeds[params] = []
        params2seeds[params].append(seed)
    df['all_seeds'] = df.apply(lambda row: (params2seeds[row['params']]), axis=1)
    return df

In [6]:
len('{weighting_scheme}-{alpha}-{fine_tune_epoch}-{reconstruct_loss}-{reconstruction_layer}-{weight_scale}-{use_down_scale}-{use_upscale}-{multi_profile}-{multi_profile_aggr_scheme}-{scale_guide_loss}-{user_profile_embeddings_files}-{seed}.yaml'.split('-'))

13

In [7]:
MODEL_NAME = 'sasrec'
DATASET = 'beauty'
EXPERIMENT_NAME = 'BEAUTY_INITIAL'
SPLIT_NAME = 'general'
MODE = 'LLM'

local_directory = f"experiments-2_0/results/{MODEL_NAME}/{DATASET}/{EXPERIMENT_NAME}" 

seed_folders = ['single_seed', 'other_seed']

In [8]:
import os


results = {}
for seed_folder in seed_folders:
    cur_folder = os.path.join(local_directory, seed_folder)
    for file in os.listdir(cur_folder):
        log_data = open_text(f'{cur_folder}/{file}')
        cur_results = parse_file_llm(log_data)
        # if cur_results['simple']['validation_metrics'] == []:
        #     continue
        results.update(cur_results)
    
    
exp_data_with_max_val = {}
for config in results:
    exp_data_with_max_val[config] = find_max(results, config, metric='NDCG@10')
exp_data_with_max_val.keys()

dict_keys(['-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-True-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-True-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-False-mean-False-gemma_long_large_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-False-mean-False-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-Fal

In [9]:
exp_data_with_max_val['/configs/new_exps_2025_embs/beauty/sasrec-attention-0.8-12-RMSE-2-0.1-False-True-short-256.yaml']

KeyError: '/configs/new_exps_2025_embs/beauty/sasrec-attention-0.8-12-RMSE-2-0.1-False-True-short-256.yaml'

In [12]:
df = create_dataframe(exp_data_with_max_val)
df

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,multi_profile_aggr_scheme,scale_guide_loss,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch
0,attention-0.7-12-MSE-2-0.1-False-False-False-m...,ttention,0.7,12,MSE,2,0.1,False,False,False,mean,False,gemma_long_large_umap_single,42,0.02095,0.03895,22,0.00876,0.01795,22
1,attention-0.7-12-MSE-2-0.1-False-False-False-m...,ttention,0.7,12,MSE,2,0.1,False,False,False,mean,False,gemma_short_large_umap_single,42,0.02141,0.04147,21,0.00936,0.01962,21
2,attention-0.7-12-MSE-2-0.1-False-False-False-m...,ttention,0.7,12,MSE,2,0.1,False,False,False,mean,True,gemma_long_large_umap_single,42,0.02053,0.03820,23,0.00796,0.01652,23
3,attention-0.7-12-MSE-2-0.1-False-False-False-m...,ttention,0.7,12,MSE,2,0.1,False,False,False,mean,True,gemma_short_large_umap_single,42,0.02084,0.04103,20,0.00892,0.01908,20
4,attention-0.7-12-MSE-2-0.1-True-False-False-me...,ttention,0.7,12,MSE,2,0.1,True,False,False,mean,False,gemma_long_large_single,42,0.02426,0.04802,7,0.01037,0.02187,7
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
553,attention-0.5-12-RMSE-2-0.1-True-False-False-m...,ttention,0.5,12,RMSE,2,0.1,True,False,False,mean,False,gemma_long_small_single,42,0.02328,0.04564,11,0.01094,0.02277,11
554,attention-0.5-12-RMSE-2-0.1-True-False-False-m...,ttention,0.5,12,RMSE,2,0.1,True,False,False,mean,False,gemma_short_large_umap_single,42,0.02334,0.04697,10,0.01038,0.02205,10
555,attention-0.5-12-RMSE-2-0.1-True-False-False-m...,ttention,0.5,12,RMSE,2,0.1,True,False,False,mean,True,gemma_long_large_single,42,0.02094,0.03806,17,0.00981,0.02033,17
556,attention-0.5-12-RMSE-2-0.1-True-False-False-m...,ttention,0.5,12,RMSE,2,0.1,True,False,False,mean,True,gemma_long_large_umap_single,42,0.02323,0.04504,15,0.01068,0.02241,15


In [13]:
df = apply_seed_info(df)
df.sort_values(by='test_NDCG@10', ascending=False, inplace=True)
df

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,...,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch,params,all_seeds
426,mean-0.7-12-MSE-1-0.1-True-False-False-mean-Tr...,mean,0.7,12,MSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.02298,0.04593,16,0.01162,0.02514,16,(mean-0.7-12-MSE-1-0.1-True-False-False-mean-T...,[42]
11,attention-0.7-12-MSE-2-0.1-True-False-False-me...,ttention,0.7,12,MSE,2,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.02330,0.04802,14,0.01138,0.02413,14,(attention-0.7-12-MSE-2-0.1-True-False-False-m...,[42]
88,mean-0.7-6-RMSE-1-0.1-True-False-False-mean-Tr...,mean,0.7,6,RMSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.02265,0.04549,10,0.01128,0.02378,10,(mean-0.7-6-RMSE-1-0.1-True-False-False-mean-T...,[42]
283,mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tru...,mean,0.5,6,MSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.02301,0.04712,9,0.01121,0.02360,9,(mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tr...,[42]
373,attention-0.5-12-MSE-1-0.1-True-False-False-me...,ttention,0.5,12,MSE,1,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.02411,0.04906,14,0.01120,0.02330,14,(attention-0.5-12-MSE-1-0.1-True-False-False-m...,[42]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
73,mean-0.7-6-MSE-2-0.1-True-False-False-mean-Tru...,mean,0.7,6,MSE,2,0.1,True,False,False,...,gemma_long_large_single,42,0.01945,0.03776,19,0.00842,0.01712,19,(mean-0.7-6-MSE-2-0.1-True-False-False-mean-Tr...,[42]
257,mean-0.7-12-MSE-2-0.1-True-False-False-mean-Tr...,mean,0.7,12,MSE,2,0.1,True,False,False,...,gemma_long_large_single,42,0.01829,0.03687,24,0.00827,0.01694,24,(mean-0.7-12-MSE-2-0.1-True-False-False-mean-T...,[42]
234,mean-0.6-6-MSE-2-0.1-True-False-False-mean-Tru...,mean,0.6,6,MSE,2,0.1,True,False,False,...,gemma_long_small_single,42,0.02085,0.03820,21,0.00826,0.01718,21,(mean-0.6-6-MSE-2-0.1-True-False-False-mean-Tr...,[42]
259,mean-0.7-12-MSE-2-0.1-True-False-False-mean-Tr...,mean,0.7,12,MSE,2,0.1,True,False,False,...,gemma_long_small_single,42,0.01899,0.03479,24,0.00816,0.01670,24,(mean-0.7-12-MSE-2-0.1-True-False-False-mean-T...,[42]


In [14]:
def get_stats_seeds(df_all):
    def aggregate_mean(series):
        first_elem = list(series)[0]
        if isinstance(first_elem, str) or isinstance(first_elem, bool):
            return first_elem
        return series.mean()
        # print(df)
        # all_cols = set(df.columns)
        # df_mean = df.mean(numeric_only=True)
        # rest_cols = all_cols - set(df_mean.columns)
        # for col in rest_cols:
        #     df_mean[col] = list(df[col])[0]
        # return df_mean
    
    def aggregate_std(series):
        # print(series)
        first_elem = list(series)[0]
        if isinstance(first_elem, str) or isinstance(first_elem, bool):
            return first_elem
        if len(series) == 1:
            return 0
        return series.std()
        # all_cols = set(df.columns)
        # df_mean = df.std(numeric_only=True)
        # rest_cols = all_cols - set(df_mean.columns)
        # for col in rest_cols:
        #     df_mean[col] = list(df[col])[0]
        # return df_mean
    
    
    metric_cols = ["val_NDCG@10", "val_Recall@10", "val_epoch", "test_NDCG@10", "test_Recall@10"]
        
    grouped_df = df_all.drop(['config_file', 'all_seeds', 'seed'], axis=1).groupby('params')
    df_mean = grouped_df.agg(aggregate_mean)
    df_std = grouped_df.agg(aggregate_std)
    for col in metric_cols:
        df_mean[col + '__std'] = df_std[col]
    return df_mean


def reorder_cols(df):
    order = ['val_epoch', 'val_epoch__std'] + PARAMS[1:] + ['val_NDCG@10', 'val_NDCG@10__std', 'val_Recall@10', 'val_Recall@10__std', 'test_NDCG@10', 'test_NDCG@10__std', 'test_Recall@10', 'test_Recall@10__std']
    return df[order]


df_all = df[df['all_seeds'].apply(len) == 1]
df_final = get_stats_seeds(df_all)
df_final = reorder_cols(df_final)
df_final.sort_values(by='test_NDCG@10', ascending=False, inplace=True)
df_final

Unnamed: 0_level_0,val_epoch,val_epoch__std,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,...,scale_guide_loss,user_profile_embeddings_files,val_NDCG@10,val_NDCG@10__std,val_Recall@10,val_Recall@10__std,test_NDCG@10,test_NDCG@10__std,test_Recall@10,test_Recall@10__std
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(mean-0.7-12-MSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, mean, 0.7, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",16.0,0,mean,0.7,12.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02298,0,0.04593,0,0.01162,0,0.02514,0
"(attention-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, ttention, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",14.0,0,ttention,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02330,0,0.04802,0,0.01138,0,0.02413,0
"(mean-0.7-6-RMSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, mean, 0.7, 6, RMSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",10.0,0,mean,0.7,6.0,RMSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02265,0,0.04549,0,0.01128,0,0.02378,0
"(mean-0.5-6-MSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, mean, 0.5, 6, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",9.0,0,mean,0.5,6.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02301,0,0.04712,0,0.01121,0,0.02360,0
"(attention-0.5-12-MSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, ttention, 0.5, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",14.0,0,ttention,0.5,12.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02411,0,0.04906,0,0.01120,0,0.02330,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"(mean-0.7-6-MSE-2-0.1-True-False-False-mean-True-gemma_long_large_single-42.yaml, mean, 0.7, 6, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_single)",19.0,0,mean,0.7,6.0,MSE,2.0,0.1,True,False,...,True,gemma_long_large_single,0.01945,0,0.03776,0,0.00842,0,0.01712,0
"(mean-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_long_large_single-42.yaml, mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_single)",24.0,0,mean,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_long_large_single,0.01829,0,0.03687,0,0.00827,0,0.01694,0
"(mean-0.6-6-MSE-2-0.1-True-False-False-mean-True-gemma_long_small_single-42.yaml, mean, 0.6, 6, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_small_single)",21.0,0,mean,0.6,6.0,MSE,2.0,0.1,True,False,...,True,gemma_long_small_single,0.02085,0,0.03820,0,0.00826,0,0.01718,0
"(mean-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_long_small_single-42.yaml, mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_small_single)",24.0,0,mean,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_long_small_single,0.01899,0,0.03479,0,0.00816,0,0.01670,0


In [18]:
df_final['weighting_scheme'] = df_final['weighting_scheme'].apply(lambda x: x if x != 'ttention' else 'attention')
df_final

Unnamed: 0_level_0,val_epoch,val_epoch__std,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,...,scale_guide_loss,user_profile_embeddings_files,val_NDCG@10,val_NDCG@10__std,val_Recall@10,val_Recall@10__std,test_NDCG@10,test_NDCG@10__std,test_Recall@10,test_Recall@10__std
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(mean-0.7-12-MSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, mean, 0.7, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",16.0,0,mean,0.7,12.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02298,0,0.04593,0,0.01162,0,0.02514,0
"(attention-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, ttention, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",14.0,0,attention,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02330,0,0.04802,0,0.01138,0,0.02413,0
"(mean-0.7-6-RMSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, mean, 0.7, 6, RMSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",10.0,0,mean,0.7,6.0,RMSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02265,0,0.04549,0,0.01128,0,0.02378,0
"(mean-0.5-6-MSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, mean, 0.5, 6, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",9.0,0,mean,0.5,6.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02301,0,0.04712,0,0.01121,0,0.02360,0
"(attention-0.5-12-MSE-1-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml, ttention, 0.5, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",14.0,0,attention,0.5,12.0,MSE,1.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02411,0,0.04906,0,0.01120,0,0.02330,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"(mean-0.7-6-MSE-2-0.1-True-False-False-mean-True-gemma_long_large_single-42.yaml, mean, 0.7, 6, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_single)",19.0,0,mean,0.7,6.0,MSE,2.0,0.1,True,False,...,True,gemma_long_large_single,0.01945,0,0.03776,0,0.00842,0,0.01712,0
"(mean-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_long_large_single-42.yaml, mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_single)",24.0,0,mean,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_long_large_single,0.01829,0,0.03687,0,0.00827,0,0.01694,0
"(mean-0.6-6-MSE-2-0.1-True-False-False-mean-True-gemma_long_small_single-42.yaml, mean, 0.6, 6, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_small_single)",21.0,0,mean,0.6,6.0,MSE,2.0,0.1,True,False,...,True,gemma_long_small_single,0.02085,0,0.03820,0,0.00826,0,0.01718,0
"(mean-0.7-12-MSE-2-0.1-True-False-False-mean-True-gemma_long_small_single-42.yaml, mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, True, gemma_long_small_single)",24.0,0,mean,0.7,12.0,MSE,2.0,0.1,True,False,...,True,gemma_long_small_single,0.01899,0,0.03479,0,0.00816,0,0.01670,0


In [20]:
df_final[PARAMS[1:]].to_json('best_beauty.json', index=False, orient="records")

In [17]:
results.keys()

dict_keys(['-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-True-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-True-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-False-mean-False-gemma_long_large_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-False-mean-False-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-Fal

In [23]:
results['-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_long_large_umap_single-42.yaml']

{'validation_metrics': [{'epoch': 1,
   'loss': 10.828,
   'Recall@5': 0.004162330905306972,
   'NDCG@5': 0.0023390187078764875,
   'Recall@10': 0.006243496357960458,
   'NDCG@10': 0.00302255539974203,
   'Recall@20': 0.010851791288836034,
   'NDCG@20': 0.004198139363666408,
   'loss_recsys': np.float64(10.347532872800473),
   'loss_guide': np.float64(0.0)},
  {'epoch': 2,
   'loss': 5.2138,
   'Recall@5': 0.004013676230117437,
   'NDCG@5': 0.0024368647296024288,
   'Recall@10': 0.008324661810613945,
   'NDCG@10': 0.003826736836797046,
   'Recall@20': 0.014122194143005798,
   'NDCG@20': 0.0052561474396673165,
   'loss_recsys': np.float64(10.325713298938892),
   'loss_guide': np.float64(0.0)},
  {'epoch': 3,
   'loss': 4.8,
   'Recall@5': 0.006094841682770923,
   'NDCG@5': 0.003388410317555781,
   'Recall@10': 0.009216589861751152,
   'NDCG@10': 0.004382003442624782,
   'Recall@20': 0.015014122194143005,
   'NDCG@20': 0.005846106928100518,
   'loss_recsys': np.float64(10.315859370761448

In [22]:
df['config_file']

426    mean-0.7-12-MSE-1-0.1-True-False-False-mean-Tr...
11     attention-0.7-12-MSE-2-0.1-True-False-False-me...
88     mean-0.7-6-RMSE-1-0.1-True-False-False-mean-Tr...
283    mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tru...
373    attention-0.5-12-MSE-1-0.1-True-False-False-me...
                             ...                        
73     mean-0.7-6-MSE-2-0.1-True-False-False-mean-Tru...
257    mean-0.7-12-MSE-2-0.1-True-False-False-mean-Tr...
234    mean-0.6-6-MSE-2-0.1-True-False-False-mean-Tru...
259    mean-0.7-12-MSE-2-0.1-True-False-False-mean-Tr...
2      attention-0.7-12-MSE-2-0.1-False-False-False-m...
Name: config_file, Length: 558, dtype: object

In [1]:
PARAMS

NameError: name 'PARAMS' is not defined