In [1]:
import os
# Root of the solvgnn project checkout. Every path in this notebook is built by plain string
# concatenation from this value, so it is normalised to end with a path separator.
# Set the SOLVGNN_PROJECT_DIR environment variable to run the notebook on another machine;
# the original author's location is kept as the default.
project_dir = os.path.join(
    os.environ.get("SOLVGNN_PROJECT_DIR", "c:/Users/sqin34/OneDrive - UW-Madison/Research/solvgnn/"),
    "")
os.chdir(project_dir)

In [2]:
import pickle
import torch
from solvgnn.model.model_GNN import solvgcn_binary
from solvgnn.util.generate_dataset import solvent_dataset_binary
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

Using backend: pytorch


In [3]:
# Render all figure text through LaTeX (text.usetex requires a working LaTeX installation)
# in a Computer Modern serif face, with 10 pt axis labels and 9 pt tick labels.
_latex_style = {
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'axes.labelsize': 10,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
}
plt.rcParams.update(_latex_style)

In [4]:
# Output folder of the training job whose saved CV models, losses and split are analysed below.
saved_model_dir = f"{project_dir}results/job_220814_solvgcn_binary_catx/"

In [5]:
# Load the saved CV split and summarise the last recorded validation loss of each of the 5 folds.
valid_ind_list = np.load(saved_model_dir + 'saved_model/valid_ind_list.npy')
valid_loss_list = []
for cv_id in range(5):
    fold_losses = np.load(saved_model_dir + 'saved_model/val_loss_cv{}.npy'.format(cv_id))
    valid_loss_list.append(fold_losses[-1])  # final-epoch value only
# np.std is the population standard deviation (ddof=0).
loss_mean = np.mean(valid_loss_list)
loss_std = np.std(valid_loss_list)
print("mean CV loss {:.4f}, std CV loss {:.4f}".format(loss_mean, loss_std))
# Zero-based fold index, i.e. the same number used in the *_cv{}.npy / .pth file names.
best_cv_id = np.argmin(valid_loss_list)
print("best CV fold: {}".format(best_cv_id))


mean CV loss 0.1808, std CV loss 0.0055
best CV fold: 1


In [6]:
# Per-sample accumulators filled by the evaluation loop: one entry per validation index,
# pooled over all CV folds. Each name is bound to its own independent list.
idx_all = []
solv1_x_all = []
true_gam1_all = []
true_gam2_all = []
pred_gam1_all = []
pred_gam2_all = []
intra_hb1_all = []
intra_hb2_all = []
inter_hb_all = []

In [7]:
# Full binary-mixture data table and the solvent list it references, both inside the repo's data folder.
dataset_path = project_dir + "solvgnn/data/output_binary_all.csv"
solvent_list_path = project_dir + 'solvgnn/data/solvent_list.csv'
# NOTE(review): generate_all=True is passed here; its exact effect is defined in
# solvgnn.util.generate_dataset.solvent_dataset_binary — confirm.
dataset = solvent_dataset_binary(input_file_path=dataset_path, solvent_list_path=solvent_list_path, generate_all=True)

In [8]:
# Evaluate each fold's final model on that fold's held-out indices, collecting the true/predicted
# ln(gamma) of both components plus the per-sample descriptors used in the later analysis.
# The accumulators are (re)initialised here so that re-running this cell, e.g. after an
# interrupted run, does not append duplicate rows on top of a previous partial result.
idx_all = []
true_gam1_all, true_gam2_all = [], []
pred_gam1_all, pred_gam2_all = [], []
intra_hb1_all, intra_hb2_all = [], []
inter_hb_all = []
solv1_x_all = []
for cv_id in range(5):
    print('Analyzing CV {}'.format(cv_id+1))
    valid_ind = valid_ind_list[cv_id]
    empty_solvsys = dataset.generate_solvsys(batch_size=1)
    model = solvgcn_binary(in_dim=74, hidden_dim=256, n_classes=1).cuda()
    checkpoint = torch.load(saved_model_dir + 'saved_model/final_model_cv{}.pth'.format(cv_id))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # Inference only: skip autograd graph construction.
    # NOTE(review): assumes dataset.predict does not need gradients — confirm.
    with torch.no_grad():
        for count,idx in enumerate(valid_ind):
            # Fetch the sample once instead of re-indexing the dataset for every field.
            sample = dataset[idx]
            idx_all.append(idx)
            solv1_x_all.append(sample['solv1_x'])
            intra_hb1_all.append(sample['intra_hb1'])
            intra_hb2_all.append(sample['intra_hb2'])
            inter_hb_all.append(sample['inter_hb'])
            true,pred = dataset.predict(idx,model,empty_solvsys)
            pred_gam1_all.append(pred[0])
            pred_gam2_all.append(pred[1])
            true_gam1_all.append(true[0])
            true_gam2_all.append(true[1])
            if (count+1) % 500 == 0:
                print('{} out of {} done!'.format(count+1,len(valid_ind)))

Analayzing CV 1
500 out of 40000 done!
1000 out of 40000 done!
1500 out of 40000 done!
2000 out of 40000 done!
2500 out of 40000 done!
3000 out of 40000 done!
3500 out of 40000 done!
4000 out of 40000 done!
4500 out of 40000 done!
5000 out of 40000 done!
5500 out of 40000 done!
6000 out of 40000 done!
6500 out of 40000 done!
7000 out of 40000 done!
7500 out of 40000 done!
8000 out of 40000 done!
8500 out of 40000 done!
9000 out of 40000 done!
9500 out of 40000 done!
10000 out of 40000 done!
10500 out of 40000 done!
11000 out of 40000 done!
11500 out of 40000 done!
12000 out of 40000 done!
12500 out of 40000 done!
13000 out of 40000 done!
13500 out of 40000 done!
14000 out of 40000 done!
14500 out of 40000 done!
15000 out of 40000 done!
15500 out of 40000 done!
16000 out of 40000 done!
16500 out of 40000 done!
17000 out of 40000 done!
17500 out of 40000 done!
18000 out of 40000 done!
18500 out of 40000 done!
19000 out of 40000 done!
19500 out of 40000 done!
20000 out of 40000 done!
2050

In [9]:
# Pool the out-of-fold predictions and per-sample descriptors into a single table
# (one row per evaluated validation index) and write it to the job's analysis folder.
output_cv = pd.DataFrame({'idx':np.array(idx_all),
                          'true_gam1':np.array(true_gam1_all),
                          'pred_gam1':np.array(pred_gam1_all),
                          'true_gam2':np.array(true_gam2_all),
                          'pred_gam2':np.array(pred_gam2_all),
                          'solv1_x':np.array(solv1_x_all),
                          'intra_hb1':np.array(intra_hb1_all),
                          'intra_hb2':np.array(intra_hb2_all),
                          'inter_hb':np.array(inter_hb_all)
                          })
# DataFrame.to_csv does not create missing parent directories, so ensure analysis/ exists.
analysis_dir = saved_model_dir + 'analysis/'
os.makedirs(analysis_dir, exist_ok=True)
output_cv.to_csv(analysis_dir + 'output_cv.csv',index=False)

In [None]:
# Re-load the CV split and a dataset instance for standalone analysis.
# Paths are built from project_dir / saved_model_dir because the working directory was changed
# to project_dir at the top of the notebook; the former "../"-relative paths only resolved when
# running from inside the job's analysis/ folder. The data paths match the earlier dataset cell.
valid_ind_list = np.load(saved_model_dir + 'saved_model/valid_ind_list.npy')
# Reset every accumulator, including the H-bond lists, so that a later re-run of the evaluation
# loop cannot mix stale entries with new ones (the export cell needs equal-length columns).
idx_all = []
true_gam1_all = []
true_gam2_all = []
pred_gam1_all = []
pred_gam2_all = []
solv1_x_all = []
intra_hb1_all = []
intra_hb2_all = []
inter_hb_all = []
dataset_path = project_dir + "solvgnn/data/output_binary_all.csv"
solvent_list_path = project_dir + 'solvgnn/data/solvent_list.csv'
# NOTE(review): unlike the earlier dataset cell, generate_all is not passed here (kept as originally
# written) — confirm this is intended.
dataset = solvent_dataset_binary(input_file_path=dataset_path,
                                 solvent_list_path=solvent_list_path)




import matplotlib
# Parity plots of predicted vs. true ln(gamma) for components 1 and 2, pooled over all CV folds,
# with R^2 / MAE / RMSE in each panel title.
# Shared axis range for the y = x reference line in both panels.
xmin = -33
xmax = 13
# pyplot.get_cmap replaces matplotlib.cm.get_cmap (deprecated in Matplotlib 3.7, removed in 3.9).
color_list = plt.get_cmap('tab10')
# NOTE(review): group names for the colour coding; not currently used in the figure (no legend).
labels = ["nonpolar-nonpolar","polar-nonpolar","polar-polar"]
output_cv = pd.read_csv(saved_model_dir + 'analysis/output_cv.csv')
if 'tpsa_binary_avg' not in output_cv.columns:
    # The export cell above does not write this column, so it must be added before plotting.
    raise KeyError("output_cv.csv has no 'tpsa_binary_avg' column; add it to the exported table before plotting")
# Sorting controls the draw order of the scatter points (descending tpsa_binary_avg drawn first).
output_cv = output_cv.sort_values(by="tpsa_binary_avg",ascending=False)
point_colors = color_list(output_cv['tpsa_binary_avg'])

fig,ax = plt.subplots(1,2,figsize=(6,3))
for panel, comp in zip(ax, (1, 2)):
    true_vals = output_cv['true_gam{}'.format(comp)]
    pred_vals = output_cv['pred_gam{}'.format(comp)]
    panel.grid(color='lightgray',linewidth=0.75,alpha=0.5)
    panel.plot([xmin, xmax],
               [xmin, xmax], color='black',linestyle='--', lw=1)
    panel.scatter(true_vals,pred_vals,c=point_colors,s=6,alpha=0.7)
    # Raw strings: '\l' and '\g' are not valid Python escapes and would otherwise warn.
    panel.set_xlabel(r'True $\ln\gamma_{}$'.format(comp))
    panel.set_ylabel(r'Predicted $\ln\gamma_{}$'.format(comp))
    panel.set_title(r'$R^2$={:.2f},MAE={:.3f},RMSE={:.3f}'.format(
        r2_score(true_vals,pred_vals),
        mean_absolute_error(true_vals,pred_vals),
        np.sqrt(mean_squared_error(true_vals,pred_vals))))
plt.tight_layout()
# Save before showing: with the inline backend plt.show() closes the figure, so a savefig issued
# afterwards writes an empty canvas. pad_inches only takes effect with bbox_inches='tight'.
fig.savefig(saved_model_dir + 'analysis/cv_parity.svg',bbox_inches='tight',pad_inches=0,dpi=400,transparent=True)
plt.show()
plt.close(fig)

