# make dataset

In [None]:
import os
import torch
import pickle
import collections
import math
import pandas as pd
import numpy as np
import networkx as nx
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from torch.utils import data
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Batch
from itertools import repeat, product, chain

# add for grover
from argparse import Namespace
from typing import List, Tuple, Union
import pickle

In [None]:
# allowable node and edge features
allowable_features = {
    'possible_atomic_num_list' : list(range(0, 99)),
    'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    'possible_chirality_list' : [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list' : [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8],
    'possible_implicit_valence_list' : [0, 1, 2, 3, 4, 5, 6],
    'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'possible_bonds' : [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs' : [ # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}

def mol_to_graph_data_obj_simple(mol):
    """
    Converts rdkit mol object to graph Data object required by the pytorch
    geometric package. NB: Uses simplified atom and bond features, and represent
    as indices
    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr
    """
    # atoms
    num_atom_features = 2   # atom type,  chirality tag
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = [allowable_features['possible_atomic_num_list'].index(
            atom.GetAtomicNum())] + [allowable_features[
            'possible_chirality_list'].index(atom.GetChiralTag())]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    num_bond_features = 2   # bond type, bond direction
    if len(mol.GetBonds()) > 0: # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = [allowable_features['possible_bonds'].index(
                bond.GetBondType())] + [allowable_features[
                                            'possible_bond_dirs'].index(
                bond.GetBondDir())]
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:   # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data

def graph_data_obj_to_mol_simple(data_x, data_edge_index, data_edge_attr):
    """
    Convert pytorch geometric data obj to rdkit mol object. NB: Uses simplified
    atom and bond features, and represent as indices.
    :param: data_x:
    :param: data_edge_index:
    :param: data_edge_attr
    :return:
    """
    mol = Chem.RWMol()

    # atoms
    atom_features = data_x.cpu().numpy()
    num_atoms = atom_features.shape[0]
    for i in range(num_atoms):
        atomic_num_idx, chirality_tag_idx = atom_features[i]
        atomic_num = allowable_features['possible_atomic_num_list'][atomic_num_idx]
        chirality_tag = allowable_features['possible_chirality_list'][chirality_tag_idx]
        atom = Chem.Atom(atomic_num)
        atom.SetChiralTag(chirality_tag)
        mol.AddAtom(atom)

    # bonds
    edge_index = data_edge_index.cpu().numpy()
    edge_attr = data_edge_attr.cpu().numpy()
    num_bonds = edge_index.shape[1]
    for j in range(0, num_bonds, 2):
        begin_idx = int(edge_index[0, j])
        end_idx = int(edge_index[1, j])
        bond_type_idx, bond_dir_idx = edge_attr[j]
        bond_type = allowable_features['possible_bonds'][bond_type_idx]
        bond_dir = allowable_features['possible_bond_dirs'][bond_dir_idx]
        mol.AddBond(begin_idx, end_idx, bond_type)
        # set bond direction
        new_bond = mol.GetBondBetweenAtoms(begin_idx, end_idx)
        new_bond.SetBondDir(bond_dir)

    # Chem.SanitizeMol(mol) # fails for COC1=CC2=C(NC(=N2)[S@@](=O)CC2=NC=C(
    # C)C(OC)=C2C)C=C1, when aromatic bond is possible
    # when we do not have aromatic bonds
    # Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE)

    return mol

def graph_data_obj_to_nx_simple(data):
    """
    Converts graph Data object required by the pytorch geometric package to
    network x data object. NB: Uses simplified atom and bond features,
    and represent as indices. NB: possible issues with recapitulating relative
    stereochemistry since the edges in the nx object are unordered.
    :param data: pytorch geometric Data object
    :return: network x object
    """
    G = nx.Graph()

    # atoms
    atom_features = data.x.cpu().numpy()
    num_atoms = atom_features.shape[0]
    for i in range(num_atoms):
        atomic_num_idx, chirality_tag_idx = atom_features[i]
        G.add_node(i, atom_num_idx=atomic_num_idx, chirality_tag_idx=chirality_tag_idx)
        pass

    # bonds
    edge_index = data.edge_index.cpu().numpy()
    edge_attr = data.edge_attr.cpu().numpy()
    num_bonds = edge_index.shape[1]
    for j in range(0, num_bonds, 2):
        begin_idx = int(edge_index[0, j])
        end_idx = int(edge_index[1, j])
        bond_type_idx, bond_dir_idx = edge_attr[j]
        if not G.has_edge(begin_idx, end_idx):
            G.add_edge(begin_idx, end_idx, bond_type_idx=bond_type_idx,
                       bond_dir_idx=bond_dir_idx)

    return G

def nx_to_graph_data_obj_simple(G):
    """
    Converts nx graph to pytorch geometric Data object. Assume node indices
    are numbered from 0 to num_nodes - 1. NB: Uses simplified atom and bond
    features, and represent as indices. NB: possible issues with
    recapitulating relative stereochemistry since the edges in the nx
    object are unordered.
    :param G: nx graph obj
    :return: pytorch geometric Data object
    """
    # atoms
    num_atom_features = 2  # atom type,  chirality tag
    atom_features_list = []
    for _, node in G.nodes(data=True):
        atom_feature = [node['atom_num_idx'], node['chirality_tag_idx']]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    num_bond_features = 2  # bond type, bond direction
    if len(G.edges()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for i, j, edge in G.edges(data=True):
            edge_feature = [edge['bond_type_idx'], edge['bond_dir_idx']]
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:   # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data

def get_gasteiger_partial_charges(mol, n_iter=12):
    """
    Calculates list of gasteiger partial charges for each atom in mol object.
    :param mol: rdkit mol object
    :param n_iter: number of iterations. Default 12
    :return: list of computed partial charges for each atom.
    """
    Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter,
                                                  throwOnParamFailure=True)
    partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in
                       mol.GetAtoms()]
    return partial_charges

def create_standardized_mol_id(smiles):
    """

    :param smiles:
    :return: inchi
    """
    if check_smiles_validity(smiles):
        # remove stereochemistry
        smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
                                     isomericSmiles=False)
        mol = AllChem.MolFromSmiles(smiles)
        if mol != None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
            if '.' in smiles: # if multiple species, pick largest molecule
                mol_species_list = split_rdkit_mol_obj(mol)
                largest_mol = get_largest_mol(mol_species_list)
                inchi = AllChem.MolToInchi(largest_mol)
            else:
                inchi = AllChem.MolToInchi(mol)
            return inchi
        else:
            return
    else:
        return

def create_circular_fingerprint(mol, radius, size, chirality):
    """

    :param mol:
    :param radius:
    :param size:
    :param chirality:
    :return: np array of morgan fingerprint
    """
    fp = GetMorganFingerprintAsBitVect(mol, radius,
                                       nBits=size, useChirality=chirality)
    return np.array(fp)

def check_smiles_validity(smiles):
    try:
        m = Chem.MolFromSmiles(smiles)
        if m:
            return True
        else:
            return False
    except:
        return False

def split_rdkit_mol_obj(mol):
    """
    Split rdkit mol object containing multiple species or one species into a
    list of mol objects or a list containing a single object respectively
    :param mol:
    :return:
    """
    smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
    smiles_list = smiles.split('.')
    mol_species_list = []
    for s in smiles_list:
        if check_smiles_validity(s):
            mol_species_list.append(AllChem.MolFromSmiles(s))
    return mol_species_list

def get_largest_mol(mol_list):
    """
    Given a list of rdkit mol objects, returns mol object containing the
    largest num of atoms. If multiple containing largest num of atoms,
    picks the first one
    :param mol_list:
    :return:
    """
    num_atoms_list = [len(m.GetAtoms()) for m in mol_list]
    largest_mol_idx = num_atoms_list.index(max(num_atoms_list))
    return mol_list[largest_mol_idx]

#add for grover feature
ATOM_FEATURES = {
    'atomic_num': list(range(100)),
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': [0, 1, 2, 3],
    'num_Hs': [0, 1, 2, 3, 4],
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2
    ],
}
atom_fdim = 151
bond_fdim = 165

In [None]:
def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
    """
    Creates a one-hot encoding.

    :param value: The value for which the encoding should be one.
    :param choices: A list of possible values.
    :return: A one-hot encoding of the value in a list of length len(choices) + 1.
    If value is not in the list of choices, then the final element in the encoding is 1.
    """
    encoding = [0] * (len(choices) + 1)
    if min(choices) < 0:
        index = value
    else:
        index = choices.index(value) if value in choices else -1
    encoding[index] = 1

    return encoding

def atom_features(atom: Chem.rdchem.Atom, hydrogen_acceptor_match, hydrogen_donor_match, acidic_match, basic_match, ring_info) -> List[Union[bool, int, float]]:
    """
    Builds a feature vector for an atom.

    :param atom: An RDKit atom.
    :param functional_groups: A k-hot vector indicating the functional groups the atom belongs to.
    :return: A list containing the atom features.
    """
    features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \
               onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \
               onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \
               onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \
               onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \
               onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \
               [1 if atom.GetIsAromatic() else 0] + \
               [atom.GetMass() * 0.01]
    atom_idx = atom.GetIdx()
    features = features + \
               onek_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
               [atom_idx in hydrogen_acceptor_match] + \
               [atom_idx in hydrogen_donor_match] + \
               [atom_idx in acidic_match] + \
               [atom_idx in basic_match] + \
               [ring_info.IsAtomInRingOfSize(atom_idx, 3),
                ring_info.IsAtomInRingOfSize(atom_idx, 4),
                ring_info.IsAtomInRingOfSize(atom_idx, 5),
                ring_info.IsAtomInRingOfSize(atom_idx, 6),
                ring_info.IsAtomInRingOfSize(atom_idx, 7),
                ring_info.IsAtomInRingOfSize(atom_idx, 8)]
    return features

def bond_features(bond: Chem.rdchem.Bond
                  ) -> List[Union[bool, int, float]]:
    """
    Builds a feature vector for a bond.

    :param bond: A RDKit bond.
    :return: A list containing the bond features.
    """

    if bond is None:
        fbond = [1] + [0] * (BOND_FDIM - 1)
    else:
        bt = bond.GetBondType()
        fbond = [
            0,  # bond is not None
            bt == Chem.rdchem.BondType.SINGLE,
            bt == Chem.rdchem.BondType.DOUBLE,
            bt == Chem.rdchem.BondType.TRIPLE,
            bt == Chem.rdchem.BondType.AROMATIC,
            (bond.GetIsConjugated() if bt is not None else 0),
            (bond.IsInRing() if bt is not None else 0)
        ]
        fbond += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
    return fbond

def mol_to_graph_data_obj_grover(mol):
    #mol = Chem.MolFromSmiles(mol)
    hydrogen_donor = Chem.MolFromSmarts("[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]")
    hydrogen_acceptor = Chem.MolFromSmarts(
        "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),"
        "n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]")
    acidic = Chem.MolFromSmarts("[$([C,S](=[O,S,P])-[O;H1,-1])]")
    basic = Chem.MolFromSmarts(
        "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);"
        "!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]")

    hydrogen_donor_match = sum(mol.GetSubstructMatches(hydrogen_donor), ())
    hydrogen_acceptor_match = sum(mol.GetSubstructMatches(hydrogen_acceptor), ())
    acidic_match = sum(mol.GetSubstructMatches(acidic), ())
    basic_match = sum(mol.GetSubstructMatches(basic), ())
    ring_info = mol.GetRingInfo()

    n_atoms = mol.GetNumAtoms()
    
    f_atoms = []
    for _, atom in enumerate(mol.GetAtoms()):
        f_atoms.append(atom_features(atom, hydrogen_donor_match, hydrogen_acceptor_match, acidic_match, basic_match, ring_info))
    f_atoms = [f_atoms[i] for i in range(n_atoms)]
    
    f_bonds = []
    bond_list = []
    for a1 in range(n_atoms):
        for a2 in range(a1 + 1, n_atoms):
            bond = mol.GetBondBetweenAtoms(a1, a2)

            if bond is None:
                continue

            f_bond = bond_features(bond)

            # Always treat the bond as directed.
            f_bonds.append(f_atoms[a1] + f_bond)
            bond_list.append([a1, a2])
            f_bonds.append(f_atoms[a2] + f_bond)
            bond_list.append([a2, a1])
    
    #data = [f_atoms, bond_list, f_bonds]
    data = Data(x=f_atoms, edge_index=bond_list, edge_attr=f_bonds)
    return data

class MoleculeDataset_grover(InMemoryDataset):
    def __init__(self,
                 root,
                 #data = None,
                 #slices = None,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None,
                 dataset='zinc250k',
                 empty=False):
        """
        Adapted from qm9.py. Disabled the download functionality
        :param root: directory of the dataset, containing a raw and processed
        dir. The raw dir should contain the file containing the smiles, and the
        processed dir can either empty or a previously processed file
        :param dataset: name of the dataset. Currently only implemented for
        zinc250k, chembl_with_labels, tox21, hiv, bace, bbbp, clintox, esol,
        freesolv, lipophilicity, muv, pcba, sider, toxcast
        :param empty: if True, then will not load any data obj. For
        initializing empty dataset
        """
        self.dataset = dataset
        self.root = root
        self.smiles = []

        super(MoleculeDataset_grover, self).__init__(root, transform, pre_transform,
                                                 pre_filter)
        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])

    def get(self, idx):
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx],
                                                    slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        file_name_list = os.listdir(self.raw_dir)
        # assert len(file_name_list) == 1     # currently assume we have a
        # # single raw file
        return file_name_list

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def download(self):
        raise NotImplementedError('Must indicate valid location of raw data. '
                                  'No download allowed')
        
    def process(self):
        data_smiles_list = []
        data_list = []
        
        smiles_list, rdkit_mol_objs, labels = \
            _load_other_dataset(self.raw_paths[0])
        for i in range(len(smiles_list)):
            #print(i)
            rdkit_mol = rdkit_mol_objs[i]
            if rdkit_mol != None:
                data = mol_to_graph_data_obj_grover(rdkit_mol)
                # manually add mol id
                data.id = torch.tensor(
                    [i])  # id here is the index of the mol in
                # the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # write data_smiles_list in processed paths
        data_smiles_series = pd.Series(data_smiles_list)
        data_smiles_series.to_csv(os.path.join(self.processed_dir,
                                               'smiles.csv'), index=False,
                                  header=False)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        
        self.smiles = smiles_list
        self.data_list = data_list



In [None]:
def _load_other_dataset(input_path):
    """
    this is for loading other datasets
    """
    df = pd.read_csv(input_path)
    smiles_list = df.iloc[:,0]
    rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
    tasks = list(df.columns[1:])
    labels = df[tasks]
    labels = labels.replace(0, -1)
    labels = labels.replace(0.0, -1)
    labels = labels.fillna(0)
    assert len(smiles_list) == len(rdkit_mol_objs_list)
    assert len(smiles_list) == len(labels)
    return smiles_list, rdkit_mol_objs_list, labels.values

class MoleculeDataset_other(InMemoryDataset):
    def __init__(self,
                 root,
                 #data = None,
                 #slices = None,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None,
                 dataset='zinc250k',
                 empty=False):
        """
        Adapted from qm9.py. Disabled the download functionality
        :param root: directory of the dataset, containing a raw and processed
        dir. The raw dir should contain the file containing the smiles, and the
        processed dir can either empty or a previously processed file
        :param dataset: name of the dataset. Currently only implemented for
        zinc250k, chembl_with_labels, tox21, hiv, bace, bbbp, clintox, esol,
        freesolv, lipophilicity, muv, pcba, sider, toxcast
        :param empty: if True, then will not load any data obj. For
        initializing empty dataset
        """
        self.dataset = dataset
        self.root = root
        self.smiles = []

        super(MoleculeDataset_other, self).__init__(root, transform, pre_transform,
                                                 pre_filter)
        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])

    def get(self, idx):
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx],
                                                    slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        file_name_list = os.listdir(self.raw_dir)
        # assert len(file_name_list) == 1     # currently assume we have a
        # # single raw file
        return file_name_list

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def download(self):
        raise NotImplementedError('Must indicate valid location of raw data. '
                                  'No download allowed')
        
    def process(self):
        data_smiles_list = []
        data_list = []
        
        smiles_list, rdkit_mol_objs, labels = \
            _load_other_dataset(f'{self.root}/raw/{self.dataset}.csv')
        for i in range(len(smiles_list)):
            #print(i)
            rdkit_mol = rdkit_mol_objs[i]
            if rdkit_mol != None:
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                # manually add mol id
                data.id = torch.tensor(
                    [i])  # id here is the index of the mol in
                # the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # write data_smiles_list in processed paths
        data_smiles_series = pd.Series(data_smiles_list)
        data_smiles_series.to_csv(os.path.join(self.processed_dir,
                                               'smiles.csv'), index=False,
                                  header=False)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        
        self.smiles = smiles_list
        self.data_list = data_list

In [None]:
name = 'tg407'
ds = MoleculeDataset_other(f'dataset_grover/{name}', dataset=name)

## origin dataset

In [None]:
data_smiles_list = []
data_list = []

smiles_list, rdkit_mol_objs, labels = \
    _load_other_dataset('dataset_grover/tox21/raw/tox21.csv')




<function torch_geometric.data.in_memory_dataset.InMemoryDataset.collate(data_list: List[torch_geometric.data.data.Data]) -> Tuple[torch_geometric.data.data.Data, Union[Dict[str, torch.Tensor], NoneType]]>

In [None]:
for i in range(len(smiles_list)):
    #print(i)
    rdkit_mol = rdkit_mol_objs[i]
    if rdkit_mol != None:
        data = mol_to_graph_data_obj_simple(rdkit_mol)
        # manually add mol id
        data.id = torch.tensor(
            [i])  # id here is the index of the mol in
        # the dataset
        data.y = torch.tensor(labels[i, :])
        data_list.append(data)
        data_smiles_list.append(smiles_list[i])

# write data_smiles_list in processed paths
data_smiles_series = pd.Series(data_smiles_list)

data_new, slices = MoleculeDataset_other.collate(data_list)

In [None]:
data

Data(x=[25, 2], edge_index=[2, 58], edge_attr=[58, 2], id=[1], y=[12])

In [None]:
data_new.edge_index

tensor([[ 0,  1,  1,  ..., 14, 24, 16],
        [ 1,  0,  2,  ..., 19, 16, 24]])

In [None]:
slices

defaultdict(dict,
            {'x': tensor([     0,     16,     31,  ..., 145414, 145434, 145459]),
             'edge_index': tensor([     0,     34,     66,  ..., 302086, 302132, 302190]),
             'edge_attr': tensor([     0,     34,     66,  ..., 302086, 302132, 302190]),
             'id': tensor([   0,    1,    2,  ..., 7829, 7830, 7831]),
             'y': tensor([    0,    12,    24,  ..., 93948, 93960, 93972])})

In [None]:
mol = rdkit_mol_objs[0]
num_atom_features = 2   # atom type,  chirality tag
atom_features_list = []
for atom in mol.GetAtoms():
    atom_feature = [allowable_features['possible_atomic_num_list'].index(
        atom.GetAtomicNum())] + [allowable_features[
        'possible_chirality_list'].index(atom.GetChiralTag())]
    atom_features_list.append(atom_feature)
x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

# bonds
num_bond_features = 2   # bond type, bond direction
if len(mol.GetBonds()) > 0: # mol has bonds
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = [allowable_features['possible_bonds'].index(
            bond.GetBondType())] + [allowable_features[
                                        'possible_bond_dirs'].index(
            bond.GetBondDir())]
        edges_list.append((i, j))
        edge_features_list.append(edge_feature)
        edges_list.append((j, i))
        edge_features_list.append(edge_feature)

    # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
    edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

    # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
    edge_attr = torch.tensor(np.array(edge_features_list),
                             dtype=torch.long)
else:   # mol has no bonds
    edge_index = torch.empty((2, 0), dtype=torch.long)
    edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

In [None]:
x.shape

torch.Size([16, 2])

In [None]:
torch.tensor(f_atoms).shape

torch.Size([16, 151])

In [None]:
edge_index.shape

torch.Size([2, 34])

In [None]:
torch.tensor(np.array(bond_list).T, dtype=torch.long).shape

torch.Size([2, 34])

In [None]:
edge_attr.shape

torch.Size([34, 2])

In [None]:
torch.tensor(np.array(f_bonds), dtype=torch.long).shape

torch.Size([34, 185])

## grover dataset

In [None]:
data_smiles_list2 = []
data_list2 = []

In [None]:
for i in range(len(smiles_list)):
    #print(i)
    rdkit_mol2 = rdkit_mol_objs[i]
    if rdkit_mol2 != None:
        data2 = mol_to_graph_data_obj_grover(rdkit_mol2)
        # manually add mol id
        data2.id = torch.tensor([i])  # id here is the index of the mol in
        # the dataset
        data2.y = torch.tensor(labels[i, :])
        data_list2.append(data2)
        data_smiles_list2.append(smiles_list[i])

# write data_smiles_list in processed paths
data_smiles_series2 = pd.Series(data_smiles_list2)

data2_new, slices2 = MoleculeDataset_grover.collate(data_list2)

In [None]:
hydrogen_donor = Chem.MolFromSmarts("[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]")
hydrogen_acceptor = Chem.MolFromSmarts(
    "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),"
    "n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]")
acidic = Chem.MolFromSmarts("[$([C,S](=[O,S,P])-[O;H1,-1])]")
basic = Chem.MolFromSmarts(
    "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);"
    "!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]")

hydrogen_donor_match = sum(mol.GetSubstructMatches(hydrogen_donor), ())
hydrogen_acceptor_match = sum(mol.GetSubstructMatches(hydrogen_acceptor), ())
acidic_match = sum(mol.GetSubstructMatches(acidic), ())
basic_match = sum(mol.GetSubstructMatches(basic), ())
ring_info = mol.GetRingInfo()

n_atoms = mol.GetNumAtoms()

f_atoms = []
for _, atom in enumerate(mol.GetAtoms()):
    f_atoms.append(atom_features(atom, hydrogen_donor_match, hydrogen_acceptor_match, acidic_match, basic_match, ring_info))
f_atoms = [f_atoms[i] for i in range(n_atoms)]

f_bonds = []
bond_list = []
for a1 in range(n_atoms):
    for a2 in range(a1 + 1, n_atoms):
        bond = mol.GetBondBetweenAtoms(a1, a2)

        if bond is None:
            continue

        f_bond = bond_features(bond)

        # Always treat the bond as directed.
        f_bonds.append(f_atoms[a1] + f_bond)
        bond_list.append([a1, a2])
        f_bonds.append(f_atoms[a2] + f_bond)
        bond_list.append([a2, a1])

In [None]:
data2

Data(x=[25], edge_index=[58], edge_attr=[58], id=[1], y=[12])

In [None]:
data2_new.edge_index

[[[0, 1],
  [1, 0],
  [1, 2],
  [2, 1],
  [2, 3],
  [3, 2],
  [3, 4],
  [4, 3],
  [3, 15],
  [15, 3],
  [4, 5],
  [5, 4],
  [5, 6],
  [6, 5],
  [6, 7],
  [7, 6],
  [6, 14],
  [14, 6],
  [7, 8],
  [8, 7],
  [8, 9],
  [9, 8],
  [8, 13],
  [13, 8],
  [9, 10],
  [10, 9],
  [9, 11],
  [11, 9],
  [9, 12],
  [12, 9],
  [13, 14],
  [14, 13],
  [14, 15],
  [15, 14]],
 [[0, 1],
  [1, 0],
  [1, 2],
  [2, 1],
  [2, 3],
  [3, 2],
  [2, 13],
  [13, 2],
  [3, 4],
  [4, 3],
  [3, 5],
  [5, 3],
  [5, 6],
  [6, 5],
  [6, 7],
  [7, 6],
  [6, 13],
  [13, 6],
  [7, 8],
  [8, 7],
  [7, 12],
  [12, 7],
  [8, 9],
  [9, 8],
  [9, 10],
  [10, 9],
  [10, 11],
  [11, 10],
  [11, 12],
  [12, 11],
  [13, 14],
  [14, 13]],
 [[0, 1],
  [1, 0],
  [1, 2],
  [2, 1],
  [2, 3],
  [3, 2],
  [2, 4],
  [4, 2],
  [2, 19],
  [19, 2],
  [4, 5],
  [5, 4],
  [5, 6],
  [6, 5],
  [6, 7],
  [7, 6],
  [6, 19],
  [19, 6],
  [7, 8],
  [8, 7],
  [7, 16],
  [16, 7],
  [8, 9],
  [9, 8],
  [9, 10],
  [10, 9],
  [10, 11],
  [11, 10],
  [10,

In [None]:
slices2

defaultdict(dict,
            {'x': tensor([   0,    1,    2,  ..., 7829, 7830, 7831]),
             'edge_index': tensor([   0,    1,    2,  ..., 7829, 7830, 7831]),
             'edge_attr': tensor([   0,    1,    2,  ..., 7829, 7830, 7831]),
             'id': tensor([   0,    1,    2,  ..., 7829, 7830, 7831]),
             'y': tensor([    0,    12,    24,  ..., 93948, 93960, 93972])})

# model

In [None]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
atom_fdim = 151
bond_fdim = 165

In [None]:
class GINConv_grover(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not. 
        

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr = "add"):
        super(GINConv_grover, self).__init__()
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding = torch.nn.Linear(bond_fdim, emb_dim)

        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), bond_fdim)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding(edge_attr)

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)


In [None]:
class GNN_grover(torch.nn.Module):
    """
    

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        JK (str): last, concat, max or sum.
        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        node representations

    """
    def __init__(self, args):
        super(GNN_grover, self).__init__()
        self.num_layer = args.num_layer
        self.drop_ratio = args.dropout_ratio
        self.JK = args.JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding = torch.nn.Linear(atom_fdim, args.emb_dim, bias=True)

        ###List of MLPs
        self.gnns = torch.nn.ModuleList()
        for layer in range(args.num_layer):
            if args.gnn_type == "gin":
                self.gnns.append(GINConv_grover(args.emb_dim, aggr = "add"))
            elif args.gnn_type == "gcn":
                self.gnns.append(GCNConv(args.emb_dim))
            elif args.gnn_type == "gat":
                self.gnns.append(GATConv(args.emb_dim))
            elif args.gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv(args.emb_dim))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(args.num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(args.emb_dim))

    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding(x)

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            #h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat   #레이어간 노드 기능들을 어떻게 할건지, 기본은 Last다
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim = 1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim = 0), dim = 0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim = 0), dim = 0)[0]

        return node_representation

In [None]:
class GNN_graphpred_grover(torch.nn.Module):
    """
    Extension of GIN to incorporate edge information by concatenation.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        drop_ratio (float): dropout rate
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, set2set
        gnn_type: gin, gcn, graphsage, gat
        
    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """
    def __init__(self, args):
        super(GNN_graphpred_grover, self).__init__()
        self.num_layer = args.num_layer
        self.drop_ratio = args.dropout_ratio
        self.JK = args.JK
        self.emb_dim = args.emb_dim
        self.num_tasks = args.num_tasks
        self.num_class = args.num_class
        self.gnn_type = args.gnn_type
        self.regression = args.regression

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn = GNN_grover(args)

        #Different kind of graph pooling
        if args.graph_pooling == "sum":
            self.pool = global_add_pool
        elif args.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif args.graph_pooling == "max":
            self.pool = global_max_pool
        elif args.graph_pooling == "attention":
            if self.JK == "concat":
                self.pool = GlobalAttention(gate_nn = torch.nn.Linear((self.num_layer + 1) * self.emb_dim, 1))
            else:
                self.pool = GlobalAttention(gate_nn = torch.nn.Linear(self.emb_dim, 1))
        elif args.graph_pooling[:-1] == "set2set":
            set2set_iter = int(graph_pooling[-1])
            if self.JK == "concat":
                self.pool = Set2Set((self.num_layer + 1) * self.emb_dim, set2set_iter)
            else:
                self.pool = Set2Set(self.emb_dim, set2set_iter)
        else:
            raise ValueError("Invalid graph pooling type.")

        #For graph-level binary classification
        if args.graph_pooling[:-1] == "set2set":
            self.mult = 2
        else:
            self.mult = 1
        
        if self.JK == "concat":
            if args.regression : 
                self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, 1)
            elif args.multi_class :
                self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_class)
            else : 
                self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)
        else:
            if args.regression :
                self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, 1)
            elif args.multi_class :
                self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_class)
            else : 
                self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks)

    def from_pretrained(self, model_file):
        self.gnn.load_state_dict(torch.load(model_file))

    def forward(self, *argv):
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.gnn(x, edge_index, edge_attr)

        output = self.graph_pred_linear(self.pool(node_representation, batch))
        return output

# main

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [None]:
# python finetune_grover.py --dataset tg407 --epochs 10 --output_path output_grover/tg407_RS1 --batch_size 96 > g_gin_tg407_RS.txt &

import os, time, random, shutil
import argparse
from argparse import ArgumentParser, Namespace
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.data import DataLoader
from loader import MoleculeDataset_grover, MoleculeDataset_other

from sklearn.metrics import roc_auc_score

from model import GNN_grover, GNN_graphpred_grover
from util import calcul_loss, save_cp, confusion_mat, makedirs, create_logger
from splitters import scaffold_split, random_split

from rdkit import RDLogger
import logging
from logging import Logger

# i don't want see warning of torch dataset
import warnings
warnings.filterwarnings(action='ignore')

## definitions

In [None]:
def train(args, model, device, loader, optimizer):
    model.train()

    loss_sum = 0
    iter_count = 0

    for step, batch in enumerate(loader):
        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        if args.multi_class:
            y = batch.y
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)

        #loss matrix after removing null target
        loss = calcul_loss(pred, y, args)

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        loss_sum += loss
        iter_count += 1

    torch.cuda.empty_cache()
    return loss_sum / iter_count

In [None]:
def valid(args, model, device, loader):
    model.eval()
    y_true = []
    y_scores = []
    cum_loss = 0

    for step, batch in enumerate(loader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            
            if args.multi_class:
                y = batch.y
            else : 
                y = batch.y.view(pred.shape).to(torch.float64)
            
            loss = calcul_loss(pred, y, args)
        
        cum_loss += loss
        
        if not args.regression:
            y_true.append(y)
            y_scores.append(pred)

    if not args.regression:
        y_true = torch.cat(y_true, dim = 0).cpu().numpy()
        y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

        roc_list = []
        for i in range(y_true.shape[1]):
            #AUC is only defined when there is at least one positive data.
            if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == -1) > 0:
                is_valid = y_true[:,i]**2 > 0
                roc_list.append(roc_auc_score((y_true[is_valid,i] + 1)/2, y_scores[is_valid,i]))

        if len(roc_list) < y_true.shape[1]:
            print("Some target is missing!")
            print("Missing ratio: %f" %(1 - float(len(roc_list))/y_true.shape[1]))

        torch.cuda.empty_cache()
        return cum_loss, sum(roc_list)/len(roc_list) #y_true.shape[1]
    else : 
        torch.cuda.empty_cache()
        return cum_loss, 0

In [None]:
def test(args, model, device, loader):
    
    model.eval()
    y_true = []
    y_scores = []
    cum_loss = 0

    for step, batch in enumerate(loader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            
            if args.multi_class:
                y = batch.y
            else : 
                y = batch.y.view(pred.shape).to(torch.float64)
            loss = mean_absolute_error(y.cpu().detach().numpy(), pred.cpu().detach().numpy())

        cum_loss += loss
        
        if not args.regression:
            y_true.append(y)
            y_scores.append(pred)
        

    if not args.regression:
        y_true = torch.cat(y_true, dim = 0).cpu().numpy()
        y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

        auc_list = []
        acc_list = []
        rec_list = []
        prec_list = []
        f1s_list = []
        BA_list = []
        tp_list = []
        fp_list = []
        tn_list = []
        fn_list = []
        for i in range(y_true.shape[1]):
            try : 
                auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = confusion_mat(y_true[:,i], y_scores[:,i])
                auc_list.append(auc)
                acc_list.append(acc)
                rec_list.append(rec)
                prec_list.append(prec)
                f1s_list.append(f1s)
                BA_list.append(BA)
                tp_list.append(tp)
                fp_list.append(fp)
                tn_list.append(tn)
                fn_list.append(fn)
            except : 
                info(f'{i}th task only one class')

        torch.cuda.empty_cache()
        return cum_loss, auc_list, acc_list, rec_list, prec_list, f1s_list, BA_list, tp_list, fp_list, tn_list, fn_list
    else : 
        return cum_loss


In [None]:
def run_training(args: Namespace, logger: Logger = None):
    info = logger.info if logger is not None else print
    
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        
    #set up dataset
    dataset = MoleculeDataset_grover(args.data_path + args.dataset, dataset=args.dataset)
    args.num_tasks = len(dataset[0]['y'])
    labels = pd.read_csv(args.data_path + args.dataset + '/raw/' + args.dataset + '.csv', header=None)[1][1:]
    unique_labels = np.unique(labels[~labels.isnull()])
    args.num_class = len(unique_labels)
    if args.num_class>2 :
        if args.num_tasks>1 : 
            raise ValueError("this model can't treat multi-task and multi-class")
        else:
            args.multi_class=True
    
    if args.split == "scaffold":
        smiles_list = pd.read_csv(args.data_path + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
        info(f'scaffold_balanced_split')
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(dataset, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
        info("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv(args.data_path + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
        info("random scaffold")
    else:
        raise ValueError("Invalid split option.")
        
    info(f'total_size:{len(dataset)} train_size:{len(train_dataset)} val_size:{len(valid_dataset)} test_size:{len(test_dataset)}')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

    #set up model
    model = GNN_graphpred_grover(args)
    
    model.to(device)

    #set up optimizer
    #different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    info(optimizer)

    best_val_loss = 9999
    best_model_path = os.path.join(args.output_path, str(args.seed))
    for epoch in range(1, args.epochs+1):
        info("====epoch " + str(epoch))
        tst = time.time()
        train_loss = train(args, model, device, train_loader, optimizer)
        tet = time.time() - tst
        info("====Evaluation")
        vst = time.time()
        val_loss, val_auc = valid(args, model, device, val_loader)
        vet = time.time() - vst
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_cp(args, model, path=best_model_path)
        if not args.regression : 
            info(f'train_loss:{train_loss:.4f} val_loss:{val_loss:.4f} val_auc:{val_auc:.4f} t_time:{tet} v_time:{vet}')
        else : 
            info(f'train_loss:{train_loss:.4f} val_loss:{val_loss:.4f} t_time:{tet} v_time:{vet}')
    
    best_state = torch.load(os.path.join(best_model_path,'model.pt'))
    model.load_state_dict(best_state['state_dict'])
    
    if not args.regression:
        test_loss, auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = test(args, model, device, test_loader)
        avg_auc = sum(auc)/num_tasks
        avg_acc = sum(acc)/num_tasks
        avg_rec = sum(rec)/num_tasks
        avg_prec = sum(prec)/num_tasks
        avg_f1s = sum(f1s)/num_tasks
        avg_BA = sum(BA)/num_tasks
        avg_tp = sum(tp)/num_tasks
        avg_fp = sum(fp)/num_tasks
        avg_tn = sum(tn)/num_tasks
        avg_fn = sum(fn)/num_tasks

        info(f'seed:{args.seed} loss:{test_loss} auc:{avg_auc} acc:{avg_acc} rec:{avg_rec} prec:{avg_prec} f1:{avg_f1s} BA:{avg_BA}\ntp:{avg_tp} fp:{avg_fp} fn:{avg_fn} tn:{avg_tn}')
        #delete for memory
        del train_dataset, valid_dataset, test_dataset, train_loader, val_loader, test_loader
        return avg_auc, avg_acc, avg_rec, avg_prec, avg_f1s, avg_BA, avg_tp, avg_fp, avg_tn, avg_fn
    
    else:
        test_loss = test(args, model, device, test_loader)
        
        info(f'seed:{args.seed} test_MSE:{test_loss}')
        del train_dataset, valid_dataset, test_dataset, train_loader, val_loader, test_loader
        return test_loss
    

In [None]:
def cross_validate(args: Namespace, logger: Logger = None):
    info = logger.info if logger is not None else print
    
    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)
    if not args.regression:
        auc_list = []
        acc_list = []
        rec_list = []
        prec_list = []
        f1s_list = []
        BA_list = []
        tp_list = []
        fp_list = []
        tn_list = []
        fn_list = []
        for k in range(3):
            auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = run_training(args)
            auc_list.append(auc)
            acc_list.append(acc)
            rec_list.append(rec)
            prec_list.append(prec)
            f1s_list.append(f1s)
            BA_list.append(BA)
            tp_list.append(tp)
            fp_list.append(fp)
            tn_list.append(tn)
            fn_list.append(fn)
            args.seed += 1
        info(f'all test end')
        info(f'overall test_auc : {np.nanmean(auc_list):.4f}\nstd={np.nanstd(auc_list):.4f}')
        info(f'overall test_accuracy : {np.nanmean(acc_list):.4f}\nstd={np.nanstd(acc_list):.4f}')
        info(f'overall test_recall : {np.nanmean(rec_list):.4f}\nstd={np.nanstd(rec_list):.4f}')
        info(f'overall test_precision : {np.nanmean(prec_list):.4f}\nstd={np.nanstd(prec_list):.4f}')
        info(f'overall test_f1score : {np.nanmean(f1s_list):.4f}\nstd={np.nanstd(f1s_list):.4f}')
        info(f'overall test_Balanced_Accuracy : {np.nanmean(BA_list):.4f}\nstd={np.nanstd(BA_list):.4f}')
        info(f'overall test_tp : {np.nanmean(tp_list):.2f}\nstd={np.nanstd(tp_list):.2f}')
        info(f'overall test_fp : {np.nanmean(fp_list):.2f}\nstd={np.nanstd(fp_list):.2f}')
        info(f'overall test_fn : {np.nanmean(fn_list):.2f}\nstd={np.nanstd(fn_list):.2f}')
        info(f'overall test_tn : {np.nanmean(tn_list):.2f}\nstd={np.nanstd(tn_list):.2f}')
        info(f'\n       (pred)pos    neg(pred)')
        info(f'pos(true)    {tp:.2f}  {fn:.2f}')
        info(f'neg(true)    {fp:.2f}  {tn:.2f}')

        return np.nanmean(auc_list)
    else : 
        mse_list = []
        for k in range(3):
            mse = run_training(args)
            mse_list.append(mse)
            args.seed += 1
        info(f'all test end')
        info(f'overall test_mse : {np.nanmean(mse_list):.4f}\nstd={np.nanstd(mse_list):.4f}')
        return np.nanmean(mse_list)

In [None]:
def random_search(args: Namespace, logger: Logger = None):
    info = logger.info if logger is not None else print
    
    init_seed = args.seed
    save_dir = args.output_path

    #randomize parameter list
    lr_list = [0.0005, 0.00075, 0.001, 0.00125, 0.0015, 0.00175, 0.002]
    dropout_list = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
    gpooling_list = ['mean', 'sum']
    lr_scale_list = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]

    # Run training with different random seeds for each fold
    all_scores = []
    params = []
    for iter_num in range(0, args.n_iters):
        info(f'iter {iter_num}')

        #randomize parameter
        np.random.seed()
        random.seed()
        args.lr = np.random.choice(lr_list, 1)[0]
        args.dropout_ratio = np.random.choice(dropout_list, 1)[0]
        args.graph_pooling = np.random.choice(gpooling_list, 1)[0]
        args.lr_scale = np.random.choice(lr_scale_list, 1)[0]
        params.append(f'\n{iter_num}th search parameter : lr is {args.lr} \n dropout is {args.dropout_ratio} \n batch_size is {args.batch_size}')
        info(params[iter_num])

        args.seed = init_seed                        # if change this, result will be change
        iter_dir = os.path.join(save_dir, f'iter_{iter_num}')
        args.output_path = iter_dir
        makedirs(args.output_path)

        iter_score = cross_validate(args, logger)
        all_scores.append(iter_score)

        if not args.regression:
            if max(all_scores)==iter_score : 
                best_iter = iter_num
                best_score = iter_score
                best_param = params[iter_num]
        else : 
            if min(all_scores)==iter_score : 
                best_iter = iter_num
                best_score = iter_score
                best_param = param[iter_num]

    all_scores = np.array(all_scores)

    # Report scores for each iter
    info(f'\n---- {args.n_iters}-iter random search ----')

    for iter_num, scores in enumerate(all_scores):
        info(params[iter_num])
        info(f'Seed {init_seed} ==> test AUC = {np.nanmean(scores):.6f}\n')

    # Report best model
    info(f'\nbest_iter : {best_iter}\nbest_score is {np.nanmean(best_score)}\nbest_param : {best_param}')

In [None]:
def main():
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--lr_scale', type=float, default=1,
                        help='relative learning rate for the feature extraction layer (default: 1)')
    parser.add_argument('--decay', type=float, default=1e-7,
                        help='weight decay (default: 0)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5).')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='embedding dimensions (default: 300)')
    parser.add_argument('--dropout_ratio', type=float, default=0.2,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--graph_pooling', type=str, default="mean",
                        help='graph level pooling (sum, mean, max, set2set, attention)')
    parser.add_argument('--JK', type=str, default="last",
                        help='how the node features across layers are combined. last, sum, max or concat')
    parser.add_argument('--gnn_type', type=str, default="gin")
    parser.add_argument('--dataset', type=str, default = 'sider', help='root directory of dataset. For now, only classification.')
    parser.add_argument('--data_path', type=str, default = 'dataset/', help='filename to read the model (if there is any)')
    parser.add_argument('--output_path', type=str, default = 'output', help='output filename')
    parser.add_argument('--seed', type=int, default=0, help = "Seed for splitting the dataset.")
    parser.add_argument('--runseed', type=int, default=0, help = "Seed for minibatch selection, random initialization.")
    parser.add_argument('--split', type = str, default="scaffold", help = "random or scaffold or random_scaffold")
    parser.add_argument('--eval_train', type=int, default = 1, help='evaluating training or not')
    parser.add_argument('--num_workers', type=int, default = 4, help='number of workers for dataset loading')
    # For search
    parser.add_argument('--randomsearch', action='store_true', default=False, help='randomsearch mode')
    #parser.add_argument('--gridsearch', action='store_true', default=False, help='gridsearch mode')
    parser.add_argument('--n_iters', type=int, default=1,
                        help='Number of search')
    parser.add_argument('--grover', action='store_true', default=False, help='use grover feature')
    parser.add_argument('--regression', action='store_true', default=False, help='data is regression')
    parser.add_argument('--multi_class', action='store_true', default=False, help='data is multi_class')
    parser.add_argument('--num_tasks', type=int, default=False, help='use grover feature')
    parser.add_argument('--num_class', type=int, default=False, help='use grover feature')
    
    args = parser.parse_args(['--data_path','dataset_grover/','--dataset','bbbp','--grover', '--gnn_type', 'gcn'])
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)
    
    logger = create_logger(name='train', save_dir=args.output_path, quiet=False)
    if args.randomsearch:
        best_metric = 0
        random_search(args=args, logger=logger)
    else : 
        cross_validate(args=args, logger=logger)

In [None]:
main()

scaffold_balanced_split
total_size:2039 train_size:1631 val_size:204 test_size:204
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
====epoch 1


RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 165 but got size 2 for tensor number 1 in the list.

# 상세 실행

In [1]:
# python finetune_grover.py --dataset tg407 --epochs 10 --output_path output_grover/tg407_RS1 --batch_size 96 > g_gin_tg407_RS.txt &

import os, time, random, shutil, math
import argparse
from argparse import ArgumentParser, Namespace
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.data import DataLoader
from loader import MoleculeDataset_grover, MoleculeDataset_other

from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler

from model import GNN_grover, GNN_graphpred_grover
from util import calcul_loss, save_cp, confusion_mat, makedirs, create_logger
from splitters import scaffold_split, random_split

from rdkit import RDLogger
import logging
from logging import Logger

# i don't want see warning of torch dataset
import warnings
warnings.filterwarnings(action='ignore')

In [2]:
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--lr_scale', type=float, default=1,
                    help='relative learning rate for the feature extraction layer (default: 1)')
parser.add_argument('--decay', type=float, default=1e-7,
                    help='weight decay (default: 0)')
parser.add_argument('--num_layer', type=int, default=5,
                    help='number of GNN message passing layers (default: 5).')
parser.add_argument('--emb_dim', type=int, default=300,
                    help='embedding dimensions (default: 300)')
parser.add_argument('--dropout_ratio', type=float, default=0.2,
                    help='dropout ratio (default: 0.5)')
parser.add_argument('--graph_pooling', type=str, default="mean",
                    help='graph level pooling (sum, mean, max, set2set, attention)')
parser.add_argument('--JK', type=str, default="last",
                    help='how the node features across layers are combined. last, sum, max or concat')
parser.add_argument('--gnn_type', type=str, default="gin")
parser.add_argument('--dataset', type=str, default = 'sider', help='root directory of dataset. For now, only classification.')
parser.add_argument('--data_path', type=str, default = 'dataset/', help='filename to read the model (if there is any)')
parser.add_argument('--output_path', type=str, default = 'output', help='output filename')
parser.add_argument('--seed', type=int, default=0, help = "Seed for splitting the dataset.")
parser.add_argument('--runseed', type=int, default=0, help = "Seed for minibatch selection, random initialization.")
parser.add_argument('--split', type = str, default="scaffold", help = "random or scaffold or random_scaffold")
parser.add_argument('--eval_train', type=int, default = 1, help='evaluating training or not')
parser.add_argument('--num_workers', type=int, default = 4, help='number of workers for dataset loading')

# For search
parser.add_argument('--randomsearch', action='store_true', default=False, help='randomsearch mode')
#parser.add_argument('--gridsearch', action='store_true', default=False, help='gridsearch mode')
parser.add_argument('--n_iters', type=int, default=1,
                    help='Number of search')
parser.add_argument('--grover', action='store_true', default=False, help='use grover feature')
parser.add_argument('--regression', action='store_true', default=False, help='data is regression')
parser.add_argument('--multi_class', action='store_true', default=False, help='data is multi_class')
parser.add_argument('--num_tasks', type=int, default=1, help='number of tasks')
parser.add_argument('--num_class', type=int, default=2, help='number of class')

# For predict
parser.add_argument('--model_path', type=str, default = 'dataset/', help='filename to read the model (if there is any)')
parser.add_argument('--predict', action='store_true', default=False, help='only predicition')

#args = parser.parse_args(['--split','random','--data_path','dataset_grover/','--dataset','toxcast','--grover','--epochs','2', '--gnn_type','gin', '--output_path', 'output_grover/test'])
args = parser.parse_args(['--data_path', 'dataset_grover/' ,'--epochs', '100', '--batch_size', '96', '--grover', '--dataset', 'qm7', '--output_path', 'output_grover/test', '--gnn_type', 'gcn', '--predict', '--model_path', 'results/gsage_results/gsage_qm7_RS1/iter_8','--regression'])
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

logger = create_logger(name='train', save_dir=args.output_path, quiet=False)

# run_predict

In [3]:
info = logger.info if logger is not None else print

torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

#set up dataset
dataset = MoleculeDataset_grover(args.data_path + args.dataset, dataset=args.dataset)
args.num_tasks = len(dataset[0]['y'])
labels = pd.read_csv(args.data_path + args.dataset + '/raw/' + args.dataset + '.csv', header=None)[1][1:]
try : unique_labels = np.unique(labels[~labels.isnull()])
except : unique_labels = np.unique(labels[~labels.isnull()].astype(float))
args.num_class = len(unique_labels)
if args.num_class>2 :
    if args.num_tasks>1 and not args.regression: 
        raise ValueError("this model can't treat multi-task and multi-class")
    else:
        args.multi_class=True

if args.regression: 
    scaler = StandardScaler()
    scaler.fit(dataset.y.view(-1,1))
else:
    scaler=None

if args.split == "scaffold":
    smiles_list = pd.read_csv(args.data_path + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
    train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
    info(f'scaffold_balanced_split')
elif args.split == "random":
    train_dataset, valid_dataset, test_dataset = random_split(dataset, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
    info("random")
elif args.split == "random_scaffold":
    smiles_list = pd.read_csv(args.data_path + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
    train_dataset, valid_dataset, test_dataset = random_scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
    info("random scaffold")
else:
    raise ValueError("Invalid split option.")

info(f'total_size:{len(dataset)} train_size:{len(train_dataset)} val_size:{len(valid_dataset)} test_size:{len(test_dataset)}')

test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

best_model_path = os.path.join(args.model_path, str(args.seed))
best_state = torch.load(os.path.join(best_model_path,'model.pt'))

#set up model
model = GNN_graphpred_grover(best_state['args'])

model.to(device)

#set up optimizer
#different learning rate for different part of GNN
model_param_group = []
model_param_group.append({"params": model.gnn.parameters()})
if args.graph_pooling == "attention":
    model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
info(optimizer)


model.load_state_dict(best_state['state_dict'])

scaffold_balanced_split
total_size:6830 train_size:5464 val_size:685 test_size:681
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)


<All keys matched successfully>

In [4]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(test_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        if args.multi_class and not args.regression:
            y = batch.y.view(pred.shape[0],1).to(torch.long)
        elif args.regression and args.num_tasks > 1:
            y = batch.y.view(pred.shape).to(torch.float)
            pred = scaler.inverse_transform(pred.cpu())
        elif args.regression:
            y = batch.y.view(pred.shape).to(torch.float)
            pred = scaler.inverse_transform(pred.cpu())
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)

    cum_loss += loss
    y_true.extend(np.array(y.cpu().detach().numpy()))
    y_scores.extend(np.array(pred))

In [16]:
loss

78.96452480328585

In [11]:
if args.dataset=='qm7' or args.dataset=='qm8':
    loss = mean_absolute_error(y_true, y_scores)
    #loss = loss.mean()
elif args.regression:
    loss = math.sqrt(mean_squared_error(y_true.cpu().detach().numpy(), y_scores))
    #loss = np.mean(loss)
else:
    loss = calcul_loss(pred, y, args)

In [None]:
math.sqrt(mean_squared_error(targets, preds))

In [153]:
data = pd.read_csv(args.data_path + args.dataset + '/raw/' + args.dataset + '.csv')
pred_list=np.zeros([data.shape[0],data.shape[1]-1])

In [160]:
pred_list+=np.array(y_scores)

In [161]:
pred_list

array([[ 57.14084835, 108.03748702],
       [ 98.85486648, 131.8904612 ],
       [ 81.2640806 , 109.51091295],
       [ 81.47341476, 105.31523212],
       [ 87.92007038, 126.76156198],
       [ 76.63792481, 112.10301446],
       [  8.66721805,  44.48740677],
       [ 44.80099992,  72.86893065],
       [ 63.23977757, 101.17489118],
       [ 43.66668234,  85.66692466],
       [ 23.83146507,  50.12814708],
       [ 73.32021955, 127.03896234],
       [ 93.14637285, 140.34434595],
       [ 67.31356404,  96.7674099 ],
       [ 29.48518066,  50.24708751],
       [ 93.07601004, 123.17195241],
       [ 10.73766195,  26.92253401],
       [ 71.57157168, 121.17616075],
       [115.07041336, 130.40075252],
       [ 91.51343695, 121.86989965],
       [ 71.31928645,  88.41206066],
       [ 55.84418574,  86.25821381],
       [ 32.47871572,  61.71954971],
       [ 44.06031789,  89.08134523],
       [ 43.03973779,  68.16330154],
       [ 53.81777649,  69.32285887],
       [ 84.47462053, 123.65100659],
 

In [9]:
data = pd.read_csv(args.data_path + args.dataset + '/raw/' + args.dataset + '.csv')
pred_list = np.zeros([data.shape[0],data.shape[1]-1])
for k in range(3):
    info = logger.info if logger is not None else print

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    #set up dataset
    dataset = MoleculeDataset_grover(args.data_path + args.dataset, dataset=args.dataset)
    args.num_tasks = len(dataset[0]['y'])
    labels = pd.read_csv(args.data_path + args.dataset + '/raw/' + args.dataset + '.csv', header=None)[1][1:]
    try : unique_labels = np.unique(labels[~labels.isnull()])
    except : unique_labels = np.unique(labels[~labels.isnull()].astype(float))
    args.num_class = len(unique_labels)
    if args.num_class>2 :
        if args.num_tasks>1 and not args.regression: 
            raise ValueError("this model can't treat multi-task and multi-class")
        else:
            args.multi_class=True

    if args.regression: 
        scaler = StandardScaler()
        scaler.fit(dataset.y.view(-1,1))
    else:
        scaler=None

    info(f'total_size:{len(dataset)}')

    test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

    best_model_path = os.path.join(args.model_path, str(args.seed))
    best_state = torch.load(os.path.join(best_model_path,'model.pt'))

    #set up model
    model = GNN_graphpred_grover(best_state['args'])

    model.to(device)

    #set up optimizer
    #different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    info(optimizer)


    model.load_state_dict(best_state['state_dict'])

    model.eval()
    y_true = []
    y_scores = []
    cum_loss = 0

    for step, batch in enumerate(dataset):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

            if args.multi_class and not args.regression:
                y = batch.y.view(pred.shape[0],1).to(torch.long)
            elif args.regression and args.num_tasks > 1:
                y = batch.y.view(pred.shape).to(torch.float)
                pred = scaler.inverse_transform(pred.cpu())
            elif args.regression:
                y = batch.y.view(pred.shape).to(torch.float)
                pred = scaler.inverse_transform(pred.cpu())
            else : 
                y = batch.y.view(pred.shape).to(torch.float64)

            if args.dataset=='qm7' or args.dataset=='qm8' and args.regression:
                loss = mean_absolute_error(y.cpu().detach().numpy(), pred)
                loss = loss.mean()
            elif args.regression:
                loss = math.sqrt(mean_squared_error(y.cpu().detach().numpy(), pred))
                loss = np.mean(loss)
            else:
                loss = calcul_loss(pred, y, args)

        cum_loss += loss

        if args.num_tasks>1:
            y_scores.append(pred.reshape(-1))
        else:
            y_scores.append(pred)
    pred_list += np.array(y_scores)
    args.seed += 1

    del dataset, model

for i in range(len(data.columns)-1):
    data[data.columns[i+1]]=pred_list[:,i]/3
data.to_csv(os.path.join(args.output_path,'predict.csv'),index=False)

total_size:483
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
total_size:483
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
total_size:483
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9

In [26]:
args.output_path

'output_grover/test'

In [117]:
y_scores

[array([27.08547453, 49.69276835]),
 array([55.02218913, 71.96164271]),
 array([39.74188021, 58.61276142]),
 array([44.85122198, 62.44884996]),
 array([45.91978682, 67.07985464]),
 array([33.94706781, 57.96532812]),
 array([ 8.20200527, 18.84516631]),
 array([20.1628386 , 38.83685551]),
 array([31.27429386, 53.36570318]),
 array([20.95259779, 41.17154345]),
 array([10.58626482, 27.14912935]),
 array([41.47350512, 55.45135962]),
 array([56.98622184, 75.57178274]),
 array([31.43158784, 49.97460978]),
 array([16.1136228 , 27.73765938]),
 array([49.88889244, 71.21865871]),
 array([ 3.44072578, 12.54108647]),
 array([38.3340447 , 62.57370468]),
 array([57.23693035, 71.90347965]),
 array([47.72191146, 68.28711833]),
 array([36.25778651, 49.00004111]),
 array([27.38002301, 44.68013214]),
 array([15.31441616, 29.97158539]),
 array([25.17605887, 44.05743356]),
 array([20.01608851, 35.42128209]),
 array([21.78137694, 37.53308912]),
 array([42.61014204, 56.4446082 ]),
 array([42.92344477, 70.1648

In [112]:
for i in range(len(data.columns)-1):
    data[data.columns[i+1]]=np.array(y_scores)[:,0]

In [113]:
data

Unnamed: 0,smiles,MLM,HLM
0,CC(C)Nc1ccnc(N2CCN(Cc3cccs3)C(CCO)C2)n1,27.085475,27.085475
1,COc1cc(=O)n(-c2ccccc2)cc1C(=O)N1CCC2(CC1)OCCO2,55.022189,55.022189
2,Cc1cccc(NC(=N)/N=c2\nc(O)c(Cc3ccccc3)c(C)[nH]2)c1,39.741880,39.741880
3,O=C(c1nc2ncccn2n1)N1CCCn2cc(-c3ccccc3)nc21,44.851222,44.851222
4,CCN1CCN(C(=O)c2cc3c(=O)n4cc(C)ccc4nc3n2C)CC1,45.919787,45.919787
...,...,...,...
478,CCc1noc(CC)c1CC(=O)NCC1(CC)CCCCC1,1.194652,1.194652
479,CC(=O)N1CCC2(CC1)OC(=O)C(C)=C2C(=O)N1CCN(C)CC1,57.291971,57.291971
480,CC(C)NC(=O)CN1C(=O)c2ccccc2N2C(=O)c3ccccc3C12,36.517765,36.517765
481,Cn1cc(Br)c(=O)c(NC(=O)c2ccc(O)cc2F)c1,65.171847,65.171847


# run_training

In [25]:
info = logger.info if logger is not None else print

torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

#set up dataset
dataset = MoleculeDataset_grover(args.data_path + args.dataset, dataset=args.dataset)
args.num_tasks = len(dataset[0]['y'])
info(f'number of task : {args.num_tasks}')
labels = pd.read_csv(args.data_path + args.dataset + '/raw/' + args.dataset + '.csv', header=None)[1][1:]
try : unique_labels = np.unique(labels[~labels.isnull()])
except : unique_labels = np.unique(labels[~labels.isnull()].astype(float))
args.num_class = len(unique_labels)
if args.num_class>2 :
    if args.num_tasks>1 and not args.regression: 
        raise ValueError("this model can't treat multi-task and multi-class")
    else:
        args.multi_class=True

if args.regression: 
    scaler = StandardScaler()
    scaler.fit(dataset.y.view(-1,1))
else:
    scaler=None

if args.split == "scaffold":
    smiles_list = pd.read_csv(args.data_path + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
    train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
    info(f'scaffold_balanced_split')
elif args.split == "random":
    train_dataset, valid_dataset, test_dataset = random_split(dataset, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
    info("random")
elif args.split == "random_scaffold":
    smiles_list = pd.read_csv(args.data_path + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
    train_dataset, valid_dataset, test_dataset = random_scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
    info("random scaffold")
else:
    raise ValueError("Invalid split option.")

info(f'total_size:{len(dataset)} train_size:{len(train_dataset)} val_size:{len(valid_dataset)} test_size:{len(test_dataset)}')

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

#set up model
model = GNN_graphpred_grover(args)

model.to(device)

#set up optimizer
#different learning rate for different part of GNN
model_param_group = []
model_param_group.append({"params": model.gnn.parameters()})
if args.graph_pooling == "attention":
    model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
info(optimizer)

number of task : 617
number of task : 617
random
random
total_size:8576 train_size:6860 val_size:857 test_size:859
total_size:8576 train_size:6860 val_size:857 test_size:859
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)


In [30]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(test_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        if args.multi_class and not args.regression:
            y = batch.y.view(pred.shape[0],1).to(torch.long)
        elif args.regression and args.num_tasks > 1:
            y = batch.y.view(pred.shape).to(torch.float)
            pred = scaler.inverse_transform(pred.cpu())
        elif args.regression:
            y = batch.y.view(pred.shape).to(torch.float)
            pred = scaler.inverse_transform(pred.cpu())
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)

        if args.dataset=='qm7' or args.dataset=='qm8':
            loss = mean_absolute_error(y.cpu().detach().numpy(), pred)
            loss = loss.mean()
        elif args.regression:
            loss = math.sqrt(mean_squared_error(y.cpu().detach().numpy(), pred))
            loss = np.mean(loss)
        else:
            loss = calcul_loss(pred, y, args)

    cum_loss += loss

    if not args.regression:
        y_true.append(y)
        y_scores.append(pred)


if not args.regression:
    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

    auc_list = []
    acc_list = []
    rec_list = []
    prec_list = []
    f1s_list = []
    tp_list = []
    fp_list = []
    tn_list = []
    fn_list = []
for i in range(y_true.shape[1]):
    if sum([y>=0 for y in y_true[:,i]])==len(y_true[:,i]) or sum([y<=0 for y in y_true[:,i]])==len(y_true[:,i]):
        pass
    else:
        auc, acc, rec, prec, f1s, tp, fp, tn, fn = confusion_mat(y_true[:,i], y_scores[:,i], args)
        auc_list.append(auc)
        acc_list.append(acc)
        rec_list.append(rec)
        prec_list.append(prec)
        f1s_list.append(f1s)
        tp_list.append(tp)
        fp_list.append(fp)
        tn_list.append(tn)
        fn_list.append(fn)

In [31]:
=

[0.3590433956287615,
 0.5316334355828221,
 0.3952745849297573,
 0.400990099009901,
 0.5379436964504284,
 0.5440278988666085,
 0.48574643660915234,
 0.338,
 0.49903288201160545,
 0.3707482993197279,
 0.4454775993237532,
 0.47111111111111104,
 0.4215686274509804,
 0.5561594202898551,
 0.44893617021276594,
 0.4666666666666667,
 0.42713903743315507,
 0.38191489361702124,
 0.5105882352941177,
 0.3980263157894737,
 0.31699346405228757,
 0.4427609427609428,
 0.5991516436903499,
 0.47245963912630584,
 0.4193548387096774,
 0.598,
 0.4863157894736842,
 0.47660427807486627,
 0.606790799561884,
 0.7701149425287356,
 0.5942857142857143,
 0.6,
 0.5714285714285714,
 0.44642857142857145,
 0.576923076923077,
 0.28846153846153844,
 0.6592592592592593,
 0.3703703703703704,
 0.30769230769230765,
 0.8148148148148149,
 0.2,
 0.7053571428571428,
 0.23333333333333334,
 0.0,
 0.6828636216391318,
 0.5966764942374698,
 0.45707831325301207,
 0.48357348703170033,
 0.11747851002865328,
 0.48419540229885055,
 0.4808

In [10]:
for i in range(len(y_true)):
    if sum(y_true[:,i])==0 :
        print(i)

359
362
390
406
447
473
481


IndexError: index 617 is out of bounds for axis 1 with size 617

In [16]:
loss = mean_absolute_error(y.cpu().detach().numpy(), pred)

In [20]:
y

tensor([[-1335.7300],
        [-1626.8000],
        [-1653.7400],
        [-1917.0000],
        [-1217.2300],
        [-1781.4301],
        [-1633.6801],
        [-1916.7400],
        [-1105.6700],
        [-1351.8700],
        [-1297.0500],
        [-1519.7000],
        [-1246.6899],
        [-1321.3500],
        [-1296.8000],
        [-1460.7300],
        [-1327.6200],
        [-1246.1700],
        [-1422.1300],
        [-1447.4700],
        [-1260.7800],
        [-1250.5000],
        [-1295.8800],
        [-1318.6000],
        [-1254.5500],
        [-1291.9000],
        [-1310.1400],
        [-1370.7200],
        [-1371.6500],
        [-1368.2200],
        [-1320.9800],
        [-1320.2800]], device='cuda:0')

In [4]:
model.train()

loss_sum = 0
iter_count = 0

for step, batch in enumerate(train_loader):
    batch = batch.to(device)
    pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    if args.multi_class and not args.regression:
        y = batch.y.view(pred.shape[0],1).to(torch.long)
    elif args.regression and args.dataset=='qm8':
        y = batch.y.view(pred.shape).to(torch.float)
        y = scaler.transform(y.cpu().view(-1,1))
        y = torch.tensor(y, dtype=torch.float).view(pred.shape).to(device)
    elif args.regression:
        y = batch.y.view(pred.shape).to(torch.float)
        y = scaler.transform(y.cpu().view(-1,1))
        y = torch.tensor(y, dtype=torch.float).to(device)
    else : 
        y = batch.y.view(pred.shape).to(torch.float64)
    if step==0:break

In [49]:
loss_func1 = nn.CrossEntropyLoss()
loss_func2 = nn.CrossEntropyLoss(reduction='none')
is_valid = y**2 > 0
true = y[is_valid]+1
true = torch.tensor([x-1 if x != 0 else x for x in true]).cuda()
loss = loss_func(pred, true)

In [50]:
loss

tensor(1.6105, device='cuda:0', grad_fn=<NllLossBackward0>)

In [52]:
loss_func2(pred,true).mean()

tensor(1.6105, device='cuda:0', grad_fn=<MeanBackward0>)

In [38]:
soft(pred)[2].argmax()

tensor(3, device='cuda:0')

In [46]:
loss

tensor(1.6105, device='cuda:0', grad_fn=<NllLossBackward0>)

In [47]:
loss_func(soft(pred),true)

tensor(1.3655, device='cuda:0', grad_fn=<NllLossBackward0>)

In [48]:
model

GNN_graphpred_grover(
  (gnn): GNN_grover(
    (x_embedding): Linear(in_features=151, out_features=300, bias=True)
    (gnns): ModuleList(
      (0): GCNConv_grover()
      (1): GCNConv_grover()
      (2): GCNConv_grover()
      (3): GCNConv_grover()
      (4): GCNConv_grover()
    )
    (batch_norms): ModuleList(
      (0): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (graph_pred_linear): Linear(in_features=300, out_features=4, bias=True)
)

In [45]:
soft(soft(soft(pred)))[0]

tensor([0.2548, 0.2468, 0.2473, 0.2511], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [None]:
#loss matrix after removing null target
loss = calcul_loss(pred, y, args)
optimizer.zero_grad()
loss.backward()

optimizer.step()

loss_sum += loss
iter_count += 1


In [None]:
y[0]

tensor([ 0.4531,  0.7428, -0.4257,  0.8987,  0.2446,  0.3713, -0.7473, -0.5184,
         0.3524,  0.4847, -6.9525, -0.4706], device='cuda:0')

In [None]:
scaler.inverse_transform(y[0].cpu().view(-1,1))

array([[ 1.93457307e-01],
       [ 2.40145910e-01],
       [ 5.18318096e-02],
       [ 2.65271931e-01],
       [ 1.59862830e-01],
       [ 1.80278897e-01],
       [ 1.42235136e-06],
       [ 3.68902425e-02],
       [ 1.77234651e-01],
       [ 1.98541910e-01],
       [-9.99999985e-01],
       [ 4.45999974e-02]])

In [None]:
#loss matrix after removing null target
loss = calcul_loss(pred, y, args)
optimizer.zero_grad()
loss.backward()

optimizer.step()

loss_sum += loss
iter_count += 1

RuntimeError: The size of tensor a (32) must match the size of tensor b (384) at non-singleton dimension 0

In [None]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(test_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        
        if args.multi_class and not args.regression:
            y = batch.y.view(pred.shape[0],1).to(torch.long)
        elif args.regression:
            y = batch.y.view(pred.shape).to(torch.float)
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)

        if args.dataset=='qm7' or args.dataset=='qm8':
            loss = mean_absolute_error(y.cpu().detach().numpy(), pred.cpu().detach().numpy())
        else:
            loss = calcul_loss(pred, y, args)
    if step==0:break


In [None]:
with torch.no_grad():
    pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

    if args.multi_class and not args.regression:
        y = batch.y.view(pred.shape[0],1).to(torch.long)
    elif args.regression:
        y = batch.y.view(pred.shape).to(torch.float)
    else : 
        y = batch.y.view(pred.shape).to(torch.float64)

    if args.dataset=='qm7' or args.dataset=='qm8':
        loss = mean_absolute_error(y.cpu().detach().numpy(), pred.cpu().detach().numpy())
    else:
        loss = calcul_loss(pred, y, args)

In [None]:
pred

tensor([[ 7.2334e-05],
        [ 1.0687e-02],
        [ 1.9241e-02],
        [ 1.1172e-03],
        [ 1.4585e-02],
        [ 2.8826e-02],
        [ 1.5862e-02],
        [ 1.1386e-02],
        [ 1.1742e-02],
        [ 2.1130e-02],
        [ 2.2717e-02],
        [ 7.7560e-03],
        [ 2.0837e-02],
        [ 1.3149e-02],
        [ 7.0134e-03],
        [ 1.1621e-02],
        [ 2.3123e-02],
        [ 2.0164e-02],
        [ 2.6623e-02],
        [ 2.1454e-02],
        [ 3.6066e-02],
        [ 1.8459e-02],
        [ 3.2620e-02],
        [ 4.9480e-02],
        [ 4.2613e-02],
        [ 1.1911e-02],
        [ 3.5042e-02],
        [ 6.8747e-03],
        [ 3.8131e-03],
        [-8.3040e-05],
        [ 1.6836e-04],
        [ 1.8702e-02]], device='cuda:0')

In [None]:
torch.tensor(scaler.inverse_transform(pred.cpu()), dtype=torch.float).to(device)

tensor([[-3.0517],
        [-3.0295],
        [-3.0116],
        [-3.0495],
        [-3.0213],
        [-2.9915],
        [-3.0187],
        [-3.0280],
        [-3.0273],
        [-3.0076],
        [-3.0043],
        [-3.0356],
        [-3.0083],
        [-3.0244],
        [-3.0372],
        [-3.0275],
        [-3.0035],
        [-3.0097],
        [-2.9961],
        [-3.0070],
        [-2.9764],
        [-3.0132],
        [-2.9836],
        [-2.9483],
        [-2.9627],
        [-3.0269],
        [-2.9785],
        [-3.0375],
        [-3.0439],
        [-3.0520],
        [-3.0515],
        [-3.0127]], device='cuda:0')

In [None]:
loss

tensor(4.4069, device='cuda:0', dtype=torch.float64)

In [None]:
torch.tensor(y, dtype=torch.float).to(device)

tensor([[ 1.1843],
        [ 0.2114],
        [-1.0566],
        [ 0.0033],
        [ 0.6132],
        [-2.7555],
        [-0.6404],
        [ 1.0826],
        [-2.0972],
        [ 0.1388],
        [-0.3839],
        [ 1.0826],
        [-0.4497],
        [ 0.2502],
        [ 1.8812],
        [-0.2726],
        [ 0.3421],
        [ 0.0614],
        [-1.3035],
        [ 0.1214],
        [-1.9811],
        [ 0.3823],
        [-0.8582],
        [-0.5252],
        [-0.5146],
        [ 0.9132],
        [ 0.6587],
        [ 1.1940],
        [-0.8582],
        [ 1.4408],
        [-0.2982],
        [ 0.5502]], device='cuda:0')

In [None]:
import math
from sklearn.metrics import multilabel_confusion_matrix, accuracy_score, mean_squared_error, roc_auc_score, mean_absolute_error, r2_score, precision_recall_curve, auc, recall_score, confusion_matrix, f1_score, precision_score, classification_report

In [None]:
test_loss = test(args, model, device, test_loader)

In [None]:
test_loss

15.05982248096842

In [None]:
batch.y.to(torch.float32).dtype

torch.float32

In [None]:
batch.y.view(pred.shape)

tensor([[-1.9470],
        [-3.7300],
        [-3.8100],
        [-2.7400],
        [-1.2280],
        [-4.8100],
        [-2.3400],
        [-3.4300],
        [-1.4900],
        [-0.4700],
        [-4.4800],
        [-4.1600],
        [-3.2800],
        [-3.5900],
        [-7.0000],
        [-3.9240],
        [-1.9200],
        [-2.6470],
        [-7.2000],
        [-2.6100],
        [-6.1440],
        [-0.8760],
        [-0.9200],
        [-0.7300],
        [-0.2800],
        [-2.7300],
        [-3.5100],
        [-0.6000],
        [-2.1800],
        [-6.9800],
        [-4.6900],
        [-4.4600]], device='cuda:0', dtype=torch.float64)

In [None]:
pred_loss = nn.MSELoss(reduction='none')

In [None]:
loss = pred_loss(pred,y)
loss

tensor([[5.9198e+00],
        [1.1617e+01],
        [2.1863e-02],
        [2.8816e+01],
        [1.6508e+01],
        [7.6355e-01],
        [5.8748e+01],
        [2.0256e+00],
        [4.6746e-01],
        [9.1268e+00],
        [2.0356e+00],
        [4.9889e+01],
        [3.7043e+00],
        [5.3669e+00],
        [1.5966e+01],
        [4.8138e+00],
        [1.2043e+00],
        [1.2110e+01],
        [5.2670e+00],
        [1.2118e+01],
        [1.1560e-01],
        [1.3402e-02],
        [4.2507e+00],
        [1.9482e+01],
        [3.2265e+01],
        [1.3190e+01],
        [2.7398e-01],
        [5.9838e+01],
        [1.0240e+01],
        [3.3492e+00],
        [9.5198e-02],
        [3.4239e+00]], device='cuda:0', grad_fn=<MseLossBackward0>)

In [None]:
loss.dtype

torch.float64

In [None]:
loss.mean().backward()

In [None]:
loss.dtype

torch.float64

In [None]:
loss.to(torch.float).dtype

torch.float32

In [None]:
loss1 = math.sqrt(mean_squared_error(y.cpu().detach().numpy(), pred.cpu().detach().numpy()))
loss2 = torch.tensor(loss1, dtype=float, requires_grad=True).cuda()

In [None]:
loss2.item()

3.727366979585915

In [None]:
model.train()

loss_sum = 0
iter_count = 0

for step, batch in enumerate(train_loader):
    batch = batch.to(device)
    pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    if args.multi_class and not args.regression:
        y = batch.y.view(pred.shape[0],1).to(torch.long)
    else : 
        y = batch.y.view(pred.shape).to(torch.float64)


    #loss matrix after removing null target
    loss = calcul_loss(pred, y, args)

    optimizer.zero_grad()
    loss.backward()

    optimizer.step()

    loss_sum += loss
    iter_count += 1

RuntimeError: shape '[-1, 165, 300]' is invalid for input of size 716700

In [None]:
class GCNConv(MessagePassing):

    def __init__(self, emb_dim, aggr = "add"):
        super(GCNConv, self).__init__()

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        self.aggr = aggr

    def norm(self, edge_index, num_nodes, dtype):
        ### assuming that self-loops have been already added in edge_index
        edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]


    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        norm = self.norm(edge_index, x.size(0), x.dtype)

        x = self.linear(x)

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * (x_j + edge_attr)


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
batch.x, batch.edge_index, batch.edge_attr, batch.batch

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[  0,   1,   1,  ..., 289, 287, 290],
         [  1,   0,   2,  ..., 287, 290, 287]], device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,
          2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,
          4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,
          6,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
     

In [None]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros

num_atom_type = 120 #including the extra mask tokens
num_chirality_tag = 3

num_bond_type = 6 #including aromatic and self-loop edge, and extra masked tokens
num_bond_direction = 3 

#for grover feature size
atom_fdim = 151
bond_fdim = 165
class GCNConv_grover(MessagePassing):

    def __init__(self, emb_dim, aggr = "add"):
        super(GCNConv_grover, self).__init__()

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding = torch.nn.Linear(bond_fdim, emb_dim)

        self.aggr = aggr

    def norm(self, edge_index, num_nodes, dtype):
        ### assuming that self-loops have been already added in edge_index
        edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]


    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), bond_fdim)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding(edge_attr)

        norm = self.norm(edge_index, x.size(0), x.dtype)

        x = self.linear(x)

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * (x_j + edge_attr)

In [None]:
class GNN_grover(torch.nn.Module):
    """
    

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        JK (str): last, concat, max or sum.
        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        node representations

    """
    def __init__(self, args):
        super(GNN_grover, self).__init__()
        self.num_layer = args.num_layer
        self.drop_ratio = args.dropout_ratio
        self.JK = args.JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding = torch.nn.Linear(atom_fdim, args.emb_dim, bias=True)

        ###List of MLPs
        self.gnns = torch.nn.ModuleList()
        for layer in range(args.num_layer):
            if args.gnn_type == "gin":
                self.gnns.append(GINConv_grover(args.emb_dim, aggr = "add"))
            elif args.gnn_type == "gcn":
                self.gnns.append(GCNConv_grover(args.emb_dim))
            elif args.gnn_type == "gat":
                self.gnns.append(GATConv_grover(args.emb_dim))
            elif args.gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv_grover(args.emb_dim))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(args.num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(args.emb_dim))

    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding(x)

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            #h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat   #레이어간 노드 기능들을 어떻게 할건지, 기본은 Last다
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim = 1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim = 0), dim = 0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim = 0), dim = 0)[0]

        return node_representation

In [None]:
gnn = GNN_grover(args).cuda()

In [None]:
gnn(batch.x, batch.edge_index, batch.edge_attr)

tensor([[ 0.0000,  0.2067, -0.0000,  ..., -1.5656, -1.2953, -1.1305],
        [ 0.5938, -0.0312, -1.9180,  ..., -1.6548, -1.3092, -1.6185],
        [ 0.0000, -0.1467, -0.0000,  ..., -0.6165, -0.6392, -0.7175],
        ...,
        [ 0.3571, -0.7075, -2.3615,  ...,  0.9665, -0.5579, -0.5173],
        [ 0.3602, -1.0445, -0.0000,  ...,  0.2776, -0.1047, -0.8614],
        [ 0.2864, -1.1540, -2.5878,  ..., -0.8184, -0.0000,  0.1091]],
       device='cuda:0', grad_fn=<NativeDropoutBackward0>)

In [None]:
gnn

GNN_grover(
  (x_embedding): Linear(in_features=151, out_features=300, bias=True)
  (gnns): ModuleList(
    (0): GCNConv_grover()
    (1): GCNConv_grover()
    (2): GCNConv_grover()
    (3): GCNConv_grover()
    (4): GCNConv_grover()
  )
  (batch_norms): ModuleList(
    (0): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [None]:
x_embedding = torch.nn.Linear(151, 300, bias=True).cuda()

In [None]:
x = x_embedding(batch.x)

In [None]:
h_list = [x]
h_list

[tensor([[ 0.0631,  0.1483, -0.0175,  ...,  0.0675,  0.1497,  0.1085],
         [-0.1237,  0.1433, -0.0254,  ..., -0.0213,  0.0557, -0.1806],
         [-0.3229,  0.1615, -0.0075,  ...,  0.0665, -0.0487, -0.0212],
         ...,
         [ 0.1647,  0.1275, -0.0356,  ...,  0.2162,  0.0066, -0.2511],
         [ 0.1647,  0.1275, -0.0356,  ...,  0.2162,  0.0066, -0.2511],
         [ 0.1647,  0.1275, -0.0356,  ...,  0.2162,  0.0066, -0.2511]],
        device='cuda:0', grad_fn=<AddmmBackward0>)]

In [None]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros

num_atom_type = 120 #including the extra mask tokens
num_chirality_tag = 3

num_bond_type = 6 #including aromatic and self-loop edge, and extra masked tokens
num_bond_direction = 3 

#for grover feature size
atom_fdim = 151
bond_fdim = 165
class GCNConv_grover(MessagePassing):

    def __init__(self, emb_dim, aggr = "add"):
        super(GCNConv_grover, self).__init__()

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding = torch.nn.Linear(bond_fdim, emb_dim)

        self.aggr = aggr

    def norm(self, edge_index, num_nodes, dtype):
        ### assuming that self-loops have been already added in edge_index
        edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]


    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding(edge_attr)

        norm = self.norm(edge_index, x.size(0), x.dtype)

        x = self.linear(x)

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * (x_j + edge_attr)

In [None]:
gnns = torch.nn.ModuleList()
for layer in range(5):
    gnns.append(GCNConv_grover(300)).cuda()

In [None]:
for layer in range(5):
    h = gnns[layer](h_list[layer], batch.edge_index, batch.edge_attr)

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 165 but got size 2 for tensor number 1 in the list.

In [None]:
gnns[0](h_list[0], batch.edge_index, batch.edge_attr)

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 165 but got size 2 for tensor number 1 in the list.

In [None]:
h_list[0].shape

torch.Size([291, 300])

In [None]:
batch.edge_index.shape

torch.Size([2, 550])

In [None]:
batch.edge_attr.shape

torch.Size([550, 165])

In [None]:
from typing import List, Callable, Union

In [None]:
from sklearn.metrics import multilabel_confusion_matrix, accuracy_score, mean_squared_error, roc_auc_score, mean_absolute_error, r2_score, \
    precision_recall_curve, auc, recall_score, confusion_matrix, f1_score, precision_score, classification_report

In [None]:
def confusion_mat(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
    """
    Computes the specificity of a binary prediction task using a given threshold for generating hard predictions.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0
    :return: The computed specificity.
    """
    if args.multi_class:
        is_valid = targets**2 > 0
        targets = targets[is_valid]+1
        targets = torch.tensor([x-1 if x != 0 else x for x in targets]).cuda()
        preds = torch.softmax(torch.tensor(preds),dim=1)
        hard_preds = [1 if p > threshold else 0 for p in preds]
        auc = roc_auc_score(targets, preds)
        tn, fp, fn, tp = multilabel_confusion_matrix(targets, hard_preds).ravel()
        acc = accuracy_score(targets, hard_preds)
        rec = recall_score(targets, hard_preds)
        prec = precision_score(targets, hard_preds)
        spe = tn / float(tn + fp)
        f1s = f1_score(targets, hard_preds)
        BA = (rec+spe)/2
    else : 
        valid = targets**2>0
        targets = (targets[valid]+1)/2
        preds = preds[valid]
        #hard_preds = [1 if p > threshold else 0 for p in preds]
        auc = roc_auc_score(targets, preds)
        tn, fp, fn, tp = confusion_matrix(targets, hard_preds).ravel()
        acc = accuracy_score(targets, hard_preds)
        rec = recall_score(targets, hard_preds)
        prec = precision_score(targets, hard_preds)
        spe = tn / float(tn + fp)
        f1s = f1_score(targets, hard_preds)
        BA = (rec+spe)/2
    return auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn

In [None]:
y_true.shape

(103, 1)

In [None]:
y_scores.shape

(103, 3)

In [None]:
targets = y_true
preds = y_scores
threshold=0.5

In [None]:
targets.shape

torch.Size([83])

In [None]:
preds.shape

torch.Size([103, 3])

In [None]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(test_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        if args.multi_class:
            y = batch.y.view(pred.shape[0],1).to(torch.long)
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)

        loss = calcul_loss(pred, y, args)

    cum_loss += loss

    if not args.regression:
        y_true.append(y)
        y_scores.append(pred)


if not args.regression:
    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

targets = y_true
preds = y_scores
threshold=0.5
    
is_valid = targets**2 > 0
targets = targets[is_valid]+1
targets = torch.tensor([x-1 if x != 0 else x for x in targets])
preds = torch.softmax(torch.tensor(preds),dim=1).cpu()
hard_preds = torch.max(preds,1).indices
auc = roc_auc_score(targets, preds, multi_class='ovr')
cm = multilabel_confusion_matrix(targets, hard_preds).ravel()
acc = accuracy_score(targets, hard_preds)
rec = recall_score(targets, hard_preds, average='micro')
prec = precision_score(targets, hard_preds, average='micro')
f1s = f1_score(targets, hard_preds, average='micro')

In [None]:
auc = roc_auc_score(targets, preds, multi_class='ovr')
cm = multilabel_confusion_matrix(targets, hard_preds).ravel()
acc = accuracy_score(targets, hard_preds)
rec = recall_score(targets, hard_preds, average='micro')
prec = precision_score(targets, hard_preds, average='micro')
f1s = f1_score(targets, hard_preds, average='micro')

In [None]:
len(cm)/4

3.0

In [None]:
tp=[]
fp=[]
tn=[]
fn=[]
for i in range(int(len(cm)/4)):
    tp.append(cm[4*i])
    fp.append(cm[4*i+1])
    fn.append(cm[4*i+2])
    tn.append(cm[4*i+3])

In [None]:
tp

[46, 23, 64]

In [None]:
args

Namespace(JK='last', batch_size=32, data_path='dataset_grover/', dataset='tg407', decay=1e-07, device=0, dropout_ratio=0.2, emb_dim=300, epochs=2, eval_train=1, gnn_type='gin', graph_pooling='mean', grover=True, lr=0.0001, lr_scale=1, multi_class=True, n_iters=1, num_class=3, num_layer=5, num_tasks=1, num_workers=4, output_path='output', randomsearch=False, regression=False, runseed=0, seed=0, split='scaffold')

In [None]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(test_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        if args.multi_class:
            y = batch.y.view(pred.shape[0],1).to(torch.long)
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)

        loss = calcul_loss(pred, y, args)

    cum_loss += loss

    if not args.regression:
        y_true.append(y)
        y_scores.append(pred)


if not args.regression:
    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

    auc_list = []
    acc_list = []
    rec_list = []
    prec_list = []
    f1s_list = []
    BA_list = []
    tp_list = []
    fp_list = []
    tn_list = []
    
    if args.multi_class:
        auc, acc, rec, prec, f1s, tp, fp, tn, fn = confusion_mat(y_true, y_scores)
    else : 
        for i in range(y_true.shape[1]):
            auc, acc, rec, prec, f1s, tp, fp, tn, fn = confusion_mat(y_true[:,i], y_scores[:,i])

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [None]:
score = torch.softmax(torch.tensor(y_scores), dim=1)

In [None]:
roc_auc_score(y_true,score, multi_class='ovr')

0.446751697537255

In [None]:
y_scores.shape

(103, 3)

In [None]:
y_true.shape

(103, 1)

In [None]:
roc_list = []
for i in range(y_true.shape[1]):
    #AUC is only defined when there is at least one positive data.
    is_valid = y_true**2 > 0
    y_true = y_true[is_valid]+1
    y_true = torch.tensor([x-1 if x != 0 else x for x in y_true]).view(-1,1)
#    roc_score = roc_auc_score(y_true, y_scores, multi_class='ovr')

In [None]:
new_score = (y_scores - y_scores.mean())/y_scores.std()

In [None]:
sum(y_scores)

array([-1.2087897, -1.8467627, -9.9976425], dtype=float32)

In [None]:
roc_auc_score(y_true, torch.softmax(torch.tensor(y_scores), dim=1), multi_class='ovr')

0.446751697537255

In [None]:
loss_func = nn.CrossEntropyLoss()
is_valid = true**2 > 0
true = true[is_valid]+1
true = torch.tensor([x-1 if x != 0 else x for x in true]).cuda()

In [None]:
pred.dtype

torch.float32

In [None]:
true.dtype

torch.int64

In [None]:
loss = loss_func(pred, true)

In [None]:
loss_func = nn.CrossEntropyLoss()

In [None]:
loss_func_bce = nn.BCEWithLogitsLoss(reduction = "none")

In [None]:
pred.shape

torch.Size([32, 1])

In [None]:
true.shape

torch.Size([32, 1])

In [None]:
(true.to(int)+1)/2

tensor([[0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.]], device='cuda:0')

In [None]:
loss_func(pred.double(), (true.to(int)+1)/2)

tensor(-0., device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward1>)

In [None]:
loss = calcul_loss(pred, y, args)
loss

tensor(0., device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>)

In [None]:
args.num_tasks

1

In [None]:
batch.y.shape

torch.Size([32])

In [None]:
print(pred.shape)

torch.Size([32, 3])


In [None]:
true = y
is_valid = true**2 > 0
#Loss matrix
true = true[is_valid]+1
true = torch.tensor([x-1 if x != 0 else x for x in true]).cuda()
loss = loss_func(pred, true)

In [None]:
loss

tensor(1.2007, device='cuda:0', grad_fn=<NllLossBackward0>)

In [None]:
true.T

tensor([1, 1, 1, 0, 0, 2, 1, 2, 2, 1, 1, 2, 2, 0, 1, 1, 2, 2, 2, 2, 2, 0, 0, 2,
        0, 1, 2, 1, 1, 2, 1, 2], device='cuda:0')

In [None]:
def calcul_loss(pred, true, args):
    if args.regression:
        loss_func = nn.MSELoss(reduction='none')
        loss = loss_func(pred, true).mean()
        loss = torch.tensor(loss, dtype=float, requires_grad=True)
        return loss
    else : 
        loss_func = nn.CrossEntropyLoss()
        is_valid = true**2 > 0
        #Loss matrix
        true = true[is_valid]+1
        true = torch.tensor([x-1 if x != 0 else x for x in true]).cuda()
        loss = loss_func(pred.T, true.T)
        return loss

In [None]:
calcul_loss(pred,y,args)

ValueError: Expected input batch_size (1) to match target batch_size (32).

In [None]:
loss_func(pred.double(), (y+1)/2)

tensor(-0., device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward1>)

In [None]:
y = batch.y

In [None]:
true=y#.to(torch.float64)

In [None]:
loss_func = nn.CrossEntropyLoss()
is_valid = true**2 > 0
#Loss matrix
true = true[is_valid]+1
true = torch.tensor([x-1 if x != 0 else x for x in true]).cuda()

In [None]:
pred

tensor([[-0.2155],
        [-0.4307],
        [-0.0201],
        [ 0.2719],
        [-0.2402],
        [ 0.3789],
        [ 0.2305],
        [-0.0304],
        [ 0.1599],
        [ 0.1946],
        [ 0.0443],
        [-0.1231],
        [-0.2343],
        [ 0.2610],
        [-0.1922],
        [-0.3010],
        [ 0.2754],
        [ 0.0177],
        [ 0.1343],
        [-0.6307],
        [-0.0493],
        [ 0.0009],
        [ 0.1362],
        [-0.0248],
        [-0.2006],
        [-0.1762],
        [ 0.0333],
        [-0.0517],
        [-0.3256],
        [-0.0681],
        [-0.1654],
        [-0.1134]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [None]:
true

tensor([1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1,
        0, 0, 0, 0, 1, 0, 0, 0], device='cuda:0')

In [None]:
def calcul_loss(pred, true, args):
    if args.regression:
        loss_func = nn.MSELoss(reduction='none')
        loss = loss_func(pred, true).mean()
        loss = torch.tensor(loss, dtype=float, requires_grad=True)
        return loss
    else : 
        loss_func = nn.CrossEntropyLoss()
        
        if args.multi_class : loss = loss_func(pred, true)
        else : loss = loss_func(pred.T[0], true)
        
        return loss

In [None]:
confusion_mat(y_true[:,i], y_scores[:,i])

ValueError: continuous format is not supported

In [None]:
model.train()

loss_sum = 0
iter_count = 0

for step, batch in enumerate(train_loader):
    batch = batch.to(device)
    pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    if args.multi_class:
        y = batch.y
    else : 
        y = batch.y.view(pred.shape).to(torch.float64)

    #loss matrix after removing null target
    loss = calcul_loss(pred, y, args)

    optimizer.zero_grad()
    loss.backward()

    optimizer.step()

    loss_sum += loss
    iter_count += 1

torch.cuda.empty_cache()

RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Double'

In [None]:
np.sum(y_true[:,i] == -1) > 0

False

## valid

In [None]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(val_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        if args.multi_class:
            y = batch.y.unsqueeze(dim=-1)
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)
        is_valid = y**2 > 0
        y = y[is_valid]+1
        y = torch.tensor([x-1 if x != 0 else x for x in y]).cuda()
        

        loss = calcul_loss(pred, y, args)

    cum_loss += loss

    if not args.regression:
        y_true.append(y)
        y_scores.append(pred)

if not args.regression:
    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()
    
    if args.multi_class : y_true = y_true[:,np.newaxis] 
    
    roc_list = []
    for i in range(y_true.shape[1]):
        #AUC is only defined when there is at least one positive data.
        if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == 0) > 0:
            is_valid = y_true[:,i]**2 > 0
            roc_list.append(roc_auc_score(y_true[is_valid,i], y_scores[is_valid,i]))

    if len(roc_list) < y_true.shape[1]:
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(roc_list))/y_true.shape[1]))

    torch.cuda.empty_cache()

## test

In [None]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(test_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        if args.multi_class:
            y = batch.y
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)
        is_valid = y**2 > 0
        y = y[is_valid]+1
        y = torch.tensor([x-1 if x != 0 else x for x in y]).cuda()

        loss = calcul_loss(pred, y, args)

    cum_loss += loss

    if not args.regression:
        y_true.append(y)
        y_scores.append(pred)


if not args.regression:
    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

    auc_list = []
    acc_list = []
    rec_list = []
    prec_list = []
    f1s_list = []
    BA_list = []
    tp_list = []
    fp_list = []
    tn_list = []
    fn_list = []

    if args.multi_class : y_true = y_true[:,np.newaxis] 

    for i in range(y_true.shape[1]):
        auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = confusion_mat(y_true[:,i], y_scores[:,i])

ValueError: continuous format is not supported

In [None]:
y_scores[:,0].shape

(103,)

In [None]:
y_true[:,0].shape

(103,)

In [None]:
info = logger.info if logger is not None else print

if not os.path.exists(args.output_path):
    os.makedirs(args.output_path)
if not args.regression:
    auc_list = []
    acc_list = []
    rec_list = []
    prec_list = []
    f1s_list = []
    BA_list = []
    tp_list = []
    fp_list = []
    tn_list = []
    fn_list = []
    for k in range(3):
        auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = run_training(args)
        auc_list.append(auc)
        acc_list.append(acc)
        rec_list.append(rec)
        prec_list.append(prec)
        f1s_list.append(f1s)
        BA_list.append(BA)
        tp_list.append(tp)
        fp_list.append(fp)
        tn_list.append(tn)
        fn_list.append(fn)
        args.seed += 1
    info(f'all test end')
    info(f'overall test_auc : {np.nanmean(auc_list):.4f}\nstd={np.nanstd(auc_list):.4f}')
    info(f'overall test_accuracy : {np.nanmean(acc_list):.4f}\nstd={np.nanstd(acc_list):.4f}')
    info(f'overall test_recall : {np.nanmean(rec_list):.4f}\nstd={np.nanstd(rec_list):.4f}')
    info(f'overall test_precision : {np.nanmean(prec_list):.4f}\nstd={np.nanstd(prec_list):.4f}')
    info(f'overall test_f1score : {np.nanmean(f1s_list):.4f}\nstd={np.nanstd(f1s_list):.4f}')
    info(f'overall test_Balanced_Accuracy : {np.nanmean(BA_list):.4f}\nstd={np.nanstd(BA_list):.4f}')
    info(f'overall test_tp : {np.nanmean(tp_list):.2f}\nstd={np.nanstd(tp_list):.2f}')
    info(f'overall test_fp : {np.nanmean(fp_list):.2f}\nstd={np.nanstd(fp_list):.2f}')
    info(f'overall test_fn : {np.nanmean(fn_list):.2f}\nstd={np.nanstd(fn_list):.2f}')
    info(f'overall test_tn : {np.nanmean(tn_list):.2f}\nstd={np.nanstd(tn_list):.2f}')
    info(f'\n       (pred)pos    neg(pred)')
    info(f'pos(true)    {tp:.2f}  {fn:.2f}')
    info(f'neg(true)    {fp:.2f}  {tn:.2f}')

    #return np.nanmean(auc_list)
else : 
    mse_list = []
    for k in range(3):
        mse = run_training(args)
        mse_list.append(mse)
        args.seed += 1
    info(f'all test end')
    info(f'overall test_mse : {np.nanmean(mse_list):.4f}\nstd={np.nanstd(mse_list):.4f}')
    #return np.nanmean(mse_list)

scaffold_balanced_split
total_size:642 train_size:513 val_size:64 test_size:65
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
====epoch 1
====Evaluation
train_loss:21.4141 val_loss:151.4233 t_time:0.5759155750274658 v_time:0.3699347972869873
====epoch 2
====Evaluation
train_loss:20.6982 val_loss:155.1019 t_time:0.5931878089904785 v_time:0.3742382526397705
seed:0 test_MSE:105.32380150194973
scaffold_balanced_split
total_size:642 train_size:513 val_size:64 test_size:65
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: F

all test end
all test end
all test end


seed:2 test_MSE:93.06369794457649


AttributeError: 'torch.dtype' object has no attribute 'type'

In [None]:
np.nanstd(torch.tensor(mse_list, dtype=float))

5.010453576621481

In [None]:
torch.max(pred, 1)

torch.return_types.max(
values=tensor([ 0.2405,  0.7774,  0.9729,  0.4813,  0.2257,  1.0280,  0.6748,  0.7067,
        -0.0335,  0.6781,  0.5633,  0.6351,  0.7562,  0.1922, -0.0019, -0.3982,
         0.5939,  0.5229,  0.5480,  0.7797,  0.1883, -0.1349], device='cuda:0',
       grad_fn=<MaxBackward0>),
indices=tensor([2, 0, 1, 1, 1, 1, 1, 1, 2, 0, 0, 1, 0, 1, 2, 1, 1, 0, 0, 2, 2, 1],
       device='cuda:0'))

In [None]:
batch.y

tensor([2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2],
       device='cuda:0')

In [None]:
y = batch.y[batch.y**2>0]+1
y

tensor([3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3, 2, 3, 2, 2, 3, 3, 2, 3, 3, 2, 3],
       device='cuda:0')

In [None]:
y2 = torch.tensor([x-1 if x != 0 else x for x in y]).cuda()
y2

tensor([2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2],
       device='cuda:0')

In [None]:
loss_func(pred.view(1,32), y)

tensor([[19.1630,  5.4770,  0.4633,  4.2864,  3.1695,  0.8794, 13.4749,  4.9141,
          9.5742,  3.6060, 15.7390,  5.7975,  3.6819, 11.1252,  1.1287,  4.4419,
          8.3167,  1.4711,  0.7695,  9.1187,  0.6961,  7.7937,  5.0153, 23.4030,
          2.7433,  5.3533,  1.1984,  2.4379, 19.5267, 15.2287,  0.2520, 12.8192]],
       device='cuda:0', dtype=torch.float64, grad_fn=<MseLossBackward0>)

In [None]:
loss_func(pred, y.view(32,1)).mean()

tensor(6.9708, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

In [None]:
def calcul_loss(pred, true, args):
    if args.regression:
        loss_func = nn.MSELoss(reduction='none')
    else : 
        loss_func = nn.CrossEntropyLoss()
    #Whether y is non-null or not.
    is_valid = true**2 > 0
    #Loss matrix
    true = true[is_valid]+1
    true = torch.tensor([x-1 if x != 0 else x for x in true]).cuda()
    loss_mat = loss_func(pred.double(), true)
    #loss matrix after removing null target
    loss_mat = torch.where(is_valid, loss_mat, torch.zeros(loss_mat.shape).to(loss_mat.device).to(loss_mat.dtype))
    loss = torch.sum(loss_mat)/torch.sum(is_valid)
    return loss

In [None]:
model.train()

loss_sum = 0
iter_count = 0

for step, batch in enumerate(train_loader):
    batch = batch.to(device)
    pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    if args.multi_class:
        y = batch.y
    else : 
        y = batch.y.view(pred.shape).to(torch.float64)

    #loss matrix after removing null target
    loss = calcul_loss(pred, y, args)
    

    optimizer.zero_grad()
    loss.backward()

    optimizer.step()

    loss_sum += loss
    iter_count += 1

torch.cuda.empty_cache()
print(loss_sum / iter_count)

tensor(1.0508, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>)


In [None]:
pred

tensor([[-0.0192, -0.0096, -0.1274],
        [-0.0130, -0.0034, -0.1019],
        [-0.0126, -0.0118, -0.1022],
        [-0.0163, -0.0278, -0.0550],
        [ 0.0024, -0.0260, -0.0689],
        [-0.0132,  0.0036, -0.1478],
        [-0.0029, -0.0643, -0.0745]], device='cuda:0')

In [None]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(val_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        if args.multi_class:
            y = batch.y.view([len(batch.y),1])
        else : 
            y = batch.y.view(pred.shape).to(torch.float64)

        loss = calcul_loss(pred, y, args)

    cum_loss += loss

    if not args.regression:
        y_true.append(y)
        y_scores.append(pred)

if not args.regression:
    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        #AUC is only defined when there is at least one positive data.
        if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == -1) > 0:
            is_valid = y_true[:,i]**2 > 0
            if args.multi_class:
                roc_list.append(roc_auc_score(y_true, y_scores, multi_class='ovr'))
            else : 
                roc_list.append(roc_auc_score((y_true[is_valid,i] + 1)/2, y_scores[is_valid,i]))

    if len(roc_list) < y_true.shape[1]:
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(roc_list))/y_true.shape[1]))

    torch.cuda.empty_cache()
    #return cum_loss, sum(roc_list)/len(roc_list) #y_true.shape[1]
else : 
    torch.cuda.empty_cache()
    #return cum_loss, 0

ValueError: Target scores need to be probabilities for multiclass roc_auc, i.e. they should sum up to 1.0 over classes

In [None]:
best_val_loss = 9999
best_model_path = os.path.join(args.output_path, str(args.seed))
for epoch in range(1, args.epochs+1):
    info("====epoch " + str(epoch))
    tst = time.time()
    train_loss = train(args, model, device, train_loader, optimizer)
    tet = time.time() - tst
    info("====Evaluation")
    vst = time.time()
    val_loss, val_auc = valid(args, model, device, val_loader)
    vet = time.time() - vst
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_cp(args, model, path=best_model_path)
    info(f'train_loss:{train_loss:.4f} val_loss:{val_loss:.4f} val_auc:{val_auc:.4f} t_time:{tet} v_time:{vet}')

best_state = torch.load(os.path.join(best_model_path,'model.pt'))
model.load_state_dict(best_state['state_dict'])

if not args.regression:
    test_loss, auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = test(args, model, device, test_loader)
    avg_auc = sum(auc)/num_tasks
    avg_acc = sum(acc)/num_tasks
    avg_rec = sum(rec)/num_tasks
    avg_prec = sum(prec)/num_tasks
    avg_f1s = sum(f1s)/num_tasks
    avg_BA = sum(BA)/num_tasks
    avg_tp = sum(tp)/num_tasks
    avg_fp = sum(fp)/num_tasks
    avg_tn = sum(tn)/num_tasks
    avg_fn = sum(fn)/num_tasks
else : 
    test_loss = test(args, model, device, test_loader)

del train_dataset, valid_dataset, test_dataset, train_loader, val_loader, test_loader

if not args.regression:
    info(f'seed:{args.seed} loss:{test_loss} auc:{avg_auc} acc:{avg_acc} rec:{avg_rec} prec:{avg_prec} f1:{avg_f1s} BA:{avg_BA}\ntp:{avg_tp} fp:{avg_fp} fn:{avg_fn} tn:{avg_tn}')
    return avg_auc, avg_acc, avg_rec, avg_prec, avg_f1s, avg_BA, avg_tp, avg_fp, avg_tn, avg_fn
else : 
    info(f'seed:{args.seed} test_MSE:{test_loss}')
    return test_loss
#delete for memory

In [None]:
import torch
import torch.nn as nn

class MulticlassModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(MulticlassModel, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        x = self.fc(x)
        return x

# Assuming you have data and labels
data = torch.randn(10, 10)  # Example input data
labels = torch.tensor([0, 2, 1, 0, 1, 2, 0, 1, 2, 0])  # Example labels

num_classes = 3
model = MulticlassModel(10, num_classes)
criterion = nn.CrossEntropyLoss()

# Forward pass
outputs = model(data)

# Calculate the loss
loss = criterion(outputs, labels)

# Backpropagation and optimization
loss.backward()
# optimizer.step()

In [None]:
outputs

tensor([[ 0.1644, -0.2507, -0.6007],
        [ 0.2412,  0.4681,  0.4865],
        [ 0.2217,  0.5646,  0.8483],
        [ 1.0216,  0.7531,  0.0209],
        [-0.0742,  0.1176, -0.3856],
        [-0.2104,  0.3781,  0.9974],
        [-0.8425, -0.0156, -0.1311],
        [ 0.6459, -0.5710, -0.5398],
        [-0.0431, -0.2920,  0.0311],
        [ 0.2606, -0.9355,  0.0621]], grad_fn=<AddmmBackward0>)

In [None]:
labels

tensor([0, 2, 1, 0, 1, 2, 0, 1, 2, 0])

# origin MGSSL multiclass

In [1]:

import os, time, random, shutil, math
import argparse
from argparse import ArgumentParser, Namespace
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.data import DataLoader
from loader import MoleculeDataset_grover, MoleculeDataset_other

from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler

from model import GNN_grover, GNN_graphpred, GNN_graphpred_grover
from util import calcul_loss, save_cp, confusion_mat, makedirs, create_logger
from splitters import scaffold_split, random_split

from rdkit import RDLogger
import logging
from logging import Logger

# i don't want see warning of torch dataset
import warnings
warnings.filterwarnings(action='ignore')


In [2]:
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--lr_scale', type=float, default=1,
                    help='relative learning rate for the feature extraction layer (default: 1)')
parser.add_argument('--decay', type=float, default=1e-7,
                    help='weight decay (default: 0)')
parser.add_argument('--num_layer', type=int, default=5,
                    help='number of GNN message passing layers (default: 5).')
parser.add_argument('--emb_dim', type=int, default=300,
                    help='embedding dimensions (default: 300)')
parser.add_argument('--dropout_ratio', type=float, default=0.2,
                    help='dropout ratio (default: 0.5)')
parser.add_argument('--graph_pooling', type=str, default="mean",
                    help='graph level pooling (sum, mean, max, set2set, attention)')
parser.add_argument('--JK', type=str, default="last",
                    help='how the node features across layers are combined. last, sum, max or concat')
parser.add_argument('--gnn_type', type=str, default="gin")
parser.add_argument('--dataset', type=str, default = 'sider', help='root directory of dataset. For now, only classification.')
parser.add_argument('--data_path', type=str, default = 'dataset/', help='filename to read the model (if there is any)')
parser.add_argument('--output_path', type=str, default = 'output', help='output filename')
parser.add_argument('--seed', type=int, default=0, help = "Seed for splitting the dataset.")
parser.add_argument('--runseed', type=int, default=0, help = "Seed for minibatch selection, random initialization.")
parser.add_argument('--split', type = str, default="scaffold", help = "random or scaffold or random_scaffold")
parser.add_argument('--eval_train', type=int, default = 1, help='evaluating training or not')
parser.add_argument('--num_workers', type=int, default = 4, help='number of workers for dataset loading')
parser.add_argument('--input_model_file', type=str, default = 'pretrained.pth', help='filename to read the model (if there is any)')

# For search
parser.add_argument('--randomsearch', action='store_true', default=False, help='randomsearch mode')
#parser.add_argument('--gridsearch', action='store_true', default=False, help='gridsearch mode')
parser.add_argument('--n_iters', type=int, default=1,
                    help='Number of search')
parser.add_argument('--grover', action='store_true', default=False, help='use grover feature')
parser.add_argument('--regression', action='store_true', default=False, help='data is regression')
parser.add_argument('--multi_class', action='store_true', default=False, help='data is multi_class')
parser.add_argument('--num_tasks', type=int, default=1, help='number of tasks')
parser.add_argument('--num_class', type=int, default=2, help='number of class')

# For predict
parser.add_argument('--model_path', type=str, default = 'dataset/', help='filename to read the model (if there is any)')
parser.add_argument('--predict', action='store_true', default=False, help='only predicition')

args = parser.parse_args(['--data_path', 'dataset/' ,'--epochs', '100', '--batch_size', '96', '--grover', '--dataset', 'tg423', '--output_path', 'output/test'])
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

logger = create_logger(name='train', save_dir=args.output_path, quiet=False)

In [3]:
info = logger.info if logger is not None else print

torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

#set up dataset
dataset = MoleculeDataset_other(args.data_path + args.dataset, dataset=args.dataset)
args.num_tasks = len(dataset[0]['y'])
labels = pd.read_csv(args.data_path + args.dataset + '/raw/' + args.dataset + '.csv', header=None)[1][1:]
try : unique_labels = np.unique(labels[~labels.isnull()])
except : unique_labels = np.unique(labels[~labels.isnull()].astype(float))
args.num_class = len(unique_labels)
if args.num_class>2 :
    if args.num_tasks>1 and not args.regression: 
        raise ValueError("this model can't treat multi-task and multi-class")
    else:
        args.multi_class=True

if args.regression: 
    scaler = StandardScaler()
    scaler.fit(dataset.y.view(-1,1))
else:
    scaler=None

if args.split == "scaffold":
    smiles_list = pd.read_csv(args.data_path + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
    train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
    info(f'scaffold_balanced_split')
elif args.split == "random":
    train_dataset, valid_dataset, test_dataset = random_split(dataset, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
    info("random")
elif args.split == "random_scaffold":
    smiles_list = pd.read_csv(args.data_path + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
    train_dataset, valid_dataset, test_dataset = random_scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
    info("random scaffold")
else:
    raise ValueError("Invalid split option.")

info(f'total_size:{len(dataset)} train_size:{len(train_dataset)} val_size:{len(valid_dataset)} test_size:{len(test_dataset)}')

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

#set up model
model = GNN_graphpred(args)
if not args.input_model_file==None:
    model.from_pretrained(args.input_model_file)


model.to(device)

#set up optimizer
#different learning rate for different part of GNN
model_param_group = []
model_param_group.append({"params": model.gnn.parameters()})
if args.graph_pooling == "attention":
    model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
info(optimizer)

best_val_loss = 999999
best_model_path = os.path.join(args.output_path, str(args.seed))

scaffold_balanced_split
total_size:1578 train_size:1262 val_size:158 test_size:158
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)


In [4]:
model.train()

loss_sum = 0
iter_count = 0

for step, batch in enumerate(train_loader):
    batch = batch.to(device)
    pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    if args.multi_class and not args.regression:
        y = batch.y.view(pred.shape[0],1).to(torch.long)
    elif args.regression and args.dataset=='qm8':
        y = batch.y.view(pred.shape).to(torch.float)
        y = scaler.transform(y.cpu().view(-1,1))
        y = torch.tensor(y, dtype=torch.float).view(pred.shape).to(device)
    elif args.regression:
        y = batch.y.view(pred.shape).to(torch.float)
        y = scaler.transform(y.cpu().view(-1,1))
        y = torch.tensor(y, dtype=torch.float).to(device)
    else : 
        y = batch.y.view(pred.shape).to(torch.float64)
    if step==0:break

In [5]:
pred.shape

torch.Size([96, 4])

In [6]:
y.shape

torch.Size([96, 1])

In [7]:
true=y
loss_func = nn.CrossEntropyLoss()
is_valid = true**2 > 0
true = true[is_valid]+1
true = torch.tensor([x-1 if x != 0 else x for x in true]).cuda()
#loss = loss_func(pred, true)

In [9]:
true.shape

torch.Size([96])

In [22]:
loss_func = nn.CrossEntropyLoss()
is_valid = true**2 > 0
true = true[is_valid]+1
true = torch.tensor([x-1 if x != 0 else x for x in true]).cuda()
loss = loss_func(pred, true)

In [23]:
true=true.unsqueeze(dim=1).to(float)

In [24]:
true.shape

torch.Size([58, 1])

In [25]:
pred.shape

torch.Size([58, 1])

In [26]:
loss_func(pred, true)

tensor(-0., device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward1>)

In [31]:
class GNN_graphpred(torch.nn.Module):
    """
    Extension of GIN to incorporate edge information by concatenation.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        drop_ratio (float): dropout rate
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, set2set
        gnn_type: gin, gcn, graphsage, gat
        
    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """
    def __init__(self, args):
        super(GNN_graphpred, self).__init__()
        self.num_layer = args.num_layer
        self.drop_ratio = args.drop_ratio
        self.JK = args.JK
        self.emb_dim = args.emb_dim
        self.num_tasks = args.num_tasks
        self.num_class = args.num_class
        self.gnn_type = args.gnn_type
        self.regression = args.regression
        self.graph_pooling = args.graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn = GNN(self.num_layer, self.emb_dim, self.JK, self.drop_ratio, gnn_type = self.gnn_type)

        #Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            if self.JK == "concat":
                self.pool = GlobalAttention(gate_nn = torch.nn.Linear((self.num_layer + 1) * self.emb_dim, 1))
            else:
                self.pool = GlobalAttention(gate_nn = torch.nn.Linear(self.emb_dim, 1))
        elif graph_pooling[:-1] == "set2set":
            set2set_iter = int(graph_pooling[-1])
            if self.JK == "concat":
                self.pool = Set2Set((self.num_layer + 1) * self.emb_dim, set2set_iter)
            else:
                self.pool = Set2Set(self.emb_dim, set2set_iter)
        else:
            raise ValueError("Invalid graph pooling type.")

        #For graph-level binary classification
        if graph_pooling[:-1] == "set2set":
            self.mult = 2
        else:
            self.mult = 1
        
        if self.JK == "concat":
            if args.regression and args.num_tasks>1:
                self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)
            elif args.regression:
                self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, 1)
            elif args.multi_class :
                self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_class)
            else : 
                self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)
        else:
            if args.regression and args.num_tasks>1:
                self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks)
            elif args.regression:
                self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, 1)
            elif args.multi_class :
                self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_class)
            else : 
                self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks)

    def from_pretrained(self, model_file):
        #self.gnn = GNN(self.num_layer, self.emb_dim, JK = self.JK, drop_ratio = self.drop_ratio)
        self.gnn.load_state_dict(torch.load(model_file))

    def forward(self, *argv):
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.gnn(x, edge_index, edge_attr)

        return self.graph_pred_linear(self.pool(node_representation, batch))


True

In [15]:
true.unsqueeze(dim=1).shape

torch.Size([58, 1])

In [10]:
pred.shape

torch.Size([58, 1])

In [5]:
#loss matrix after removing null target
loss = calcul_loss(pred, y, args)
print(loss)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [6]:
optimizer.zero_grad()
loss.backward()

optimizer.step()

loss_sum += loss
iter_count += 1

torch.cuda.empty_cache()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.