In [3]:
import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
import yaml
import pickle
from pyteomics import mgf

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn


from rdkit import Chem
# ignore the warning
from rdkit import RDLogger 
RDLogger.DisableLog('rdApp.*')
from rdkit.Chem import Descriptors

from molmspack.molnet import MolNet_Oth, MolNet_MS
from molmspack.dataset import MolORNL_Dataset_For_Ploting
from molmspack.data_utils import csv2pkl_wfilter, mgf2pkl_wfilter, nce2ce, precursor_calculator

import h5py
from draw_spectra import draw_spectrum_direct

# Run settings for this evaluation.
model_str = '3DMolMS_MSE'  # label written into every output row and group_id
data_type = 'validation'  # dataset split; selects ./data/data_<split>.hdf5
normalization = 'MAX'  # per-spectrum normalisation used by pred_step: 'MAX' or 'SUM'
small_true = False  # True -> use the "_small" variants of the data / prediction files
save_data_bool = False  # True -> evaluate everything and write the CSV; False -> plot a short preview

# Suffixes for the input HDF5 file and the output CSV.
end_str_1 = '_small.hdf5' if small_true else '.hdf5'
end_str_2 = '_small.csv' if small_true else '.csv'

args = argparse.Namespace(
	test_data=f'./data/data_{data_type}{end_str_1}',
	save_pkl=True,
	model_config_path='./config/molnet.yml',
	data_config_path='./config/preprocess_uv-vis.yml',
	resume_path='./check_point/molnet_uv-vis_from_start_max_norm_transfer_p_2.pt',
	result_path=f'./check_point/predictions/{data_type}/pred{end_str_2}',
	true_output='y',
	seed=42,
	device=0,
	no_cuda=False,
)

# The wavelength grid is defined by the "model" section of the model config.
with open(args.model_config_path, 'r') as f:
	config = yaml.load(f, Loader=yaml.FullLoader)
spectrum_discretization_step = config['model']['resolution']
xmin_spectrum = config['model']['min_wavelength']
# One extra step so that a grid built with arange(min, max, step) reaches max_wavelength.
xmax_spectrum = config['model']['max_wavelength'] + spectrum_discretization_step
# Gaussian broadening width passed to gauss_torch: the line shape falls to half
# its height at a distance of w (same units as the wavelength grid).
w = 10


def gauss_torch(a, m, x, w, log_2=None):
	"""Evaluate Gaussian line shapes and scale each row by an amplitude.

	Computes ``e = exp(-log_2 * ((m - x) / w) ** 2)`` and returns ``e`` with
	row ``i`` multiplied by ``a[i]``. The expression is symmetric in ``m`` and
	``x``, so it does not matter which one holds the evaluation grid and
	which holds the line centre(s).

	Args:
		a: 1-D tensor of amplitudes (intensities), one per row of ``e``.
		m: positions at which to evaluate the line shape, or the line centre(s).
		x: the other operand of ``m - x``; must broadcast against ``m`` so
			that ``e`` is 2-D.
		w: line width. The exponent uses ``ln 2`` (not ``4 ln 2``), so the
			line shape equals 1/2 at ``|m - x| == w``. ``w`` is therefore the
			half width at half maximum (HWHM), i.e. FWHM = 2 * w.
		log_2: ln(2) as a float or tensor. Computed when omitted.

	Returns:
		Tensor shaped like ``e`` with ``out[i, j] = a[i] * e[i, j]``.
	"""
	if log_2 is None:
		import math
		log_2 = math.log(2.0)
	e = torch.exp(-(log_2 * ((m - x) / w) ** 2))
	# 'i,ij->ij' scales row i of e by a[i]; e must be 2-D with len(a) rows.
	return torch.einsum('i,ij->ij', a, e)


def _normalize_spectra(spectra):
	"""Scale each row of ``spectra`` according to the global ``normalization``.

	'MAX' divides every row by its maximum and 'SUM' by its sum. Any other
	value returns the tensor unchanged.
	"""
	if normalization == 'MAX':
		return spectra / torch.max(spectra, dim=1, keepdim=True)[0]
	if normalization == 'SUM':
		return spectra / torch.sum(spectra, dim=1, keepdim=True)
	return spectra


def pred_step(model, device, loader, batch_size, num_points):
	"""Run ``model`` over ``loader`` and collect broadened spectra and errors.

	For every batch:
	1. The model's per-grid-point outputs go through LeakyReLU(0.1).
	2. Each output is broadened with a Gaussian of half width ``w`` onto the
	   wavelength grid.
	3. The result is normalised per ``normalization``, squared, and compared
	   with the identically normalised target by MSE.

	Args:
		model: network called as ``model(x, None, idx_base)``.
		device: torch device used for every tensor.
		loader: iterable of ``(title, smiles, x, y)`` batches. ``x`` is
			permuted with ``(0, 2, 1)`` before the forward pass (presumably
			(batch, points, features) -> (batch, features, points); TODO confirm).
		batch_size: nominal batch size of ``loader``. Kept for backward
			compatibility; the size of each batch is read from the data, so a
			short final batch also works.
		num_points: per-molecule point count used to offset ``idx_base``.

	Returns:
		``(dict_pred, accuracy)``. ``dict_pred`` maps a running row index to
		``[model_str, title, smiles, wavelength, pred, y, group_id]``, with one
		row per (sample, wavelength). ``accuracy`` holds one MSE value per
		batch (lower is better, despite the name).

	When ``save_data_bool`` is False, every sample is plotted with
	``draw_spectrum_direct`` and the function returns after the sixth batch.
	"""
	model.eval()
	dict_pred = {}
	accuracy = []
	row = 0  # running key into dict_pred

	# Wavelength grid: the NumPy copy supplies the saved/plotted wavelengths,
	# the torch copy is used for the broadening maths.
	x_spectra_cpu = np.arange(xmin_spectrum, xmax_spectrum, spectrum_discretization_step)
	x_spectra = torch.arange(xmin_spectrum, xmax_spectrum, spectrum_discretization_step, device=device)
	# Use the grid's real length. int((max - min) / step) can come out one
	# short for non-integer steps and would break the shapes below.
	bins = x_spectra.shape[0]
	log_2 = torch.log(torch.tensor(2.0, device=device))
	# kernel[j, k] is the Gaussian centred at x_spectra[j] evaluated at
	# x_spectra[k]. It is the same for every batch, so build it once.
	# pred @ kernel then sums one Gaussian per grid point for all samples at once.
	kernel = gauss_torch(
		torch.ones(bins, device=device),
		x_spectra.unsqueeze(0),
		x_spectra.unsqueeze(1),
		w,
		log_2)

	with torch.no_grad(), tqdm(total=len(loader), desc='Eval') as bar:
		for step, batch in enumerate(loader):
			title, smiles, x, y = batch
			x = x.to(device=device, dtype=torch.float).permute(0, 2, 1)
			y = _normalize_spectra(y.to(device=device, dtype=torch.float))
			n_samples = y.shape[0]
			idx_base = torch.arange(0, n_samples, device=device).view(-1, 1, 1) * num_points

			pred = nn.LeakyReLU(0.1)(model(x, None, idx_base))

			# Column j of pred is the stick intensity at x_spectra[j].
			gauss_sum = _normalize_spectra(pred[:, :bins] @ kernel)
			gauss_sum = torch.pow(gauss_sum, 2)
			bar.update(1)

			accuracy.append(F.mse_loss(gauss_sum, y).item())

			# Max-normalise both spectra for saving and plotting.
			y = y / torch.max(y, dim=1, keepdim=True)[0]
			gauss_sum = gauss_sum / torch.max(gauss_sum, dim=1, keepdim=True)[0]
			gauss_np = gauss_sum.cpu().numpy()
			y_np = y.cpu().numpy()
			for i in range(n_samples):
				gauss_sum_tmp = gauss_np[i]
				y_tmp = y_np[i]
				# Index by i so each sample keeps its own title / SMILES.
				group_id = model_str + '_' + smiles[i] + '_' + normalization
				for idx, intensity in enumerate(gauss_sum_tmp):
					dict_pred[row] = [model_str, title[i], smiles[i], x_spectra_cpu[idx], intensity, y_tmp[idx], group_id]
					row += 1

				if not save_data_bool:
					draw_spectrum_direct(gauss_sum_tmp/np.max(gauss_sum_tmp),
						y_tmp/np.max(y_tmp), str(i),
						max_wavelength=xmax_spectrum,
						min_wavelength=xmin_spectrum,
						resolution=spectrum_discretization_step)
			# Preview mode: stop after the first six batches (steps 0..5).
			if not save_data_bool and step == 5:
				return dict_pred, accuracy
	return dict_pred, accuracy

def init_random_seed(seed):
	"""Seed Python's ``random``, NumPy and PyTorch generators with ``seed``.

	Python's ``random`` is seeded too, so any code that draws from it is
	reproducible as well.
	"""
	import random

	random.seed(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed(seed)


init_random_seed(args.seed)


# Load the model/training and data configuration files.
with open(args.model_config_path, 'r') as f:
	config = yaml.load(f, Loader=yaml.FullLoader)
print('Load the model & training configuration from {}'.format(args.model_config_path))
with open(args.data_config_path, 'r') as f:
	data_config = yaml.load(f, Loader=yaml.FullLoader)
print('Load the data configuration from {}'.format(args.data_config_path))

# 1. Data
# One batch-size constant shared by the DataLoader and pred_step, so the two
# values cannot drift apart.
eval_batch_size = 1
valid_set = MolORNL_Dataset_For_Ploting(args.test_data, args.true_output, data_augmentation=False, partitioned=False)
valid_loader = DataLoader(
	valid_set,
	batch_size=eval_batch_size,
	shuffle=False,
	num_workers=config['train']['num_workers'],
	drop_last=True)

# 2. Model
# NOTE(review): inference is pinned to the CPU; args.device / args.no_cuda are
# not consulted here.
device = torch.device("cpu")

model = MolNet_MS(config['model']).to(device)
num_params = sum(p.numel() for p in model.parameters())
print(f'{str(model)} #Params: {num_params}')

# 3. Evaluation
print("Load the checkpoints...")
model.load_state_dict(torch.load(args.resume_path, map_location=device)['model_state_dict'])

dict_pred, accuracy = pred_step(model, device, valid_loader,
	batch_size=eval_batch_size, num_points=config['model']['max_atom_num'])
# pred_step returns one MSE value per batch; report their mean and std.
accuracy = np.array(accuracy)

print(np.mean(accuracy), np.std(accuracy))
if save_data_bool:

	# Write the dict_pred to csv file
	df_pred = pd.DataFrame.from_dict(dict_pred, orient='index', columns=['model', 'title', 'smiles', 'nm', 'pred', 'y', 'group_id'])
	df_pred.to_csv(args.result_path, index=False)
	# `display` is not imported in this file; presumably the IPython/Jupyter built-in.
	display(df_pred.head())



Load 8301 data from ./data/data_validation.hdf5
MolNet_MS(
  (encoder): Encoder(
    (hidden_layers): ModuleList(
      (0): MolConv k = 6 (15 -> 64)
      (1): MolConv k = 6 (64 -> 64)
      (2): MolConv k = 6 (64 -> 128)
      (3): MolConv k = 6 (128 -> 256)
      (4): MolConv k = 6 (256 -> 512)
      (5): MolConv k = 6 (512 -> 1024)
    )
    (conv): Sequential(
      (0): Conv1d(2048, 2048, kernel_size=(1,), stride=(1,), bias=False)
      (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (merge): Sequential(
      (0): Linear(in_features=4096, out_features=2048, bias=True)
      (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (decoder): MSDecoder(
    (blocks): ModuleList(
      (0-6): 7 x FCResBlock (2048 -> 2048)
    )
    (fc): Linear(in_features=2048, out_features=171, bias=True)
  )
) #Params: 101815

Eval: 100%|██████████| 8301/8301 [04:03<00:00, 34.13it/s]


0.01795123675512785 0.017595261302713446
