In [None]:
import os
import glob

import gdown
from rdkit import Chem
from rdkit.Chem import PandasTools
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
import networkx as nx
import matplotlib.pyplot as plt

from intuitive_sc.utils.paths import DATA_PATH

In [None]:
# Raise RDKit's log threshold to CRITICAL so that, e.g., SMILES parse errors raised in
# the parsing loops further down are not printed.
from rdkit import RDLogger

RDLogger.logger().setLevel(RDLogger.CRITICAL)

## USPTO

### Load USPTO from Graph2SMILES

In [None]:
# Google Drive download links for the Graph2SMILES-preprocessed USPTO datasets.
# Maps dataset name -> list of (download url, local file name). Each dataset has
# src/tgt text files for the train, val and test splits; the loading cell below reads
# src-* as products and tgt-* as reactants, and strips the spaces between tokens.
urls_fns_dict = {
    "USPTO_50k": [
        ("https://drive.google.com/uc?id=1pz-qkfeXzeD_drO9XqZVGmZDSn20CEwr", "src-train.txt"),
        ("https://drive.google.com/uc?id=1ZmmCJ-9a0nHeQam300NG5i9GJ3k5lnUl", "tgt-train.txt"),
        ("https://drive.google.com/uc?id=1NqLI3xpy30kH5fbVC0l8bMsMxLKgO-5n", "src-val.txt"),
        ("https://drive.google.com/uc?id=19My9evSNc6dlk9od5OrwkWauBpzL_Qgy", "tgt-val.txt"),
        ("https://drive.google.com/uc?id=1l7jSqYfIr0sL5Ad6TUxsythqVFjFudIx", "src-test.txt"),
        ("https://drive.google.com/uc?id=17ozyajoqPFeVjfViI59-QpVid1M0zyKN", "tgt-test.txt")
    ],
    "USPTO_full": [
        ("https://drive.google.com/uc?id=1PbHoIYbm7-69yPOvRA0CrcjojGxVCJCj", "src-train.txt"),
        ("https://drive.google.com/uc?id=1RRveZmyXAxufTEix-WRjnfdSq81V9Ud9", "tgt-train.txt"),
        ("https://drive.google.com/uc?id=1jOIA-20zFhQ-x9fco1H7Q10R6CfxYeZo", "src-val.txt"),
        ("https://drive.google.com/uc?id=19ZNyw7hLJaoyEPot5ntKBxz_o-_R14QP", "tgt-val.txt"),
        ("https://drive.google.com/uc?id=1ErtNB29cpSld8o_gr84mKYs51eRat0H9", "src-test.txt"),
        ("https://drive.google.com/uc?id=1kV9p1_KJm8EqK6OejSOcqRsO8DwOgjL_", "tgt-test.txt")
    ],
    "USPTO_480k": [
        ("https://drive.google.com/uc?id=1RysNBvB2rsMP0Ap9XXi02XiiZkEXCrA8", "src-train.txt"),
        ("https://drive.google.com/uc?id=1CxxcVqtmOmHE2nhmqPFA6bilavzpcIlb", "tgt-train.txt"),
        ("https://drive.google.com/uc?id=1FFN1nz2yB4VwrpWaBuiBDzFzdX3ONBsy", "src-val.txt"),
        ("https://drive.google.com/uc?id=1pYCjWkYvgp1ZQ78EKQBArOvt_2P1KnmI", "tgt-val.txt"),
        ("https://drive.google.com/uc?id=10t6pHj9yR8Tp3kDvG0KMHl7Bt_TUbQ8W", "src-test.txt"),
        ("https://drive.google.com/uc?id=1FeGuiGuz0chVBRgePMu0pGJA4FVReA-b", "tgt-test.txt")
    ],
    "USPTO_STEREO": [
        ("https://drive.google.com/uc?id=1r3_7WMEor7-CgN34Foj-ET-uFco0fURU", "src-train.txt"),
        ("https://drive.google.com/uc?id=1HUBLDtqEQc6MQ-FZQqNhh2YBtdc63xdG", "tgt-train.txt"),
        ("https://drive.google.com/uc?id=1WwCH8ASgBM1yOmZe0cJ46bj6kPSYYIRc", "src-val.txt"),
        ("https://drive.google.com/uc?id=19OsSpXxWJ-XWuDwfG04VTYzcKAJ28MTw", "tgt-val.txt"),
        ("https://drive.google.com/uc?id=1FcbWZnyixhptaO6DIVjCjm_CeTomiCQJ", "src-test.txt"),
        ("https://drive.google.com/uc?id=1rVWvbmoVC90jyGml_t-r3NhaoWVVSKLe", "tgt-test.txt")
    ]
}

In [None]:
dataset_name = 'USPTO_50k'
# every file of one dataset goes into the same directory, so create it once up front
raw_dir = os.path.join(DATA_PATH, 'raw_uspto', dataset_name)
os.makedirs(raw_dir, exist_ok=True)
for url, fn in urls_fns_dict[dataset_name]:
    ofn = os.path.join(raw_dir, fn)
    if os.path.exists(ofn):
        print(f"{ofn} exists, skip downloading")
        continue
    gdown.download(url, ofn, quiet=False)
    # explicit check rather than `assert`, which is stripped when Python runs with -O
    if not os.path.exists(ofn):
        raise FileNotFoundError(f"downloading {url} did not produce {ofn}")

In [None]:
raw_dir = os.path.join(DATA_PATH, 'raw_uspto', dataset_name)
# glob order is filesystem-dependent; sort it so the concatenation order is reproducible
raw_filepaths = sorted(glob.glob(os.path.join(raw_dir, '*.txt')))
if not raw_filepaths:
    raise FileNotFoundError(f"no *.txt files in {raw_dir}; run the download cell first")

# file naming: src-<split>.txt holds products, tgt-<split>.txt holds reactants
moltype_by_prefix = {'src': 'product', 'tgt': 'reactant'}
uspto_data_prod = []
uspto_data_react = []
for filepath in raw_filepaths:
    # classify on the file name only: testing the full path misfires whenever a parent
    # directory name happens to contain 'src', 'tgt', 'train', 'val' or 'test'
    fname = os.path.basename(filepath)
    split = next((name for name in ('train', 'val', 'test') if name in fname), None)
    prefix = next((p for p in moltype_by_prefix if p in fname), None)
    if split is None or prefix is None:
        print(f"skipping unrecognised file {fname}")
        continue
    moltype = moltype_by_prefix[prefix]
    with open(filepath, "r") as f:
        # each line is a space-separated token sequence; drop all whitespace to get plain SMILES
        lines = [''.join(line.split()) for line in f]
    sub_df = pd.DataFrame(lines, columns=[moltype])
    sub_df['split'] = split
    if moltype == 'product':
        uspto_data_prod.append(sub_df)
    else:
        uspto_data_react.append(sub_df)
    print(f"{fname}: {len(sub_df)}")
uspto_data_prod = pd.concat(uspto_data_prod)
uspto_data_react = pd.concat(uspto_data_react)

In [None]:
# Parse every product SMILES; print the ones RDKit rejects, then how many there were.
n = 0
mols = []
for product_smiles in uspto_data_prod['product']:
    parsed = Chem.MolFromSmiles(product_smiles)
    mols.append(parsed)
    if parsed is not None:
        continue
    n += 1
    print(product_smiles)
print(n)

In [None]:
uspto_data_prod.head(n=5)

In [None]:
uspto_data_prod.shape[0], uspto_data_react.shape[0]

### Combine reactants and products into one DataFrame

In [None]:
# slice the stacked product and reactant frames back into their train / val / test parts
split_names = ('train', 'val', 'test')
uspto_prod_train, uspto_prod_val, uspto_prod_test = (
    uspto_data_prod.loc[uspto_data_prod['split'].eq(name)] for name in split_names
)
uspto_react_train, uspto_react_val, uspto_react_test = (
    uspto_data_react.loc[uspto_data_react['split'].eq(name)] for name in split_names
)

In [None]:
# Pair each split's reactants with its products column-wise; rows are matched on their
# per-file line index. Only the product side keeps its 'split' column.
split_parts = {
    'train': (uspto_react_train, uspto_prod_train),
    'val': (uspto_react_val, uspto_prod_val),
    'test': (uspto_react_test, uspto_prod_test),
}
combined_parts = {
    name: pd.concat([react.drop(columns=['split']), prod], axis=1)
    for name, (react, prod) in split_parts.items()
}
uspto_train = combined_parts['train']
uspto_val = combined_parts['val']
uspto_test = combined_parts['test']
# stack train, val and test into one frame
uspto_data = pd.concat([uspto_train, uspto_val, uspto_test])
uspto_data.head()

### Analyze

Some reactions have two reactants (dot-separated in the reactant SMILES), and sometimes a reactant is as small as `CC` (ethane).

In [None]:
# parse the first n_max training pairs (reactant first, then product) for a visual check
n_max = 1000
mols = []
# .head() is positional: exactly n_max rows, independent of the frame's index labels
# (the old `if i > n_max: break` compared index labels and took n_max + 2 rows)
first_rows = uspto_train.head(n_max)
for reactant_smi, product_smi in zip(first_rows['reactant'], first_rows['product']):
    mols.append(Chem.MolFromSmiles(reactant_smi))
    mols.append(Chem.MolFromSmiles(product_smi))

In [None]:
# first 10 (reactant, product) pairs, one pair per row
first_pairs = mols[:20]
Chem.Draw.MolsToGridImage(first_pairs, subImgSize=(300, 300), molsPerRow=2)

In [None]:
# drop unusable reactions step by step, reporting the row count after each step
print('all unfiltered reactions: ', len(uspto_data))
uspto_data = uspto_data.dropna()  # rows missing either side
print('after removing nan values: ', len(uspto_data))
has_both_sides = uspto_data['reactant'].ne('') & uspto_data['product'].ne('')
uspto_data = uspto_data.loc[has_both_sides]
print('after removing empty strings: ', len(uspto_data))
uspto_data = uspto_data.loc[uspto_data['reactant'].ne(uspto_data['product'])]
print('after removing reactant and product that are the same: ', len(uspto_data))
uspto_data = uspto_data.drop_duplicates(keep='first')  # compares all columns, including 'split'
print('after removing duplicates: ', len(uspto_data))

In [None]:
# split datapoints in two if there are two reactants (string has '.' in it)
# str.split + explode yields one row per '.'-separated reactant fragment and repeats the
# original index label for each fragment. That is the same result as building one
# single-row DataFrame per fragment, without constructing tens of thousands of frames;
# it also works for an empty input, where the old pd.concat([]) raised.
uspto_data_split = (
    uspto_data[['reactant', 'product']]
    .assign(reactant=lambda d: d['reactant'].str.split('.'))
    .explode('reactant')
)
uspto_data_split.head()

In [None]:
# canonicalize smiles
def _canon_smiles_or_none(smiles):
    '''
    Return the canonical isomeric SMILES of `smiles`, or None if RDKit cannot parse it.

    Chem.CanonSmiles raises on unparsable input, because it hands the None molecule to
    MolToSmiles. Fragments produced by the '.' split are not guaranteed to parse.
    Returning None lets the dropna() of the following filtering cell discard such rows
    instead of aborting the whole column.
    '''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    # the call Chem.CanonSmiles makes with its default useChiral=1
    return Chem.MolToSmiles(mol, isomericSmiles=True)


uspto_data_split['reactant'] = uspto_data_split['reactant'].apply(_canon_smiles_or_none)
uspto_data_split['product'] = uspto_data_split['product'].apply(_canon_smiles_or_none)

In [None]:
# Canonicalization can turn different spellings into identical strings, so re-apply the
# same filters to the split, canonicalized pairs and report the count after each step.
print('all unfiltered reactions: ', len(uspto_data_split))
filter_steps = [
    ('after removing nan values: ', lambda d: d.dropna()),
    ('after removing empty strings: ', lambda d: d[(d['reactant'] != '') & (d['product'] != '')]),
    ('after removing reactant and product that are the same: ', lambda d: d[d['reactant'] != d['product']]),
    ('after removing duplicates: ', lambda d: d.drop_duplicates()),
]
for message, apply_step in filter_steps:
    uspto_data_split = apply_step(uspto_data_split)
    print(message, len(uspto_data_split))

In [None]:
# save to csv (index kept, so the reload below can restore it with index_col=0)
raw_dir = os.path.join(DATA_PATH, 'raw_uspto', dataset_name)
for frame, csv_name in ((uspto_data, 'uspto_raw_combo.csv'),
                        (uspto_data_split, 'uspto_raw_split_combo.csv')):
    frame.to_csv(os.path.join(raw_dir, csv_name))

In [None]:
# load data written by the cell above
raw_dir = os.path.join(DATA_PATH, 'raw_uspto', dataset_name)
uspto_data, uspto_data_split = (
    pd.read_csv(os.path.join(raw_dir, csv_name), index_col=0)
    for csv_name in ('uspto_raw_combo.csv', 'uspto_raw_split_combo.csv')
)

#### Analyze the reaction graph

In [None]:
# plot the graph using plotly
def plot_graph(graph, title='USPTO Reaction Network', seed=None):
    '''
    Draw `graph` as an interactive plotly figure and show it.

    Nodes are laid out with nx.spring_layout. Each node is coloured by the size of its
    adjacency (the number of successors for a DiGraph) and labelled on hover with the
    node itself (the SMILES string in this notebook).

    Parameters
    ----------
    graph : networkx graph to draw.
    title : figure title.
    seed : optional seed passed to nx.spring_layout for a reproducible layout.
    '''
    pos = nx.spring_layout(graph, k=0.5, iterations=50, seed=seed)

    # all edges as a single line trace; None separates consecutive segments
    edge_x = []
    edge_y = []
    for u, v in graph.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.3, color='black'),
        hoverinfo='none',
        mode='lines')

    # Positions, colour values and hover labels are collected in one pass over `graph`
    # (the argument, not a global), so they stay aligned node by node.
    node_x = []
    node_y = []
    node_adjacencies = []
    node_text = []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_adjacencies.append(len(graph.adj[node]))
        node_text.append(node)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        text=node_text,
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=node_adjacencies,
            size=10,
            colorbar=dict(
                thickness=15,
                # title.side replaces the deprecated `titleside` property
                title=dict(text='Node Connections', side='right'),
                xanchor='left'
            ),
            line_width=2))

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        # title.font replaces the deprecated `titlefont` property
                        title=dict(text=title, font=dict(size=16)),
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20, l=5, r=5, t=40),
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))
    fig.show()

In [None]:
# create a list of (reactant, product) tuples, one per row
reac_prod = list(uspto_data_split[['reactant', 'product']].itertuples(index=False, name=None))

In [None]:
# directed reaction graph: one reactant -> product edge per pair (repeated pairs collapse)
digraph = nx.DiGraph()
for reactant, product in reac_prod:
    digraph.add_edge(reactant, product)

In [None]:
# get random rows from dataframe and extract reactant and product
n_max = 1000
# random_state makes the plotted sample reproducible; capping at len() avoids the
# ValueError DataFrame.sample raises when asked for more rows than exist
sampled_rows = uspto_data_split.sample(min(n_max, len(uspto_data_split)), random_state=42)
subgraph_smiles = []
for reactant, product in zip(sampled_rows['reactant'], sampled_rows['product']):
    subgraph_smiles.append(reactant)
    subgraph_smiles.append(product)
# induced subgraph: every digraph edge between the sampled molecules, not just the sampled pairs
G = digraph.subgraph(subgraph_smiles)
plot_graph(G)

### Remove reactions involving molecules with fewer than 4 heavy atoms

In [None]:
# attach an RDKit Mol column next to each SMILES column
for smiles_col, mol_col in (('reactant', 'reactant_mol'), ('product', 'product_mol')):
    PandasTools.AddMoleculeColumnToFrame(uspto_data_split, smilesCol=smiles_col, molCol=mol_col)

In [None]:
# heavy-atom count of each side, used below to find pairs involving tiny molecules
for mol_col, count_col in (('reactant_mol', 'reactant_natoms'), ('product_mol', 'product_natoms')):
    uspto_data_split[count_col] = uspto_data_split[mol_col].apply(lambda mol: mol.GetNumHeavyAtoms())

In [None]:
# pairs where at least one side has fewer than 4 heavy atoms
has_tiny_side = uspto_data_split['reactant_natoms'].lt(4) | uspto_data_split['product_natoms'].lt(4)
small_mols = uspto_data_split.loc[has_tiny_side]
len(small_mols)

In [None]:
# keep only pairs where both sides have at least 4 heavy atoms
print(len(uspto_data_split))
both_large_enough = uspto_data_split['reactant_natoms'].ge(4) & uspto_data_split['product_natoms'].ge(4)
uspto_data_split_fil = uspto_data_split.loc[both_large_enough]
print(len(uspto_data_split_fil))

In [None]:
# (reactant, product) edges to drop from the graph: the pairs involving tiny molecules
reac_prod_remove = list(small_mols[['reactant', 'product']].itertuples(index=False, name=None))

In [None]:
# nodes with neither in- nor out-edges, in node order
list(nx.isolates(digraph))

In [None]:
# remove the edges listed in reac_prod_remove; pairs that are not graph edges are skipped
digraph.remove_edges_from([pair for pair in reac_prod_remove if digraph.has_edge(*pair)])

In [None]:
# molecules left without any reaction after the edge removal
lonely_nodes = list(nx.isolates(digraph))

In [None]:
# remove lonely nodes; nodes that are already gone are ignored
for lonely in lonely_nodes:
    if lonely in digraph:
        digraph.remove_node(lonely)

In [None]:
# get random rows from the filtered dataframe and extract reactant and product
n_max = 1000
# random_state makes the plotted sample reproducible; capping at len() avoids the
# ValueError DataFrame.sample raises when asked for more rows than exist
sampled_rows = uspto_data_split_fil.sample(min(n_max, len(uspto_data_split_fil)), random_state=42)
subgraph_smiles = []
for reactant, product in zip(sampled_rows['reactant'], sampled_rows['product']):
    subgraph_smiles.append(reactant)
    subgraph_smiles.append(product)
# induced subgraph: every digraph edge between the sampled molecules, not just the sampled pairs
G = digraph.subgraph(subgraph_smiles)
plot_graph(G)

In [None]:
# checkpoint the heavy-atom-filtered pairs (cycles still included)
withloops_csv = os.path.join(DATA_PATH, 'raw_uspto', dataset_name, 'uspto_split_combo_fil_withloops.csv')
uspto_data_split_fil.to_csv(withloops_csv)

### Remove two-node cycles (bicycles)

In [None]:
# reload the checkpoint and rebuild the reactant -> product graph from it
withloops_csv = os.path.join(DATA_PATH, 'raw_uspto', dataset_name, 'uspto_split_combo_fil_withloops.csv')
uspto_data_split_fil = pd.read_csv(withloops_csv, index_col=0)
reac_prod = list(uspto_data_split_fil[['reactant', 'product']].itertuples(index=False, name=None))
digraph = nx.DiGraph()
for reactant, product in reac_prod:
    digraph.add_edge(reactant, product)

In [None]:
# Show one cycle as a list of (u, v, 'forward') edges. find_cycle raises NetworkXNoCycle
# when the graph has none, so that case is reported as an empty list instead.
try:
    first_cycle = nx.find_cycle(digraph, orientation="original")
except nx.NetworkXNoCycle:
    first_cycle = []
first_cycle

In [None]:
def remove_bicycle(digraph, df, self_loops=None):
    '''
    Remove "bicycles" (two nodes u, v with edges u -> v and v -> u) from a directed graph.

    A pair is removed when u's only successor is v and the reverse edge v -> u exists.
    Both nodes are deleted from `digraph` in place, every row of `df` pairing u and v in
    either direction is dropped, and (u, v) is appended to `self_loops`.

    The scan is iterative, so it is not bounded by Python's recursion limit. It repeats
    until a full pass removes nothing, because deleting nodes can leave other nodes with
    a single successor. Self-loops (u -> u) and cycles of three or more nodes are skipped
    rather than raising, so they may remain in the graph.

    NOTE(review): deleting u and v also deletes any *other* edges they have, while df only
    loses the u/v rows, so graph and df can diverge. Confirm this is intended.

    Parameters
    ----------
    digraph : nx.DiGraph
        Reaction graph with reactant -> product edges; modified in place.
    df : pd.DataFrame
        Reaction table with 'reactant' and 'product' columns.
    self_loops : list, optional
        List the removed (u, v) pairs are appended to; a new list is created when omitted.

    Returns
    -------
    tuple
        (digraph, filtered df, self_loops)
    '''
    if self_loops is None:
        # a fresh list per call; a mutable default would accumulate pairs across calls
        self_loops = []
    removed_pairs = []
    changed = True
    while changed:
        changed = False
        for u in list(digraph.nodes):
            if u not in digraph:
                # already deleted in this pass as the partner of an earlier node
                continue
            successors = list(digraph.successors(u))
            if len(successors) != 1:
                continue
            v = successors[0]
            # v == u is a self-loop; a missing v -> u means u -> v belongs to a longer cycle
            if v == u or not digraph.has_edge(v, u):
                continue
            digraph.remove_nodes_from([u, v])  # also removes the edges u -> v and v -> u
            removed_pairs.append((u, v))
            changed = True
    if removed_pairs:
        # drop both orientations of every removed pair in a single pass over df
        drop_keys = set(removed_pairs) | {(v, u) for u, v in removed_pairs}
        keep = [(r, p) not in drop_keys for r, p in zip(df['reactant'], df['product'])]
        df = df.loc[keep]
    self_loops.extend(removed_pairs)
    return digraph, df, self_loops

# Pass a fresh list explicitly so re-running this cell never reuses a list shared through
# a mutable default argument. new_digraph is the same object as digraph, modified in place.
new_digraph, uspto_data_split_fil_lin, self_loops = remove_bicycle(digraph, uspto_data_split_fil, [])

In [None]:
# Single-step variant: removes at most one bicycle per call, so removal can be driven from
# a loop instead of recursion. NOTE(review): this redefines `remove_bicycle` with a
# different signature (cycles first) and shadows any earlier definition of that name.
def remove_bicycle(cycles, digraph, df, self_loops=None):
    '''
    Remove at most one bicycle (two nodes connected to each other) found among `cycles`.

    Takes the first edge u -> v of `cycles` for which u's only successor is v and the
    reverse edge v -> u exists. Both nodes are deleted from `digraph` (in place, together
    with all their edges), the df rows pairing u and v in either direction are dropped,
    and (u, v) is appended to `self_loops`. Self-loops and edges of longer cycles are
    skipped; if nothing qualifies, digraph and df are returned unchanged.

    Parameters
    ----------
    cycles : list
        Edges as returned by nx.find_cycle(..., orientation="original"), i.e.
        (u, v, 'forward') tuples; only the first two entries of each are used.
    digraph : nx.DiGraph
        Reaction graph; modified in place.
    df : pd.DataFrame
        Reaction table with 'reactant' and 'product' columns.
    self_loops : list, optional
        List to append the removed pair to; a new list is created when omitted.

    Returns
    -------
    tuple
        (digraph, df, self_loops)

    Example driver. It stops at the first cycle it cannot remove, so remaining longer
    cycles cannot cause an endless loop::

        uspto_data_split_fil_lin = uspto_data_split_fil.copy()
        self_loops = []
        while True:
            try:
                cycles = nx.find_cycle(digraph, orientation="original")
            except nx.NetworkXNoCycle:
                break
            n_before = len(self_loops)
            digraph, uspto_data_split_fil_lin, self_loops = remove_bicycle(
                cycles, digraph, uspto_data_split_fil_lin, self_loops)
            if len(self_loops) == n_before:
                break
    '''
    if self_loops is None:
        # a fresh list per call; a mutable default would accumulate pairs across calls
        self_loops = []
    for edge in cycles:
        u, v = edge[0], edge[1]
        # Skip self-loops and edges without a reverse edge (part of a longer cycle);
        # removing the missing edge v -> u would raise for those.
        if u == v or not digraph.has_edge(v, u):
            continue
        # only remove the pair when v is u's sole successor
        if list(digraph.successors(u)) != [v]:
            continue
        digraph.remove_nodes_from([u, v])  # also removes u -> v and v -> u
        df = df[~((df['reactant'] == u) & (df['product'] == v))]
        df = df[~((df['reactant'] == v) & (df['product'] == u))]
        self_loops.append((u, v))
        break
    return digraph, df, self_loops

In [None]:
uspto_data_split_fil_lin.shape[0], uspto_data_split_fil.shape[0]

In [None]:
# Check whether any cycle is left after bicycle removal ([] means none). find_cycle raises
# NetworkXNoCycle on an acyclic graph, which is the hoped-for outcome here.
try:
    remaining_cycle = nx.find_cycle(digraph, orientation="original")
except nx.NetworkXNoCycle:
    remaining_cycle = []
remaining_cycle

In [None]:
# examples among the removed pairs: protection groups and hydroxy <-> aldehyde/ketone
n = 1
# draw both molecules of the n-th removed two-node cycle side by side
loop_pair = self_loops[n]
Chem.Draw.MolsToGridImage([Chem.MolFromSmiles(smiles) for smiles in loop_pair], molsPerRow=2, subImgSize=(300, 300))

In [None]:
# dag_longest_path_length raises NetworkXUnfeasible on a graph with cycles. Bicycle removal
# only targets two-node cycles, so longer ones may remain; check first.
if nx.is_directed_acyclic_graph(digraph):
    longest_path_length = nx.dag_longest_path_length(digraph)
else:
    longest_path_length = None
    print('digraph still contains cycles; the longest path is only defined for a DAG')
longest_path_length

In [None]:
# dag_longest_path raises NetworkXUnfeasible if cycles remain, so check first
if nx.is_directed_acyclic_graph(digraph):
    longest_graph = nx.dag_longest_path(digraph)
else:
    longest_graph = []
    print('digraph still contains cycles; the longest path is only defined for a DAG')
# only draw when there is a path to show
longest_path_img = (
    Chem.Draw.MolsToGridImage([Chem.MolFromSmiles(x) for x in longest_graph], subImgSize=(300, 300))
    if longest_graph else None
)
longest_path_img

In [None]:
# pairs remaining after bicycle removal, written without the index
final_csv = os.path.join(DATA_PATH, 'raw_uspto', dataset_name, 'uspto_split_combo_fil.csv')
uspto_data_split_fil_lin.to_csv(final_csv, index=False)

## Final graph analysis

In [None]:
# get some stats on digraph
# nx.info() was deprecated in networkx 2.7 and removed in 3.0, so build the summary directly
print(f"{type(digraph).__name__} with {digraph.number_of_nodes()} nodes and {digraph.number_of_edges()} edges")
# 3 highest degree
print(sorted(digraph.degree, key=lambda x: x[1], reverse=True)[:3])
# longest path; only defined when no cycles remain (otherwise None)
nx.dag_longest_path_length(digraph) if nx.is_directed_acyclic_graph(digraph) else None

In [None]:
# count nodes whose total, out- and in-degree equals 1
degree_one = sum(1 for _, deg in digraph.degree if deg == 1)
out_degree_one = sum(1 for _, deg in digraph.out_degree if deg == 1)
in_degree_one = sum(1 for _, deg in digraph.in_degree if deg == 1)
degree_one, out_degree_one, in_degree_one

In [None]:
# plot distribution of heavy atoms
# XXX far too small molecules!
reactant_natoms = uspto_data_split_fil_lin['reactant_natoms']
product_natoms = uspto_data_split_fil_lin['product_natoms']
# A shared range gives both histograms the same 20 bin edges, so the overlaid bars are
# directly comparable; with bins=20 alone each series got its own edges.
all_natoms = pd.concat([reactant_natoms, product_natoms])
natoms_range = (all_natoms.min(), all_natoms.max()) if len(all_natoms) else None
fig, ax = plt.subplots()
ax.hist(reactant_natoms, bins=20, range=natoms_range, alpha=0.5, label='reactant')
ax.hist(product_natoms, bins=20, range=natoms_range, alpha=0.5, label='product')
ax.set(xlabel='number of heavy atoms', ylabel='count', title='Heavy-atom count distribution')
ax.legend();

### Format for input
* add a label/target column
* randomize the order of reactant and product within each pair
* columns
    * smiles_i
    * smiles_j
    * target: 1 if `smiles_j` is the product (presumably the more complex molecule), 0 if `smiles_i` is the product

In [None]:
import random

# fixed seed so the random reactant/product ordering below is reproducible
random.seed(42)


def format_rows(row):
    '''
    Return one reaction as (smiles_i, smiles_j, target) in random order.

    When random.random() > 0.5 the reactant comes first and target is 1; otherwise
    the product comes first and target is 0. Draws from the module-level `random` state.
    '''
    reactant_first = random.random() > 0.5
    if not reactant_first:
        return row['product'], row['reactant'], 0
    return row['reactant'], row['product'], 1
    
# one (smiles_i, smiles_j, target) tuple per reaction, in row order
formatted_rows = [format_rows(record) for _, record in uspto_data_split_fil_lin.iterrows()]
uspto_reorder = pd.DataFrame(formatted_rows, columns=['smiles_i', 'smiles_j', 'target'])

In [None]:
uspto_reorder.head(n=5)

In [None]:
# model input table, written without the index
input_csv = os.path.join(DATA_PATH, 'raw_uspto', dataset_name, 'uspto_input_trial.csv')
uspto_reorder.to_csv(input_csv, index=False)