In [4]:
import pandas as pd
import numpy as np

Load the `.csv` file produced by `stage-p2x`.

In [5]:
# Read the run table from a CSV in the working directory. It is later merged
# with the fine-tuning metrics on its `run_name` column.
df = pd.read_csv('runs_with_eval_loss_and_params.csv')

Load the MolNet fine-tuning results from S3.

In [6]:
import s3fs 

In [7]:
import os
import json

In [8]:
# Module-level S3 filesystem handle; get_dataframes uses it to list directories
# and open metrics files. No arguments are passed here.
# NOTE(review): presumably credentials come from the environment — confirm.
fs = s3fs.S3FileSystem()

In [1]:
# Name of the S3 bucket, interpolated into cloud_dir.
# NOTE(review): this is an empty string as written, so cloud_dir evaluates to
# "s3:///chemberta/..." — fill in the real bucket name before running.
model_bucket = ""

In [16]:
# Root S3 prefix passed to get_dataframes; each entry listed directly under it
# is treated as one run directory.
cloud_dir = f"s3://{model_bucket}/chemberta/mlm_pretraining_77M_20210729/molnet_mlm_77M_ft_20210729/"

In [17]:
def get_dataframes(cloud_dir):
    """Collect MolNet fine-tuning metrics for every run found under ``cloud_dir``.

    The directory tree is read through the module-level ``fs`` filesystem as
    ``<cloud_dir>/<run>/<task>/results/<subset>/metrics.json`` for ``subset`` in
    ``("valid", "test")``. Each ``metrics.json`` is parsed as a mapping whose
    values are ``{metric_name: result}`` mappings (one entry per seed).

    Parameters
    ----------
    cloud_dir : str
        Directory whose immediate children are the run directories.

    Returns
    -------
    df_all : pandas.DataFrame
        One row per seed per run. Columns are named
        ``<task>_<subset>_<metric>`` followed by ``run_name``. The per-run
        index (0..n_seeds-1) is kept, so index values repeat across runs.
    df_avg : pandas.DataFrame
        One row per run with ``run_name`` plus ``<task>_<subset>_<metric>_mean``
        and ``<task>_<subset>_<metric>_std`` columns (``np.std`` default, i.e.
        population standard deviation).

    Raises
    ------
    IndexError
        If a ``metrics.json`` file contains an empty mapping.
    ValueError
        If the metric lists gathered for one run differ in length.
    """
    per_run_frames = []
    data_avg = []
    for rd in fs.ls(cloud_dir):
        run_name = os.path.basename(os.path.normpath(rd))
        # go one level down to get the molnet task
        molnet_task_data_avg = {}
        molnet_task_data_all = {}
        for molnet_task_dir in fs.ls(rd):
            molnet_task_name = os.path.basename(os.path.normpath(molnet_task_dir))
            results_dir = os.path.join(molnet_task_dir, "results/")
            for subset in ["valid", "test"]:
                with fs.open(os.path.join(results_dir, subset, "metrics.json")) as f:
                    metrics = json.load(f)
                # The first entry defines which metric names are reported.
                metric_names = list(list(metrics.values())[0].keys())
                metric_res = {mn: [] for mn in metric_names}
                for res in metrics.values():
                    for mn, mres in res.items():
                        # "pearsonr" is stored as a sequence; only its first
                        # element is kept.
                        # NOTE(review): presumably (statistic, p-value) — confirm.
                        metric_res[mn].append(mres[0] if mn == "pearsonr" else mres)
                prefix = f"{molnet_task_name}_{subset}"
                for mn in metric_names:
                    molnet_task_data_all[f"{prefix}_{mn}"] = metric_res[mn]
                average_metrics = {f"{prefix}_{mn}_mean": np.mean(metric_res[mn]) for mn in metric_names}
                std_metrics = {f"{prefix}_{mn}_std": np.std(metric_res[mn]) for mn in metric_names}
                molnet_task_data_avg.update({**average_metrics, **std_metrics})
        # Broadcast the run name over however many seeds were found instead of
        # assuming exactly five; `assign` appends it as the last column.
        per_run_frames.append(pd.DataFrame(molnet_task_data_all).assign(run_name=run_name))
        data_avg.append({"run_name": run_name, **molnet_task_data_avg})

    # Concatenate once at the end: DataFrame.append was removed in pandas 2.0
    # and growing a frame inside the loop copies it on every iteration.
    df_all = pd.concat(per_run_frames) if per_run_frames else pd.DataFrame()
    df_avg = pd.DataFrame(data_avg)
    return df_all, df_avg

In [18]:
# Walk every run under cloud_dir: df_all holds the per-seed metrics,
# df_avg the per-run mean/std of those metrics.
df_all, df_avg = get_dataframes(cloud_dir)

In [19]:
# Show the per-seed metrics (one row per seed, tagged with run_name).
df_all

Unnamed: 0,bace_classification_valid_roc_auc_score,bace_classification_valid_average_precision_score,bace_classification_test_roc_auc_score,bace_classification_test_average_precision_score,bace_regression_valid_pearsonr,bace_regression_valid_rmse,bace_regression_test_pearsonr,bace_regression_test_rmse,bbbp_valid_roc_auc_score,bbbp_valid_average_precision_score,...,delaney_test_rmse,lipo_valid_pearsonr,lipo_valid_rmse,lipo_test_pearsonr,lipo_test_rmse,tox21_valid_roc_auc_score,tox21_valid_average_precision_score,tox21_test_roc_auc_score,tox21_test_average_precision_score,run_name
0,0.658138,0.731607,0.769022,0.81259,0.025686,0.520232,0.775713,1.088148,0.953416,0.954161,...,0.557614,0.587488,0.814128,0.469805,0.812314,0.717914,0.319912,0.745829,0.285211,run_11
1,0.692786,0.730077,0.806703,0.85036,0.041414,0.524236,0.7678,1.093925,0.664596,0.673363,...,0.587209,0.608315,0.799419,0.510296,0.79568,0.769175,0.3835,0.731151,0.302215,run_11
2,0.676084,0.750218,0.797283,0.824736,0.034139,0.508618,0.786691,1.080424,0.953028,0.955238,...,0.577331,0.57765,0.821567,0.477865,0.817874,0.756778,0.390135,0.743138,0.26123,run_11
3,0.657249,0.72296,0.787319,0.825685,0.062565,0.510056,0.773089,1.074092,0.955648,0.955838,...,0.534699,0.566758,0.828631,0.488899,0.813623,0.747083,0.349745,0.738784,0.259383,run_11
4,0.688699,0.728319,0.805435,0.848996,0.043745,0.52486,0.775766,1.103782,0.733307,0.715206,...,0.551188,0.58545,0.81525,0.474445,0.818551,0.73949,0.368941,0.726993,0.269691,run_11
0,0.684435,0.682067,0.778623,0.818933,0.386013,0.526491,0.69696,1.066217,0.64606,0.700208,...,1.089734,0.675606,0.744013,0.611014,0.72799,0.695908,0.389212,0.708303,0.330024,run_19
1,0.663824,0.657262,0.787138,0.819535,0.351971,0.51946,0.698363,1.084981,0.647807,0.701861,...,1.089348,0.65849,0.759977,0.588983,0.742809,0.698181,0.351627,0.718822,0.373326,run_19
2,0.672175,0.655272,0.783514,0.810926,0.380552,0.522898,0.706767,1.103609,0.651203,0.702432,...,1.089461,0.651133,0.768026,0.580118,0.746829,0.708219,0.392658,0.722614,0.316314,run_19
3,0.676262,0.663668,0.781522,0.815367,0.380487,0.528667,0.683448,1.126033,0.649845,0.700853,...,1.089637,0.675172,0.744739,0.606664,0.733963,0.70646,0.336662,0.72391,0.288643,run_19
4,0.685679,0.687578,0.791667,0.825282,0.394321,0.535069,0.699458,1.130868,0.650718,0.703735,...,1.090029,0.694749,0.729115,0.626268,0.723935,0.714203,0.384258,0.739811,0.314118,run_19


In [20]:
# Show the per-run summary (mean and std of each metric across seeds).
df_avg

Unnamed: 0,run_name,bace_classification_valid_roc_auc_score_mean,bace_classification_valid_average_precision_score_mean,bace_classification_valid_roc_auc_score_std,bace_classification_valid_average_precision_score_std,bace_classification_test_roc_auc_score_mean,bace_classification_test_average_precision_score_mean,bace_classification_test_roc_auc_score_std,bace_classification_test_average_precision_score_std,bace_regression_valid_pearsonr_mean,...,lipo_test_pearsonr_std,lipo_test_rmse_std,tox21_valid_roc_auc_score_mean,tox21_valid_average_precision_score_mean,tox21_valid_roc_auc_score_std,tox21_valid_average_precision_score_std,tox21_test_roc_auc_score_mean,tox21_test_average_precision_score_mean,tox21_test_roc_auc_score_std,tox21_test_average_precision_score_std
0,run_11,0.674591,0.732636,0.014858,0.009263,0.793152,0.832473,0.013913,0.014793,0.04151,...,0.014461,0.008315,0.746088,0.362447,0.017237,0.025387,0.737179,0.275546,0.007113,0.016156
1,run_19,0.676475,0.669169,0.008083,0.013194,0.784493,0.818009,0.004532,0.004759,0.378669,...,0.01636,0.008643,0.704594,0.370884,0.006716,0.022498,0.722692,0.324485,0.010164,0.027839
2,run_38,0.480384,0.581494,0.000142,9.5e-05,0.40779,0.528025,0.0,0.0,0.17013,...,0.022283,0.012488,0.77041,0.430292,0.011813,0.024791,0.745829,0.302185,0.010798,0.004848
3,run_39,0.666098,0.686661,0.009845,0.009544,0.7225,0.770258,0.045955,0.021503,0.164242,...,0.015542,0.012064,0.779049,0.435612,0.008611,0.021197,0.745281,0.316157,0.017602,0.019095
4,run_45,0.613539,0.64057,0.0,0.0,0.560181,0.667201,7.2e-05,0.00058,0.152134,...,0.023401,0.012898,0.705079,0.41325,0.002234,0.002803,0.735437,0.331569,0.002723,0.002399


In [21]:
# Join the run table `df` with the per-run mean/std metrics on run_name.
# The default inner join keeps only run names present in both frames.
combined_avg_df = df.merge(df_avg, on='run_name')
# Disabled: prefix every run name with "mlm_".
# combined_avg_df['run_name'] = combined_avg_df['run_name'].apply(lambda x: f"mlm_{x}")

# Same join against the per-seed metrics; each row of `df` is repeated once
# per matching seed row.
combined_all_df = df.merge(df_all, on='run_name')
# Disabled: prefix every run name with "mlm_".
# combined_all_df['run_name'] = combined_all_df['run_name'].apply(lambda x: f"mlm_{x}")

In [22]:
# Show the run table joined with the per-run mean/std metrics.
combined_avg_df

Unnamed: 0,run_name,min_eval_loss,hidden_size,attention_probs_dropout_prob,hidden_dropout_prob,intermediate_size,num_attention_heads,num_hidden_layers,learning_rate,pretraining_task,...,lipo_test_pearsonr_std,lipo_test_rmse_std,tox21_valid_roc_auc_score_mean,tox21_valid_average_precision_score_mean,tox21_valid_roc_auc_score_std,tox21_valid_average_precision_score_std,tox21_test_roc_auc_score_mean,tox21_test_average_precision_score_mean,tox21_test_roc_auc_score_std,tox21_test_average_precision_score_std
0,run_19,0.168913,57,0.129,0.139,10476,3,5,5.8e-05,77M-MLM,...,0.01636,0.008643,0.704594,0.370884,0.006716,0.022498,0.722692,0.324485,0.010164,0.027839
1,run_11,0.578779,112,0.118,0.183,4844,8,5,2e-06,77M-MLM,...,0.014461,0.008315,0.746088,0.362447,0.017237,0.025387,0.737179,0.275546,0.007113,0.016156
2,run_39,0.42175,209,0.176,0.128,3968,11,3,2e-06,77M-MLM,...,0.015542,0.012064,0.779049,0.435612,0.008611,0.021197,0.745281,0.316157,0.017602,0.019095
3,run_38,0.386785,126,0.109,0.279,456,3,2,2.1e-05,77M-MLM,...,0.022283,0.012488,0.77041,0.430292,0.011813,0.024791,0.745829,0.302185,0.010798,0.004848
4,run_45,0.136374,384,0.109,0.144,464,12,3,0.000141,77M-MLM,...,0.023401,0.012898,0.705079,0.41325,0.002234,0.002803,0.735437,0.331569,0.002723,0.002399


In [23]:
# Show the run table joined with the per-seed metrics.
combined_all_df

Unnamed: 0,run_name,min_eval_loss,hidden_size,attention_probs_dropout_prob,hidden_dropout_prob,intermediate_size,num_attention_heads,num_hidden_layers,learning_rate,pretraining_task,...,delaney_test_pearsonr,delaney_test_rmse,lipo_valid_pearsonr,lipo_valid_rmse,lipo_test_pearsonr,lipo_test_rmse,tox21_valid_roc_auc_score,tox21_valid_average_precision_score,tox21_test_roc_auc_score,tox21_test_average_precision_score
0,run_19,0.168913,57,0.129,0.139,10476,3,5,5.8e-05,77M-MLM,...,0.437073,1.089734,0.675606,0.744013,0.611014,0.72799,0.695908,0.389212,0.708303,0.330024
1,run_19,0.168913,57,0.129,0.139,10476,3,5,5.8e-05,77M-MLM,...,0.439334,1.089348,0.65849,0.759977,0.588983,0.742809,0.698181,0.351627,0.718822,0.373326
2,run_19,0.168913,57,0.129,0.139,10476,3,5,5.8e-05,77M-MLM,...,0.439593,1.089461,0.651133,0.768026,0.580118,0.746829,0.708219,0.392658,0.722614,0.316314
3,run_19,0.168913,57,0.129,0.139,10476,3,5,5.8e-05,77M-MLM,...,0.439891,1.089637,0.675172,0.744739,0.606664,0.733963,0.70646,0.336662,0.72391,0.288643
4,run_19,0.168913,57,0.129,0.139,10476,3,5,5.8e-05,77M-MLM,...,0.440339,1.090029,0.694749,0.729115,0.626268,0.723935,0.714203,0.384258,0.739811,0.314118
5,run_11,0.578779,112,0.118,0.183,4844,8,5,2e-06,77M-MLM,...,0.841089,0.557614,0.587488,0.814128,0.469805,0.812314,0.717914,0.319912,0.745829,0.285211
6,run_11,0.578779,112,0.118,0.183,4844,8,5,2e-06,77M-MLM,...,0.820132,0.587209,0.608315,0.799419,0.510296,0.79568,0.769175,0.3835,0.731151,0.302215
7,run_11,0.578779,112,0.118,0.183,4844,8,5,2e-06,77M-MLM,...,0.836423,0.577331,0.57765,0.821567,0.477865,0.817874,0.756778,0.390135,0.743138,0.26123
8,run_11,0.578779,112,0.118,0.183,4844,8,5,2e-06,77M-MLM,...,0.857394,0.534699,0.566758,0.828631,0.488899,0.813623,0.747083,0.349745,0.738784,0.259383
9,run_11,0.578779,112,0.118,0.183,4844,8,5,2e-06,77M-MLM,...,0.853444,0.551188,0.58545,0.81525,0.474445,0.818551,0.73949,0.368941,0.726993,0.269691


In [24]:
# Write both merged tables to CSV in the working directory, leaving out the
# DataFrame index column.
combined_avg_df.to_csv(path_or_buf='ft_results_combined.csv', index=False)
combined_all_df.to_csv(path_or_buf='ft_results_all_seeds.csv', index=False)