In [1]:
import pandas as pd
import numpy as np

Load the `.csv` of pre-training runs (minimum eval loss and hyperparameters per run) produced by `stage-p2x`.

In [2]:
# Pre-training run summary exported by stage-p2x (one row per run).
df = pd.read_csv(filepath_or_buffer="runs_with_eval_loss_and_params.csv")

In [3]:
# Pre-training runs: run_name, min_eval_loss, model hyperparameters, learning_rate, pretraining_task.
df

Unnamed: 0,run_name,min_eval_loss,hidden_size,attention_probs_dropout_prob,hidden_dropout_prob,intermediate_size,num_attention_heads,num_hidden_layers,learning_rate,pretraining_task
0,run_39,0.429169,209,0.176,0.128,3968,11,3,2e-06,10M-MLM
1,run_11,0.615919,112,0.118,0.183,4844,8,5,2e-06,10M-MLM
2,run_34,0.131611,696,0.148,0.226,8436,12,2,8.7e-05,10M-MLM
3,run_2,0.172452,82,0.232,0.16,11024,2,6,0.000144,10M-MLM
4,run_38,0.457947,126,0.109,0.279,456,3,2,2.1e-05,10M-MLM
5,run_45,0.147017,384,0.109,0.144,464,12,3,0.000141,10M-MLM
6,run_43,0.15832,324,0.201,0.126,5428,9,2,0.000262,10M-MLM
7,run_4,0.711699,344,0.235,0.139,1252,8,4,3e-06,10M-MLM
8,run_19,0.19644,57,0.129,0.139,10476,3,5,5.8e-05,10M-MLM
9,run_9,0.149599,580,0.249,0.121,5712,10,3,0.000279,10M-MLM


Load the MolNet fine-tuning results (per-seed metrics for each pre-training run) from S3.

In [4]:
import s3fs 

In [5]:
import os
import json

In [6]:
# Filesystem handle used below to list run directories and open metrics.json files on S3.
# NOTE(review): constructed with no arguments, so credentials come from s3fs's default lookup — confirm the environment provides them.
fs = s3fs.S3FileSystem()

In [1]:
# S3 bucket holding the pre-training / fine-tuning artifacts. Kept out of the notebook:
# supply it via the MODEL_BUCKET environment variable (falls back to "" when unset).
model_bucket = os.environ.get("MODEL_BUCKET", "")

In [7]:
# Fail fast: an empty bucket would silently yield an "s3:///..." URL and an opaque error later in fs.ls.
if not model_bucket:
    raise ValueError("model_bucket is empty; set it to the S3 bucket that holds the fine-tuning results")
# Root containing one sub-directory per pre-training run, each holding its MolNet fine-tuning results.
cloud_dir = f"s3://{model_bucket}/chemberta/mlm_pretraining_10M_20210723/molnet_mlm_10M_ft_20210723/"

In [50]:
def get_dataframes(cloud_dir):
    run_dirs = fs.ls(cloud_dir)
    data_avg = []
    df_all = pd.DataFrame()
    for rd in run_dirs:
        run_name = os.path.basename(os.path.normpath(rd))
        # go one level down to get the molnet task
        molnet_task_data_avg = {}
        molnet_task_data_all = {}
        for molnet_task_dir in fs.ls(rd):
            molnet_task_name = os.path.basename(os.path.normpath(molnet_task_dir))
            results_dir = os.path.join(molnet_task_dir, "results/")
            for subset in ["valid", "test"]:
                with fs.open(os.path.join(results_dir, subset, "metrics.json")) as f:
                    metrics = json.load(f)
                # pick first item to get the keys
                metric_names = list(list(metrics.items())[0][1].keys())
                metric_res = {mn: [] for mn in metric_names}
                for seed, res in metrics.items():
                    for mn, mres in res.items():
                        if mn == "pearsonr":
                            metric_res[mn].append(mres[0])
                        else:
                            metric_res[mn].append(mres)
                molnet_task_data_all.update({f"{molnet_task_name}_{subset}_{mn}": metric_res[mn] for mn in metric_names})
                average_metrics = {f"{molnet_task_name}_{subset}_{mn}_mean": np.mean(metric_res[mn]) for mn in metric_names}
                std_metrics = {f"{molnet_task_name}_{subset}_{mn}_std": np.std(metric_res[mn]) for mn in metric_names}
                molnet_task_data_avg.update({**average_metrics, **std_metrics})
        molnet_task_data_all.update({"run_name": [run_name]*5})
        df_all = df_all.append(pd.DataFrame(molnet_task_data_all))
        data_avg.append({"run_name": run_name, **molnet_task_data_avg})

    df_avg = pd.DataFrame(data_avg)
    return df_all, df_avg

In [51]:
# Network-bound: lists every run/task under cloud_dir and reads each metrics.json.
df_all, df_avg = get_dataframes(cloud_dir=cloud_dir)

In [52]:
# One row per (run, fine-tuning seed); columns are <task>_<subset>_<metric> plus run_name.
df_all

Unnamed: 0,bace_classification_valid_roc_auc_score,bace_classification_valid_average_precision_score,bace_classification_test_roc_auc_score,bace_classification_test_average_precision_score,bace_regression_valid_pearsonr,bace_regression_valid_rmse,bace_regression_test_pearsonr,bace_regression_test_rmse,bbbp_valid_roc_auc_score,bbbp_valid_average_precision_score,...,delaney_test_rmse,lipo_valid_pearsonr,lipo_valid_rmse,lipo_test_pearsonr,lipo_test_rmse,tox21_valid_roc_auc_score,tox21_valid_average_precision_score,tox21_test_roc_auc_score,tox21_test_average_precision_score,run_name
0,0.622246,0.692692,0.793841,0.862725,-0.230438,0.596198,-0.217982,1.276962,0.558036,0.639609,...,0.602184,0.576544,0.823305,0.512566,0.806596,0.762783,0.412171,0.729537,0.286098,run_11
1,0.617271,0.649605,0.772101,0.814909,-0.229989,0.595898,-0.216031,1.276871,0.557793,0.639575,...,0.538507,0.590941,0.813351,0.524219,0.790027,0.761239,0.368507,0.741719,0.285398,run_11
2,0.618515,0.651651,0.768659,0.818526,-0.230389,0.596087,-0.219816,1.276949,0.557939,0.639505,...,0.54178,0.602478,0.807255,0.50532,0.81251,0.734858,0.433442,0.74578,0.299844,run_11
3,0.64339,0.70372,0.809601,0.868508,-0.23028,0.59591,-0.218276,1.276893,0.557648,0.639309,...,0.568094,0.556328,0.835626,0.447331,0.826765,0.762268,0.433016,0.737609,0.343815,run_11
4,0.629353,0.699278,0.768841,0.829362,-0.230689,0.596143,-0.219356,1.276957,0.558036,0.639608,...,0.51918,0.569293,0.827344,0.491301,0.799713,0.785904,0.453306,0.746074,0.306306,run_11
0,0.520611,0.591752,0.580978,0.687419,0.268872,0.567899,0.653381,1.199541,0.591421,0.696895,...,0.72013,0.660857,0.759722,0.582182,0.750324,0.735308,0.404245,0.730173,0.37877,run_19
1,0.520789,0.591859,0.580978,0.687419,0.343638,0.559259,0.672247,1.170662,0.593071,0.690113,...,0.721768,0.670714,0.746713,0.624257,0.717793,0.730268,0.414122,0.712804,0.345822,run_19
2,0.520611,0.591752,0.580978,0.687419,0.353859,0.561093,0.675618,1.16582,0.589383,0.691106,...,0.712005,0.697282,0.730289,0.663698,0.696253,0.709463,0.43281,0.715593,0.341018,run_19
3,0.520789,0.591798,0.580978,0.687419,0.319156,0.566845,0.683307,1.181743,0.597826,0.695128,...,0.729394,0.675801,0.747815,0.633899,0.718338,0.700969,0.360835,0.724546,0.379368,run_19
4,0.520611,0.589814,0.580978,0.687419,0.313719,0.568147,0.665972,1.191766,0.593168,0.688426,...,0.723447,0.654816,0.76809,0.565669,0.759907,0.722268,0.380278,0.722491,0.316989,run_19


In [53]:
# One row per run; each metric summarized across seeds as <task>_<subset>_<metric>_mean / _std.
df_avg

Unnamed: 0,run_name,bace_classification_valid_roc_auc_score_mean,bace_classification_valid_average_precision_score_mean,bace_classification_valid_roc_auc_score_std,bace_classification_valid_average_precision_score_std,bace_classification_test_roc_auc_score_mean,bace_classification_test_average_precision_score_mean,bace_classification_test_roc_auc_score_std,bace_classification_test_average_precision_score_std,bace_regression_valid_pearsonr_mean,...,lipo_test_pearsonr_std,lipo_test_rmse_std,tox21_valid_roc_auc_score_mean,tox21_valid_average_precision_score_mean,tox21_valid_roc_auc_score_std,tox21_valid_average_precision_score_std,tox21_test_roc_auc_score_mean,tox21_test_average_precision_score_mean,tox21_test_roc_auc_score_std,tox21_test_average_precision_score_std
0,run_11,0.626155,0.679389,0.00959,0.023753,0.782609,0.838806,0.01642747,0.02247588,-0.230357,...,0.026639,0.012347,0.76141,0.420089,0.016172,0.028887,0.740144,0.304292,0.006138,0.021326
1,run_19,0.520682,0.591395,8.7e-05,0.000791,0.580978,0.687419,0.0,0.0,0.319849,...,0.035551,0.023326,0.719655,0.398458,0.012787,0.025318,0.721121,0.352393,0.006249,0.023872
2,run_34,0.685856,0.684942,0.020912,0.016583,0.730833,0.795026,0.04402426,0.028128,0.010563,...,0.010953,0.008062,0.770084,0.428939,0.00187,0.001061,0.752444,0.391813,0.001946,0.006686
3,run_38,0.471891,0.567802,7.1e-05,2e-05,0.547826,0.633643,4.9650680000000004e-17,4.9650680000000004e-17,0.178568,...,0.021279,0.016272,0.735398,0.40196,0.009267,0.021842,0.753726,0.345733,0.010551,0.016043
4,run_39,0.689126,0.697289,0.003346,0.004038,0.718623,0.779431,0.03411678,0.01993097,-0.00318,...,0.009715,0.006544,0.774108,0.415123,0.00764,0.011772,0.745301,0.298567,0.015782,0.015021
5,run_45,0.639943,0.683314,0.003346,0.002691,0.807572,0.840228,0.003605755,0.001992883,0.179036,...,0.013561,0.015164,0.776767,0.457633,0.010897,0.020736,0.740105,0.357602,0.006108,0.004955


In [57]:
# Attach fine-tuning metrics to each pre-training run's hyperparameters, keyed on run_name.
# Inner join (pandas default): runs in `df` without fine-tuning results are dropped.
# `validate` raises if run_name is unexpectedly duplicated, which would otherwise
# silently multiply rows.
combined_avg_df = pd.merge(left=df, right=df_avg, on='run_name', validate='one_to_one')

# df_all has one row per seed, so each run in `df` may match several rows.
combined_all_df = pd.merge(left=df, right=df_all, on='run_name', validate='one_to_many')

In [58]:
# Pre-training hyperparameters joined with per-run mean/std fine-tuning metrics (only runs present in both).
combined_avg_df

Unnamed: 0,run_name,min_eval_loss,hidden_size,attention_probs_dropout_prob,hidden_dropout_prob,intermediate_size,num_attention_heads,num_hidden_layers,learning_rate,pretraining_task,...,lipo_test_pearsonr_std,lipo_test_rmse_std,tox21_valid_roc_auc_score_mean,tox21_valid_average_precision_score_mean,tox21_valid_roc_auc_score_std,tox21_valid_average_precision_score_std,tox21_test_roc_auc_score_mean,tox21_test_average_precision_score_mean,tox21_test_roc_auc_score_std,tox21_test_average_precision_score_std
0,run_39,0.429169,209,0.176,0.128,3968,11,3,2e-06,10M-MLM,...,0.009715,0.006544,0.774108,0.415123,0.00764,0.011772,0.745301,0.298567,0.015782,0.015021
1,run_11,0.615919,112,0.118,0.183,4844,8,5,2e-06,10M-MLM,...,0.026639,0.012347,0.76141,0.420089,0.016172,0.028887,0.740144,0.304292,0.006138,0.021326
2,run_34,0.131611,696,0.148,0.226,8436,12,2,8.7e-05,10M-MLM,...,0.010953,0.008062,0.770084,0.428939,0.00187,0.001061,0.752444,0.391813,0.001946,0.006686
3,run_38,0.457947,126,0.109,0.279,456,3,2,2.1e-05,10M-MLM,...,0.021279,0.016272,0.735398,0.40196,0.009267,0.021842,0.753726,0.345733,0.010551,0.016043
4,run_45,0.147017,384,0.109,0.144,464,12,3,0.000141,10M-MLM,...,0.013561,0.015164,0.776767,0.457633,0.010897,0.020736,0.740105,0.357602,0.006108,0.004955
5,run_19,0.19644,57,0.129,0.139,10476,3,5,5.8e-05,10M-MLM,...,0.035551,0.023326,0.719655,0.398458,0.012787,0.025318,0.721121,0.352393,0.006249,0.023872


In [59]:
# Pre-training hyperparameters repeated on every per-seed fine-tuning row of the matching run.
combined_all_df

Unnamed: 0,run_name,min_eval_loss,hidden_size,attention_probs_dropout_prob,hidden_dropout_prob,intermediate_size,num_attention_heads,num_hidden_layers,learning_rate,pretraining_task,...,delaney_test_pearsonr,delaney_test_rmse,lipo_valid_pearsonr,lipo_valid_rmse,lipo_test_pearsonr,lipo_test_rmse,tox21_valid_roc_auc_score,tox21_valid_average_precision_score,tox21_test_roc_auc_score,tox21_test_average_precision_score
0,run_39,0.429169,209,0.176,0.128,3968,11,3,2e-06,10M-MLM,...,0.871144,0.509745,0.645184,0.769643,0.531337,0.782116,0.78273,0.409358,0.724644,0.319663
1,run_39,0.429169,209,0.176,0.128,3968,11,3,2e-06,10M-MLM,...,0.893529,0.469569,0.599654,0.806719,0.51414,0.788926,0.761625,0.408209,0.771858,0.277485
2,run_39,0.429169,209,0.176,0.128,3968,11,3,2e-06,10M-MLM,...,0.875403,0.512288,0.609248,0.797949,0.509164,0.795617,0.769647,0.42642,0.751504,0.299939
3,run_39,0.429169,209,0.176,0.128,3968,11,3,2e-06,10M-MLM,...,0.88888,0.484284,0.586889,0.815365,0.503509,0.800754,0.776381,0.431358,0.738343,0.287014
4,run_39,0.429169,209,0.176,0.128,3968,11,3,2e-06,10M-MLM,...,0.853968,0.541653,0.60497,0.804004,0.507554,0.796507,0.780156,0.400269,0.740154,0.308734
5,run_11,0.615919,112,0.118,0.183,4844,8,5,2e-06,10M-MLM,...,0.812892,0.602184,0.576544,0.823305,0.512566,0.806596,0.762783,0.412171,0.729537,0.286098
6,run_11,0.615919,112,0.118,0.183,4844,8,5,2e-06,10M-MLM,...,0.860755,0.538507,0.590941,0.813351,0.524219,0.790027,0.761239,0.368507,0.741719,0.285398
7,run_11,0.615919,112,0.118,0.183,4844,8,5,2e-06,10M-MLM,...,0.852103,0.54178,0.602478,0.807255,0.50532,0.81251,0.734858,0.433442,0.74578,0.299844
8,run_11,0.615919,112,0.118,0.183,4844,8,5,2e-06,10M-MLM,...,0.833494,0.568094,0.556328,0.835626,0.447331,0.826765,0.762268,0.433016,0.737609,0.343815
9,run_11,0.615919,112,0.118,0.183,4844,8,5,2e-06,10M-MLM,...,0.863538,0.51918,0.569293,0.827344,0.491301,0.799713,0.785904,0.453306,0.746074,0.306306


In [61]:
# Persist both views without the pandas index: per-run summary, then per-seed rows.
combined_avg_df.to_csv(path_or_buf='ft_results_combined.csv', index=False)
combined_all_df.to_csv(path_or_buf='ft_results_all_seeds.csv', index=False)