In [None]:
import mlflow
import datetime
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Load every run of the experiment from the local MLflow file store and
# report how many of them ended in the FAILED state.
experiment_id = '100992505006922144' # TODO update mlflow experiment ID if it changes (check mlruns directory)

# Relative path: resolved against the directory the kernel was started in.
mlflow.set_tracking_uri("../../mlruns")

runs = mlflow.search_runs(experiment_ids=[experiment_id])

# An empty result would otherwise surface below as a ZeroDivisionError (or a
# KeyError on a missing column); fail here with a message that names the
# likely cause instead.
if runs.empty:
    raise ValueError(
        "No MLflow runs found for experiment_id={!r}; check the experiment ID "
        "and the tracking URI.".format(experiment_id)
    )

# Count on the 'status' column alone so the tally does not depend on any
# particular params.* column having been logged.
failed_runs = int((runs['status'] == 'FAILED').sum())
print("{} experiment runs failed ({}% of total)".format(failed_runs, failed_runs/len(runs)*100))

In [None]:
# Wall-clock stamp (YYYYMMDD-HHMMSS, local time) embedded in the file names of
# the tables and figures written by this notebook run.
timestamp = f"{datetime.datetime.now():%Y%m%d-%H%M%S}"
print(timestamp)

In [None]:
# Snapshot the raw MLflow runs table next to the notebook. `index` is a
# boolean flag; pass False explicitly to omit the row index from the file.
runs.to_csv(f'results-{timestamp}.csv', index=False)

In [None]:
# Derive grouping fields from the logged params, then keep, for every
# (distance type, distance level, trial) group, the single run with the lowest
# validation RMSE.

# Second-to-last underscore-separated token of the data file name.
runs['trial'] = runs['params.datafile'].str.split('_').str[-2]
# Everything except the last character is the type; the last character is the
# level. NOTE(review): this assumes a single trailing digit - confirm no
# level >= 10 is ever logged.
runs['causal_distance_type'] = runs['params.causal_distance'].str[:-1]
runs['causal_distance_no'] = runs['params.causal_distance'].str[-1].astype(int)

group_keys = ['causal_distance_type', 'causal_distance_no', 'trial']
metric_cols = ['metrics.RMSE_avg_val', 'metrics.RMSE_avg_test']

# Runs without a validation score cannot take part in model selection.
scored_runs = runs.dropna(subset=['metrics.RMSE_avg_val'])

# Pick whole rows via idxmin. groupby(...).first() returns the first non-null
# value of *each column independently*, so a best-validation run with a
# missing test metric would silently borrow the test RMSE of another run.
best_run_idx = scored_runs.groupby(group_keys)['metrics.RMSE_avg_val'].idxmin()
results = scored_runs.loc[best_run_idx, group_keys + metric_cols].reset_index(drop=True)

# Short (3-character) label used for the plot legend.
results['causal_dist_name'] = results['causal_distance_type'].str[:3]
# Level index -> value shown on the x axis (doubling per level). Levels that
# are not listed here become NaN. Presumably these are the alpha settings used
# when generating the data - TODO confirm.
causal_dist_no_map = {1:1, 2:2, 3:4, 4:8, 5:16, 6:32, 7:64}
results['causal_distance_no'] = results['causal_distance_no'].map(causal_dist_no_map)

In [None]:
# Two-panel figure of test RMSE against the mapped distance level: one panel
# per causal-distance family, both sharing the y axis.

sns.set_style("whitegrid")

fig, axes = plt.subplots(1, 2, figsize=(12,4), sharey=True)
ax1, ax2 = axes

# (target axes, substring selecting the rows for that panel, x-axis label)
panels = [
    (ax1, 'dist', 'α1 (100s), for fixed α2'),
    (ax2, 'struc', 'α2, for fixed α1'),
]

for panel_ax, type_key, _ in panels:
    panel_rows = results[results['causal_distance_type'].str.contains(type_key)]
    sns.lineplot(
        panel_rows,
        x='causal_distance_no',
        y='metrics.RMSE_avg_test',
        hue='causal_dist_name',
        style='causal_dist_name',
        ax=panel_ax,
        errorbar='sd',  # shaded band shows one standard deviation
    )

sns.despine(left=True)

for ax, _, x_label in panels:
    ax.set_ylabel('RMSE of test tasks')
    ax.set_xlabel(x_label)
    # Rebuild the legend with an empty title, then place it below the panel.
    ax.legend().set_title('')
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=2, borderpad=0.1, columnspacing=0.5)

# Save before show(): bbox_inches='tight' keeps the below-axes legends in frame.
plt.savefig(f'results-alpha-{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()
