In [1]:
%cd ../../../

/Users/nseverin/MyData/Projects/Science/LLM/sasrec-bert4rec-recsys23


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
from collections import defaultdict
import re
import numpy as np
import pandas as pd

def open_text(filename): 
    with open(filename) as f:
        return f.read()


def parse_file_llm(log_data):
    # Regex patterns
    config_pattern = re.compile(r"experiments(.+?\.yaml)")
    epoch_pattern = re.compile(r"Epoch: (\d+)/\d+ \| Train Total Loss: ([\d.]+)")
    metrics_pattern = re.compile(r"(Validation|Test) Metrics: ({.+?})")
    
    # Parse the log
    results = defaultdict(lambda: {"validation_metrics": [], "test_metrics": []})
    current_config = None
        
    for line in log_data.splitlines():
        config_match = config_pattern.search(line)
        if config_match:
            current_config = config_match.group(1)
        
        epoch_match = epoch_pattern.search(line)
        if epoch_match:
            print(epoch_match)
            epoch_num = int(epoch_match.group(1))
            loss = float(epoch_match.group(2))
        
        metrics_match = metrics_pattern.search(line)
        if metrics_match and current_config:
            metric_type = metrics_match.group(1).lower()  # "validation" or "test"
            metrics = eval(metrics_match.group(2))  # Convert string to dict
            results[current_config][f"{metric_type}_metrics"].append({"epoch": epoch_num, "loss": loss, **metrics})
    return results


def find_max(results, config, metric='NDCG@10'):
    def get_dct_epoch(lst, epoch):
        for x in lst:
            if x['epoch'] == epoch:
                return x
        return lst[-1]
        # raise Exception()
    
    def postprocess(lst):
        res = []
        for dct in lst:
            res.append({
                'NDCG@10': round(dct['NDCG@10'], 5),
                'Recall@10': round(dct['Recall@10'], 5),
                'epoch': dct['epoch']
            })
        return res
    
    best_val_dct = None
    for dct in results[config]['validation_metrics']:
        if best_val_dct is None:
            best_val_dct = dct
        elif dct[metric] > best_val_dct[metric]:
            best_val_dct = dct
    
    best_val_epoch = best_val_dct['epoch']
    
    return {
        'test': postprocess([get_dct_epoch(results[config]['test_metrics'], best_val_epoch)]),
        'validation': postprocess([best_val_dct]),
    }


PARAMS = ["config_file", "weighting_scheme", "alpha", "fine_tune_epoch", 'reconstruct_loss', "reconstruction_layer", 'weight_scale', 'use_down_scale','use_upscale','multi_profile','multi_profile_aggr_scheme','scale_guide_loss','user_profile_embeddings_files']


def create_dataframe(exp_data_with_max_val):
    # List to store extracted experiment data
    data = []
    
    # Parse experiment results
    for config_path, metrics in exp_data_with_max_val.items():
        # Extract config details from the filename
        config_file = config_path.split("/")[-1]  # Get only the filename
        config_name = config_file.replace(".yaml", "").replace('--1', '-4').lstrip('sasrec-').split("-")
        
        
        # {weighting_scheme}-{alpha}-{fine_tune_epoch}-{reconstruct_loss}-{reconstruction_layer}-{weight_scale}-{use_down_scale}-{use_upscale}-{multi_profile}-{multi_profile_aggr_scheme}-{scale_guide_loss}-{user_profile_embeddings_files}-{seed}.yaml
        weighting_scheme = (config_name[0])
        alpha = float(config_name[1])
        fine_tune_epoch = int(float(config_name[2]))
        reconstruct_loss = (config_name[3])
        reconstruction_layer = int(config_name[4])
        weight_scale = float(config_name[5])
        use_down_scale = (config_name[6])
        use_upscale = (config_name[7])
        multi_profile = (config_name[8])
        multi_profile_aggr_scheme = (config_name[9])
        scale_guide_loss = (config_name[10])
        user_profile_embeddings_files = (config_name[11])
        seed = int(config_name[12])
        
        # Extract validation and test metrics
        val_metrics = metrics.get("validation", [{}])[0]
        test_metrics = metrics.get("test", [{}])[0]
    
        # Append extracted values to data list
        data.append([
            config_file, weighting_scheme, alpha, fine_tune_epoch, reconstruct_loss, reconstruction_layer, weight_scale, use_down_scale,use_upscale,multi_profile,multi_profile_aggr_scheme,scale_guide_loss,user_profile_embeddings_files,seed,
            val_metrics.get("NDCG@10", None), val_metrics.get("Recall@10", None), val_metrics.get("epoch", None),
            test_metrics.get("NDCG@10", None), test_metrics.get("Recall@10", None), test_metrics.get("epoch", None)
        ])
    
    # Create DataFrame
    columns = PARAMS + ['seed',
        "val_NDCG@10", "val_Recall@10", "val_epoch", "test_NDCG@10", "test_Recall@10", "test_epoch"
        ]
    
    df = pd.DataFrame(data, columns=columns)
    return df


def apply_seed_info(df):
    df['params'] = df.apply(lambda row: tuple([row[param] for param in PARAMS[1:]]), axis=1)
    params2seeds = {}
    for params, seed in zip(df['params'], df['seed']):
        if params not in params2seeds:
            params2seeds[params] = []
        params2seeds[params].append(seed)
    df['all_seeds'] = df.apply(lambda row: (params2seeds[row['params']]), axis=1)
    return df

In [3]:
len('{weighting_scheme}-{alpha}-{fine_tune_epoch}-{reconstruct_loss}-{reconstruction_layer}-{weight_scale}-{use_down_scale}-{use_upscale}-{multi_profile}-{multi_profile_aggr_scheme}-{scale_guide_loss}-{user_profile_embeddings_files}-{seed}.yaml'.split('-'))

13

In [4]:
MODEL_NAME = 'sasrec'
DATASET = 'beauty'
EXPERIMENT_NAME = 'BEAUTY_history_30_70'
SEED_FOLDER = 'other_seed'
SPLIT_NAME = 'history_30_70'
MODE = 'LLM'


local_directory = f"experiments-2_0/results/{MODEL_NAME}/{DATASET}/{EXPERIMENT_NAME}" 

seed_folders = ['other_seed']

In [5]:
import os


results = {}
for seed_folder in seed_folders:
    cur_folder = os.path.join(local_directory, seed_folder)
    for file in os.listdir(cur_folder):
        log_data = open_text(f'{cur_folder}/{file}')
        cur_results = parse_file_llm(log_data)
        # if cur_results['simple']['validation_metrics'] == []:
        #     continue
        results.update(cur_results)
    
    
exp_data_with_max_val = {}
for config in results:
    exp_data_with_max_val[config] = find_max(results, config, metric='NDCG@10')
exp_data_with_max_val.keys()

<re.Match object; span=(0, 39), match='Epoch: 1/30 | Train Total Loss: 10.7847'>
<re.Match object; span=(0, 39), match='Epoch: 2/30 | Train Total Loss: 10.7006'>
<re.Match object; span=(0, 39), match='Epoch: 3/30 | Train Total Loss: 10.6162'>
<re.Match object; span=(0, 39), match='Epoch: 4/30 | Train Total Loss: 10.5121'>
<re.Match object; span=(0, 39), match='Epoch: 5/30 | Train Total Loss: 10.3856'>
<re.Match object; span=(0, 39), match='Epoch: 6/30 | Train Total Loss: 10.2048'>
<re.Match object; span=(0, 38), match='Epoch: 7/30 | Train Total Loss: 9.9741'>
<re.Match object; span=(0, 38), match='Epoch: 8/30 | Train Total Loss: 9.7169'>
<re.Match object; span=(0, 38), match='Epoch: 9/30 | Train Total Loss: 9.3564'>
<re.Match object; span=(0, 39), match='Epoch: 10/30 | Train Total Loss: 8.9398'>
<re.Match object; span=(0, 39), match='Epoch: 11/30 | Train Total Loss: 8.5903'>
<re.Match object; span=(0, 40), match='Epoch: 12/30 | Train Total Loss: 10.2759'>
<re.Match object; span=(0, 39)

dict_keys(['-2_0/configs/sasrec/beauty/BEAUTY_history_30_70/other_seed/mean-0.3-12-MSE-3-0.1-True-False-False-mean-True-gemma_short_large_single-1.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_history_30_70/other_seed/mean-0.3-12-MSE-3-0.1-True-False-False-mean-True-gemma_short_large_single-256.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_history_30_70/other_seed/mean-0.3-12-MSE-3-0.1-True-False-False-mean-True-gemma_short_large_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_history_30_70/other_seed/mean-0.3-12-MSE-3-0.1-True-False-False-mean-True-gemma_short_large_umap_single-1.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_history_30_70/other_seed/mean-0.3-12-MSE-3-0.1-True-False-False-mean-True-gemma_short_large_umap_single-256.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_history_30_70/other_seed/mean-0.3-12-MSE-3-0.1-True-False-False-mean-True-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_history_30_70/other_seed/mean-0.3-12-RMSE-3-0.1-True-False-False-mean-True-

In [6]:
df = create_dataframe(exp_data_with_max_val)
df['weighting_scheme'] = df['weighting_scheme'].apply(lambda x: x if x != 'ttention' else 'attention')
df['weighting_scheme'] = df['weighting_scheme'].apply(lambda x: x if x != 'xponential' else 'exponential')
df

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,multi_profile_aggr_scheme,scale_guide_loss,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch
0,mean-0.3-12-MSE-3-0.1-True-False-False-mean-Tr...,mean,0.3,12,MSE,3,0.1,True,False,False,mean,True,gemma_short_large_single,1,0.02350,0.04638,25,0.01089,0.02283,25
1,mean-0.3-12-MSE-3-0.1-True-False-False-mean-Tr...,mean,0.3,12,MSE,3,0.1,True,False,False,mean,True,gemma_short_large_single,256,0.02357,0.04697,22,0.01052,0.02158,22
2,mean-0.3-12-MSE-3-0.1-True-False-False-mean-Tr...,mean,0.3,12,MSE,3,0.1,True,False,False,mean,True,gemma_short_large_single,42,0.02433,0.04861,21,0.00962,0.01926,21
3,mean-0.3-12-MSE-3-0.1-True-False-False-mean-Tr...,mean,0.3,12,MSE,3,0.1,True,False,False,mean,True,gemma_short_large_umap_single,1,0.02354,0.04623,24,0.01017,0.02116,24
4,mean-0.3-12-MSE-3-0.1-True-False-False-mean-Tr...,mean,0.3,12,MSE,3,0.1,True,False,False,mean,True,gemma_short_large_umap_single,256,0.02383,0.04742,23,0.01064,0.02182,23
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
91,exponential-0.6-12-RMSE-3-0.1-True-False-False...,exponential,0.6,12,RMSE,3,0.1,True,False,False,mean,True,gemma_short_large_single,256,0.00036,0.00074,2,0.00002,0.00006,2
92,exponential-0.6-12-RMSE-3-0.1-True-False-False...,exponential,0.6,12,RMSE,3,0.1,True,False,False,mean,True,gemma_short_large_single,42,0.00096,0.00193,5,0.00040,0.00077,5
93,exponential-0.6-12-RMSE-3-0.1-True-False-False...,exponential,0.6,12,RMSE,3,0.1,True,False,False,mean,True,gemma_short_large_umap_single,1,0.00022,0.00045,1,0.00004,0.00012,1
94,exponential-0.6-12-RMSE-3-0.1-True-False-False...,exponential,0.6,12,RMSE,3,0.1,True,False,False,mean,True,gemma_short_large_umap_single,256,0.02275,0.04489,19,0.00851,0.01670,19


In [7]:
df['weighting_scheme'].value_counts()

weighting_scheme
mean           48
exponential    48
Name: count, dtype: int64

In [8]:
df = apply_seed_info(df)
df.sort_values(by='val_NDCG@10', ascending=False, inplace=True)
df

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,...,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch,params,all_seeds
14,mean-0.4-12-MSE-3-0.1-True-False-False-mean-Tr...,mean,0.4,12,MSE,3,0.1,True,False,False,...,gemma_short_large_single,42,0.02497,0.04891,21,0.01021,0.02039,21,"(mean, 0.4, 12, MSE, 3, 0.1, True, False, Fals...","[1, 256, 42]"
73,exponential-0.5-12-MSE-3-0.1-True-False-False-...,exponential,0.5,12,MSE,3,0.1,True,False,False,...,gemma_short_large_single,256,0.02493,0.04995,22,0.01074,0.02176,22,"(exponential, 0.5, 12, MSE, 3, 0.1, True, Fals...","[1, 256, 42]"
77,exponential-0.5-12-MSE-3-0.1-True-False-False-...,exponential,0.5,12,MSE,3,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.02487,0.04787,23,0.01093,0.02152,23,"(exponential, 0.5, 12, MSE, 3, 0.1, True, Fals...","[1, 256, 42]"
70,exponential-0.4-12-RMSE-3-0.1-True-False-False...,exponential,0.4,12,RMSE,3,0.1,True,False,False,...,gemma_short_large_umap_single,256,0.02482,0.05025,26,0.01039,0.02158,26,"(exponential, 0.4, 12, RMSE, 3, 0.1, True, Fal...","[1, 256, 42]"
26,mean-0.5-12-MSE-3-0.1-True-False-False-mean-Tr...,mean,0.5,12,MSE,3,0.1,True,False,False,...,gemma_short_large_single,42,0.02474,0.04906,22,0.01033,0.02122,22,"(mean, 0.5, 12, MSE, 3, 0.1, True, False, Fals...","[1, 256, 42]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
45,mean-0.6-12-RMSE-3-0.1-True-False-False-mean-T...,mean,0.6,12,RMSE,3,0.1,True,False,False,...,gemma_short_large_umap_single,1,0.00020,0.00045,1,0.00008,0.00024,1,"(mean, 0.6, 12, RMSE, 3, 0.1, True, False, Fal...","[1, 256, 42]"
71,exponential-0.4-12-RMSE-3-0.1-True-False-False...,exponential,0.4,12,RMSE,3,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.00017,0.00045,3,0.00020,0.00053,3,"(exponential, 0.4, 12, RMSE, 3, 0.1, True, Fal...","[1, 256, 42]"
83,exponential-0.5-12-RMSE-3-0.1-True-False-False...,exponential,0.5,12,RMSE,3,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.00017,0.00045,3,0.00018,0.00048,3,"(exponential, 0.5, 12, RMSE, 3, 0.1, True, Fal...","[1, 256, 42]"
11,mean-0.3-12-RMSE-3-0.1-True-False-False-mean-T...,mean,0.3,12,RMSE,3,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.00016,0.00045,2,0.00028,0.00071,2,"(mean, 0.3, 12, RMSE, 3, 0.1, True, False, Fal...","[1, 256, 42]"


In [20]:
df.to_csv('ml20m_llm_all_runs.csv')

In [9]:
def get_stats_seeds(df_all):
    def aggregate_mean(series):
        first_elem = list(series)[0]
        if isinstance(first_elem, str) or isinstance(first_elem, bool):
            return first_elem
        return series.mean()
        # print(df)
        # all_cols = set(df.columns)
        # df_mean = df.mean(numeric_only=True)
        # rest_cols = all_cols - set(df_mean.columns)
        # for col in rest_cols:
        #     df_mean[col] = list(df[col])[0]
        # return df_mean
    
    def aggregate_std(series):
        # print(series)
        first_elem = list(series)[0]
        if isinstance(first_elem, str) or isinstance(first_elem, bool):
            return first_elem
        if len(series) == 1:
            return 0
        return series.std()
        # all_cols = set(df.columns)
        # df_mean = df.std(numeric_only=True)
        # rest_cols = all_cols - set(df_mean.columns)
        # for col in rest_cols:
        #     df_mean[col] = list(df[col])[0]
        # return df_mean
    
    
    metric_cols = ["val_NDCG@10", "val_Recall@10", "val_epoch", "test_NDCG@10", "test_Recall@10"]
        
    grouped_df = df_all.drop(['config_file', 'all_seeds', 'seed'], axis=1).groupby('params')
    df_mean = grouped_df.agg(aggregate_mean)
    df_std = grouped_df.agg(aggregate_std)
    for col in metric_cols:
        df_mean[col + '__std'] = df_std[col]
    return df_mean


def reorder_cols(df):
    order = ['val_epoch', 'val_epoch__std'] + PARAMS[1:] + ['val_NDCG@10', 'val_NDCG@10__std', 'val_Recall@10', 'val_Recall@10__std', 'test_NDCG@10', 'test_NDCG@10__std', 'test_Recall@10', 'test_Recall@10__std']
    return df[order]


df_all = df[df['all_seeds'].apply(len) == 3]
df_final = get_stats_seeds(df_all)
df_final = reorder_cols(df_final)
df_final.sort_values(by='val_NDCG@10', ascending=False, inplace=True)
df_final

Unnamed: 0_level_0,val_epoch,val_epoch__std,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,...,scale_guide_loss,user_profile_embeddings_files,val_NDCG@10,val_NDCG@10__std,val_Recall@10,val_Recall@10__std,test_NDCG@10,test_NDCG@10__std,test_Recall@10,test_Recall@10__std
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(exponential, 0.5, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_single)",23.0,2.645751,exponential,0.5,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_single,0.024277,0.000631,0.048217,0.001644,0.010617,0.000832,0.021797,0.002435
"(exponential, 0.4, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_single)",23.666667,2.516611,exponential,0.4,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_single,0.02419,0.000422,0.04836,0.001044,0.010603,0.000693,0.021877,0.001958
"(exponential, 0.3, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_single)",23.333333,3.21455,exponential,0.3,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_single,0.024027,0.000372,0.046927,0.001053,0.010273,0.000682,0.020963,0.002044
"(mean, 0.4, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_single)",23.333333,2.081666,mean,0.4,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_single,0.023973,0.000868,0.04762,0.001343,0.01063,0.000453,0.021797,0.001279
"(exponential, 0.5, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",24.333333,1.154701,exponential,0.5,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.023863,0.001025,0.046827,0.001555,0.010717,0.00022,0.021717,0.000394
"(exponential, 0.6, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",24.0,0.0,exponential,0.6,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.023833,0.000542,0.046927,0.002165,0.010803,0.000653,0.021913,0.001316
"(mean, 0.3, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_single)",22.666667,2.081666,mean,0.3,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_single,0.0238,0.00046,0.04732,0.001155,0.010343,0.000653,0.021223,0.001812
"(exponential, 0.4, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",23.666667,0.57735,exponential,0.4,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.02372,0.00052,0.04663,0.001096,0.010533,0.000348,0.021677,0.000865
"(mean, 0.5, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_single)",23.0,1.732051,mean,0.5,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_single,0.023713,0.000896,0.047023,0.001765,0.010473,0.000339,0.021597,0.000869
"(mean, 0.4, 12, MSE, 3, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)",23.0,1.0,mean,0.4,12.0,MSE,3.0,0.1,True,False,...,True,gemma_short_large_umap_single,0.023683,0.000592,0.046577,0.001333,0.010263,0.000293,0.020767,0.000805


In [23]:
df_final.to_csv('BEAUTY_ALPHA_EXPS.csv')

In [22]:
df[df['params'] == ('exponential', 0.8, 30, 'MSE', 1, 0.1, 'True', 'False', 'False', 'mean', 'False', 'gemma_short_large_umap_single')]

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,...,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch,params,all_seeds


In [32]:
df['params'][0]

('mean',
 0.5,
 50,
 'MSE',
 1,
 0.1,
 'True',
 'False',
 'False',
 'mean',
 'False',
 'gemma_short_large_umap_single')

In [30]:
df[df['all_seeds'].apply(len) != 3]

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,...,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch,params,all_seeds


In [10]:
df_final['weighting_scheme'].value_counts()

weighting_scheme
mean           16
attention      15
exponential     9
Name: count, dtype: int64

In [14]:
df_final[df_final['scale_guide_loss'] == 'False']

Unnamed: 0_level_0,val_epoch,val_epoch__std,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,...,scale_guide_loss,user_profile_embeddings_files,val_NDCG@10,val_NDCG@10__std,val_Recall@10,val_Recall@10__std,test_NDCG@10,test_NDCG@10__std,test_Recall@10,test_Recall@10__std
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(mean, 0.5, 6, MSE, 1, 0.1, False, True, False, mean, False, gemma_short_large_umap_single)",14.0,0.0,mean,0.5,6.0,MSE,1.0,0.1,False,True,...,False,gemma_short_large_umap_single,0.099597,0.000808,0.1834,0.000946,0.059867,0.000621,0.1152,0.000375
"(mean, 0.5, 4, RMSE, 2, 0.1, False, True, False, mean, False, gemma_long_large_umap_single)",9.333333,1.527525,mean,0.5,4.0,RMSE,2.0,0.1,False,True,...,False,gemma_long_large_umap_single,0.099523,7.2e-05,0.182587,0.000343,0.059717,0.000998,0.114783,0.001294
"(mean, 0.5, 6, MSE, 1, 0.1, False, True, False, mean, False, gemma_long_large_umap_single)",14.0,0.0,mean,0.5,6.0,MSE,1.0,0.1,False,True,...,False,gemma_long_large_umap_single,0.099293,0.000384,0.18376,0.000537,0.059667,0.000748,0.114967,0.000718
"(exponential, 0.8, 8, RMSE, 2, 0.1, False, True, False, mean, False, gemma_long_large_umap_single)",12.666667,1.527525,exponential,0.8,8.0,RMSE,2.0,0.1,False,True,...,False,gemma_long_large_umap_single,0.099367,0.000376,0.183277,0.000603,0.05964,0.000927,0.114837,0.001643
"(exponential, 0.8, 8, RMSE, 1, 0.1, False, True, False, mean, False, gemma_short_large_umap_single)",13.0,1.732051,exponential,0.8,8.0,RMSE,1.0,0.1,False,True,...,False,gemma_short_large_umap_single,0.099387,0.000116,0.183037,0.000872,0.059383,0.000947,0.11402,0.000931
"(exponential, 0.5, 8, RMSE, 2, 0.1, False, True, False, mean, False, gemma_short_large_umap_single)",13.666667,0.57735,exponential,0.5,8.0,RMSE,2.0,0.1,False,True,...,False,gemma_short_large_umap_single,0.099817,0.000266,0.183417,0.001399,0.059373,0.001048,0.114633,0.002053
"(exponential, 0.8, 8, MSE, 1, 0.1, False, True, False, mean, False, gemma_long_large_umap_single)",14.666667,1.154701,exponential,0.8,8.0,MSE,1.0,0.1,False,True,...,False,gemma_long_large_umap_single,0.099327,0.000539,0.183563,0.000333,0.059353,0.000945,0.114333,0.000127
"(mean, 0.65, 8, RMSE, 1, 0.1, False, True, False, mean, False, gemma_short_large_umap_single)",14.0,0.0,mean,0.65,8.0,RMSE,1.0,0.1,False,True,...,False,gemma_short_large_umap_single,0.099287,0.000652,0.183017,0.001596,0.05933,0.000872,0.11389,0.000656
"(mean, 0.8, 6, RMSE, 1, 0.1, False, True, False, mean, False, gemma_short_large_umap_single)",14.0,0.0,mean,0.8,6.0,RMSE,1.0,0.1,False,True,...,False,gemma_short_large_umap_single,0.09929,0.000521,0.183927,0.000652,0.059287,0.001598,0.114723,0.00165
"(exponential, 0.65, 8, MSE, 1, 0.1, False, True, False, mean, False, gemma_short_large_umap_single)",14.0,0.0,exponential,0.65,8.0,MSE,1.0,0.1,False,True,...,False,gemma_short_large_umap_single,0.098963,0.0001,0.183003,0.000907,0.05925,0.00099,0.113957,0.001367


In [40]:
df['params'].value_counts()

params
(mean, 0.7, 12, MSE, 1, 0.1, True, False, False, mean, True, gemma_short_large_umap_single)         3
(mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, False, gemma_long_large_umap_single)         3
(mean, 0.6, 6, RMSE, 2, 0.1, True, False, False, mean, True, gemma_long_large_umap_single)          3
(attention, 0.5, 6, RMSE, 1, 0.1, True, False, False, mean, False, gemma_long_large_umap_single)    3
(mean, 0.7, 12, MSE, 2, 0.1, True, False, False, mean, False, gemma_long_large_single)              3
                                                                                                   ..
(mean, 0.6, 12, MSE, 1, 0.1, True, False, False, mean, False, gemma_long_large_umap_single)         1
(mean, 0.7, 6, RMSE, 1, 0.1, False, False, False, mean, True, gemma_long_large_umap_single)         1
(attention, 0.5, 6, MSE, 2, 0.1, True, False, False, mean, False, gemma_long_large_single)          1
(attention, 0.6, 6, MSE, 2, 0.1, False, False, False, mean, True, gemma_lon

In [18]:
df['']

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,...,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch,params,all_seeds
663,mean-0.8-6-RMSE-1-0.1-False-True-False-mean-Fa...,mean,0.80,6,RMSE,1,0.1,False,True,False,...,gemma_short_large_umap_single,42,0.09987,0.18422,14,0.06111,0.11661,14,"(mean, 0.8, 6, RMSE, 1, 0.1, False, True, Fals...",[42]
837,exponential-0.8-8-MSE-2-0.1-False-True-False-m...,exponential,0.80,8,MSE,2,0.1,False,True,False,...,gemma_long_large_umap_single,42,0.09898,0.18188,14,0.06083,0.11617,14,"(exponential, 0.8, 8, MSE, 2, 0.1, False, True...",[42]
779,exponential-0.65-8-RMSE-2-0.1-False-True-False...,exponential,0.65,8,RMSE,2,0.1,False,True,False,...,gemma_short_large_umap_single,42,0.09998,0.18356,14,0.06072,0.11579,14,"(exponential, 0.65, 8, RMSE, 2, 0.1, False, Tr...",[42]
649,mean-0.8-6-MSE-2-0.1-False-True-False-mean-Tru...,mean,0.80,6,MSE,2,0.1,False,True,False,...,gemma_long_large_umap_single,42,0.09983,0.18389,14,0.06072,0.11615,14,"(mean, 0.8, 6, MSE, 2, 0.1, False, True, False...",[42]
467,exponential-0.5-8-RMSE-2-0.1-True-False-False-...,exponential,0.50,8,RMSE,2,0.1,True,False,False,...,gemma_short_large_umap_single,42,0.09985,0.18454,14,0.06072,0.11703,14,"(exponential, 0.5, 8, RMSE, 2, 0.1, True, Fals...",[42]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
876,exponential-0.8-8-RMSE-2-0.1-True-False-False-...,exponential,0.80,8,RMSE,2,0.1,True,False,False,...,gemma_long_large_single,42,0.09951,0.18465,19,0.05589,0.10964,19,"(exponential, 0.8, 8, RMSE, 2, 0.1, True, Fals...",[42]
792,exponential-0.8-4-MSE-1-0.1-False-True-False-m...,exponential,0.80,4,MSE,1,0.1,False,True,False,...,gemma_long_large_single,42,0.09851,0.18318,11,0.05586,0.10847,11,"(exponential, 0.8, 4, MSE, 1, 0.1, False, True...",[42]
25,attention-0.65-8-MSE-1-0.1-False-True-False-me...,attention,0.65,8,MSE,1,0.1,False,True,False,...,gemma_long_large_umap_single,42,0.09773,0.18112,18,0.05583,0.10835,18,"(attention, 0.65, 8, MSE, 1, 0.1, False, True,...",[42]
274,attention-0.8-6-MSE-2-0.1-True-False-False-mea...,attention,0.80,6,MSE,2,0.1,True,False,False,...,gemma_short_large_single,42,0.09796,0.18210,7,0.05579,0.10824,7,"(attention, 0.8, 6, MSE, 2, 0.1, True, False, ...",[42]


In [10]:
df[PARAMS[1:]].to_json('best_ml20m.json', index=False, orient="records")

In [42]:
df_final[PARAMS[1:]].to_json('best_beauty_3seeds.json', index=False, orient="records")

In [17]:
results.keys()

dict_keys(['-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-True-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-True-gemma_short_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-False-mean-False-gemma_long_large_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-False-mean-False-gemma_long_large_umap_single-42.yaml', '-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-True-False-Fal

In [23]:
results['-2_0/configs/sasrec/beauty/BEAUTY_INITIAL/single_seed/attention-0.7-12-MSE-2-0.1-False-False-False-mean-False-gemma_long_large_umap_single-42.yaml']

{'validation_metrics': [{'epoch': 1,
   'loss': 10.828,
   'Recall@5': 0.004162330905306972,
   'NDCG@5': 0.0023390187078764875,
   'Recall@10': 0.006243496357960458,
   'NDCG@10': 0.00302255539974203,
   'Recall@20': 0.010851791288836034,
   'NDCG@20': 0.004198139363666408,
   'loss_recsys': np.float64(10.347532872800473),
   'loss_guide': np.float64(0.0)},
  {'epoch': 2,
   'loss': 5.2138,
   'Recall@5': 0.004013676230117437,
   'NDCG@5': 0.0024368647296024288,
   'Recall@10': 0.008324661810613945,
   'NDCG@10': 0.003826736836797046,
   'Recall@20': 0.014122194143005798,
   'NDCG@20': 0.0052561474396673165,
   'loss_recsys': np.float64(10.325713298938892),
   'loss_guide': np.float64(0.0)},
  {'epoch': 3,
   'loss': 4.8,
   'Recall@5': 0.006094841682770923,
   'NDCG@5': 0.003388410317555781,
   'Recall@10': 0.009216589861751152,
   'NDCG@10': 0.004382003442624782,
   'Recall@20': 0.015014122194143005,
   'NDCG@20': 0.005846106928100518,
   'loss_recsys': np.float64(10.315859370761448

In [22]:
df['config_file']

426    mean-0.7-12-MSE-1-0.1-True-False-False-mean-Tr...
11     attention-0.7-12-MSE-2-0.1-True-False-False-me...
88     mean-0.7-6-RMSE-1-0.1-True-False-False-mean-Tr...
283    mean-0.5-6-MSE-1-0.1-True-False-False-mean-Tru...
373    attention-0.5-12-MSE-1-0.1-True-False-False-me...
                             ...                        
73     mean-0.7-6-MSE-2-0.1-True-False-False-mean-Tru...
257    mean-0.7-12-MSE-2-0.1-True-False-False-mean-Tr...
234    mean-0.6-6-MSE-2-0.1-True-False-False-mean-Tru...
259    mean-0.7-12-MSE-2-0.1-True-False-False-mean-Tr...
2      attention-0.7-12-MSE-2-0.1-False-False-False-m...
Name: config_file, Length: 558, dtype: object

In [1]:
PARAMS

NameError: name 'PARAMS' is not defined

In [11]:
df_final

Unnamed: 0_level_0,val_epoch,val_epoch__std,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,...,scale_guide_loss,user_profile_embeddings_files,val_NDCG@10,val_NDCG@10__std,val_Recall@10,val_Recall@10__std,test_NDCG@10,test_NDCG@10__std,test_Recall@10,test_Recall@10__std
params,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(mean, 0.65, 130, RMSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",187.0,19.924859,mean,0.65,130.0,RMSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.181537,0.000266,0.288197,0.001323,0.04259,0.000977,0.073347,0.002034
"(mean, 0.65, 30, RMSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",196.333333,3.785939,mean,0.65,30.0,RMSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.181477,0.001445,0.287293,0.00312,0.042173,0.000328,0.07328,0.000868
"(mean, 0.8, 30, MSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",178.0,1.732051,mean,0.8,30.0,MSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.181333,0.001285,0.288867,0.002677,0.0412,0.001629,0.0717,0.003037
"(mean, 0.65, 130, MSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",174.333333,25.10644,mean,0.65,130.0,MSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.18107,0.000131,0.2879,0.001914,0.04258,0.002975,0.07462,0.0047
"(mean, 0.8, 130, MSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",168.0,18.248288,mean,0.8,130.0,MSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.181033,0.000687,0.287797,0.000582,0.04126,0.001527,0.07175,0.003095
"(mean, 0.8, 100, RMSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",178.0,6.0,mean,0.8,100.0,RMSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.18092,0.00056,0.286803,0.001107,0.04095,0.001562,0.07058,0.002951
"(exponential, 0.8, 30, MSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",177.0,18.520259,exponential,0.8,30.0,MSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.18074,0.000197,0.287647,0.001047,0.04789,0.000676,0.08185,0.001945
"(mean, 0.65, 100, MSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",162.333333,15.567059,mean,0.65,100.0,MSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.180707,0.000315,0.286637,0.001048,0.042893,0.003585,0.074783,0.005056
"(exponential, 0.8, 130, RMSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",180.666667,18.036999,exponential,0.8,130.0,RMSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.1806,0.00062,0.28759,5.2e-05,0.045393,0.002886,0.078397,0.003748
"(exponential, 0.8, 100, RMSE, 1, 0.1, True, False, False, mean, False, gemma_short_large_umap_single)",170.666667,7.767453,exponential,0.8,100.0,RMSE,1.0,0.1,True,False,...,False,gemma_short_large_umap_single,0.18059,0.000679,0.286117,0.000657,0.044977,0.002299,0.07753,0.002516


In [14]:
df_final[6:].to_csv('ml20m_llm_runs_FINAL_4_200_epochs.csv')

In [27]:
np.corrcoef(df_final['test_NDCG@10'].values[:15], df_final['val_NDCG@10'].values[:15])

array([[ 1.       , -0.5951883],
       [-0.5951883,  1.       ]])

In [13]:
df[df['params'] ==('exponential', 0.65, 50, 'MSE', 1, 0.1, 'True', 'False', 'False', 'mean', 'True', 'gemma_short_large_single')]

Unnamed: 0,config_file,weighting_scheme,alpha,fine_tune_epoch,reconstruct_loss,reconstruction_layer,weight_scale,use_down_scale,use_upscale,multi_profile,...,user_profile_embeddings_files,seed,val_NDCG@10,val_Recall@10,val_epoch,test_NDCG@10,test_Recall@10,test_epoch,params,all_seeds
12,exponential-0.65-50-MSE-1-0.1-True-False-False...,exponential,0.65,50,MSE,1,0.1,True,False,False,...,gemma_short_large_single,1,0.17211,0.27639,98,0.04652,0.0794,98,"(exponential, 0.65, 50, MSE, 1, 0.1, True, Fal...","[1, 256, 42]"
13,exponential-0.65-50-MSE-1-0.1-True-False-False...,exponential,0.65,50,MSE,1,0.1,True,False,False,...,gemma_short_large_single,256,0.17094,0.2775,100,0.04522,0.07727,100,"(exponential, 0.65, 50, MSE, 1, 0.1, True, Fal...","[1, 256, 42]"
14,exponential-0.65-50-MSE-1-0.1-True-False-False...,exponential,0.65,50,MSE,1,0.1,True,False,False,...,gemma_short_large_single,42,0.17012,0.27362,99,0.04753,0.08052,99,"(exponential, 0.65, 50, MSE, 1, 0.1, True, Fal...","[1, 256, 42]"
