In [1]:
import os
# Root of the solvgnn checkout; every later path is built by plain string
# concatenation onto it. The hardcoded default is the original author's machine;
# set the SOLVGNN_PROJECT_DIR environment variable to run elsewhere.
project_dir = os.environ.get(
    "SOLVGNN_PROJECT_DIR",
    "c:/Users/sqin34/OneDrive - UW-Madison/Research/solvgnn/")
# Concatenation below (project_dir + "results/...") needs a trailing separator.
if not project_dir.endswith(("/", "\\")):
    project_dir += "/"
os.chdir(project_dir)

In [2]:
import pickle
import torch
from solvgnn.model.model_GNN import solvcat_ternary, get_n_params
from solvgnn.util.generate_dataset import solvent_dataset_ternary
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

Using backend: pytorch


In [3]:
# Global figure style: render all text through LaTeX (text.usetex needs an
# external LaTeX installation) in a Computer Modern serif face, with compact
# axis-label and tick-label font sizes.
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Computer Modern']
plt.rcParams['axes.labelsize'] = 10
plt.rcParams['xtick.labelsize'] = 9
plt.rcParams['ytick.labelsize'] = 9

In [4]:
# Output directory of the training job whose CV models are analysed below.
saved_model_dir = "{}results/job_220822_solvcat_ternary_catx_randflip/".format(project_dir)

In [5]:
# Summarise the final-epoch validation loss of every CV fold and report the best fold.
valid_ind_list = np.load(saved_model_dir + 'saved_model/valid_ind_list.npy')
# One entry of validation indices per fold; derive the fold count from it
# instead of hardcoding 5.
n_folds = len(valid_ind_list)
valid_loss_list = [
    np.load(saved_model_dir + 'saved_model/val_loss_cv{}.npy'.format(cv_id))[-1]  # last epoch
    for cv_id in range(n_folds)
]
# np.std defaults to ddof=0, i.e. the population std across folds.
print("mean CV loss {:.4f}, std CV loss {:.4f}".format(np.mean(valid_loss_list), np.std(valid_loss_list)))
best_cv_id = np.argmin(valid_loss_list)
print("best CV fold: {}".format(best_cv_id))

mean CV loss 0.0107, std CV loss 0.0004
best CV fold: 3


In [6]:
# Per-sample accumulators for the CV evaluation loop. Each name gets its own
# fresh list object (no aliasing between them).
idx_all = []
true_gam1_all = []
true_gam2_all = []
true_gam3_all = []
pred_gam1_all = []
pred_gam2_all = []
pred_gam3_all = []
intra_hb1_all = []
intra_hb2_all = []
intra_hb3_all = []
inter_hb12_all = []
inter_hb13_all = []
inter_hb23_all = []
solv1_x_all = []
solv2_x_all = []

In [7]:
# Full ternary dataset plus the solvent lookup table, both under project_dir.
dataset_path = "{}solvgnn/data/output_ternary_all.csv".format(project_dir)
solvent_list_path = "{}solvgnn/data/solvent_list.csv".format(project_dir)
dataset_kwargs = dict(input_file_path=dataset_path,
                      solvent_list_path=solvent_list_path,
                      generate_all=True)
dataset = solvent_dataset_ternary(**dataset_kwargs)

In [9]:
# Sanity-check cell (note: executed out of order, In [9] before In [8]):
# build the architecture, report its parameter count and load fold 0's weights.
cv_id = 0
valid_ind = valid_ind_list[cv_id]
# NOTE(review): empty_solvsys is not used anywhere else in this notebook.
empty_solvsys = dataset.generate_solvsys(batch_size=1)
model = solvcat_ternary(in_dim=74, hidden_dim=256, n_classes=3)
model = model.cuda()
checkpoint_path = '{}saved_model/final_model_cv{}.pth'.format(saved_model_dir, cv_id)
checkpoint = torch.load(checkpoint_path)
n_params = get_n_params(model)
print("# model params: {}".format(n_params))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# model params: 349187


solvcat_ternary(
  (conv1): GraphConv(in=74, out=256, normalization=both, activation=None)
  (conv2): GraphConv(in=256, out=256, normalization=both, activation=None)
  (classify1): Linear(in_features=771, out_features=256, bias=True)
  (classify2): Linear(in_features=256, out_features=256, bias=True)
  (classify3): Linear(in_features=256, out_features=3, bias=True)
)

In [8]:
# Run each CV fold's final model on that fold's validation indices and collect
# per-sample targets, predictions and descriptors.
# The accumulators are (re)created here so that re-running this cell, e.g.
# after an interrupted run, starts from scratch instead of appending duplicates.
idx_all = []
true_gam1_all, true_gam2_all, true_gam3_all = [], [], []
pred_gam1_all, pred_gam2_all, pred_gam3_all = [], [], []
intra_hb1_all, intra_hb2_all, intra_hb3_all = [], [], []
inter_hb12_all, inter_hb13_all, inter_hb23_all = [], [], []
solv1_x_all, solv2_x_all = [], []
progress_every = 1000  # print one progress line per this many samples

for cv_id in range(len(valid_ind_list)):
    print('Analyzing CV {}'.format(cv_id+1))
    valid_ind = valid_ind_list[cv_id]
    model = solvcat_ternary(in_dim=74, hidden_dim=256, n_classes=3).cuda()
    checkpoint = torch.load(saved_model_dir + 'saved_model/final_model_cv{}.pth'.format(cv_id))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for count, idx in enumerate(valid_ind):
            # Fetch the sample once rather than once per field; this also keeps
            # all fields from the same draw if item access is stochastic.
            sample = dataset[idx]
            true, pred = dataset.predict(idx, model)
            # All slow work is done before any append, so an interrupt is very
            # unlikely to leave the lists with different lengths.
            idx_all.append(idx)
            intra_hb1_all.append(sample['intra_hb1'])
            intra_hb2_all.append(sample['intra_hb2'])
            intra_hb3_all.append(sample['intra_hb3'])
            inter_hb12_all.append(sample['inter_hb12'])
            inter_hb13_all.append(sample['inter_hb13'])
            inter_hb23_all.append(sample['inter_hb23'])
            solv1_x_all.append(sample['solv1_x'])
            solv2_x_all.append(sample['solv2_x'])
            pred_gam1_all.append(pred[0])
            pred_gam2_all.append(pred[1])
            pred_gam3_all.append(pred[2])
            true_gam1_all.append(true[0])
            true_gam2_all.append(true[1])
            true_gam3_all.append(true[2])
            if (count+1) % progress_every == 0:
                print('{} out of {} done!'.format(count+1, len(valid_ind)))

Analayzing CV 1
100 out of 32000 done!
200 out of 32000 done!
300 out of 32000 done!
400 out of 32000 done!
500 out of 32000 done!
600 out of 32000 done!
700 out of 32000 done!
800 out of 32000 done!
900 out of 32000 done!
1000 out of 32000 done!
1100 out of 32000 done!
1200 out of 32000 done!
1300 out of 32000 done!
1400 out of 32000 done!
1500 out of 32000 done!
1600 out of 32000 done!
1700 out of 32000 done!
1800 out of 32000 done!
1900 out of 32000 done!
2000 out of 32000 done!
2100 out of 32000 done!
2200 out of 32000 done!
2300 out of 32000 done!
2400 out of 32000 done!
2500 out of 32000 done!
2600 out of 32000 done!
2700 out of 32000 done!
2800 out of 32000 done!
2900 out of 32000 done!
3000 out of 32000 done!
3100 out of 32000 done!
3200 out of 32000 done!
3300 out of 32000 done!
3400 out of 32000 done!
3500 out of 32000 done!
3600 out of 32000 done!
3700 out of 32000 done!
3800 out of 32000 done!
3900 out of 32000 done!
4000 out of 32000 done!
4100 out of 32000 done!
4200 out 

In [9]:
# Pool the per-sample CV results into one table (one row per validation sample)
# and write it to <saved_model_dir>/analysis/output_cv.csv.
output_cv = pd.DataFrame({'idx':np.array(idx_all),
                          'solv1_x':np.array(solv1_x_all),
                          'solv2_x':np.array(solv2_x_all),
                          'true_gam1':np.array(true_gam1_all),
                          'pred_gam1':np.array(pred_gam1_all),
                          'true_gam2':np.array(true_gam2_all),
                          'pred_gam2':np.array(pred_gam2_all),
                          'true_gam3':np.array(true_gam3_all),
                          'pred_gam3':np.array(pred_gam3_all),
                          'intra_hb1':np.array(intra_hb1_all),
                          'intra_hb2':np.array(intra_hb2_all),
                          'intra_hb3':np.array(intra_hb3_all),
                          'inter_hb12':np.array(inter_hb12_all),
                          'inter_hb13':np.array(inter_hb13_all),
                          'inter_hb23':np.array(inter_hb23_all)
                          })
# to_csv does not create missing directories.
os.makedirs(saved_model_dir + 'analysis', exist_ok=True)
output_cv.to_csv(saved_model_dir + 'analysis/output_cv.csv',index=False)

In [None]:




# Parity plots (true vs. predicted ln(gamma_i)) for the three components of the
# pooled CV predictions, with R^2 / MAE / RMSE in each panel title.
xmin = -21
xmax = 8
# NOTE(review): this reads '../analysis/', while the cell above writes to
# saved_model_dir + 'analysis/'; confirm which file is intended.
output_cv = pd.read_csv('../analysis/output_cv.csv')
# 'tpsa_binary_avg' is not among the columns this notebook writes; only colour
# by it (and sort so its high values are drawn first) when the file has it.
color_col = 'tpsa_binary_avg'
has_color_col = color_col in output_cv.columns
if has_color_col:
    output_cv = output_cv.sort_values(by=color_col, ascending=False)
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
color_list = plt.get_cmap('tab10')
# NOTE(review): not used below; presumably meant for a legend keyed on tpsa_binary_avg.
labels = ["nonpolar-nonpolar-nonpolar","polar-nonpolar-nonpolar","polar-polar-nonpolar","polar-polar-polar"]
point_colors = color_list(output_cv[color_col]) if has_color_col else None

fig, ax = plt.subplots(1, 3, figsize=(9, 3))
for i in range(3):
    comp = i + 1
    true = output_cv['true_gam{}'.format(comp)]
    pred = output_cv['pred_gam{}'.format(comp)]
    ax[i].grid(color='lightgray', linewidth=0.75, alpha=0.5)
    ax[i].plot([xmin, xmax], [xmin, xmax], color='black', linestyle='--', lw=1)  # y = x
    ax[i].scatter(true, pred, c=point_colors, s=6, alpha=0.7)
    # Raw strings: '\l' and '\g' are not valid Python escapes.
    ax[i].set_xlabel(r'True $\ln\gamma_{}$'.format(comp))
    ax[i].set_ylabel(r'Predicted $\ln\gamma_{}$'.format(comp))
    ax[i].set_title('$R^2$={:.2f},MAE={:.3f},RMSE={:.3f}'.format(
        r2_score(true, pred),
        mean_absolute_error(true, pred),
        np.sqrt(mean_squared_error(true, pred))))
    ax[i].set_xticks(np.arange(-20, 6, 5))
fig.tight_layout()
# Save before showing: the inline backend closes figures in plt.show(), so a
# pyplot-level savefig issued afterwards writes a new, empty figure.
fig.savefig('../analysis/cv_parity_color.png', pad_inches=0, dpi=400, transparent=True)
plt.show()
plt.close(fig)








    