## Table of contents
* [Define parameters](#parameters)
* [Class definitions](#class)
* [Function definitions](#function)
* [Run](#run)

In [21]:
from pathlib import Path
import pickle
from typing import Tuple,  Dict

import numpy as np
import pandas as pd

from tqdm import tqdm
import typer
from typing import List, Optional

from torch import FloatTensor, LongTensor
from torch import flatten, device, cuda, nn, from_numpy
from torch import load as load_module
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional, Module

from sklearn.preprocessing import LabelEncoder

import dask.dataframe as dd
from dask.diagnostics import ProgressBar

from drfp import DrfpEncoder

## Define parameters <a class="anchor" id="parameters"></a>

In step 2 we have seen that the <code>ec123_drfp_mlp</code> model has the best accuracy. We will therefore use this model here to make the EC predictions.

In [22]:
modeltype='ec123_drfp_mlp'

In [23]:
Path('experiments/predictions').mkdir(exist_ok=True)

## Class definitions <a class="anchor" id="class"></a>

In [24]:
class InferenceReactionDataset(Dataset):
    def __init__(self, rxns: List, label: str = "label"):
        self.rxns = rxns
        self.size = len(rxns)
        self.label = label
        
        fps, _, _ = DrfpEncoder.encode(
            rxns,
            mapping=True,
            atom_index_mapping=True,
            root_central_atom=False,
            radius=2,
            include_hydrogens=True,
            n_folded_length=10240,
        )

        self.X = FloatTensor(
            np.array([x.astype(np.float32) for x in fps], dtype=np.float32)
        )

    def __getitem__(self, i):
        return self.X[i]

    def __len__(self):
        return self.size

In [25]:
class MLPClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLPClassifier, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.fc1 = nn.Linear(self.input_size, self.hidden_size)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x):
        hidden = self.fc1(x)
        tanh = self.tanh(hidden)
        output = self.fc2(tanh)
        return output

In [26]:
class ReactionDataset(Dataset):
    def __init__(self, df: pd.DataFrame, label: str = "label"):
        self.size = len(df)
        self.label = label
        self.X = FloatTensor(
            np.array([x.astype(np.float32) for x in df.fps], dtype=np.float32)
        )
        self.y = LongTensor(df[self.label].to_numpy(dtype=np.int32))
        self.fps = df["fps"]
        self.rxn_smiles = df["rxn_smiles"]

    def __getitem__(self, i):
        return (
            self.X[i],
            self.y[i],
        )

    def __len__(self):
        return self.size

## Function definitions <a class="anchor" id="function"></a>

In [27]:
def get_device():
    if cuda.is_available():
        return device("cuda:0")
    return device("cpu")

In [28]:
def load_models(
    source: str = "rheadb", names = [], split: str = "0", fplength: int = 10240
) -> Dict[str, Tuple[MLPClassifier, LabelEncoder, Dataset]]:

    models = {}

    for name in names:
        model_path = f"models/{source}-{name}.pt"

        classifier = None
        label_encoder = None

        with open(f"models/{source}-{name}-le.pkl", "rb") as f:
            label_encoder: LabelEncoder = pickle.load(f)

        classifier = MLPClassifier(fplength, 1664, len(label_encoder.classes_))
        classifier.load_state_dict(load_module(model_path))
        classifier.eval()

        models[name] = (classifier, label_encoder)

    return models

In [29]:
def predict_internal(
    model: Module,
    device: device,
    data_set: Dataset,
    label_encoder: LabelEncoder,
    topk: int = 10,
) -> Tuple[str, Dict[str, float], List[int]]:
    data_sample = next(iter(DataLoader(data_set)))
    data_sample = data_sample.to(device)
    pred_raw = model(data_sample)
    probs = flatten(functional.softmax(pred_raw, dim=1)).cpu().detach().numpy()
    pred = pred_raw.max(1, keepdim=True)[1]
    y_pred = flatten(pred).tolist()

    topk_indices = (-probs).argsort()[:topk]
    probabilities = {
        label_encoder.inverse_transform([i])[0]: prob for i, prob in enumerate(probs)
    }

    return label_encoder.inverse_transform(y_pred)[0], probabilities, topk_indices

In [30]:
def predict_one(
    rxn, model, label_encoder, device, probs, topk, dataset
):
    
    pred, probabilities, topk_indices = predict_internal(
        model, device, dataset, label_encoder, topk
    )

    result = [pred]

    if probs:
        top_k_classes = [label_encoder.inverse_transform([i])[0] for i in topk_indices]
        result.append({c: probabilities[c] for c in top_k_classes})

    return result

In [31]:
def predict(
    model_id: str,
    rxn_smiles: str,
    topk: Optional[int] = 5,
    explain: Optional[bool] = False,
    probs: Optional[bool] = False,
):
    dataset = InferenceReactionDataset([rxn_smiles])
    
    vals = model_id.split(".")
    models = load_models(vals[0], [vals[1]], '0', len(dataset[0]))

    device = get_device()
    model, label_encoder = models[vals[1]]
    model = model.to(device)

    result = predict_one(
        rxn_smiles,
        model,
        label_encoder,
        device,
        probs,
        topk,
        dataset
    )

    return result[1]

In [32]:
def predict_EC_with_probs(
    model: str = typer.Argument(
        ...,
        help="The name of the model. Options are rheadb.ec1, rheadb.ec12, rheadb.ec123, ecreact.ec1, ecreact.ec12, and ecreact.ec123.",
    ),
    rxn_smiles: str = typer.Argument(
        ..., help="The reaction smiles in the form a.b>>c.d."
    ),
    topk: Optional[int] = 5,
    probs: Optional[bool] = True,
):
    val = predict(model, rxn_smiles, topk, False, probs)
    val = dict(filter(lambda elem: elem[1]>0.01, val.items()))
    return val


In [33]:
# Test for one smiles:
predict_EC_with_probs(f'rheadb.{modeltype}', "CCCCC(N)=O.[H]O[H]>>CCCCC(=O)[O-].[H][N+]([H])([H])[H]")

{'3.5.1': 0.99916303}

In [34]:
def predictECs(row):
    try:
        res = predict_EC_with_probs(f'rheadb.{modeltype}', row['rxn_smiles'])
        return ' | '.join([f"{key}:{value}" for key, value in res.items()])
    except Exception as e:
        #print(e)
        return 'No prediction'

In [35]:
def main(infile: str, outfile: str):
    df = pd.read_csv(infile)

    # Replace ecreact style reactions with standard reaction SMILES.
    df["rxn_smiles"] = df["rxn_smiles"].str.replace(r"\|.*>", ">>", regex=True)

    ddata = dd.from_pandas(df, npartitions=1000)
    ProgressBar().register()
    res = ddata.map_partitions(
        lambda df: df.assign(EC_prediction=df.apply(predictECs, axis=1))).compute()
    
    res.to_csv(outfile, index=False, sep='\t')
    print('=> Created file', outfile)
    print('=> Finished')
    
# Alternative main function with pandas instead of dask - slower!

# def main(infile: str, outfile: str):
#     df = pd.read_csv(infile)

#     # Replace ecreact style reactions with standard reaction SMILES.
#     df["rxn_smiles"] = df["rxn_smiles"].str.replace(r"\|.*>", ">>", regex=True)

#     df_nan = df[~df.ec.notna()]

#     tqdm.pandas()

#     df_nan['EC predictions | probabilities'] = df_nan.progress_apply(predictECs, axis=1)
#     df_nan.to_csv(outfile, index=False, sep='\t')

## Run <a class="anchor" id="run"></a>

In [36]:
main('data/rheadb.csv.gz', f'experiments/predictions/rheadb_predicted_ECs_{modeltype}.tsv')

[###                                     ] | 9% Completed |  8min 50.2s
[###                                     ] | 9% Completed |  8min 50.4s


KeyboardInterrupt: 