In [75]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import re

In [76]:
def extract_parameters(filename):
    # Define the regular expression pattern to extract the parameters
    pattern = r'mse_corr(?P<mse_corr>[\d\.]+)_diff(?P<diff>[\d\.]+)_ssim(?P<ssim>[\d\.]+)_div(?P<div>[\d\.]+)\.csv'
    
    # Match the pattern against the filename
    match = re.match(pattern, filename)
    
    # Extract the parameters and convert them to float
    if match:
        params = {
            'corr': float(match.group('mse_corr')),
            'diff': float(match.group('diff')),
            'ssim': float(match.group('ssim')),
            'div': float(match.group('div'))
        }
        return params
    else:
        return None

In [77]:
# Define the directory containing the CSV files
dir = './results'
files = os.listdir(dir)

# Define the columns of the DataFrame
cols = ['lat', 'lbd', 'file_name', 'train_loss', 'val_loss', 'fc_diff_train', 'fc_diff_val', 'corr_train', 'corr_val']
df = pd.DataFrame(columns=cols)

# Iterate through each file in the directory
for file in files:
    if file.endswith('.csv'):
        # Read the CSV file into a temporary DataFrame
        temp_df = pd.read_csv(os.path.join(dir, file))
        
        # Add a column for the file name
        temp_df['file_name'] = file
        
        # Extract parameters from the file name and add them as columns
        params = extract_parameters(file)
        if params is not None:
            for key, value in params.items():
                temp_df[key] = value
        
        # Append the temporary DataFrame to the main DataFrame
        df = pd.concat([df, temp_df], ignore_index=True)

# drop lbd column
df = df.drop('lbd', axis=1)

# Print the resulting DataFrame
df.head()

  df = pd.concat([df, temp_df], ignore_index=True)


Unnamed: 0,lat,file_name,train_loss,val_loss,fc_diff_train,fc_diff_val,corr_train,corr_val,corr,diff,ssim,div
0,6,mse_corr0.0005_diff0.0005_ssim0.0005_div0.1.csv,1.00207,0.873928,0.169593,0.174031,0.257627,0.263212,0.0005,0.0005,0.0005,0.1
1,6,mse_corr0.0005_diff0.0005_ssim0.0005_div0.1.csv,0.796497,0.691415,0.169776,0.163963,0.367836,0.266597,0.0005,0.0005,0.0005,0.1
2,6,mse_corr0.0005_diff0.0005_ssim0.0005_div0.1.csv,0.673552,0.593045,0.161773,0.186425,0.392346,0.272627,0.0005,0.0005,0.0005,0.1
3,6,mse_corr0.0005_diff0.0005_ssim0.0005_div0.1.csv,0.630108,0.563485,0.145922,0.1473,0.437058,0.351155,0.0005,0.0005,0.0005,0.1
4,6,mse_corr0.0005_diff0.0005_ssim0.0005_div0.1.csv,0.608795,0.544221,0.125267,0.119031,0.460977,0.398583,0.0005,0.0005,0.0005,0.1


In [78]:
# Group by the hyperparameters and latent dimension
grouped = df.groupby(['lat', 'corr', 'diff', 'ssim', 'div'])

# Find the best performing run for each group based on the lowest validation loss
best_runs = grouped.apply(lambda x: x.loc[x['val_loss'].idxmin()])

# Reset index to get a clean DataFrame
best_runs.reset_index(drop=True, inplace=True)

In [79]:
best_runs

Unnamed: 0,lat,file_name,train_loss,val_loss,fc_diff_train,fc_diff_val,corr_train,corr_val,corr,diff,ssim,div
0,6,mse_corr0_diff0_ssim0_div0.1.csv,0.459475,0.427357,0.079160,0.033135,0.567320,0.631136,0.0000,0.0000,0.0000,0.1
1,6,mse_corr0_diff0_ssim0.0005_div0.1.csv,0.447489,0.431571,0.065162,0.043013,0.544203,0.559863,0.0000,0.0000,0.0005,0.1
2,6,mse_corr0_diff0_ssim0.005_div0.1.csv,0.411772,0.407401,0.090612,0.069676,0.509739,0.542241,0.0000,0.0000,0.0050,0.1
3,6,mse_corr0_diff0.0005_ssim0_div0.1.csv,0.442642,0.429785,0.070269,0.045401,0.568069,0.636237,0.0000,0.0005,0.0000,0.1
4,6,mse_corr0_diff0.0005_ssim0.0005_div0.1.csv,0.414113,0.411274,0.077295,0.073527,0.496637,0.494534,0.0000,0.0005,0.0005,0.1
...,...,...,...,...,...,...,...,...,...,...,...,...
160,16,mse_corr0.0005_diff0.0005_ssim0.0005_div0.1.csv,0.278120,0.241780,0.044463,0.028533,0.683368,0.673939,0.0005,0.0005,0.0005,0.1
161,16,mse_corr0.005_diff0_ssim0_div0.1.csv,0.305363,0.265618,0.041162,0.030320,0.668523,0.641925,0.0050,0.0000,0.0000,0.1
162,16,mse_corr0.005_diff0_ssim0.005_div0.1.csv,0.242290,0.203124,0.047995,0.028253,0.686892,0.651599,0.0050,0.0000,0.0050,0.1
163,16,mse_corr0.005_diff0.005_ssim0_div0.1.csv,0.275167,0.243343,0.038975,0.026175,0.690727,0.721439,0.0050,0.0050,0.0000,0.1


In [80]:
top_performers = best_runs.sort_values(by='val_loss').head(10)
top_performers

Unnamed: 0,lat,file_name,train_loss,val_loss,fc_diff_train,fc_diff_val,corr_train,corr_val,corr,diff,ssim,div
136,15,mse_corr0_diff0_ssim0.0005_div0.1.csv,0.235193,0.196677,0.041319,0.020949,0.696303,0.745745,0.0,0.0,0.0005,0.1
148,15,mse_corr0.005_diff0.005_ssim0_div0.1.csv,0.237035,0.198275,0.041711,0.020396,0.706941,0.743455,0.005,0.005,0.0,0.1
162,16,mse_corr0.005_diff0_ssim0.005_div0.1.csv,0.24229,0.203124,0.047995,0.028253,0.686892,0.651599,0.005,0.0,0.005,0.1
127,14,mse_corr0.0005_diff0_ssim0_div0.1.csv,0.256543,0.209588,0.038311,0.018061,0.695951,0.764503,0.0005,0.0,0.0,0.1
138,15,mse_corr0_diff0.0005_ssim0_div0.1.csv,0.250726,0.211179,0.034824,0.016137,0.699583,0.780684,0.0,0.0005,0.0,0.1
121,14,mse_corr0_diff0_ssim0.0005_div0.1.csv,0.252623,0.213478,0.039493,0.027148,0.735576,0.749088,0.0,0.0,0.0005,0.1
159,16,mse_corr0.0005_diff0.0005_ssim0_div0.1.csv,0.25455,0.213886,0.043669,0.029926,0.689753,0.668147,0.0005,0.0005,0.0,0.1
157,16,mse_corr0.0005_diff0_ssim0_div0.1.csv,0.252209,0.21514,0.03821,0.022762,0.710146,0.730936,0.0005,0.0,0.0,0.1
143,15,mse_corr0.0005_diff0_ssim0.0005_div0.1.csv,0.258585,0.216277,0.038412,0.020534,0.714709,0.7603,0.0005,0.0,0.0005,0.1
128,14,mse_corr0.0005_diff0_ssim0.0005_div0.1.csv,0.256077,0.216935,0.036846,0.025918,0.706146,0.704655,0.0005,0.0,0.0005,0.1


In [81]:
top_performers_corr = best_runs.sort_values(by='corr_val', ascending=False).head(10)
top_performers_corr

Unnamed: 0,lat,file_name,train_loss,val_loss,fc_diff_train,fc_diff_val,corr_train,corr_val,corr,diff,ssim,div
138,15,mse_corr0_diff0.0005_ssim0_div0.1.csv,0.250726,0.211179,0.034824,0.016137,0.699583,0.780684,0.0,0.0005,0.0,0.1
149,15,mse_corr0.005_diff0.005_ssim0.005_div0.1.csv,0.26526,0.231014,0.040395,0.01976,0.692458,0.771465,0.005,0.005,0.005,0.1
98,12,mse_corr0.0005_diff0_ssim0.0005_div0.1.csv,0.286353,0.251982,0.042611,0.017668,0.663266,0.76846,0.0005,0.0,0.0005,0.1
127,14,mse_corr0.0005_diff0_ssim0_div0.1.csv,0.256543,0.209588,0.038311,0.018061,0.695951,0.764503,0.0005,0.0,0.0,0.1
90,12,mse_corr0_diff0_ssim0_div0.1.csv,0.293557,0.257944,0.038378,0.017529,0.694973,0.762429,0.0,0.0,0.0,0.1
77,11,mse_corr0_diff0_ssim0.005_div0.1.csv,0.315493,0.285775,0.043501,0.023665,0.685423,0.761611,0.0,0.0,0.005,0.1
143,15,mse_corr0.0005_diff0_ssim0.0005_div0.1.csv,0.258585,0.216277,0.038412,0.020534,0.714709,0.7603,0.0005,0.0,0.0005,0.1
122,14,mse_corr0_diff0_ssim0.005_div0.1.csv,0.269638,0.235933,0.043504,0.019354,0.700798,0.758743,0.0,0.0,0.005,0.1
126,14,mse_corr0_diff0.005_ssim0.005_div0.1.csv,0.255772,0.220148,0.041344,0.016828,0.681446,0.756508,0.0,0.005,0.005,0.1
93,12,mse_corr0_diff0.0005_ssim0_div0.1.csv,0.29186,0.258191,0.040221,0.016625,0.674616,0.753459,0.0,0.0005,0.0,0.1


In [82]:
best_runs_no_corr = best_runs[(best_runs['corr'] == 0) & (best_runs['diff'] == 0) & (best_runs['ssim'] == 0)]
best_runs_no_corr

Unnamed: 0,lat,file_name,train_loss,val_loss,fc_diff_train,fc_diff_val,corr_train,corr_val,corr,diff,ssim,div
0,6,mse_corr0_diff0_ssim0_div0.1.csv,0.459475,0.427357,0.07916,0.033135,0.56732,0.631136,0.0,0.0,0.0,0.1
15,7,mse_corr0_diff0_ssim0_div0.1.csv,0.416366,0.396605,0.064008,0.040547,0.581604,0.58023,0.0,0.0,0.0,0.1
30,8,mse_corr0_diff0_ssim0_div0.1.csv,0.349174,0.337708,0.067203,0.046444,0.566891,0.590643,0.0,0.0,0.0,0.1
45,9,mse_corr0_diff0_ssim0_div0.1.csv,0.361151,0.337661,0.054643,0.044604,0.602813,0.664729,0.0,0.0,0.0,0.1
60,10,mse_corr0_diff0_ssim0_div0.1.csv,0.322559,0.296315,0.047704,0.045707,0.655121,0.692589,0.0,0.0,0.0,0.1
75,11,mse_corr0_diff0_ssim0_div0.1.csv,0.30208,0.273958,0.044765,0.029479,0.681217,0.67693,0.0,0.0,0.0,0.1
90,12,mse_corr0_diff0_ssim0_div0.1.csv,0.293557,0.257944,0.038378,0.017529,0.694973,0.762429,0.0,0.0,0.0,0.1
105,13,mse_corr0_diff0_ssim0_div0.1.csv,0.292193,0.263132,0.038932,0.045718,0.704519,0.652193,0.0,0.0,0.0,0.1
120,14,mse_corr0_diff0_ssim0_div0.1.csv,0.257622,0.226449,0.035438,0.028497,0.722688,0.689381,0.0,0.0,0.0,0.1
135,15,mse_corr0_diff0_ssim0_div0.1.csv,0.255873,0.21978,0.041905,0.031487,0.6922,0.671201,0.0,0.0,0.0,0.1


In [83]:
top_performers_same_params = best_runs.groupby(['corr', 'diff', 'ssim', 'div']).apply(lambda x: x.loc[x['val_loss'].idxmin()])
top_performers_same_params

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,lat,file_name,train_loss,val_loss,fc_diff_train,fc_diff_val,corr_train,corr_val,corr,diff,ssim,div
corr,diff,ssim,div,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0.0,0.0,0.0,0.1,15,mse_corr0_diff0_ssim0_div0.1.csv,0.255873,0.21978,0.041905,0.031487,0.6922,0.671201,0.0,0.0,0.0,0.1
0.0,0.0,0.0005,0.1,15,mse_corr0_diff0_ssim0.0005_div0.1.csv,0.235193,0.196677,0.041319,0.020949,0.696303,0.745745,0.0,0.0,0.0005,0.1
0.0,0.0,0.005,0.1,15,mse_corr0_diff0_ssim0.005_div0.1.csv,0.259459,0.220312,0.04185,0.037232,0.676668,0.722281,0.0,0.0,0.005,0.1
0.0,0.0005,0.0,0.1,15,mse_corr0_diff0.0005_ssim0_div0.1.csv,0.250726,0.211179,0.034824,0.016137,0.699583,0.780684,0.0,0.0005,0.0,0.1
0.0,0.0005,0.0005,0.1,15,mse_corr0_diff0.0005_ssim0.0005_div0.1.csv,0.25746,0.218879,0.040891,0.02364,0.692287,0.729939,0.0,0.0005,0.0005,0.1
0.0,0.005,0.0,0.1,16,mse_corr0_diff0.005_ssim0_div0.1.csv,0.27141,0.233412,0.043339,0.027902,0.660661,0.667959,0.0,0.005,0.0,0.1
0.0,0.005,0.005,0.1,14,mse_corr0_diff0.005_ssim0.005_div0.1.csv,0.255772,0.220148,0.041344,0.016828,0.681446,0.756508,0.0,0.005,0.005,0.1
0.0005,0.0,0.0,0.1,14,mse_corr0.0005_diff0_ssim0_div0.1.csv,0.256543,0.209588,0.038311,0.018061,0.695951,0.764503,0.0005,0.0,0.0,0.1
0.0005,0.0,0.0005,0.1,15,mse_corr0.0005_diff0_ssim0.0005_div0.1.csv,0.258585,0.216277,0.038412,0.020534,0.714709,0.7603,0.0005,0.0,0.0005,0.1
0.0005,0.0005,0.0,0.1,16,mse_corr0.0005_diff0.0005_ssim0_div0.1.csv,0.25455,0.213886,0.043669,0.029926,0.689753,0.668147,0.0005,0.0005,0.0,0.1


In [84]:
top_performers_same_params = best_runs.groupby(['corr', 'diff', 'ssim', 'div']).apply(lambda x: x.loc[x['corr_val'].idxmax()])
top_performers_same_params

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,lat,file_name,train_loss,val_loss,fc_diff_train,fc_diff_val,corr_train,corr_val,corr,diff,ssim,div
corr,diff,ssim,div,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0.0,0.0,0.0,0.1,12,mse_corr0_diff0_ssim0_div0.1.csv,0.293557,0.257944,0.038378,0.017529,0.694973,0.762429,0.0,0.0,0.0,0.1
0.0,0.0,0.0005,0.1,14,mse_corr0_diff0_ssim0.0005_div0.1.csv,0.252623,0.213478,0.039493,0.027148,0.735576,0.749088,0.0,0.0,0.0005,0.1
0.0,0.0,0.005,0.1,11,mse_corr0_diff0_ssim0.005_div0.1.csv,0.315493,0.285775,0.043501,0.023665,0.685423,0.761611,0.0,0.0,0.005,0.1
0.0,0.0005,0.0,0.1,15,mse_corr0_diff0.0005_ssim0_div0.1.csv,0.250726,0.211179,0.034824,0.016137,0.699583,0.780684,0.0,0.0005,0.0,0.1
0.0,0.0005,0.0005,0.1,14,mse_corr0_diff0.0005_ssim0.0005_div0.1.csv,0.272213,0.239326,0.039657,0.023367,0.679432,0.746251,0.0,0.0005,0.0005,0.1
0.0,0.005,0.0,0.1,7,mse_corr0_diff0.005_ssim0_div0.1.csv,0.425815,0.39774,0.074891,0.024668,0.576361,0.71242,0.0,0.005,0.0,0.1
0.0,0.005,0.005,0.1,14,mse_corr0_diff0.005_ssim0.005_div0.1.csv,0.255772,0.220148,0.041344,0.016828,0.681446,0.756508,0.0,0.005,0.005,0.1
0.0005,0.0,0.0,0.1,14,mse_corr0.0005_diff0_ssim0_div0.1.csv,0.256543,0.209588,0.038311,0.018061,0.695951,0.764503,0.0005,0.0,0.0,0.1
0.0005,0.0,0.0005,0.1,12,mse_corr0.0005_diff0_ssim0.0005_div0.1.csv,0.286353,0.251982,0.042611,0.017668,0.663266,0.76846,0.0005,0.0,0.0005,0.1
0.0005,0.0005,0.0,0.1,11,mse_corr0.0005_diff0.0005_ssim0_div0.1.csv,0.319585,0.289441,0.04652,0.020741,0.671815,0.728882,0.0005,0.0005,0.0,0.1


In [85]:
top_performers_same_params = best_runs.groupby(['corr', 'diff', 'ssim', 'div']).apply(lambda x: x.loc[x['fc_diff_val'].idxmin()])
top_performers_same_params

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,lat,file_name,train_loss,val_loss,fc_diff_train,fc_diff_val,corr_train,corr_val,corr,diff,ssim,div
corr,diff,ssim,div,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0.0,0.0,0.0,0.1,12,mse_corr0_diff0_ssim0_div0.1.csv,0.293557,0.257944,0.038378,0.017529,0.694973,0.762429,0.0,0.0,0.0,0.1
0.0,0.0,0.0005,0.1,16,mse_corr0_diff0_ssim0.0005_div0.1.csv,0.297727,0.262251,0.042987,0.019023,0.680534,0.721974,0.0,0.0,0.0005,0.1
0.0,0.0,0.005,0.1,14,mse_corr0_diff0_ssim0.005_div0.1.csv,0.269638,0.235933,0.043504,0.019354,0.700798,0.758743,0.0,0.0,0.005,0.1
0.0,0.0005,0.0,0.1,15,mse_corr0_diff0.0005_ssim0_div0.1.csv,0.250726,0.211179,0.034824,0.016137,0.699583,0.780684,0.0,0.0005,0.0,0.1
0.0,0.0005,0.0005,0.1,11,mse_corr0_diff0.0005_ssim0.0005_div0.1.csv,0.311381,0.284228,0.042672,0.019437,0.680807,0.720522,0.0,0.0005,0.0005,0.1
0.0,0.005,0.0,0.1,7,mse_corr0_diff0.005_ssim0_div0.1.csv,0.425815,0.39774,0.074891,0.024668,0.576361,0.71242,0.0,0.005,0.0,0.1
0.0,0.005,0.005,0.1,14,mse_corr0_diff0.005_ssim0.005_div0.1.csv,0.255772,0.220148,0.041344,0.016828,0.681446,0.756508,0.0,0.005,0.005,0.1
0.0005,0.0,0.0,0.1,14,mse_corr0.0005_diff0_ssim0_div0.1.csv,0.256543,0.209588,0.038311,0.018061,0.695951,0.764503,0.0005,0.0,0.0,0.1
0.0005,0.0,0.0005,0.1,12,mse_corr0.0005_diff0_ssim0.0005_div0.1.csv,0.286353,0.251982,0.042611,0.017668,0.663266,0.76846,0.0005,0.0,0.0005,0.1
0.0005,0.0005,0.0,0.1,10,mse_corr0.0005_diff0.0005_ssim0_div0.1.csv,0.329901,0.299173,0.048861,0.020257,0.632849,0.706591,0.0005,0.0005,0.0,0.1


In [None]:
# TODO: add the mse_corr.csv etc files to the dataframe