In [None]:
import networkx as nx
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdPartialCharges
from torch_geometric.data import DataLoader
from torch_geometric.utils import from_networkx
from molgraph.chemistry import Featurizer, features

# Atom-level featurizer used by make_pyg_graph to build each node's `x`.
# Every molgraph feature class below is instantiated with no arguments, in
# the listed order; the order presumably fixes the layout of the encoded
# vector (TODO confirm against the molgraph Featurizer docs).
atom_encoder = Featurizer([
    feature_cls()
    for feature_cls in (
        features.Symbol,
        features.TotalNumHs,
        features.ChiralCenter,
        features.Aromatic,
        features.Ring,
        features.Hetero,
        features.HydrogenDonor,
        features.HydrogenAcceptor,
        features.RingSize,
        features.GasteigerCharge,
    )
])


def make_pyg_graph(input_str, label, format="InChI"):
    """Build a PyTorch Geometric graph for a single molecule.

    Every atom becomes a node whose ``x`` feature comes from ``atom_encoder``.
    Every chemical bond becomes ONE directed edge, oriented from the atom with
    the higher Gasteiger partial charge to the atom with the lower one. When
    the charges are equal (or a comparison involves NaN, which is always
    False) the edge points from the bond's end atom to its begin atom.

    Args:
        input_str: Molecule string written in the notation named by ``format``.
        label: Graph-level target, stored as ``data.y`` (1-element long tensor).
        format: ``"InChI"`` (default) or ``"SMILES"``.

    Returns:
        A PyG ``Data`` object with float32 ``x``, long ``y`` and
        ``edge_index``; or ``None`` if RDKit cannot parse ``input_str`` or the
        parsed molecule contains no atoms.

    Raises:
        ValueError: If ``format`` is neither ``"InChI"`` nor ``"SMILES"``.
    """
    if format == "InChI":
        mol = Chem.MolFromInchi(input_str)
    elif format == "SMILES":
        mol = Chem.MolFromSmiles(input_str)
    else:
        raise ValueError(
            f"format kwarg needs to be 'InChI' or 'SMILES', got {format!r}"
        )

    # RDKit returns None when parsing fails. A zero-atom molecule (e.g. from
    # an empty string) would produce a graph without node features, so it is
    # rejected the same way.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    # Attaches a '_GasteigerCharge' property to every atom; read below.
    rdPartialCharges.ComputeGasteigerCharges(mol)

    graph = nx.DiGraph()
    for atom in mol.GetAtoms():
        graph.add_node(atom.GetIdx(), x=atom_encoder(atom))

    for bond in mol.GetBonds():
        begin_atom = bond.GetBeginAtom()
        end_atom = bond.GetEndAtom()
        begin_charge = float(begin_atom.GetProp('_GasteigerCharge'))
        end_charge = float(end_atom.GetProp('_GasteigerCharge'))
        # Orient the edge from the more positive atom to the more negative one.
        # TODO: bond features (e.g. bond.GetBondType()) are not attached yet.
        if begin_charge > end_charge:
            graph.add_edge(begin_atom.GetIdx(), end_atom.GetIdx())
        else:
            graph.add_edge(end_atom.GetIdx(), begin_atom.GetIdx())

    data = from_networkx(graph)
    # Depending on what atom_encoder returns, from_networkx may hand back x
    # as a tensor of some other float dtype; as_tensor normalises it to
    # float32, whereas the legacy torch.FloatTensor(...) constructor errors
    # on non-float32 tensor input.
    data.x = torch.as_tensor(data.x, dtype=torch.float32)
    data.y = torch.tensor([label], dtype=torch.long)
    return data

def make_train_data(train_csv, batch_size=32, shuffle=True):
    """Load molecules from a CSV file into a PyG DataLoader.

    The CSV is read with pandas and must have an ``InChI`` column (molecule
    strings) and a ``covalent`` column (graph labels). Each row is converted
    with ``make_pyg_graph``. Rows it cannot convert (it returns ``None``) are
    dropped with a warning, because a ``None`` entry would make batching fail
    inside the loader.

    Args:
        train_csv: Path (or buffer) accepted by ``pandas.read_csv``.
        batch_size: Graphs per batch (default 32).
        shuffle: Whether the loader reshuffles every epoch (default True).

    Returns:
        A ``DataLoader`` over the successfully converted graphs.
    """
    import warnings

    df = pd.read_csv(train_csv)
    graphs = [
        make_pyg_graph(inchi, label)
        for inchi, label in zip(df['InChI'], df['covalent'])
    ]
    train_data = [graph for graph in graphs if graph is not None]

    skipped = len(graphs) - len(train_data)
    if skipped:
        warnings.warn(
            f"Skipped {skipped} of {len(graphs)} molecules that could not be "
            "converted to graphs"
        )

    # NOTE(review): newer PyG versions expose this class as
    # torch_geometric.loader.DataLoader; confirm the installed version still
    # provides the torch_geometric.data alias imported at the top of the file.
    return DataLoader(train_data, batch_size=batch_size, shuffle=shuffle)

# Example of converting a single molecule (benzene) with label 0:
# a = make_pyg_graph("InChI=1S/C6H6/c1-2-4-6-5-3-1/h1-6H", 0)

# NOTE(review): despite its name, `train_data` holds the DataLoader returned by
# make_train_data, and it is built from a file named "test_data_all.csv" —
# confirm both are intended.
train_data = make_train_data(train_csv="./victor_data/test_data_all.csv")