In [None]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader
from dgl import model_zoo,DGLGraph
import math
import networkx as nx
import matplotlib.pyplot as plt
from utils_mpnn import Meter, set_random_seed, collate, EarlyStopping, load_model,load_brl_dataset,regress,run_a_train_epoch,run_an_eval_epoch
import argparse
# from sklearn import svm, datasets
# from sklearn.model_selection import train_test_split
# from sklearn.metrics import plot_confusion_matrix
import seaborn as sn
import pandas as pd
from sklearn import metrics
import collections
import sys; sys.argv=['']; del sys
# ---- Run configuration ----------------------------------------------------
# argparse only supplies the base option dict here: sys.argv is blanked on the
# line above, so inside the notebook every option takes its default value.
parser = argparse.ArgumentParser(description='Molecule Regression')
parser.add_argument('-m', '--model', type=str, default='MPNN', help='Model to use')#choices=['MPNN', 'SCHNET', 'MGCN', 'AttentiveFP'],
#parser.add_argument('-d', '--dataset', type=str, default='bridge',help='Dataset to use')#choices=['Alchemy', 'Aromaticity'],
parser.add_argument('-p', '--pre-trained', action='store_true', default=False,
                    help='Whether to skip training and use a pre-trained model')
args = parser.parse_args().__dict__

# Hyper-parameters; merged into `args` so the helpers imported from utils_mpnn
# receive a single dict.
training_setting = {
    'random_seed': 0,
    'batch_size': 16,
    'num_epochs': 500,
    'node_in_feats': 4,
    'edge_in_feats': 6,
    'output_dim': 1,
    'lr': 0.0001,#0.001,#
    'patience': 100,
    'metric_name': 'l1',  # only used as a label in the per-epoch log line
    'weight_decay': 0,
    'n_task': 41,
}
args.update(training_setting)
args['device'] = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
args['data_path_x'] = 'data/data_mpnn/edgetype/'
args['data_path_label'] = 'data/'
set_random_seed(args['random_seed'])

# Report the active GPU index. Guarded because torch.cuda.current_device()
# raises when no CUDA device is present, whereas args['device'] above already
# falls back to the CPU.
if torch.cuda.is_available():
    print(torch.cuda.current_device())
# torch.cuda.set_device(1)
# print(torch.cuda.current_device())

# ---- Data, model, loss, optimizer -----------------------------------------
train_loader, val_loader, test_loader = load_brl_dataset(args)
if args['pre_trained']:
    # Skip training entirely and use published weights.
    args['num_epochs'] = 0
    # NOTE(review): 'exp' is never put into `args` in this notebook, so this
    # lookup raises KeyError as written -- set args['exp'] before enabling -p.
    model = model_zoo.chem.load_pretrained(args['exp'])
else:
    model = load_model(args)
    if args['model'] in ['SCHNET', 'MGCN']:
        # NOTE(review): `train_set` is not defined anywhere in the visible
        # notebook; this branch raises NameError unless it is provided.
        model.set_mean_std(train_set.mean, train_set.std, args['device'])
    # reduction='none' keeps the element-wise losses; any aggregation is left
    # to run_a_train_epoch.
    loss_fn = nn.L1Loss(reduction='none')
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    stopper = EarlyStopping(mode='lower', patience=args['patience'], filename='model_saved/edgetype/early_stop.pth')
model.to(args['device'])

In [None]:
# Epoch index (0-based) at which a snapshot is written and the learning rate
# is multiplied by LR_DECAY_FACTOR.
LR_DECAY_EPOCH = 250
LR_DECAY_FACTOR = 0.1

for epoch in range(args['num_epochs']):
    epoch_start = time.time()
    if epoch == LR_DECAY_EPOCH:
        # Keep a snapshot of the weights/optimizer as they were before the decay.
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()}, 'model_saved/edgetype/1.pth')
        # Log the learning rates before and after the one-off decay.
        for param_group in optimizer.param_groups:
            print(param_group['lr'])
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * LR_DECAY_FACTOR
        for param_group in optimizer.param_groups:
            print(param_group['lr'])
    # Train
    run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)
    # Validation and early stop
    val_score = run_an_eval_epoch(args, model, val_loader)
    # The stopper is still stepped every epoch so it can track the best score
    # (presumably also writing its checkpoint file -- confirm in utils_mpnn),
    # but its stop signal is deliberately ignored: training always runs for
    # the full num_epochs.
    early_stop = stopper.step(val_score, model, optimizer)
    print('epoch {:d}/{:d}, validation {} {:.4f}, best validation {} {:.4f}, time{:.1f}'.format(
        epoch + 1, args['num_epochs'], args['metric_name'], val_score,
        args['metric_name'], stopper.best_score, time.time() - epoch_start))
    #if early_stop:break

# `optimizer` is only created on the training branch of the setup cell, so the
# final snapshot is skipped for a pre-trained run (it would raise NameError).
if not args['pre_trained']:
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()}, 'model_saved/edgetype/last.pth')

In [None]:
# Restore the checkpoint stored under the early-stopping filename.
# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine (and vice versa) instead of failing during deserialization.
state = torch.load('model_saved/edgetype/early_stop.pth', map_location=args['device'])
model.load_state_dict(state['model_state_dict'])

# Collect predictions and targets over the whole test loader.
model.eval()
y_p = []; y_t = []
with torch.no_grad():
    for bg, labels in test_loader:
        # Labels are only needed as NumPy arrays here, so they are not moved
        # to the compute device first.
        y_t.append(labels.detach().cpu().numpy())
        prediction = regress(args, model, bg)
        y_p.append(prediction.detach().cpu().numpy())

# Flatten every batch/task entry into one long 1-D vector.
y_p = np.concatenate(y_p, axis=0).reshape((-1))
y_t = np.concatenate(y_t, axis=0).reshape((-1))

# Undo the label standardisation.
# NOTE(review): m/std are presumably the mean and standard deviation used to
# normalise the labels in load_brl_dataset -- confirm they match the dataset.
m = 355311.0926130971; std = 201425.32593642248
pred = y_p * std + m
actual = y_t * std + m

In [None]:
# Fixed axis limits shared by both scatter plots so they are directly comparable.
PLOT_AXIS_LIMITS = (17500 - 50000, 700000 + 50000)


def report_regression_metrics(y_true, y_pred, fig_path):
    """Print error metrics for a prediction vector and save an actual-vs-pred plot.

    Parameters
    ----------
    y_true, y_pred : numpy.ndarray
        Target and predicted values of identical shape.
    fig_path : str
        File the scatter plot is written to before being shown.

    Returns
    -------
    tuple
        ``(mae, mse, rmse, iqr)`` where ``iqr`` is the inter-quartile range of
        ``y_true``.
    """
    print(y_true.shape, y_pred.shape)
    mae = metrics.mean_absolute_error(y_true, y_pred)
    mse = metrics.mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    iqr = np.subtract(*np.percentile(y_true, [75, 25]))

    print('mae', mae, '| rmse:', rmse)
    print('actual mean', np.mean(y_true), '| pred mean', np.mean(y_pred))
    print(np.corrcoef(y_true.flatten(), y_pred.flatten()))
    print('rmse/range', rmse / (np.max(y_true) - np.min(y_true)))
    # Divide by |y_true| so a negative target cannot cancel out other terms.
    print('mape', np.mean(np.abs(y_true - y_pred) / np.abs(y_true)))
    print('rmse/iqr', rmse / iqr)
    print('rmse/mean', rmse / np.mean(y_true))
    print('actual min max', np.min(y_true), np.max(y_true))
    print('pred min max', np.min(y_pred), np.max(y_pred))

    fig, ax = plt.subplots()
    ax.plot(y_true.flatten(), y_pred.flatten(), '.', color='darkblue', alpha=0.8)
    ax.set_ylim(*PLOT_AXIS_LIMITS)
    ax.set_xlim(*PLOT_AXIS_LIMITS)
    ax.set_xlabel('actual')
    ax.set_ylabel('pred')
    ax.grid()
    fig.savefig(fig_path)
    plt.show()
    return mae, mse, rmse, iqr


#####results - all
y_t = actual.copy(); y_p = pred.copy()
mae, mse, rmse, iqr = report_regression_metrics(y_t, y_p, 'images/edgetype/test1.png')

#####results - 100~600
# Restrict the evaluation to samples whose actual value lies in [100k, 600k].
idx = np.where(((actual >= (100 * 1000)) & (actual <= (600 * 1000))))
print(len(idx[0]))
y_t = actual[idx[0]]
y_p = pred[idx[0]]
mae, mse, rmse, iqr = report_regression_metrics(y_t, y_p, 'images/edgetype/test2.png')