In [210]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import deepchem as dc
from tqdm import tqdm_notebook
from scipy.spatial.distance import norm

from fastai.tabular import *
from fastai.callbacks import SaveModelCallback
from fastai.basic_data import DataBunch

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [141]:
# Directory holding the competition CSV files and the structures/ folder of .xyz files.
DATA_PATH = '../data/'
# Scratch directory; not referenced by any of the cells visible in this notebook section.
PATH = '../tmp'

In [142]:
# Keep only the directory entries whose name contains '.csv'.
files = [entry for entry in os.listdir(DATA_PATH) if '.csv' in entry]
files

['scalar_coupling_contributions.csv',
 'mulliken_charges.csv',
 'structures.csv',
 'test.csv',
 'train.csv',
 'magnetic_shielding_tensors.csv',
 'dipole_moments.csv',
 'sample_submission.csv',
 'potential_energy.csv']

In [143]:
# Load the three tables used below; DATA_PATH already ends with a path separator.
train_df, test_df, structures_df = (
    pd.read_csv(DATA_PATH + table_name + '.csv')
    for table_name in ('train', 'test', 'structures')
)

## Get Molecules

In [404]:
##
# Written by Jan H. Jensen based on this paper Yeonjoon Kim and Woo Youn Kim 
# "Universal Structure Conversion Method for Organic Molecules: From Atomic Connectivity
# to Three-Dimensional Geometry" Bull. Korean Chem. Soc. 2015, Vol. 36, 1769-1777 DOI: 10.1002/bkcs.10334
#
from rdkit import Chem
from rdkit.Chem import AllChem
import itertools
from rdkit.Chem import rdmolops
from collections import defaultdict
import copy
import networkx as nx #uncomment if you don't want to use "quick"/install networkx


# Lower-case element symbols ordered by atomic number (index 0 is hydrogen, index 93 is plutonium).
__ATOM_LIST__ = (
    "h he "
    "li be b c n o f ne "
    "na mg al si p s cl ar "
    "k ca sc ti v cr mn fe co ni cu "
    "zn ga ge as se br kr "
    "rb sr y zr nb mo tc ru rh pd ag "
    "cd in sn sb te i xe "
    "cs ba la ce pr nd pm sm eu gd tb dy "
    "ho er tm yb lu hf ta w re os ir pt "
    "au hg tl pb bi po at rn "
    "fr ra ac th pa u np pu"
).split()


def get_atom(atom):
    """Return the atomic number of an element symbol (case-insensitive).

    Raises ValueError (from ``list.index``) when the symbol is not in ``__ATOM_LIST__``.
    """
    return 1 + __ATOM_LIST__.index(atom.lower())


def getUA(maxValence_list, valence_list):
    """Find the atoms whose current valence is below their allowed maximum.

    Returns ``(UA, DU)``: the indices of the unsaturated atoms and, in the same
    order, how far each one is below its maximum valence.
    """
    deficits = [
        (index, max_valence - valence)
        for index, (max_valence, valence) in enumerate(zip(maxValence_list, valence_list))
        if max_valence - valence > 0
    ]
    UA = [index for index, _ in deficits]
    DU = [gap for _, gap in deficits]
    return UA, DU


def get_BO(AC,UA,DU,valences,UA_pairs,quick):
    """Build a bond-order matrix by repeatedly raising the order of unsaturated pairs.

    Starting from a copy of ``AC``, every pair in ``UA_pairs`` gets its bond order
    incremented, the unsaturated atoms are recomputed against ``valences``, and a new
    pairing is taken from ``get_UA_pairs``. Iteration stops once the degree-of-unsaturation
    list is the same as on the previous pass. The incoming ``UA`` is overwritten before
    it is read.
    """
    bond_orders = AC.copy()
    previous_DU = []

    while previous_DU != DU:
        for first, second in UA_pairs:
            bond_orders[first, second] += 1
            bond_orders[second, first] += 1

        previous_DU = copy.copy(DU)
        UA, DU = getUA(valences, list(bond_orders.sum(axis=1)))
        UA_pairs = get_UA_pairs(UA, AC, quick)[0]

    return bond_orders


def valences_not_too_large(BO,valences):
    """Return True when no atom's bond-order row sum exceeds its entry in ``valences``."""
    bonds_per_atom = BO.sum(axis=1)
    return not any(
        bond_count > allowed
        for allowed, bond_count in zip(valences, bonds_per_atom)
    )


def BO_is_OK(BO,AC,charge,DU,atomic_valence_electrons,atomicNumList,charged_fragments):
    """Decide whether bond-order matrix ``BO`` is an acceptable solution.

    Accepts when (a) the bond order added on top of ``AC`` equals ``sum(DU)``, (b) the
    summed atomic charges equal ``charge`` and (c) the number of charged atoms does not
    exceed ``abs(charge)``. Atomic charges are only evaluated when ``charged_fragments``
    is truthy; otherwise the total is taken as 0.
    """
    total_charge = 0
    charged_atoms = []
    if charged_fragments:
        bond_counts = list(BO.sum(axis=1))
        for index, atomic_num in enumerate(atomicNumList):
            atom_charge = get_atomic_charge(atomic_num, atomic_valence_electrons[atomic_num], bond_counts[index])
            total_charge += atom_charge
            if atomic_num == 6:
                # special-case carbons with exactly two or three single bonds
                single_bonds = list(BO[index,:]).count(1)
                if single_bonds == 2 and bond_counts[index] == 2:
                    total_charge += 1
                    atom_charge = 2
                if single_bonds == 3 and total_charge + 1 < charge:
                    total_charge += 2
                    atom_charge = 1

            if atom_charge != 0:
                charged_atoms.append(atom_charge)

    return bool(
        (BO-AC).sum() == sum(DU)
        and charge == total_charge
        and len(charged_atoms) <= abs(charge)
    )


def get_atomic_charge(atom,atomic_valence_electrons,BO_valence):
    """Return the formal charge of an atom from its atomic number and bond-order valence.

    Hydrogen and boron use ``1 - BO_valence`` and ``3 - BO_valence``; phosphorus with
    valence 5 and sulfur with valence 6 are neutral; every other case uses
    ``atomic_valence_electrons - 8 + BO_valence``.
    """
    if atom == 1:
        return 1 - BO_valence
    if atom == 5:
        return 3 - BO_valence
    if (atom == 15 and BO_valence == 5) or (atom == 16 and BO_valence == 6):
        return 0
    return atomic_valence_electrons - 8 + BO_valence

def clean_charges(mol):
    """Neutralise charge-separated patterns fragment by fragment and recombine.

    Each reaction SMARTS is applied to a fragment for as long as its reactant pattern
    still matches; the first product of each application replaces the fragment.
    Kept as a fallback hack: the only call visible in this file is commented out.
    """
    rxn_smarts = ['[N+:1]=[*:2]-[C-:3]>>[N+0:1]-[*:2]=[C-0:3]',
                  '[N+:1]=[*:2]-[O-:3]>>[N+0:1]-[*:2]=[O-0:3]',
                  '[N+:1]=[*:2]-[*:3]=[*:4]-[O-:5]>>[N+0:1]-[*:2]=[*:3]-[*:4]=[O-0:5]',
                  '[#8:1]=[#6:2]([!-:6])[*:3]=[*:4][#6-:5]>>[*-:1][*:2]([*:6])=[*:3][*:4]=[*+0:5]',
                  '[O:1]=[c:2][c-:3]>>[*-:1][*:2][*+0:3]',
                  '[O:1]=[C:2][C-:3]>>[*-:1][*:2]=[*+0:3]']

    merged = mol  # returned untouched when the molecule has no fragments
    for position, fragment in enumerate(Chem.GetMolFrags(mol,asMols=True,sanitizeFrags=False)):
        for smarts in rxn_smarts:
            reactant_pattern = Chem.MolFromSmarts(smarts.split(">>")[0])
            while fragment.HasSubstructMatch(reactant_pattern):
                reaction = AllChem.ReactionFromSmarts(smarts)
                fragment = reaction.RunReactants((fragment,))[0][0]
        merged = fragment if position == 0 else Chem.CombineMols(merged, fragment)

    return merged


def BO2mol(mol,BO_matrix, atomicNumList,atomic_valence_electrons,mol_charge,charged_fragments):
    """Add the bonds described by ``BO_matrix`` to ``mol`` and set charges or radicals.

    Based on code written by Paolo Toscani. Bond orders 1/2/3 map to single/double/triple
    bonds; any other non-zero order falls back to a single bond. Raises RuntimeError when
    the matrix size and the atom list length differ.
    """
    matrix_size = len(BO_matrix)
    atom_count = len(atomicNumList)
    BO_valences = list(BO_matrix.sum(axis=1))

    if matrix_size != atom_count:
        raise RuntimeError('sizes of adjMat ({0:d}) and atomicNumList '
            '{1:d} differ'.format(matrix_size, atom_count))

    editable = Chem.RWMol(mol)
    order_to_type = {
        1: Chem.BondType.SINGLE,
        2: Chem.BondType.DOUBLE,
        3: Chem.BondType.TRIPLE
    }

    # visit the upper triangle only, so each bond is added once
    for begin, end in itertools.combinations(range(matrix_size), 2):
        order = int(round(BO_matrix[begin, end]))
        if order != 0:
            editable.AddBond(begin, end, order_to_type.get(order, Chem.BondType.SINGLE))
    mol = editable.GetMol()

    if charged_fragments:
        return set_atomic_charges(mol,atomicNumList,atomic_valence_electrons,BO_valences,BO_matrix,mol_charge)
    return set_atomic_radicals(mol,atomicNumList,atomic_valence_electrons,BO_valences)

def set_atomic_charges(mol,atomicNumList,atomic_valence_electrons,BO_valences,BO_matrix,mol_charge):
    """Assign a formal charge to every atom of ``mol`` whose computed charge is non-zero.

    Charges come from ``get_atomic_charge``; carbons with exactly two or three single
    bonds are adjusted using the running total and ``mol_charge``. Returns ``mol``.
    """
    running_total = 0
    for index, atomic_num in enumerate(atomicNumList):
        rd_atom = mol.GetAtomWithIdx(index)
        formal = get_atomic_charge(atomic_num, atomic_valence_electrons[atomic_num], BO_valences[index])
        running_total += formal
        if atomic_num == 6:
            single_bonds = list(BO_matrix[index,:]).count(1)
            if single_bonds == 2 and BO_valences[index] == 2:
                running_total += 1
                formal = 0
            if single_bonds == 3 and running_total + 1 < mol_charge:
                running_total += 2
                formal = 1

        if abs(formal) > 0:
            rd_atom.SetFormalCharge(int(formal))

    # clean_charges(mol) should no longer be needed here; it stays disabled

    return mol


def set_atomic_radicals(mol,atomicNumList,atomic_valence_electrons,BO_valences):
    """Mark atoms as radicals: radical electron count = absolute computed atomic charge.

    Atoms whose computed charge is zero are left untouched. Returns ``mol``.
    """
    for index, atomic_num in enumerate(atomicNumList):
        rd_atom = mol.GetAtomWithIdx(index)
        unpaired = get_atomic_charge(atomic_num, atomic_valence_electrons[atomic_num], BO_valences[index])
        if abs(unpaired) > 0:
            rd_atom.SetNumRadicalElectrons(abs(int(unpaired)))

    return mol

def get_bonds(UA,AC):
    """List the bonded pairs among the atoms in ``UA``.

    A pair counts as bonded when its ``AC`` entry equals 1. Each bond is returned once,
    as an index-sorted tuple, in the order the pairs appear in ``UA``.
    """
    return [
        tuple(sorted([first, second]))
        for first, second in itertools.combinations(UA, 2)
        if AC[first, second] == 1
    ]

def get_UA_pairs(UA,AC,quick):
    """Return candidate pairings of bonded unsaturated atoms.

    With no bonds among ``UA`` the result is ``[()]``. With ``quick`` a single pairing
    from networkx's ``max_weight_matching`` is returned. Otherwise every combination of
    ``int(len(UA)/2)`` bonds is enumerated and all combinations covering the largest
    number of distinct atoms are returned.
    """
    bonds = get_bonds(UA,AC)
    if not bonds:
        return [()]

    if quick:
        graph = nx.Graph()
        graph.add_edges_from(bonds)
        return [list(nx.max_weight_matching(graph))]

    best_coverage = 0
    best_combos = [()]
    for combo in itertools.combinations(bonds, int(len(UA)/2)):
        coverage = len({atom for bond in combo for atom in bond})
        if coverage > best_coverage:
            best_coverage = coverage
            best_combos = [combo]
        elif coverage == best_coverage:
            best_combos.append(combo)

    return best_combos

def AC2BO(AC,atomicNumList,charge,charged_fragments,quick):
    """Convert an atom-connectivity matrix into a bond-order matrix.

    Implements the algorithm of Figure 2 of the paper cited at the top of this cell:
    UA are the unsaturated atoms, DU the degrees of unsaturation and ``best_BO`` the
    current best candidate. Every combination of allowed valences is tried; the first
    bond-order matrix accepted by ``BO_is_OK`` is returned, otherwise the best candidate
    that does not exceed the valences. Returns ``(BO, atomic_valence_electrons)``.
    """
    # allowed valences per atomic number; unknown elements yield an empty list
    atomic_valence = defaultdict(list, {
        1: [1],
        6: [4],
        7: [4,3],
        8: [2,1],
        9: [1],
        14: [4],
        15: [5,4,3],
        16: [6,4,2],
        17: [1],
        32: [4],
        35: [1],
        53: [1],
    })

    # number of valence electrons per atomic number
    atomic_valence_electrons = {
        1: 1, 6: 4, 7: 5, 8: 6, 9: 7, 14: 4,
        15: 5, 16: 6, 17: 7, 32: 4, 35: 7, 53: 7,
    }

    # per-atom options, e.g. CO -> [[4],[2,1]], expanded to [(4,2),(4,1)]
    per_atom_options = [atomic_valence[atomic_num] for atomic_num in atomicNumList]
    valences_list = list(itertools.product(*per_atom_options))

    best_BO = AC.copy()

    for valences in valences_list:
        AC_valence = list(AC.sum(axis=1))
        UA, DU_from_AC = getUA(valences, AC_valence)

        already_saturated = len(UA) == 0
        if already_saturated and BO_is_OK(AC,AC,charge,DU_from_AC,atomic_valence_electrons,atomicNumList,charged_fragments):
            return AC, atomic_valence_electrons

        for UA_pairs in get_UA_pairs(UA,AC,quick):
            BO = get_BO(AC,UA,DU_from_AC,valences,UA_pairs,quick)
            if BO_is_OK(BO,AC,charge,DU_from_AC,atomic_valence_electrons,atomicNumList,charged_fragments):
                return BO, atomic_valence_electrons

            if BO.sum() >= best_BO.sum() and valences_not_too_large(BO,valences):
                best_BO = BO.copy()

    return best_BO, atomic_valence_electrons


def AC2mol(mol,AC,atomicNumList,charge,charged_fragments,quick):
    """Turn connectivity matrix ``AC`` into bonds and charges on ``mol``.

    First derives the bond-order matrix with ``AC2BO``, then writes the bonds and the
    charge (or radical) information onto the molecule with ``BO2mol``.
    """
    bond_orders, valence_electrons = AC2BO(AC,atomicNumList,charge,charged_fragments,quick)
    return BO2mol(mol, bond_orders, atomicNumList, valence_electrons, charge, charged_fragments)


def get_proto_mol(atomicNumList):
    """Create a molecule containing one atom per atomic number and no bonds.

    The first atom is created through a one-atom SMARTS; the remaining atoms are
    appended in list order.
    """
    seed = Chem.MolFromSmarts("[#"+str(atomicNumList[0])+"]")
    editable = Chem.RWMol(seed)
    for atomic_num in atomicNumList[1:]:
        editable.AddAtom(Chem.Atom(atomic_num))

    return editable.GetMol()


def get_atomicNumList(atomic_symbols):
    """Map each element symbol to its atomic number via ``get_atom``, preserving order."""
    return [get_atom(symbol) for symbol in atomic_symbols]


def read_xyz_file(filename):
    """Parse an .xyz file into atomic numbers, total charge and coordinates.

    Line 0 must hold the atom count, line 1 is a comment line that may carry a
    ``charge=<int>`` tag, and every following non-blank line is ``symbol x y z``.
    Blank lines after the header (e.g. a trailing newline) are ignored, and the
    charge defaults to 0 when the comment line is missing or has no tag.

    Returns ``(atomicNumList, charge, xyz_coordinates)`` where ``xyz_coordinates``
    is a list of ``[x, y, z]`` floats.
    Raises ValueError for a malformed count, charge, or atom line.
    """
    atomic_symbols = []
    xyz_coordinates = []
    charge = 0

    with open(filename, "r") as file:
        for line_number, line in enumerate(file):
            if line_number == 0:
                # the count itself is not needed, but a malformed header must still raise
                int(line)
            elif line_number == 1:
                if "charge=" in line:
                    charge = int(line.split("=")[1])
            elif line.strip():
                atomic_symbol, x, y, z = line.split()
                atomic_symbols.append(atomic_symbol)
                xyz_coordinates.append([float(x), float(y), float(z)])

    atomicNumList = get_atomicNumList(atomic_symbols)

    return atomicNumList, charge, xyz_coordinates

def xyz2AC(atomicNumList,xyz):
    """Derive an atom-connectivity matrix from 3D coordinates.

    Two atoms are connected when their distance is at most the sum of their covalent
    radii, each scaled by 1.30. Returns ``(AC, mol, dMat)``: the symmetric 0/1 integer
    matrix, the bond-less molecule carrying the conformer, and the 3D distance matrix.
    """
    import numpy as np
    mol = get_proto_mol(atomicNumList)

    # attach the coordinates as a conformer so the distance matrix can be computed
    atoms_in_mol = mol.GetNumAtoms()
    conf = Chem.Conformer(atoms_in_mol)
    for index in range(atoms_in_mol):
        position = (xyz[index][0], xyz[index][1], xyz[index][2])
        conf.SetAtomPosition(index, position)
    mol.AddConformer(conf)

    dMat = Chem.Get3DDistanceMatrix(mol)
    pt = Chem.GetPeriodicTable()

    num_atoms = len(atomicNumList)
    scaled_radii = [
        pt.GetRcovalent(mol.GetAtomWithIdx(index).GetAtomicNum())*1.30
        for index in range(num_atoms)
    ]

    AC = np.zeros((num_atoms,num_atoms)).astype(int)
    for first, second in itertools.combinations(range(num_atoms), 2):
        if dMat[first,second] <= scaled_radii[first] + scaled_radii[second]:
            AC[first,second] = 1
            AC[second,first] = 1

    return AC,mol,dMat

def chiral_stereo_check(mol):
    """Sanitize ``mol`` and (re)assign its stereochemistry; returns the same object.

    ``Chem.SanitizeMol`` is the call that raises on chemically invalid structures.
    """
    Chem.SanitizeMol(mol)
    # the -1 argument is presumably the conformer id (default conformer) — TODO confirm
    Chem.DetectBondStereochemistry(mol,-1)
    Chem.AssignStereochemistry(mol, flagPossibleStereoCenters=True, force=True)
    Chem.AssignAtomChiralTagsFromStructure(mol,-1)

    return mol


In [405]:
def xyz2mol(atomicNumList, charge, xyz_coordinates, charged_fragments, quick):
    """Build a molecule with bonds, charges and stereo information from coordinates.

    Steps: coordinates -> connectivity matrix and bond-less molecule (``xyz2AC``),
    connectivity -> bond orders and charges (``AC2mol``), then the stereo/chirality
    pass (``chiral_stereo_check``). Returns ``(mol, dMat)`` where ``dMat`` is the 3D
    distance matrix computed by ``xyz2AC``.
    """
    connectivity, bare_mol, dMat = xyz2AC(atomicNumList, xyz_coordinates)
    bonded_mol = AC2mol(bare_mol, connectivity, atomicNumList, charge, charged_fragments, quick)
    return chiral_stereo_check(bonded_mol), dMat

In [408]:
def mol_from_xyz(filepath, add_hs=True):
    """Read an .xyz file and build the molecule plus two distance descriptors.

    Returns ``(mol, dMat, dFromCentroid)``:
      - ``mol``: the molecule produced by ``xyz2mol``,
      - ``dMat``: the pairwise 3D distance matrix,
      - ``dFromCentroid``: 1-D array with each atom's Euclidean distance to the
        unweighted mean of all atom coordinates.

    ``add_hs`` is currently unused; it belonged to the disabled canonical-SMILES
    round trip noted below.
    """
    charged_fragments = True  # alternatively radicals are made

    # quick=True uses networkx matching and is faster for large systems. To run
    # without networkx, set quick=False and comment out 'import networkx as nx'.
    quick = True

    atomicNumList, charge, xyz_coordinates = read_xyz_file(filepath)
    mol, dMat = xyz2mol(atomicNumList, charge, xyz_coordinates, charged_fragments, quick)

    # Per-atom distance from the geometric centroid (previously only the shape
    # of this vector was returned).
    xyz_coord_array = np.array(xyz_coordinates)
    centroid = xyz_coord_array.mean(axis=0)
    dFromCentroid = np.linalg.norm(xyz_coord_array - centroid, axis=1)

    # Canonical hack (disabled):
#     smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
#     mol = Chem.MolFromSmiles(smiles)
#     if add_hs: mol = Chem.AddHs(mol)
    return mol, dMat, dFromCentroid

In [409]:
from glob import glob
# Alphabetically sorted paths of every .xyz structure file.
xyz_filepath_list = sorted(glob(DATA_PATH+'structures/*.xyz'))
n_mols = len(xyz_filepath_list)
print('total xyz filepath # ', n_mols)

total xyz filepath #  130775


In [410]:
# Build one molecule per .xyz file, keyed by molecule name (file name without extension).
# Structures that raise ValueError (e.g. sanitization errors) are reported and skipped,
# so they are absent from all three dictionaries.
dist_matrices = {}
mols = {}
dist_from_centroids = {}
for filepath in tqdm_notebook(xyz_filepath_list):
    # os.path handles the platform's separator, unlike splitting on '/'
    mol_name = os.path.splitext(os.path.basename(filepath))[0]
    try:
        mol, dist_matrix, dist_from_centroid = mol_from_xyz(filepath)
    except ValueError as e:
        print(mol_name, e)
        continue
    mols[mol_name] = mol
    dist_matrices[mol_name] = dist_matrix
    dist_from_centroids[mol_name] = dist_from_centroid

HBox(children=(IntProgress(value=0, max=130775), HTML(value='')))

dsgdb9nsd_017732 Sanitization error: Explicit valence for atom # 4 C greater than permitted
dsgdb9nsd_037494 Sanitization error: Explicit valence for atom # 4 C greater than permitted
dsgdb9nsd_037900 Sanitization error: Explicit valence for atom # 5 C greater than permitted
dsgdb9nsd_042676 Sanitization error: Explicit valence for atom # 3 C greater than permitted
dsgdb9nsd_042681 Sanitization error: Explicit valence for atom # 3 C greater than permitted
dsgdb9nsd_044308 Sanitization error: Explicit valence for atom # 5 C greater than permitted
dsgdb9nsd_044322 Sanitization error: Explicit valence for atom # 5 C greater than permitted
dsgdb9nsd_048903 Sanitization error: Explicit valence for atom # 2 C greater than permitted
dsgdb9nsd_066495 Sanitization error: Explicit valence for atom # 7 C greater than permitted
dsgdb9nsd_067109 Sanitization error: Explicit valence for atom # 2 C greater than permitted
dsgdb9nsd_073323 Sanitization error: Explicit valence for atom # 4 C greater tha

In [658]:
# Width of one edge-feature row as filled by get_edge_features: columns 0-5 bond
# features, 6 graph distance, 7 Euclidean distance, 8 onwards the coupling-type one-hot.
N_EDGE_FEATURES = 16
# Width of one atom-feature row: N_ATOM_FEATURES-1 columns from get_atom_features plus
# a last flag column set to 1 for the two atoms of the coupling pair.
N_ATOM_FEATURES = 26
# Padded size of the atom axis — presumably the largest atom count in the data (TODO confirm).
MAX_N_ATOMS     = 29
# Padded size of the edge axis; get_edge_features emits (n_bonds + 1) * 2 rows per sample.
MAX_N_BONDS     = 58
# Distinct coupling types found in train_df; their order defines the one-hot layout.
TYPES           = train_df['type'].unique()

In [659]:
# def get_edge_features(mol, eucl_dist):
#     """
#     Compute the following features for each entry in the adjacency matrix pf 'mol':
#         - bond type one-hot: categorical {1: single, 2: double, 3: triple, 4: aromatic}
#         - is conjugated: bool {0, 1}
#         - is in ring: bool {0, 1}
#         - graph distance: int
#         - euclidean distance: float
#     """
#     n_atoms = mol.GetNumAtoms()
#     features = np.zeros((n_atoms, n_atoms, N_EDGE_FEATURES-8))

#     # compute distance features
#     graph_dist = Chem.AllChem.GetDistanceMatrix(mol)

#     features[:,:,-1] = eucl_dist
#     features[:,:,-2] = graph_dist
#     for e in mol.GetBonds():
#         i = e.GetBeginAtomIdx()
#         j = e.GetEndAtomIdx()
#         dc_e_feats = dc.feat.graph_features.bond_features(e).astype(int)
#         features[i,j,:6], features[j,i,:6] = dc_e_feats, dc_e_feats
#     return features

In [660]:
def get_edge_features(mol, eucl_dist, row):
    """
    Build the directed-edge feature table of 'mol' for one scalar-coupling row.

    Every bond contributes two rows (one per direction) and one extra pair of rows is
    reserved for the coupling pair itself, giving (n_bonds + 1) * 2 rows. Columns:
        - 0-5: deepchem bond features (per the original notes: bond type one-hot
          {single, double, triple, aromatic}, is conjugated, is in ring)
        - 6: graph distance: int
        - 7: euclidean distance: float
        - 8 onwards: one-hot of row['type'] over TYPES, set only on the coupling pair

    If the coupled atoms share a bond, that bond's rows carry the type one-hot and the
    reserved rows stay empty (pair index -1); otherwise the reserved rows describe the
    coupling pair. Returns (features, pairs_idx), both ordered by the first atom index
    of each pair, so empty rows come first.
    """
    n_rows = (mol.GetNumBonds() + 1) * 2
    features = np.zeros((n_rows, N_EDGE_FEATURES))
    pairs_idx = np.full((n_rows, 2), -1.0)

    graph_dist = Chem.AllChem.GetDistanceMatrix(mol)
    atom_0, atom_1 = row['atom_index_0'], row['atom_index_1']
    type_one_hot = (TYPES == row['type']).astype(float)

    coupling_is_bonded = False
    for bond_number, bond in enumerate(mol.GetBonds()):
        forward_row, backward_row = 2 * bond_number, 2 * bond_number + 1
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_feats = dc.feat.graph_features.bond_features(bond).astype(int)
        is_coupling_pair = (atom_0, atom_1) in [(begin, end), (end, begin)]
        for edge_row in (forward_row, backward_row):
            features[edge_row, :6] = bond_feats
            features[edge_row, 6] = graph_dist[begin, end]
            features[edge_row, 7] = eucl_dist[begin, end]
            if is_coupling_pair:
                features[edge_row, 8:] = type_one_hot
        if is_coupling_pair:
            coupling_is_bonded = True
        pairs_idx[forward_row] = begin, end
        pairs_idx[backward_row] = end, begin

    if not coupling_is_bonded:
        # the coupled atoms are not directly bonded: use the two reserved rows
        features[-2:, 6] = graph_dist[atom_0, atom_1]
        features[-2:, 7] = eucl_dist[atom_0, atom_1]
        features[-2:, 8:] = type_one_hot
        pairs_idx[-2] = atom_0, atom_1
        pairs_idx[-1] = atom_1, atom_0

    order = pairs_idx[:, 0].argsort()
    return features[order], pairs_idx[order]

In [661]:
def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` over ``allowable_set`` as a list of booleans.

    Raises Exception when ``x`` is not a member of ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception(f"input {x} not in allowable set{allowable_set}:")

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``; values outside ``allowable_set`` map to its last element."""
    effective = x if x in allowable_set else allowable_set[-1]
    return [effective == candidate for candidate in allowable_set]

def get_atom_features(mol):
    """
    Compute the following features for each atom in 'mol' (N_ATOM_FEATURES-1 columns):
        - atom type: H, C, N, O, F (one-hot)
        - degree: 0, 1, 2, 3, 4 (one-hot)
        - implicit valence: 0, 1, 2, 3, 4 (one-hot)
        - hybridization: SP, SP2, SP3, SP3D, SP3D2, UNSPECIFIED (one-hot; any other
          value is mapped to UNSPECIFIED)
        - is aromatic: bool {0, 1}
        - formal charge: int
        - num radical electrons: int
        - atomic number: int

    Atom type, degree and implicit valence outside the listed sets raise an Exception.
    Returns an array with one row per atom, indexed by the atom's index in 'mol'.
    """
    hybridizations = [
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED]

    features = np.zeros((mol.GetNumAtoms(), N_ATOM_FEATURES-1))
    for atom in mol.GetAtoms():
        encoded = []
        encoded += one_of_k_encoding(atom.GetSymbol(), ['H', 'C', 'N', 'O', 'F'])
        encoded += one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4])
        encoded += one_of_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4])
        encoded += one_of_k_encoding_unk(atom.GetHybridization(), hybridizations)
        encoded += [atom.GetIsAromatic(), atom.GetFormalCharge(),
                    atom.GetNumRadicalElectrons(), atom.GetAtomicNum()]
        features[atom.GetIdx(), :] = np.array(encoded).astype(int)
    return features

In [662]:
# n_obs = 50000 # len(mols)
# atomic_features = np.zeros((n_obs, MAX_N_ATOMS, N_ATOM_FEATURES))
# edge_features = np.zeros((n_obs, MAX_N_ATOMS, MAX_N_ATOMS, N_EDGE_FEATURES))
# mask = np.zeros((n_mols, MAX_N_ATOMS))
# target = np.zeros(n_obs)
# keep = []
# mol_name = ''
# succesful_mols = list(mols.keys())
# types = train_df['type'].unique()
# for i in tqdm_notebook(range(n_obs)):
#     row = train_df.iloc[i,:]
#     new_mol_name = row['molecule_name']
#     if mol_name!=new_mol_name:
#         if new_mol_name in succesful_mols:
#             mol_name = new_mol_name
#             mol, dist_matrix = mols[mol_name], dist_matrices[mol_name]
#             n_atoms = mol.GetNumAtoms()
#         else:
#             continue
#     atomic_features[i, :n_atoms, :-1] = get_atom_features(mol)
#     atomic_features[i, row['atom_index_0'], -1] = 1.
#     atomic_features[i, row['atom_index_1'], -1] = 1.
#     edge_features[i, :n_atoms, :n_atoms, :-8] = get_edge_features(mol, dist_matrix)
#     edge_features[i, row['atom_index_0'], row['atom_index_1'], -8:] = (types == row['type']).astype(float)
#     mask[i,:n_atoms] = 1.
#     target[i] = row['scalar_coupling_constant']
#     keep.append(i)
# keep = np.array(keep)

In [664]:
# Featurize the first n_obs rows of train_df (one row = one scalar coupling).
# Rows whose molecule failed to build are left as zeros and excluded from `keep`.
n_obs = 50000  # number of leading train_df rows to process
atomic_features = np.zeros((n_obs, MAX_N_ATOMS, N_ATOM_FEATURES))
edge_features = np.zeros((n_obs, MAX_N_BONDS, N_EDGE_FEATURES))
pairs_idx = np.zeros((n_obs, MAX_N_BONDS, 2)) - 1
mask = np.zeros((n_obs, MAX_N_ATOMS))
edge_mask = np.zeros((n_obs, MAX_N_BONDS))
target = np.zeros(n_obs)
keep = []
mol_name = ''
succesful_mols = list(mols.keys())
for i in tqdm_notebook(range(n_obs)):
    row = train_df.iloc[i,:]
    new_mol_name = row['molecule_name']
    if mol_name!=new_mol_name:
        # dict membership is O(1); scanning the succesful_mols list is not
        if new_mol_name not in mols:
            continue
        mol_name = new_mol_name
        mol, dist_matrix = mols[mol_name], dist_matrices[mol_name]
        n_atoms = mol.GetNumAtoms()
        n_bonds = mol.GetNumBonds()
        n_edge_features = (n_bonds + 1) * 2
        # atom features depend only on the molecule, so compute them once per molecule
        mol_atom_features = get_atom_features(mol)
    atomic_features[i, :n_atoms, :-1] = mol_atom_features
    # last column flags the two atoms of the coupling pair
    atomic_features[i, row['atom_index_0'], -1] = 1.
    atomic_features[i, row['atom_index_1'], -1] = 1.
    edge_features[i, :n_edge_features, :], pairs_idx[i, :n_edge_features, :] = \
        get_edge_features(mol, dist_matrix, row)
    # real atoms / real edges are marked 1; padding (pair index -1) stays 0
    mask[i, :n_atoms], edge_mask[i, pairs_idx[i,:,0] != -1] = 1., 1.
    target[i] = row['scalar_coupling_constant']
    keep.append(i)
keep = np.array(keep)

HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))

In [665]:
# Scratch check: second column of pairs_idx for the second-to-last sample, as a long
# index tensor. Padded rows hold -1 here.
pairs_idx_tmp_torch = torch.tensor(pairs_idx[-2,:,1], dtype=torch.long)

In [666]:
# Scratch check: sizes of the atom features gathered per edge row, before and after
# keeping only the rows where edge_mask is set.
torch.tensor(atomic_features[-2])[pairs_idx_tmp_torch].size(), \
torch.tensor(atomic_features[-2])[pairs_idx_tmp_torch][torch.tensor(edge_mask[-2], dtype=torch.uint8)==True].size()

(torch.Size([58, 26]), torch.Size([40, 26]))

In [667]:
# Scratch check: the gathered atom features themselves, unmasked and masked by edge_mask.
torch.tensor(atomic_features[-2])[pairs_idx_tmp_torch], \
torch.tensor(atomic_features[-2])[pairs_idx_tmp_torch][torch.tensor(edge_mask[-2], dtype=torch.uint8)==True]

(tensor([[0., 1., 0.,  ..., 0., 6., 0.],
         [1., 0., 0.,  ..., 0., 1., 0.],
         [1., 0., 0.,  ..., 0., 1., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64),
 tensor([[0., 1., 0.,  ..., 0., 6., 0.],
         [1., 0., 0.,  ..., 0., 1., 0.],
         [1., 0., 0.,  ..., 0., 1., 0.],
         ...,
         [0., 1., 0.,  ..., 0., 6., 0.],
         [0., 1., 0.,  ..., 0., 6., 0.],
         [0., 1., 0.,  ..., 0., 6., 0.]], dtype=torch.float64))

In [668]:
# Keep only the rows that were actually featurized (indices collected in `keep`).
# NOTE: the names are rebound, so re-running this cell indexes the already-filtered
# arrays again.
atomic_features, edge_features, pairs_idx, mask, edge_mask, target = (
    padded[keep]
    for padded in (atomic_features, edge_features, pairs_idx, mask, edge_mask, target)
)

In [669]:
# Sanity check: shapes of the padded atom and edge feature arrays.
print(f'atomic_features.shape\t: {atomic_features.shape}\nedge_features.shape\t: {edge_features.shape}')

atomic_features.shape	: (50000, 29, 26)
edge_features.shape	: (50000, 58, 16)


## Define MPNN Model

In [958]:
# Keyword-argument bundles matching FullyConnectedNet's layers/act/dropout/batch_norm.
# enn_args: presumably passed as EdgeNetwork's net_args — TODO confirm at the call site.
enn_args = dict(layers=[50, 50, 50, 50], act=nn.ReLU(True), dropout=[0.0, 0.0, 0.0, 0.0], batch_norm=False)
# R_net_args: presumably the readout network's configuration; not used in the visible cells.
R_net_args = dict(layers=[200, 100], act=nn.ReLU(True), dropout=[0.0, 0.0], batch_norm=False)

In [959]:
def hidden_layer(n_in, n_out, batch_norm, dropout, act=None):
    """Return the modules of one dense layer as a list.

    Order: ``nn.Linear(n_in, n_out)``, then ``act`` (if given), ``nn.BatchNorm1d``
    (if ``batch_norm``) and ``nn.Dropout(dropout)`` (if ``dropout`` is non-zero).
    """
    layers = [nn.Linear(n_in, n_out)]
    if act: layers.append(act)
    if batch_norm: layers.append(nn.BatchNorm1d(n_out))
    if dropout != 0: layers.append(nn.Dropout(dropout))
    return layers

class FullyConnectedNet(nn.Module):
    """Multi-layer perceptron: ``n_input -> layers[0] -> ... -> layers[-1] -> n_output``.

    Arguments:
        n_input, n_output: sizes of the input and output features.
        layers: sizes of the hidden layers (none when omitted or empty).
        act: activation module appended after every hidden Linear layer; the same
            module instance is reused for all hidden layers.
        dropout: one dropout probability per hidden layer. When omitted or empty, no
            dropout is used.
        batch_norm: whether to add BatchNorm1d after each hidden activation.

    The output layer never gets activation, batch norm or dropout.
    Raises ValueError when ``dropout`` is non-empty and its length differs from
    ``layers`` (zipping mismatched lists would silently drop layers).
    """

    def __init__(self, n_input, n_output, layers=None, act=nn.ReLU(True), dropout=None, batch_norm=False):
        super().__init__()
        layers = list(layers) if layers is not None else []
        dropout = list(dropout) if dropout is not None else []
        if not dropout:
            dropout = [0.0] * len(layers)
        elif len(dropout) != len(layers):
            raise ValueError(
                f"dropout has {len(dropout)} entries but there are {len(layers)} hidden layers")

        sizes = [n_input] + layers + [n_output]
        n_hidden = len(layers)
        modules = []
        # the trailing 0.0 is the (absent) dropout of the output layer
        for i, (n_in, n_out, dr) in enumerate(zip(sizes[:-1], sizes[1:], dropout + [0.0])):
            is_hidden = i < n_hidden
            modules += hidden_layer(n_in, n_out,
                                    batch_norm if is_hidden else False,
                                    dr,
                                    act if is_hidden else None)
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.layers(x)

In [960]:
class HiddenLSTMCell(nn.Module):
    """LSTM cell update from sec 4.2 of https://arxiv.org/pdf/1511.06391.pdf (no external input)."""

    def __init__(self, n_h_out):
        """Create the cell for an output hidden size of ``n_h_out``.

        The cell receives no external 'x'; its incoming hidden state is the previous
        output concatenated with an attention readout, so it has ``2 * n_h_out``
        features. All four gates are produced by a single weight matrix and bias.
        """
        super().__init__()
        self.n_h_out = n_h_out
        self.n_h = 2 * n_h_out
        self.w_h = nn.Parameter(torch.Tensor(self.n_h, 4 * n_h_out))
        self.b = nn.Parameter(torch.Tensor(4 * n_h_out))
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise matrices; zero the bias except the forget-gate part, set to 1."""
        forget_gate = slice(self.n_h_out, 2 * self.n_h_out)
        for param in self.parameters():
            if param.data.ndimension() < 2:
                nn.init.zeros_(param.data)
                param.data[forget_gate] = 1.0
            else:
                nn.init.xavier_uniform_(param.data)

    def forward(self, h_prev, c_prev):
        """Perform one LSTM step without external input.

        ``h_prev`` has ``2 * n_h_out`` columns and ``c_prev`` has ``n_h_out``; returns
        the new ``(h, c)``, each with ``n_h_out`` columns.
        """
        # one matrix product yields all gate pre-activations: input, forget, candidate, output
        gates = torch.matmul(h_prev, self.w_h) + self.b
        in_pre, forget_pre, cand_pre, out_pre = torch.split(gates, self.n_h_out, dim=1)
        input_gate = torch.sigmoid(in_pre)
        forget_gate = torch.sigmoid(forget_pre)
        candidate = torch.tanh(cand_pre)
        output_gate = torch.sigmoid(out_pre)
        c = forget_gate * c_prev + input_gate * candidate
        h = output_gate * torch.tanh(c)
        return h, c

In [961]:
class Set2Set(nn.Module):
    """
    Set2Set readout: pools a variable-size set of node vectors into one vector per graph.

    Adapted from:
    https://rusty1s.github.io/pytorch_geometric/build/html/_modules/torch_geometric/nn/glob/set2set.html#Set2Set
    """
    def __init__(self, in_channels, proc_steps):
        """
        in_channels - number of features per node
        proc_steps - number of attention/LSTM processing steps
        The output has 2 * in_channels features (query concatenated with readout).
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = 2 * in_channels
        self.proc_steps = proc_steps
        self.lstm = HiddenLSTMCell(self.in_channels)

    def forward(self, x, mask):
        """
        x - input tensor of shape (batch_size, n_nodes, in_channels)
        mask - integer tensor used to zero out nodes missing in a particular graph
            (not all graphs have 'n_nodes'). Is of shape (batch_size, n_nodes)
        Returns a tensor of shape (batch_size, 2 * in_channels).
        """
        batch_size = mask.size(0)
        # allocate the recurrent state on the input's device so the module also runs on GPU
        device = x.device
        h = torch.zeros(batch_size, self.in_channels, device=device)
        q_star = torch.zeros(batch_size, self.out_channels, device=device)
        # 0 for present nodes, -1e6 for padded ones
        energy_offset = (mask.to(device).float() - 1) * 1e6
        for _ in range(self.proc_steps):
            q, h = self.lstm(q_star, h)
            # dot product of every node vector with its graph's query (broadcast over nodes)
            e = (x * q.unsqueeze(1)).sum(dim=-1)
            # padded nodes get a large negative energy, so softmax gives them ~0 attention
            e = e + energy_offset
            a = F.softmax(e, dim=-1)
            # attention-weighted sum of x over the node dimension
            r = torch.sum(a.unsqueeze(-1) * x, dim=1)
            q_star = torch.cat([q, r], dim=-1)

        return q_star

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels)

In [962]:
def segment_sum(data, segment_ids, num_segments=None):
    """
    Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.

    :param data: A tensor whose segments are to be summed (along dimension 0).
    :param segment_ids: The segment indices tensor (integer, non-negative). Either 1-D
        with one id per row of data, or of the same shape as data.
    :param num_segments: Number of output segments. Defaults to max(segment_ids) + 1,
        so ids that skip values produce all-zero rows instead of an out-of-bounds error.
    :return: A tensor of same data type and device as the data argument, with
        num_segments rows.
    """
    assert tuple(segment_ids.shape) == tuple(data.shape[:segment_ids.dim()]), \
        "segment_ids.shape should be a prefix of data.shape"

    # segment_ids is a 1-D tensor: broadcast it to have the same shape as data
    if segment_ids.dim() == 1:
        index_shape = [segment_ids.shape[0]] + [1] * (data.dim() - 1)
        segment_ids = segment_ids.view(*index_shape).expand_as(data).contiguous()

    assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"

    if num_segments is None:
        num_segments = int(segment_ids.max()) + 1 if segment_ids.numel() > 0 else 0

    # accumulate in float32 as before, except float64 data which keeps its precision
    acc_dtype = torch.float64 if data.dtype == torch.float64 else torch.float32
    shape = [num_segments] + list(data.shape[1:])
    summed = torch.zeros(shape, dtype=acc_dtype, device=data.device)
    summed = summed.scatter_add(0, segment_ids, data.to(acc_dtype))
    return summed.type(data.dtype)

class EdgeNetwork(nn.Module):
    """
    Edge-conditioned message function.

    A network (``adjacency_net``) maps every edge feature vector to an
    ``n_h x n_h`` matrix; the message received by a node is the sum, over its
    edges, of that matrix applied to the neighbour's hidden state, plus a
    learned bias ``b_m``.
    """

    def __init__(self, n_h, n_e, fully_connected_graph=False, use_master_node=True, net_args={}):
        """
        - n_h: size of the node hidden state.
        - n_e: number of features per edge.
        - fully_connected_graph: selects the dense (all node pairs) code path in
            ``forward`` instead of the edge-list one.
        - use_master_node: stored on the module; not read by ``forward``.
        - net_args: keyword arguments forwarded to FullyConnectedNet (only unpacked,
            never mutated).
        """
        super().__init__()
        self.n_e, self.n_h = n_e, n_h
        self.fully_connected_graph = fully_connected_graph
        self.use_master_node = use_master_node
        self.adjacency_net = FullyConnectedNet(n_e, n_h ** 2, **net_args)
        self.b_m = nn.Parameter(torch.zeros(n_h)) # bias for the message function

    def forward(self, h, e, pairs_idx=None, edge_mask=None):
        """
        Compute message vector m_t given the previous hidden state
        h_t-1 and edge features e.
        - h is a collection of hidden states of shape (batch_size, n_nodes, n_h)
        - e is a collection of edge features of shape
            (batch_size, n_nodes, n_nodes, n_e) if fully_connected_graph
            else shape is (batch_size, n_edges, n_e).
        - pairs_idx: if self.fully_connected_graph = False this is a tensor
            of shape (batch_size, n_edges, 2) mapping atom indexes
            (first column) to the other atom indexes they form a bond with
            (second column).
        - edge_mask: if self.fully_connected_graph = False this is a tensor
            of shape (batch_size, n_edges) masking non present edges
            (entries equal to 1 mark real edges).

        Returns the messages as a tensor of shape (batch_size, n_nodes, n_h).
        Nodes that receive no message hold just the bias ``b_m``.

        Raises ValueError when the edge-list path is selected but pairs_idx or
        edge_mask is missing.
        """
        batch_size, n_nodes = h.size(0), h.size(1)

        # one (n_h x n_h) matrix per edge, still flattened to n_h^2 here
        a_vect = self.adjacency_net(e.view(-1, self.n_e))

        if self.fully_connected_graph:
            # Assemble a (n_nodes*n_h) x (n_nodes*n_h) block matrix per graph so a
            # single matmul sums the contributions of every neighbour.
            a_tmp = a_vect.view(-1, n_nodes, n_nodes, self.n_h, self.n_h).transpose(2, 3)
            a = a_tmp.contiguous().view(-1, n_nodes * self.n_h, n_nodes * self.n_h)
            h_flat = h.view(batch_size, n_nodes * self.n_h, 1)
            m = torch.matmul(a, h_flat).view(batch_size * n_nodes, self.n_h)
        else:
            if pairs_idx is None or edge_mask is None:
                raise ValueError("pairs_idx and edge_mask are required when fully_connected_graph=False")

            # flat boolean selector over (batch_size * n_edges) keeping only real edges
            edge_is_real = (edge_mask.type(torch.uint8) == 1).reshape(-1)
            a_mat = a_vect[edge_is_real].view(-1, self.n_h, self.n_h)

            # hidden state of the neighbour (second column) of every real edge
            graph_idx = torch.arange(batch_size, device=h.device).unsqueeze(-1)
            h_nbr = h[graph_idx, pairs_idx[:, :, 1]].reshape(-1, self.n_h)[edge_is_real]
            ah = torch.einsum('bij,bjk->bik', h_nbr.unsqueeze(1), a_mat).squeeze(1)

            # Sum per-edge messages straight into the padded (batch_size * n_nodes)
            # layout: row b * n_nodes + i belongs to node i of graph b. Nodes without
            # any edge simply stay at zero, so no per-graph split/pad is needed.
            target_row = (graph_idx * n_nodes + pairs_idx[:, :, 0]).reshape(-1)[edge_is_real]
            m = ah.new_zeros(batch_size * n_nodes, self.n_h).index_add_(0, target_row, ah)

        m = m + self.b_m

        return m.view(batch_size, n_nodes, self.n_h)

In [964]:
# Smoke test: run EdgeNetwork (edge-list mode) on the first `batch_size` molecules
# with random hidden states. `edge_features`, `pairs_idx`, `edge_mask` and
# `enn_args` must already exist in the kernel.
batch_size = 20
n_nodes, n_edges = MAX_N_ATOMS, MAX_N_BONDS
n_h, n_e = 50, N_EDGE_FEATURES

h = torch.randn(batch_size, n_nodes, n_h, dtype=torch.float)
e = torch.tensor(edge_features[:batch_size], dtype=torch.float)
p_idx = torch.tensor(pairs_idx[:batch_size], dtype=torch.long)
msk = torch.tensor(edge_mask[:batch_size])
enn = EdgeNetwork(n_h, n_e, fully_connected_graph=False, use_master_node=True, net_args=enn_args)
enn(h, e, p_idx, msk)

tensor([[[ 8.2846e-01,  1.3722e+00,  1.3755e+00,  ...,  7.7376e-01,
           4.9090e-01, -7.3929e-01],
         [-7.6818e-02,  1.9346e-03,  3.5754e-01,  ..., -6.7729e-01,
           2.0827e-01,  4.1821e-02],
         [-6.7860e-02,  2.5336e-02,  3.3534e-01,  ..., -6.8375e-01,
           2.1888e-01, -3.5143e-03],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[-2.9843e-01,  3.0289e+00,  1.3437e+00,  ...,  5.8446e-01,
           9.8487e-01,  6.6117e-01],
         [ 1.5796e+00,  7.4156e-01,  1.0957e+00,  ...,  2.1461e-01,
           2.9847e-01, -3.5351e-01],
         [ 5.0861e-02,  9.0457e-01,  9.1477e-01,  ...,  1.7157e+00,
           1.3507e+00, -1.0601e+00],
         ...,
         [ 0.0000e+00,  0

In [965]:
# Step-by-step replay of EdgeNetwork.forward (edge-list branch) on the tensors
# h, e, p_idx, msk and the module `enn` from the previous cell, printing the
# size of every intermediate.

# flatten the edge features to one row per (graph, edge slot)
e_reshaped = e.view(-1, n_e)
print(e_reshaped.size())
# one flattened n_h*n_h matrix per edge slot
a_vect = enn.adjacency_net(e_reshaped)
print(a_vect.size())
# selectors for real (non padded) edges, per graph and flattened
edge_mask_ = msk.type(torch.uint8)==True
edge_mask_flat = msk.view(-1).type(torch.uint8)==True
print(edge_mask_.size(), edge_mask_flat.size())
# keep the matrices of real edges only
a_mat = a_vect[edge_mask_flat].view(-1, n_h, n_h)
print(a_mat.size())
# hidden state of the atom in the second column of every edge slot, graph by graph
h_flat = torch.cat([h[batch,ix,:] for batch, ix in enumerate(torch.unbind(p_idx[:,:,1]))])
print(h_flat.size())
h_flat = h_flat[edge_mask_flat]
print(h_flat.size())
# per-edge message: row vector h times the edge matrix
ah = torch.einsum('bij,bjk->bik', h_flat.unsqueeze(1), a_mat).squeeze(1)
print(ah.size())
# number of nodes per graph, taken as (largest first-column index + 1)
n_nodes_per_g = p_idx[:,:,0].max(dim=1).values + 1
# shift the first-column indices by the cumulative node count of the preceding
# graphs so that ids are unique across the batch
unique_idx = p_idx[:,:,0] + (torch.cat([
                                torch.zeros(1, dtype=torch.long), 
                                n_nodes_per_g[:-1].cumsum(dim=0)
                            ])).unsqueeze(-1).expand(-1, n_edges)
# add up the messages of all real edges that share the same first-column node
m_stacked = segment_sum(ah, unique_idx[edge_mask_])
print(m_stacked.size())

# cut the stacked messages back into one chunk per graph ...
m_per_g_lst = torch.split(m_stacked, n_nodes_per_g.tolist())
print(m_per_g_lst[0].size())
# ... and zero-pad every chunk to n_nodes rows before concatenating
m = torch.cat([F.pad(m_, pad=(0, 0, 0, n_nodes - n_nodes_)) for m_, n_nodes_ in zip(m_per_g_lst, n_nodes_per_g)]) 
print(m.size())
# the bias is added to every row, padded rows included
m += enn.b_m
print(m.size())
m.view(batch_size, n_nodes, n_h).size()

torch.Size([1160, 16])
torch.Size([1160, 2500])
torch.Size([20, 58]) torch.Size([1160])
torch.Size([164, 50, 50])
torch.Size([1160, 50])
torch.Size([164, 50])
torch.Size([164, 50])
torch.Size([91, 50])
torch.Size([5, 50])
torch.Size([580, 50])
torch.Size([580, 50])


torch.Size([20, 29, 50])

In [966]:
class GRUUpdate(nn.Module):
    """Node update function: one GRU cell applied to every node of every graph."""

    def __init__(self, n_h):
        super().__init__()
        self.n_h = n_h
        self.gru = nn.GRUCell(n_h, n_h)

    def forward(self, m, h_prev, mask):
        """
        Compute the next hidden states from the incoming messages.
        - m: messages, shape (batch_size, n_nodes, n_h); fed to the GRU cell as input
        - h_prev: previous hidden states, shape (batch_size, n_nodes, n_h)
        - mask: shape (batch_size, n_nodes); the result is multiplied by it so that
            nodes missing from a particular graph end up with an all-zero state
        Returns the new hidden states, shape (batch_size, n_nodes, n_h).
        """
        batch_size, n_nodes = h_prev.shape[:2]
        # the GRU cell works on 2-D input: fold batch and node dimensions together
        flat_messages = m.view(-1, self.n_h)
        flat_states = h_prev.view(-1, self.n_h)
        h_next = self.gru(flat_messages, flat_states).view(batch_size, n_nodes, self.n_h)
        node_mask = mask.unsqueeze(-1).expand(batch_size, n_nodes, self.n_h)
        return h_next * node_mask

In [967]:
class Set2SetOutput(nn.Module):
    """
    Readout: project [hidden state, input features] of every node to n_h, pool the
    nodes of each graph with Set2Set, and map the pooled vector to one output value.
    """

    def __init__(self, n_x, n_h, proc_steps, net_args):
        super().__init__()
        self.n_h, self.n_x = n_h, n_x
        self.R_proj = nn.Linear(n_h + n_x, n_h)
        self.R_proc = Set2Set(n_h, proc_steps)
        self.R_write = FullyConnectedNet(2 * n_h, 1, **net_args)

    def forward(self, h, x, mask):
        """
        Compute the graph-level output.
        - h is vector of hidden states of shape (batch_size, n_nodes, n_h)
        - x is vector of input features of shape (batch_size, n_nodes, n_x)
        - mask marks the nodes present in a particular graph (not all graphs
            have 'n_nodes'). Is of shape (batch_size, n_nodes); it is handed to
            the Set2Set module unchanged.
        Returns whatever R_write produces for the pooled vectors (R_write is built
        with a single output unit).
        """
        batch_size, n_nodes = h.shape[:2]
        # one row per node: hidden state followed by the raw input features
        node_inputs = torch.cat((h.view(-1, self.n_h), x.view(-1, self.n_x)), dim=1)
        projected = self.R_proj(node_inputs).view(batch_size, n_nodes, self.n_h)
        # R_write is built for 2 * n_h inputs, so the pooled vector is presumably
        # of shape (batch_size, 2 * n_h) - confirm against Set2Set
        pooled = self.R_proc(projected, mask)
        return self.R_write(pooled)

In [968]:
class MPNN(nn.Module):
    """
    Message passing network: EdgeNetwork messages (M), GRU node updates (U) and a
    Set2Set readout (R).
    """

    def __init__(self, n_x, n_h, n_e, update_steps=3, proc_steps=10, enn_args={}, R_net_args={},
                 fully_connected_graph=False):
        """
        - n_x: number of input features per node.
        - n_h: size of the node hidden state; must be at least n_x because the
            initial hidden state is the input zero-padded up to n_h.
        - n_e: number of features per edge.
        - update_steps: number of message/update rounds.
        - proc_steps: passed to the readout (Set2SetOutput).
        - enn_args / R_net_args: keyword arguments forwarded to the message network
            and to the readout respectively (never mutated here).
        - fully_connected_graph: forwarded to EdgeNetwork; when True, forward() is
            called without pairs_idx / edge_mask.

        Raises ValueError if n_h < n_x.
        """
        super().__init__()
        if n_h < n_x:
            # F.pad with a negative amount would silently crop the input features
            raise ValueError("n_h ({}) must be >= n_x ({})".format(n_h, n_x))
        self.n_h, self.n_x = n_h, n_x
        self.M = EdgeNetwork(n_h, n_e, fully_connected_graph=fully_connected_graph, net_args=enn_args)
        self.U = GRUUpdate(n_h)
        self.R = Set2SetOutput(n_x, n_h, proc_steps, R_net_args)
        self.update_steps = update_steps

    def forward(self, x, e, mask, pairs_idx=None, edge_mask=None):
        """
        - x: node input features, last dimension of size n_x.
        - e, pairs_idx, edge_mask: handed unchanged to the message function.
        - mask: node mask, handed unchanged to the update and readout functions.
        Returns the output of the readout function.
        """
        # initial hidden state: input features zero-padded on the right up to n_h
        h = F.pad(x, pad=(0, self.n_h - self.n_x))
        for _ in range(self.update_steps):
            m = self.M(h, e, pairs_idx, edge_mask)
            h = self.U(m, h, mask)
        return self.R(h, x, mask)

In [970]:
# Smoke test: build an MPNN and run it on the first `batch_size` molecules.
# The feature arrays, `enn_args` and `R_net_args` must already exist in the kernel.
batch_size = 20
n_nodes, n_edges = MAX_N_ATOMS, MAX_N_BONDS
n_h, n_e, n_x = 50, N_EDGE_FEATURES, N_ATOM_FEATURES

x = torch.tensor(atomic_features[:batch_size], dtype=torch.float)
e = torch.tensor(edge_features[:batch_size], dtype=torch.float)
msk = torch.tensor(mask[:batch_size], dtype=torch.float)
p_idx = torch.tensor(pairs_idx[:batch_size], dtype=torch.long)
e_msk = torch.tensor(edge_mask[:batch_size], dtype=torch.float)

mpnn = MPNN(n_x, n_h, n_e, update_steps=3, proc_steps=10, enn_args=enn_args, R_net_args=R_net_args)
print(mpnn)
print(mpnn(x, e, msk, p_idx, e_msk))
print(mpnn(x, e, msk, p_idx, e_msk).size())

MPNN(
  (M): EdgeNetwork(
    (adjacency_net): FullyConnectedNet(
      (layers): Sequential(
        (0): Linear(in_features=16, out_features=50, bias=True)
        (1): ReLU(inplace)
        (2): Linear(in_features=50, out_features=50, bias=True)
        (3): ReLU(inplace)
        (4): Linear(in_features=50, out_features=50, bias=True)
        (5): ReLU(inplace)
        (6): Linear(in_features=50, out_features=50, bias=True)
        (7): ReLU(inplace)
        (8): Linear(in_features=50, out_features=2500, bias=True)
      )
    )
  )
  (U): GRUUpdate(
    (gru): GRUCell(50, 50)
  )
  (R): Set2SetOutput(
    (R_proj): Linear(in_features=76, out_features=50, bias=True)
    (R_proc): Set2Set(50, 100)
    (R_write): FullyConnectedNet(
      (layers): Sequential(
        (0): Linear(in_features=100, out_features=200, bias=True)
        (1): ReLU(inplace)
        (2): Linear(in_features=200, out_features=100, bias=True)
        (3): ReLU(inplace)
        (4): Linear(in_features=100, out_fe

## Fit MPNN

In [982]:
# 75/25 shuffled split over observation indices (fixed random_state), applied
# identically to every per-observation array.
train_idx, val_idx = train_test_split(np.arange(n_obs), test_size=0.25, shuffle=True, random_state=100)

def _train_val(arr):
    """Return the (train, validation) slices of `arr` for the split above."""
    return arr[train_idx], arr[val_idx]

x_train, x_val = _train_val(atomic_features)
e_train, e_val = _train_val(edge_features)
y_train, y_val = _train_val(target)
mask_train, mask_val = _train_val(mask)
edge_mask_train, edge_mask_val = _train_val(edge_mask)
pairs_idx_train, pairs_idx_val = _train_val(pairs_idx)

In [983]:
# Standardise the target with statistics of the training split only.
# Both splits are taken from the unscaled `target` (not from y_train / y_val), so
# re-running this cell never standardises values that were already standardised,
# which would corrupt the mean_/scale_ used later to undo the scaling.
ss_target = StandardScaler()
y_train = ss_target.fit_transform(target[train_idx].reshape(-1,1))
y_val = ss_target.transform(target[val_idx].reshape(-1,1))

In [992]:
class MoleculeDataset(Dataset):
    """
    Dataset of molecular graphs stored as stacked numpy arrays (first axis = sample).

    An item is ``(inputs, y)``. ``inputs`` is ``(x, e, mask)`` when no edge mask was
    given (fully connected graphs) and ``(x, e, mask, pairs_idx, edge_mask)`` otherwise.
    """

    def __init__(self, y, x, e, mask, pairs_idx=None, edge_mask=None):
        """
        - y, x, e, mask: arrays indexed by sample; stored as float32 copies.
        - pairs_idx: edge endpoints; stored as int64. Required when edge_mask is given.
        - edge_mask: stored as float32; leaving it as None selects the
            fully connected layout.

        Raises ValueError if edge_mask is given without pairs_idx.
        """
        self.n = len(y)
        self.y = y.astype(np.float32)
        self.x = x.astype(np.float32)
        self.e = e.astype(np.float32)
        self.mask = mask.astype(np.float32)
        self.fully_connected_graphs = edge_mask is None
        if not self.fully_connected_graphs:
            if pairs_idx is None:
                raise ValueError("pairs_idx is required when edge_mask is given")
            # explicit 64-bit integers: np.long is not available in every NumPy
            # release and its width depends on the platform
            self.pairs_idx = pairs_idx.astype(np.int64)
            self.edge_mask = edge_mask.astype(np.float32)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if self.fully_connected_graphs:
            return (self.x[idx], self.e[idx], self.mask[idx]), self.y[idx]
        else:
            return (self.x[idx], self.e[idx], self.mask[idx], self.pairs_idx[idx], self.edge_mask[idx]), self.y[idx]

In [993]:
batch_size = 20  # mini-batch size handed to both DataLoaders below

In [994]:
# Wrap the two splits in datasets and loaders; only the training loader shuffles.
train_ds = MoleculeDataset(y=y_train, x=x_train, e=e_train, mask=mask_train,
                           pairs_idx=pairs_idx_train, edge_mask=edge_mask_train)
val_ds = MoleculeDataset(y=y_val, x=x_val, e=e_val, mask=mask_val,
                         pairs_idx=pairs_idx_val, edge_mask=edge_mask_val)

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8)
val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=8)

# fastai container holding the training and validation loaders
db = DataBunch(train_dl, val_dl)

In [1026]:
def group_mean_log_mae(y_true, y_pred, types, floor=1e-9):
    """
    Mean over groups of log(mean absolute error within the group).

    - y_true, y_pred: tensors on the standardised scale; they are mapped back to
        the original scale with the fitted global `ss_target` before the error is taken.
    - types: tensor with one group label per value.
    - floor: lower bound applied to every group MAE before the log, so a group that
        is predicted exactly contributes log(floor) instead of -inf.
    Returns a float.
    """
    # detach() so that tensors still attached to the autograd graph can be converted
    y_true, y_pred, types = (t.detach().cpu().numpy().ravel() for t in (y_true, y_pred, types))
    y_true = ss_target.mean_ + y_true * ss_target.scale_
    y_pred = ss_target.mean_ + y_pred * ss_target.scale_
    maes = pd.Series(y_true - y_pred).abs().groupby(types).mean()
    return np.log(maes.clip(lower=floor)).mean()

class GroupMeanLogMAE(Callback):
    """
    fastai callback that reports `group_mean_log_mae` over the non-training
    batches of every epoch as an extra metric column named 'GroupMeanLogMAE'.
    """
    _order = -20 #Needs to run before the recorder
    types_cidx = 2

    def __init__(self, learn, **kwargs):
        self.learn = learn

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['GroupMeanLogMAE'])

    def on_epoch_begin(self, **kwargs):
        # per-epoch buffers: group labels, model outputs, targets
        self.input, self.output, self.target = [], [], []

    def on_batch_end(self, last_target, last_output, last_input, train, **kwargs):
        if train:
            return
        last_e = last_input[1]  # second model input: the edge features
        # The group label is the position of a non-zero entry within the last 8
        # edge features (presumably a one-hot type encoding - TODO confirm).
        if len(last_e.size()) == 4:
            types = torch.nonzero(last_e[:,:,:,-8:])[:,-1]
        else:
            # keeps every second non-zero entry (presumably each labelled pair
            # is stored as two edges - TODO confirm)
            types = torch.nonzero(last_e[:,:,-8:])[::2,-1]
        # keep copies on the CPU so buffers do not pin accelerator memory all epoch
        self.input.append(types.detach().cpu())
        self.output.append(last_output.detach().cpu())
        self.target.append(last_target.detach().cpu())

    def on_epoch_end(self, last_metrics, **kwargs):
        if (len(self.input) > 0) and (len(self.output) > 0):
            types = torch.cat(self.input)
            preds = torch.cat(self.output)
            target = torch.cat(self.target)
            # argument order follows the signature (y_true, y_pred, types)
            metric = group_mean_log_mae(target, preds, types)
            return add_metrics(last_metrics, [metric])

def set_seed(seed=100):
    """
    Seed the Python, NumPy and PyTorch (CPU and, when available, CUDA) random
    number generators and put cuDNN in deterministic mode.

    - seed: integer seed used for every generator.
    """
    # imported here so the function does not depend on `random` leaking in
    # through a star import elsewhere in the notebook
    import random

    # python RNG
    random.seed(seed)

    # numpy RNG
    np.random.seed(seed)

    # pytorch RNGs
    torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

    # cuDNN: deterministic kernels, and no auto-tuner (it may pick different
    # algorithms from one run to the next)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [1029]:
# Hyper-parameters
wd = 1e-6        # weight decay handed to the Learner
n_hidden = 50    # size of the node hidden state

# keyword arguments for the network inside the message function
enn_args = {
    'layers': [50, 50, 50, 50],
    'act': nn.ReLU(True),
    'dropout': [0.0, 0.0, 0.0, 0.0],
    'batch_norm': False,
}
# keyword arguments for the network that writes the final output
R_net_args = {
    'layers': [200, 100],
    'act': nn.ReLU(True),
    'dropout': [0.0, 0.0],
    'batch_norm': False,
}

In [1030]:
# Seed every RNG first so the weight initialisation is repeatable.
set_seed(100)
model = MPNN(
    N_ATOM_FEATURES, n_hidden, N_EDGE_FEATURES,
    update_steps=1,
    proc_steps=5,
    enn_args=enn_args,
    R_net_args=R_net_args,
)

In [1031]:
# Train on RMSE (of the standardised target); report MAE and RMSE, plus the
# GroupMeanLogMAE column added by the callback.
learn = Learner(
    db, model,
    loss_func=root_mean_squared_error,
    metrics=[mean_absolute_error, root_mean_squared_error],
    callback_fns=GroupMeanLogMAE,
    wd=wd,
)

In [None]:
# Learning-rate range test: 100 iterations going from 1e-8 up to 1
# (stop_div=True asks fastai to stop early once the loss diverges),
# then plot what the recorder collected.
learn.lr_find(start_lr=1e-8, end_lr=1, num_it=100, stop_div=True)
learn.recorder.plot()

In [1032]:
# 10 epochs of one-cycle training with max_lr=1e-3; a checkpoint named 'mpnn' is
# written whenever the monitored GroupMeanLogMAE improves (lower is better).
save_best = SaveModelCallback(learn, every='improvement', mode='min',
                              monitor='GroupMeanLogMAE', name='mpnn')
learn.fit_one_cycle(10, max_lr=1e-3, callbacks=[save_best])

epoch,train_loss,valid_loss,mean_absolute_error,root_mean_squared_error,GroupMeanLogMAE,time
0,0.159601,0.156184,0.105536,0.156184,1.354229,03:05
1,0.127612,0.118614,0.0897,0.118614,1.153358,03:13
2,0.106659,0.100454,0.069477,0.100454,0.842295,03:18
3,0.100778,0.12122,0.082152,0.12122,0.880352,03:20
4,0.095718,0.092992,0.065019,0.092992,0.757166,03:22
5,0.082667,0.081624,0.056402,0.081624,0.651035,03:21
6,0.07521,0.07365,0.050919,0.07365,0.523686,03:13
7,0.0656,0.066438,0.045899,0.066438,0.37169,03:03
8,0.054816,0.059368,0.041857,0.059368,0.311036,03:01
9,0.053513,0.05725,0.040247,0.05725,0.267407,03:04


Better model found at epoch 0 with GroupMeanLogMAE value: 1.354228859685035.
Better model found at epoch 1 with GroupMeanLogMAE value: 1.1533581682915712.
Better model found at epoch 2 with GroupMeanLogMAE value: 0.8422953966309138.
Better model found at epoch 4 with GroupMeanLogMAE value: 0.7571660086835312.
Better model found at epoch 5 with GroupMeanLogMAE value: 0.6510346373432624.
Better model found at epoch 6 with GroupMeanLogMAE value: 0.5236864444576811.
Better model found at epoch 7 with GroupMeanLogMAE value: 0.3716901914886019.
Better model found at epoch 8 with GroupMeanLogMAE value: 0.31103578241766366.
Better model found at epoch 9 with GroupMeanLogMAE value: 0.26740657729393014.


In [1033]:
# 10 more epochs at a lower peak learning rate (max_lr=2e-4).
# NOTE(review): the checkpoint name 'mpnn' is the same in every training cell, so
# each run presumably overwrites the file of the previous one - confirm.
save_best = SaveModelCallback(learn, every='improvement', mode='min',
                              monitor='GroupMeanLogMAE', name='mpnn')
learn.fit_one_cycle(10, max_lr=2e-4, callbacks=[save_best])

epoch,train_loss,valid_loss,mean_absolute_error,root_mean_squared_error,GroupMeanLogMAE,time
0,0.057422,0.057541,0.040643,0.057541,0.286319,02:59
1,0.055529,0.059438,0.041918,0.059438,0.320873,03:05
2,0.059378,0.059227,0.042514,0.059227,0.336825,03:11
3,0.057274,0.055405,0.039596,0.055405,0.281443,03:09
4,0.053913,0.057279,0.039944,0.057279,0.303203,03:10
5,0.052195,0.056035,0.039194,0.056035,0.222439,03:15
6,0.048742,0.052346,0.036425,0.052346,0.167444,03:24
7,0.046596,0.050768,0.035409,0.050768,0.15019,03:28
8,0.048443,0.049717,0.034626,0.049717,0.131897,03:27
9,0.047076,0.049524,0.034494,0.049524,0.124534,03:22


Better model found at epoch 0 with GroupMeanLogMAE value: 0.28631927272693414.
Better model found at epoch 3 with GroupMeanLogMAE value: 0.28144344014824785.
Better model found at epoch 5 with GroupMeanLogMAE value: 0.22243898377855462.
Better model found at epoch 6 with GroupMeanLogMAE value: 0.1674443047489782.
Better model found at epoch 7 with GroupMeanLogMAE value: 0.1501904718669072.
Better model found at epoch 8 with GroupMeanLogMAE value: 0.1318967869551606.
Better model found at epoch 9 with GroupMeanLogMAE value: 0.1245335506225918.


In [1034]:
# 10 more epochs with max_lr=4e-5, checkpointing on GroupMeanLogMAE improvements.
save_best = SaveModelCallback(learn, every='improvement', mode='min',
                              monitor='GroupMeanLogMAE', name='mpnn')
learn.fit_one_cycle(10, max_lr=4e-5, callbacks=[save_best])

epoch,train_loss,valid_loss,mean_absolute_error,root_mean_squared_error,GroupMeanLogMAE,time
0,0.048812,0.049702,0.034633,0.049702,0.127906,03:07
1,0.046267,0.049696,0.03467,0.049696,0.131523,03:35
2,0.04523,0.050078,0.03513,0.050078,0.150454,03:37
3,0.048768,0.04921,0.034214,0.04921,0.121127,03:34
4,0.04683,0.049385,0.034457,0.049385,0.127631,03:30
5,0.047618,0.049113,0.034081,0.049113,0.117352,03:27
6,0.044582,0.048713,0.033777,0.048713,0.102756,03:33
7,0.04567,0.048842,0.033876,0.048842,0.107339,03:09
8,0.048197,0.048487,0.033572,0.048487,0.096942,03:04
9,0.045038,0.048442,0.033557,0.048442,0.096955,03:01


Better model found at epoch 0 with GroupMeanLogMAE value: 0.12790607972417514.
Better model found at epoch 3 with GroupMeanLogMAE value: 0.12112732865715103.
Better model found at epoch 5 with GroupMeanLogMAE value: 0.11735198926441456.
Better model found at epoch 6 with GroupMeanLogMAE value: 0.10275615135438726.
Better model found at epoch 8 with GroupMeanLogMAE value: 0.09694183179473853.


In [1035]:
# 10 more epochs with max_lr=1e-5, checkpointing on GroupMeanLogMAE improvements.
save_best = SaveModelCallback(learn, every='improvement', mode='min',
                              monitor='GroupMeanLogMAE', name='mpnn')
learn.fit_one_cycle(10, max_lr=1e-5, callbacks=[save_best])

epoch,train_loss,valid_loss,mean_absolute_error,root_mean_squared_error,GroupMeanLogMAE,time
0,0.0457,0.048459,0.033606,0.048459,0.098804,03:03
1,0.04591,0.04854,0.03355,0.04854,0.096201,03:10
2,0.046046,0.048621,0.033625,0.048621,0.095609,03:14
3,0.047394,0.048486,0.033614,0.048486,0.103721,03:20
4,0.048073,0.048659,0.033661,0.048659,0.096728,03:22
5,0.046877,0.048345,0.033501,0.048345,0.094218,03:09
6,0.046217,0.048304,0.033401,0.048304,0.09269,03:22
7,0.046137,0.048236,0.033399,0.048236,0.093047,03:24
8,0.047683,0.048207,0.033332,0.048207,0.090394,03:09
9,0.046603,0.048204,0.033339,0.048204,0.090764,03:08


Better model found at epoch 0 with GroupMeanLogMAE value: 0.09880390561805735.
Better model found at epoch 1 with GroupMeanLogMAE value: 0.09620074511375563.
Better model found at epoch 2 with GroupMeanLogMAE value: 0.09560922042310946.
Better model found at epoch 5 with GroupMeanLogMAE value: 0.09421753289340161.
Better model found at epoch 6 with GroupMeanLogMAE value: 0.09268977237219678.
Better model found at epoch 8 with GroupMeanLogMAE value: 0.09039376142404532.


In [None]:
# Model outputs for the validation set and for the test set; the second element
# returned by get_preds is discarded.
# NOTE(review): `db` is created from a training and a validation loader only -
# confirm a test loader has been attached before asking for DatasetType.Test.
# NOTE(review): the model is trained on the standardised target, so these values
# presumably still need ss_target.inverse_transform - confirm.
pred, _ = learn.get_preds(ds_type=DatasetType.Valid)
pred_test, _ = learn.get_preds(ds_type=DatasetType.Test)