# Analysis

In [139]:
import ast
import csv
import json
import math
import os
import pickle
import re
import sys
from argparse import Namespace
from collections import defaultdict, Counter
from functools import partial
from pathlib import Path

# `os` must be imported before this line runs (it was previously imported after it).
# NOTE(review): `__file__` is usually undefined inside a notebook kernel — confirm how this cell is executed.
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from lightning_fabric.utilities.cloud_io import _load as pl_load
from matplotlib import pyplot as plt
from p_tqdm import p_map
from rdkit import Chem
from rdkit.Chem import AllChem  # noqa: F401 -- loads the submodule so `Chem.AllChem` resolves
from scipy.stats import percentileofscore
from torch_geometric.data import Batch, Data
from tqdm import tqdm

import clipzyme.utils.loading as loaders
from clipzyme.utils.loading import get_object
from clipzyme.utils.pyg import from_mapped_smiles
from clipzyme.utils.smiles import remove_atom_maps

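On metrics (the number in each name is the BEDROC early-recognition parameter α):
* bedroc85 (α = 85): the top ~3.5% of the ranked list contributes ~95% of the score (≈10K of 261,907 enzymes)
* bedroc20 (α = 20): the top ~8% of the ranked list contributes ~80% of the score (≈21K of 261,907 enzymes)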

### Collect screening set hidden representation

```
import torch 
import pickle
from p_tqdm import p_map
from functools import partial

def read_protein(u, path):
    """Load the pickled record for sample `u` of experiment `path` and return its 'hidden' entry.

    Returns None if anything goes wrong (missing file, unpicklable content,
    missing 'hidden' key): the bare `except` swallows every exception.
    """
    try:
        d = pickle.load(
            open(f"/home/results/metabolomics/clip/{path}/sample_{u}.hiddens", 'rb')
        )
        return d['hidden']
    except:
        # best-effort: callers filter out the resulting None values
        return
 
# UniProt id -> sequence; drop entries whose sequence is the empty string
screening_set= pickle.load(open("uniprot2sequence_standard_set_structs.p", "rb"))
screening_set = {k:v for k,v in screening_set.items() if v!= ""}
# has structure and msa
alphafold_files = pickle.load(open("alphafold_enzymes.p", "rb"))
msa_files = pickle.load(open("uniprot2msa_embedding.p", "rb"))
# keep only ids present in both lookups and with sequence length <= 650
screening_set = {k:v for k,v in screening_set.items() if (k in alphafold_files) and (k in msa_files) and (len(v)<=650)}
len(screening_set)
screening_set_uniprots = list(screening_set.keys())

# experiment directory under /home/results/metabolomics/clip/ that holds the per-sample .hiddens files
experiment_name = "4add9a242ca4f896cd31da4d0d129c63epoch=8"

# read every per-protein file in parallel; failed reads come back as None
read_protein_func = partial(read_protein, path = experiment_name)
hiddens = p_map(read_protein_func, screening_set_uniprots)

# drop failed reads while keeping ids and tensors aligned
all_ec_uniprots_ = [u for u,h in zip(screening_set_uniprots, hiddens) if h is not None]
hiddens = [h for h in hiddens if h is not None]
hiddens = torch.stack(hiddens)

# this file is what the "load model and screening set" section reads back as precomputed_<checkpoint stem>.p
pickle.dump({'hiddens': hiddens, 'uniprots': all_ec_uniprots_}, open(f"precomputed_{experiment_name}.p",'wb'))
```

## functions

In [140]:
# Numeric bond order used in edit tuples -> RDKit bond type (1.5 denotes aromatic).
BOND_TYPE = {
    1: Chem.rdchem.BondType.SINGLE,
    2: Chem.rdchem.BondType.DOUBLE,
    3: Chem.rdchem.BondType.TRIPLE,
    1.5: Chem.rdchem.BondType.AROMATIC,
}

# Reaction templates applied, in order, to each fragment by robust_edit_mol
# after the SMILES round-trip.
clean_rxns_postsani = [
    # two adjacent aromatic nitrogens should allow for H shift
    Chem.AllChem.ReactionFromSmarts("[n;H1;+0:1]:[n;H0;+1:2]>>[n;H0;+0:1]:[n;H0;+0:2]"),
    # two aromatic nitrogens separated by one should allow for H shift
    Chem.AllChem.ReactionFromSmarts(
        "[n;H1;+0:1]:[c:3]:[n;H0;+1:2]>>[n;H0;+0:1]:[*:3]:[n;H0;+0:2]"
    ),
    # hydroxyl on a positively charged, H-free nitrogen: drop the H and put a negative charge on O
    Chem.AllChem.ReactionFromSmarts("[#7;H0;+:1]-[O;H1;+0:2]>>[#7;H0;+:1]-[O;H0;-:2]"),
    # neutralize C(=O)[O-]
    Chem.AllChem.ReactionFromSmarts(
        "[C;H0;+0:1](=[O;H0;+0:2])[O;H0;-1:3]>>[C;H0;+0:1](=[O;H0;+0:2])[O;H1;+0:3]"
    ),
    # turn neutral halogens into anions EXCEPT HCl
    Chem.AllChem.ReactionFromSmarts("[I,Br,F;H1;D0;+0:1]>>[*;H0;-1:1]"),
    # inexplicable nitrogen anion in reactants gets fixed in prods
    Chem.AllChem.ReactionFromSmarts("[N;H0;-1:1]([C:2])[C:3]>>[N;H1;+0:1]([*:2])[*:3]"),
]



def robust_edit_mol(rmol, edits):
    """Simulate a reaction by editing the bonds of the reactant graph.

    Parameters
    ----------
    rmol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance for the reactants. Every atom must carry the
        integer property ``molAtomMapNumber`` (it is read for every atom).
    edits : iterable of 3-tuples
        Each tuple is ``(x, y, t)`` where ``x`` and ``y`` are atom-map numbers
        and ``t`` is the new bond order. Any existing bond between the two
        atoms is removed first; if ``t > 0`` a bond of type ``BOND_TYPE[t]``
        is then added, so ``t <= 0`` means "bond removed".

    Returns
    -------
    pred_smiles : list of str
        One SMILES string per fragment of the edited molecule that RDKit
        could parse; unparseable fragments are dropped silently.
    """

    new_mol = Chem.RWMol(rmol)

    # Keep track of aromatic nitrogens, might cause explicit hydrogen issues
    aromatic_nitrogen_idx = set()
    # aromatic C (with a double bond) idx -> idx of neighbouring aromatic N carrying one explicit H
    aromatic_carbonyl_adj_to_aromatic_nH = {}
    # aromatic C (with exactly three bonds) idx -> idx of neighbouring aromatic N with no explicit H
    aromatic_carbondeg3_adj_to_aromatic_nH0 = {}
    for a in new_mol.GetAtoms():
        if a.GetIsAromatic() and a.GetSymbol() == "N":
            aromatic_nitrogen_idx.add(a.GetIdx())
            for nbr in a.GetNeighbors():
                if (
                    a.GetNumExplicitHs() == 1
                    and nbr.GetSymbol() == "C"
                    and nbr.GetIsAromatic()
                    and any(b.GetBondTypeAsDouble() == 2 for b in nbr.GetBonds())
                ):
                    aromatic_carbonyl_adj_to_aromatic_nH[nbr.GetIdx()] = a.GetIdx()
                elif (
                    a.GetNumExplicitHs() == 0
                    and nbr.GetSymbol() == "C"
                    and nbr.GetIsAromatic()
                    and len(nbr.GetBonds()) == 3
                ):
                    aromatic_carbondeg3_adj_to_aromatic_nH0[nbr.GetIdx()] = a.GetIdx()
        else:
            # every atom that is not an aromatic nitrogen loses its explicit H count
            a.SetNumExplicitHs(0)
    new_mol.UpdatePropertyCache()

    amap = {}
    for atom in rmol.GetAtoms():
        amap[atom.GetIntProp("molAtomMapNumber")] = atom.GetIdx()  # atom-map number -> atom index

    # Apply the edits as predicted
    for x, y, t in edits:
        bond = new_mol.GetBondBetweenAtoms(amap[x], amap[y])
        a1 = new_mol.GetAtomWithIdx(amap[x])
        a2 = new_mol.GetAtomWithIdx(amap[y])
        if bond is not None:
            new_mol.RemoveBond(amap[x], amap[y])

            # Are we losing a bond on an aromatic nitrogen?
            if bond.GetBondTypeAsDouble() == 1.0:
                if amap[x] in aromatic_nitrogen_idx:
                    if a1.GetTotalNumHs() == 0:
                        a1.SetNumExplicitHs(1)
                    elif a1.GetFormalCharge() == 1:
                        a1.SetFormalCharge(0)
                elif amap[y] in aromatic_nitrogen_idx:
                    if a2.GetTotalNumHs() == 0:
                        a2.SetNumExplicitHs(1)
                    elif a2.GetFormalCharge() == 1:
                        a2.SetFormalCharge(0)

            # Are we losing a c=O bond on an aromatic ring? If so, remove H from adjacent nH if appropriate
            if bond.GetBondTypeAsDouble() == 2.0:
                if amap[x] in aromatic_carbonyl_adj_to_aromatic_nH:
                    new_mol.GetAtomWithIdx(
                        aromatic_carbonyl_adj_to_aromatic_nH[amap[x]]
                    ).SetNumExplicitHs(0)
                elif amap[y] in aromatic_carbonyl_adj_to_aromatic_nH:
                    new_mol.GetAtomWithIdx(
                        aromatic_carbonyl_adj_to_aromatic_nH[amap[y]]
                    ).SetNumExplicitHs(0)

        if t > 0:
            new_mol.AddBond(amap[x], amap[y], BOND_TYPE[t])

            # Special alkylation case?
            if t == 1:
                if amap[x] in aromatic_nitrogen_idx:
                    if a1.GetTotalNumHs() == 1:
                        a1.SetNumExplicitHs(0)
                    else:
                        a1.SetFormalCharge(1)
                elif amap[y] in aromatic_nitrogen_idx:
                    if a2.GetTotalNumHs() == 1:
                        a2.SetNumExplicitHs(0)
                    else:
                        a2.SetFormalCharge(1)

            # Are we getting a c=O bond on an aromatic ring? If so, add H to adjacent nH0 if appropriate
            if t == 2:
                if amap[x] in aromatic_carbondeg3_adj_to_aromatic_nH0:
                    new_mol.GetAtomWithIdx(
                        aromatic_carbondeg3_adj_to_aromatic_nH0[amap[x]]
                    ).SetNumExplicitHs(1)
                elif amap[y] in aromatic_carbondeg3_adj_to_aromatic_nH0:
                    new_mol.GetAtomWithIdx(
                        aromatic_carbondeg3_adj_to_aromatic_nH0[amap[y]]
                    ).SetNumExplicitHs(1)

    pred_mol = new_mol.GetMol()

    # Clear formal charges to make molecules valid
    # Note: because S and P (among others) can change valence, be more flexible
    for atom in pred_mol.GetAtoms():
        # atom.ClearProp("molAtomMapNumber")
        if (
            atom.GetSymbol() == "N" and atom.GetFormalCharge() == 1
        ):  # exclude negatively-charged azide
            bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
            if bond_vals <= 3:
                atom.SetFormalCharge(0)
        elif (
            atom.GetSymbol() == "N" and atom.GetFormalCharge() == -1
        ):  # handle negatively-charged azide addition
            bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
            if bond_vals == 3 and any(
                [nbr.GetSymbol() == "N" for nbr in atom.GetNeighbors()]
            ):
                atom.SetFormalCharge(0)
        elif atom.GetSymbol() == "N":
            bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
            if (
                bond_vals == 4 and not atom.GetIsAromatic()
            ):  # and atom.IsInRingSize(5)):
                atom.SetFormalCharge(1)
        elif atom.GetSymbol() == "C" and atom.GetFormalCharge() != 0:
            atom.SetFormalCharge(0)
        elif atom.GetSymbol() == "O" and atom.GetFormalCharge() != 0:
            bond_vals = (
                sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
                + atom.GetNumExplicitHs()
            )
            if bond_vals == 2:
                atom.SetFormalCharge(0)
        elif atom.GetSymbol() in ["Cl", "Br", "I", "F"] and atom.GetFormalCharge() != 0:
            bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
            if bond_vals == 1:
                atom.SetFormalCharge(0)
        elif atom.GetSymbol() == "S" and atom.GetFormalCharge() != 0:
            bond_vals = sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])
            if bond_vals in [2, 4, 6]:
                atom.SetFormalCharge(0)
        elif (
            atom.GetSymbol() == "P"
        ):  # quaternary phosphorus should be pos. charge with 0 H
            bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
            if sum(bond_vals) == 4 and len(bond_vals) == 4:
                atom.SetFormalCharge(1)
                atom.SetNumExplicitHs(0)
            elif sum(bond_vals) == 3 and len(bond_vals) == 3:  # make sure neutral
                atom.SetFormalCharge(0)
        elif atom.GetSymbol() == "B":  # quaternary boron should be neg. charge with 0 H
            bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
            if sum(bond_vals) == 4 and len(bond_vals) == 4:
                atom.SetFormalCharge(-1)
                atom.SetNumExplicitHs(0)
        elif atom.GetSymbol() in ["Mg", "Zn"]:
            # exactly one single bond -> formal charge +1
            bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
            if sum(bond_vals) == 1 and len(bond_vals) == 1:
                atom.SetFormalCharge(1)
        elif atom.GetSymbol() == "Si":
            # all single bonds -> fill up to four substituents with explicit H
            bond_vals = [bond.GetBondTypeAsDouble() for bond in atom.GetBonds()]
            if sum(bond_vals) == len(bond_vals):
                atom.SetNumExplicitHs(max(0, 4 - len(bond_vals)))

    # Bounce to/from SMILES to try to sanitize
    pred_smiles = Chem.MolToSmiles(pred_mol)
    pred_list = pred_smiles.split(".")
    pred_mols = [Chem.MolFromSmiles(pred_smiles) for pred_smiles in pred_list]

    for i, mol in enumerate(pred_mols):
        # Check if we failed/succeeded in previous step
        if mol is None:
            # print('##### Unparseable mol: {}'.format(pred_list[i]))
            continue

        # Else, try post-sanitization fixes in structure
        mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
        if mol is None:
            continue
        # Every template is run against the same `mol`; a later successful
        # template overwrites pred_mols[i] set by an earlier one.
        for rxn in clean_rxns_postsani:
            out = rxn.RunReactants((mol,))
            if out:
                try:
                    Chem.SanitizeMol(out[0][0])
                    pred_mols[i] = Chem.MolFromSmiles(Chem.MolToSmiles(out[0][0]))
                except Exception as e:
                    # best-effort: keep the previous pred_mols[i] if the product cannot be sanitized
                    pass
                    # print(e)
                    # print('Could not sanitize postsani reaction product: {}'.format(Chem.MolToSmiles(out[0][0])))
                    # print('Original molecule was: {}'.format(Chem.MolToSmiles(mol)))
    pred_smiles = [
        Chem.MolToSmiles(pred_mol) for pred_mol in pred_mols if pred_mol is not None
    ]

    return pred_smiles

In [141]:
def print_metrics(metrics, percentiles, exclude_label=None):
    """Print the mean screening metrics and the mean/median percentile.

    Args:
        metrics (dict): must contain the keys "bedroc85", "bedroc20",
            "ef0.1" and "ef0.05", each mapping to a sequence of floats
            (one value per evaluated sample).
        percentiles (sequence of float): one percentile per evaluated sample.
        exclude_label (sequence of bool, optional): when given and non-empty,
            sample ``i`` is dropped from every statistic if
            ``exclude_label[i]`` is truthy. Lists and numpy arrays are both
            accepted.

    NaN metric values are ignored whether or not ``exclude_label`` is
    given (previously they were only ignored when it was given, so a single
    undefined BEDROC turned the unfiltered mean into NaN). Percentiles are
    filtered by ``exclude_label`` only.
    """
    # `len(...) > 0` instead of truthiness so numpy arrays are accepted.
    mask = exclude_label if exclude_label is not None and len(exclude_label) > 0 else None

    def _select(values, drop_nan):
        """Return the values that survive the exclusion mask (and NaN filter)."""
        kept = []
        for i, value in enumerate(values):
            if mask is not None and mask[i]:
                continue
            if drop_nan and np.isnan(value):
                continue
            kept.append(value)
        return kept

    b85 = _select(metrics["bedroc85"], drop_nan=True)
    b20 = _select(metrics["bedroc20"], drop_nan=True)
    ef1 = _select(metrics["ef0.1"], drop_nan=True)
    ef05 = _select(metrics["ef0.05"], drop_nan=True)
    percs = _select(percentiles, drop_nan=False)

    print(
        """
        * mean BEDROC_85:\t {}
        * mean BEDROC_20:\t {}
        * mean EF_0.05:\t {}
        * mean EF_0.1:\t {}
        * mean percentile:\t {}
        * median percentile:\t {}
        """.format(
            np.mean(b85),
            np.mean(b20),
            np.mean(ef05),
            np.mean(ef1),
            np.mean(percs),
            np.median(percs)
        )
    )

In [143]:
def bedroc_score(y_true, y_pred, decreasing=True, alpha=20.0):

    """BEDROC metric implemented according to Truchon and Bayley.

    The Boltzmann Enhanced Discrimination of the Receiver Operator
    Characteristic (BEDROC) score is a modification of the Receiver Operator
    Characteristic (ROC) score that allows for a factor of *early recognition*.

    References:
        The original paper by Truchon et al. is located at `10.1021/ci600426e
        <http://dx.doi.org/10.1021/ci600426e>`_.

    Args:
        y_true (array_like):
            Binary class labels. 1 for positive class, 0 otherwise.
        y_pred (array_like):
            Prediction values.
        decreasing (bool):
            True if high values of ``y_pred`` correlates to positive class.
        alpha (float):
            Early recognition parameter.

    Returns:
        float:
            Value in interval [0, 1] indicating degree to which the predictive
            technique employed detects (early) the positive class. NaN when
            the score is undefined, i.e. when ``y_true`` contains no
            positives or only positives.
     """
    # Accept plain lists as well as arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    assert len(y_true) == len(y_pred), \
        'The number of scores must be equal to the number of labels'

    big_n = len(y_true)
    is_active = y_true == 1
    n = int(np.count_nonzero(is_active))  # vectorised; builtin sum() iterates element by element

    # With no actives (or only actives) the normalisation terms below are 0/0
    # or x/0; report "undefined" explicitly instead of via numpy warnings.
    if n == 0 or n == big_n:
        return float("nan")

    if decreasing:
        order = np.argsort(-y_pred)
    else:
        order = np.argsort(y_pred)

    # Ranks are 1-based in Truchon & Bayley's formula. With 0-based ranks a
    # perfect ranking scores exp(alpha / N) instead of 1.
    m_rank = np.flatnonzero(is_active[order]) + 1

    s = np.sum(np.exp(-alpha * m_rank / big_n))

    r_a = n / big_n

    # expected value of `s` for a uniformly random ranking
    rand_sum = r_a * (1 - np.exp(-alpha))/(np.exp(alpha/big_n) - 1)

    fac = r_a * np.sinh(alpha / 2) / (np.cosh(alpha / 2) -
                                      np.cosh(alpha/2 - alpha * r_a))

    cte = 1 / (1 - np.exp(alpha * (1 - r_a)))

    return s * fac / rand_sum + cte

In [144]:
def enrichment_score(y_true, y_pred, decreasing=True, chi=0.1):
    """Enrichment Factor metric

        How many more actives we find within a defined
        "early recognition" fraction of the ordered
        list relative to a random distribution


              n_actives_in_sampled_set / n_sampled_set      n_actives_in_sampled_set
        EF =  ----------------------------------------  = ----------------------------
                   total_actives / total_ligands               chi * total_actives

    Args:
        y_true (array_like):
            Binary class labels. 1 for positive class, 0 otherwise.
        y_pred (array_like):
            Prediction values.
        decreasing (bool):
              True if high values of ``y_pred`` correlates to positive class.
        chi (float):
            Fraction of the ranked list that counts as "early"; the top
            ``floor(chi * N)`` entries are inspected.

    Returns:
        float:
            Value in interval [0, tau]
                tau = 1/chi if chi >= n/N
                tau = N/n if chi < n/N
            NaN when ``y_true`` contains no positives.
     """
    # Accept plain lists as well as arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    big_n = len(y_true)
    is_active = y_true == 1
    n = int(np.count_nonzero(is_active))  # vectorised; builtin sum() iterates element by element

    # No actives: the ratio is 0/0. Return NaN explicitly rather than
    # relying on a numpy division warning.
    if n == 0:
        return float("nan")

    if decreasing:
        order = np.argsort(-y_pred)
    else:
        order = np.argsort(y_pred)

    k = math.floor(chi * big_n)

    num_in_topk = is_active[order][:k].sum()

    return num_in_topk / (chi * n)

In [145]:
def read_clean_prediction(path):
    """Parse a CLEAN prediction CSV into ``{uniprot_id: set of EC numbers}``.

    Each row has the form ``<id>,EC:<ec>/<distance>,EC:<ec>/<distance>,...``;
    only the EC number between ``:`` and ``/`` is kept (e.g. ``3.5.2.6`` from
    ``EC:3.5.2.6/10.8359``). Blank rows and empty fields are skipped.

    Args:
        path (str): path of the comma-separated prediction file.

    Returns:
        dict: identifier (first column) -> set of predicted EC number strings.
    """
    u2ec = {}
    # Context manager so the handle is closed (it was previously leaked).
    with open(path, 'r') as result:
        for row in csv.reader(result, delimiter=','):
            if not row:
                continue  # blank line
            uni = row[0]
            # get EC number 3.5.2.6 from EC:3.5.2.6/10.8359
            u2ec[uni] = {
                pred_ec_dist.split(":")[1].split("/")[0]
                for pred_ec_dist in row[1:]
                if pred_ec_dist
            }
    return u2ec

In [146]:
def read_protein(u, path):
    """Return the 'hidden' entry stored for sample ``u`` of experiment ``path``.

    Reads ``/home/results/metabolomics/clip/{path}/sample_{u}.hiddens`` (a
    pickle) and returns its ``'hidden'`` value.

    Returns:
        The stored value, or None when the file is missing, cannot be
        unpickled, or has no ``'hidden'`` key (best-effort: callers filter
        out None).
    """
    hidden_file = f"/home/results/metabolomics/clip/{path}/sample_{u}.hiddens"
    try:
        # Context manager so the handle is closed even when unpickling fails.
        with open(hidden_file, 'rb') as handle:
            # NOTE: pickle can execute arbitrary code; only read files you produced yourself.
            return pickle.load(handle)['hidden']
    except Exception:
        # `Exception` rather than a bare except, so Ctrl-C / SystemExit are
        # no longer swallowed while thousands of files are being read.
        return None

## Screening Set

In [None]:
# UniProt id -> sequence for the candidate screening set.
# Files are opened with context managers so the handles are closed after loading.
with open("/home/uniprot2sequence_standard_set_structs.p", "rb") as handle:
    screening_set = pickle.load(handle)
print(f"Original screening set size: {len(screening_set)}")

# remove empty
screening_set = {k: v for k, v in screening_set.items() if v != ""}

# has structure and msa
with open("/home/alphafold_enzymes.p", "rb") as handle:
    alphafold_files = pickle.load(handle)
with open("/home/uniprot2msa_embedding.p", "rb") as handle:
    msa_files = pickle.load(handle)

# keep ids found in both lookups whose sequence has at most 650 residues
screening_set = {
    k: v
    for k, v in screening_set.items()
    if (k in alphafold_files) and (k in msa_files) and (len(v) <= 650)
}
print(f"Final screening set size: {len(screening_set)}")
# as list
screening_set_uniprots = list(screening_set.keys())

----------------

## load data

In [148]:
# load the pickled training arguments (a dict) and expose them as a Namespace
args_path = '/home/logs/metabo/bf6b607124c5cca3430fc0c2ee1148dd.args'
with open(args_path, 'rb') as handle:  # close the handle after loading
    args = Namespace(**pickle.load(handle))

In [149]:
# override the stored cache_path before the datasets are built
args.cache_path = None

In [None]:
# build the three splits in the same order as before: train, dev, test
train_dataset, val_dataset, test_dataset = [
    get_object(args.dataset_name, 'dataset')(args, split_name)
    for split_name in ('train', 'dev', 'test')
]

In [151]:
# distinct values seen in the training split, used later for overlap checks
train_uniprots = {sample['uniprot_id'] for sample in train_dataset.dataset}
train_reactions = {sample['reaction_string'] for sample in train_dataset.dataset}
train_ecs = {sample['ec'] for sample in train_dataset.dataset}

# the same for each of the four stored EC-level fields
train_ecs1 = {sample['ec1'] for sample in train_dataset.dataset}
train_ecs2 = {sample['ec2'] for sample in train_dataset.dataset}
train_ecs3 = {sample['ec3'] for sample in train_dataset.dataset}
train_ecs4 = {sample['ec4'] for sample in train_dataset.dataset}

In [152]:
# distinct rule ids present in the training split
train_rules = {sample['rule_id'] for sample in train_dataset.dataset}

In [153]:
# full dataset (all splits), read from the JSON file named in the args
with open(args.dataset_file_path, 'r') as handle:  # close the handle after loading
    full_dataset = json.load(handle)

In [154]:
# reaction string (without atom maps) -> set of protein references annotated for it
reaction2unis = defaultdict(set)
for d in full_dataset:
    # only keep entries whose proteins come from SwissProt / UniProt
    if d['protein_db'] not in ["swissprot", "uniprot"]:
        continue

    # without maps
    r = '.'.join(sorted(d['reactants'])) + '>>' + '.'.join(sorted(d['products']))
    # `protein_refs` is stored as the string form of a Python literal.
    # ast.literal_eval parses literals only, whereas eval() would execute any
    # expression found in the dataset file.
    # NOTE(review): this raises ValueError if a record is not a plain literal — confirm against the data.
    reaction2unis[r].update(ast.literal_eval(d['protein_refs']))

In [None]:
# number of distinct (unmapped) reactions with at least one protein reference
len(reaction2unis)

In [156]:
# processed dataset cache; context manager so the file handle is closed
with open(args.dataset_cache_path, 'rb') as handle:
    cached_dataset = pickle.load(handle)

In [None]:
# distribution of the first character of each EC string in the cached dataset
Counter(sample['ec'][0] for sample in cached_dataset)

In [None]:
# number of distinct rule ids in the cached dataset
len({sample['rule_id'] for sample in cached_dataset})

In [None]:
# number of entries in the dataset's UniProt -> sequence lookup
len(train_dataset.uniprot2sequence)

In [None]:
# percentile rank of the 650-residue cutoff among the known sequence lengths
sequence_lengths = [
    len(sequence)
    for sequence in train_dataset.uniprot2sequence.values()
    if sequence is not None
]
percentileofscore(sequence_lengths, 650)

In [180]:
# general EC -> UniProt lookup; context manager so the file handle is closed
with open('/home/Brenda/ec2uniprot_2023_1.p', 'rb') as handle:
    ec2uniprots = pickle.load(handle)

In [206]:
# level (1-4) -> {EC prefix truncated to that level -> union of UniProt ids}
ec2uniprots_levels = {}
for level in (1, 2, 3, 4):
    members_by_prefix = defaultdict(set)
    for ec, us in ec2uniprots.items():
        ec_prefix = '.'.join(ec.split('.')[:level])
        members_by_prefix[ec_prefix].update(us)
    ec2uniprots_levels[level] = members_by_prefix

In [None]:
# number of UniProt ids under each level-1 EC prefix
for ec_class, members in ec2uniprots_levels[1].items():
    print(f"EC {ec_class}: {len(members)}")

In [None]:
# average number of UniProt ids per EC prefix at each level
for level in (1, 2, 3, 4):
    group_sizes = [len(members) for members in ec2uniprots_levels[level].values()]
    avgsize = np.mean(group_sizes)
    print(f"Level {level}: {avgsize}")

## load model and screening set

In [159]:
checkpoint_path = '/home/snapshots/metabolomics/bf6b607124c5cca3430fc0c2ee1148dd/bf6b607124c5cca3430fc0c2ee1148ddepoch=26.ckpt'
# Only the stored hyper-parameters are read here, so deserialize onto the CPU.
# Without map_location, torch.load tries to restore tensors onto the device
# they were saved from, which fails when that GPU is absent.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# NOTE: this rebinds `args`; objects built earlier keep their own reference
# to the previous Namespace.
args = checkpoint['hyper_parameters']['args']
args.from_checkpoint = True
args.checkpoint_path = checkpoint_path

In [160]:
# precomputed screening-set embeddings are stored next to the notebook,
# named after the checkpoint file stem
ckptpath = Path(args.checkpoint_path)
with open(f"precomputed_{ckptpath.stem}.p", 'rb') as handle:  # close the handle after loading
    screenset = pickle.load(handle)

In [161]:
# Prefer the second GPU (the original hard-coded choice) and degrade
# gracefully on machines with a single GPU or none.
if torch.cuda.is_available():
    DEVICE = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    DEVICE = "cpu"

In [162]:
# overrides applied to the checkpoint's args before the model is built
args.train_esm_dir = '/home/snapshots/metabolomics/esm2/checkpoints/esm2_t30_150M_UR50D.pt'
args.do_ec_task = False

In [None]:
try:
    model = loaders.get_lightning_model(args)
except Exception:
    # Fallback: build the module from its registry name and load the weights
    # by hand. `Exception` rather than a bare except, so an interrupt during
    # the slow load is not mistaken for a loader failure.
    model = get_object(args.lightning_name, "lightning")(args)
    checkpoint = pl_load(args.checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])
model.eval()
model = model.to(DEVICE)
# sanity check: should print False after model.eval()
print(model.training)

In [169]:
# embeddings and matching UniProt ids of the precomputed screening set
screen_hiddens, screen_unis = screenset["hiddens"], screenset["uniprots"]

## Analysis

### table 1: ablation

In [182]:
# set view of the screening-set ids for constant-time membership tests
screen_unis_set = {*screen_unis}

In [183]:
# set the flag on the inner model's args before it is used to embed reactions
model.model.args.use_as_mol_encoder = True

In [184]:
# turn off the ec_levels_one_hot flag on the test dataset's args
test_dataset.args.ec_levels_one_hot = False

In [None]:
# number of screening-set embeddings
len(screen_hiddens)

In [None]:
# Rank the whole screening set for every held-out test reaction and collect
# BEDROC / enrichment / percentile statistics.
percentiles = []
metrics = defaultdict(list)
data_samples = []
data_sample_ids = []
all_substrate_features = []
reaction_set = set()

with torch.no_grad():
    for ref_id in tqdm(range(len(test_dataset)), ncols=100, total=len(test_dataset)):

        sample_meta = test_dataset.dataset[ref_id]
        rxn_uni = sample_meta['reaction_string']

        # evaluate each reaction once, and only if it was not seen in training
        if rxn_uni in train_reactions or rxn_uni in reaction_set:
            continue
        reaction_set.add(rxn_uni)

        # .get() so that missing reactions are not inserted into the defaultdict
        positives = reaction2unis.get(rxn_uni, ())
        labels = np.array([int(u in positives) for u in screen_unis])
        if not labels.any():
            # No annotated enzyme in the screening set: every metric is
            # undefined, and max() over the positives used to raise
            # ValueError here. Skip before anything is recorded so that
            # data_samples, metrics, percentiles and the feature list stay
            # index-aligned.
            continue

        data_samples.append(sample_meta)
        data_sample_ids.append(ref_id)

        ref_sample = loaders.default_collate([test_dataset[ref_id]])
        for key in ['reactants', 'products', 'graph', 'mol']:
            if key in ref_sample:
                ref_sample[key] = ref_sample[key].to(DEVICE)
        out = model.model(ref_sample)
        substrate_features = out['hidden'].cpu()
        all_substrate_features.append(substrate_features)

        # similarity of every screening-set embedding to this reaction
        scores = screen_hiddens @ substrate_features.T
        scores = np.array(scores.squeeze().tolist())

        # compute metrics
        metrics["bedroc85"].append( bedroc_score(labels, scores, decreasing=True, alpha=85.0) )
        metrics["bedroc20"].append( bedroc_score(labels, scores, decreasing=True, alpha=20.0) )
        metrics["ef0.1"].append( enrichment_score(labels, scores, decreasing=True, chi=0.1) )
        metrics["ef0.05"].append( enrichment_score(labels, scores, decreasing=True, chi=0.05) )

        # percentile of the best-scoring true enzyme among all candidates
        refscore = scores[labels == 1].max()
        percentiles.append( percentileofscore(scores, refscore) )
        

In [188]:
# join the per-reaction features along dim 0
# (rebinds the list name to a tensor, so this cell is not meant to be re-run on its own)
all_substrate_features = torch.cat(all_substrate_features, dim=0)

In [None]:
# shape of the concatenated reaction features
all_substrate_features.shape

In [228]:
# row i holds the scores of evaluated reaction i against every screening-set embedding
sim_matrix = torch.matmul(all_substrate_features, screen_hiddens.T)

In [49]:
# True where the evaluated sample's own enzyme also appears in the training split
label = [sample['uniprot_id'] in train_uniprots for sample in data_samples]

In [None]:
# report metrics only for samples whose enzyme was not in the training split
print_metrics(metrics, percentiles, exclude_label=label)

### table 5: within EC

In [247]:
# the dataset's EC -> UniProt lookup with the values converted to sets
ec2uniprot = {ec: set(members) for ec, members in train_dataset.ec2uniprot.items()}

In [250]:
# level (1-4) -> {EC prefix truncated to that level -> union of UniProt ids}
ec2uniprot_levels = {}
for level in (1, 2, 3, 4):
    ec2uniprot_levels[level] = defaultdict(set)
    for ec_, u_ in ec2uniprot.items():
        prefix = '.'.join(ec_.split('.')[:level])
        ec2uniprot_levels[level][prefix].update(u_)

In [None]:
# how well can rank uniprots in an EC for a given reaction

data_samples_ec = []
percentiles = []
metrics = defaultdict(list)
reaction_set = set()

for level in [1]:  # the missing ':' made this cell a SyntaxError
    # Reset every accumulator per level so that metrics and percentiles always
    # describe the same samples (identical to before for a single level).
    percentiles = []
    data_samples_ec = []
    metrics = defaultdict(list)
    reaction_set = set()
    for idx, ref_sample in tqdm(enumerate(data_samples), ncols=100, total=len(data_samples)):

        rxn = ref_sample['reaction_string']
        if rxn in reaction_set:
            continue

        if train_dataset.ec2uniprot.get(ref_sample['ec'], None) is None:
            continue

        # take only the sequences in EC (membership is computed once and
        # reused for both labels and scores)
        ec_ = '.'.join(ref_sample['ec'].split('.')[:level])
        ec_members = ec2uniprot_levels[level][ec_]
        in_ec = np.array([u in ec_members for u in screen_unis], dtype=bool)

        # label scores; .get() avoids inserting keys into the defaultdict
        positives = reaction2unis.get(rxn, ())
        labels = np.array(
            [int(u in positives) for u, keep in zip(screen_unis, in_ec) if keep]
        )
        if not labels.any():
            continue

        reaction_set.add(rxn)
        data_samples_ec.append(ref_sample)

        # row `idx` of sim_matrix belongs to data_samples[idx]
        scores = sim_matrix[idx].squeeze().numpy()[in_ec]

        ref_score = scores[labels == 1].max()

        metrics["bedroc85"].append( bedroc_score(labels, scores, decreasing=True, alpha=85.0) )
        metrics["bedroc20"].append( bedroc_score(labels, scores, decreasing=True, alpha=20.0) )
        metrics["ef0.1"].append( enrichment_score(labels, scores, decreasing=True, chi=0.1) )
        metrics["ef0.05"].append( enrichment_score(labels, scores, decreasing=True, chi=0.05) )

        percentiles.append(
            percentileofscore(scores, ref_score )
        )
    print(f"Level: {level}")
    print_metrics(metrics, percentiles)

### table 2a: EC-based screening using distances

In [190]:
def get_ec_id_dict(csv_name: str) -> tuple:
    """Read a tab-separated id/EC table into forward and reverse lookups.

    The first row is treated as a header and skipped. In every other row,
    column 0 is the sequence id and column 1 a ';'-separated list of EC
    numbers; further columns are ignored. Rows with fewer than two columns
    (e.g. blank lines) are skipped.

    Args:
        csv_name: path of the tab-separated file.

    Returns:
        ``(id_ec, ec_id)`` where ``id_ec`` maps id -> list of EC numbers and
        ``ec_id`` maps EC number -> set of ids.
    """
    id_ec = {}
    ec_id = {}

    # Context manager so the handle is closed (it was previously leaked).
    with open(csv_name, newline="") as csv_file:
        csvreader = csv.reader(csv_file, delimiter="\t")
        next(csvreader, None)  # header row
        for rows in csvreader:
            if len(rows) < 2:
                continue
            seq_id = rows[0]
            ecs = rows[1].split(";")
            id_ec[seq_id] = ecs
            for ec in ecs:
                ec_id.setdefault(ec, set()).add(seq_id)
    return id_ec, ec_id

In [191]:
def get_cluster_center(model_emb, ec_id_dict):
    """Average consecutive row-slices of ``model_emb``, one slice per EC key.

    Keys of ``ec_id_dict`` are visited in iteration order; each key consumes as
    many rows of ``model_emb`` as it has ids, starting where the previous key
    stopped.
    NOTE(review): this presumably requires ``model_emb`` rows to be laid out in
    exactly that order -- confirm against the caller.

    Returns:
        dict mapping EC key -> mean embedding (detached, moved to CPU).
    """
    centers = {}
    start = 0
    with torch.no_grad():
        for ec in tqdm(list(ec_id_dict.keys())):
            stop = start + len(list(ec_id_dict[ec]))
            centers[ec] = model_emb[start:stop].mean(dim=0).detach().cpu()
            start = stop
    return centers

In [193]:
# CLEAN embeddings for screening set
clean_screening_set_embed = torch.load('CLEAN_screening_set_embed.p', map_location='cpu')

# reference sequences for computing EC anchors / clusters
# (opened via a context manager so the file handle is closed after loading)
with open('CLEAN_split100_embed.p', 'rb') as clean_split100_file:
    clean_split100 = pickle.load(clean_split100_file)
# UniProt id -> CLEAN embedding for the reference (split100) sequences
clean_split100_uni2embed = dict(zip(clean_split100['uniprots'], clean_split100['clean_embeddings']))

In [194]:
# row identifiers: the 'Entry' column of the tab-separated screening-set table
row_uniprots = list(
    pd.read_csv("uniprot2sequence_standard_set_structs.csv", delimiter='\t')['Entry']
)

In [195]:
# order screen embeds in same order as screen_unis
# zip() silently truncates to the shorter input, so check that ids and embedding rows
# line up. This also catches an accidental re-run of this cell, which overwrites
# `clean_screening_set_embed` with the already re-ordered tensor.
assert len(row_uniprots) == len(clean_screening_set_embed), (
    f"{len(row_uniprots)} row ids vs {len(clean_screening_set_embed)} embeddings: "
    "reload CLEAN_screening_set_embed.p before re-running this cell"
)
clean_screening_uni_2_embed = dict(zip(row_uniprots, clean_screening_set_embed))
clean_screening_set_embed_ = [clean_screening_uni_2_embed[u] for u in screen_unis]
clean_screening_set_embed = torch.vstack(clean_screening_set_embed_)

In [196]:
# EC annotations of the CLEAN split100 table: id -> list of ECs, and EC -> set of ids.
# The EC -> ids mapping is what the cluster centers (anchors) are later built from.
id_ec_train, ec_id_dict_train = get_ec_id_dict("CLEAN_split100.csv")

In [197]:
# every EC number that appears in the CLEAN split100 table
clean_ecs = list(ec_id_dict_train.keys())

In [198]:
# EC -> UniProts at each level of the EC hierarchy
# (level n keeps the first n dot-separated fields of the EC number)
clean_ec2unis_level = {}
for level in [1,2,3,4]:
    unis_by_prefix = defaultdict(set)
    for k,v in ec_id_dict_train.items():
        unis_by_prefix['.'.join(k.split('.')[:level])].update(v)
    clean_ec2unis_level[level] = unis_by_prefix

In [199]:
# for each level, compute distance matrix to the anchors

# level -> (distance matrix, list of EC keys naming the matrix columns)
eval_dist_levels = {}
for level in [1,2,3,4]:
    
    # get cluster anchors
    # each anchor is the mean CLEAN embedding of the reference sequences in that EC group
    model_lookup_ec_keys, model_lookup = [], []
    for k, uni_in_ec in clean_ec2unis_level[level].items():
        embeds_ = torch.vstack([clean_split100_uni2embed[u] for u in uni_in_ec])
        model_lookup.append(embeds_.mean(0))
        model_lookup_ec_keys.append(k)
    
    # stack anchors row-wise; row i corresponds to model_lookup_ec_keys[i]
    model_lookup = torch.vstack(model_lookup)
    
    # distance: columns are ECs
    # rows follow clean_screening_set_embed (re-ordered to screen_unis in an earlier cell)
    eval_dist = torch.cdist(clean_screening_set_embed, model_lookup).detach()
    eval_dist_levels[level] = (eval_dist, model_lookup_ec_keys)

In [None]:
"""STRATEGY

Given a reaction with EC 1.1.1.x: 
    - If anchor 1.1.1 exists:
        Take distance of screening enzymes to that EC and use to rank
    - Else:
        Take distances to 1.1 (or higher until EC is found -- guaranteed to find highest level)
"""


# most specific EC level to try first; the loop below backs off to coarser levels
level = 4

reaction_set = set()
percentiles = []
metrics = defaultdict(list)

# NOTE(review): iterates len(test_dataset) but indexes test_dataset.dataset -- confirm both have the same length
for ref_id in tqdm(range(len(test_dataset)), ncols=100, total=len(test_dataset)):

    # get sample
    ref_sample = test_dataset.dataset[ref_id]

    # check if reaction was in train set
    rxn_uni = ref_sample['reaction_string']
    if rxn_uni in train_reactions:
        continue 

    # check if reaction was already evaluated in this loop
    if rxn_uni in reaction_set:
        continue 

    reaction_set.add(rxn_uni)

    # get ec, level
    ec = ref_sample["ec"]

    # if EC is not within split100 dataset, then check above 
    # (walk from `level` down to 1 and stop at the first truncated EC that has an anchor)
    for prev_level in range(level, 0, -1):
        ec_ = '.'.join(ec.split('.')[:prev_level])
        level_eval_dist, level_model_lookup_ec_keys = eval_dist_levels[prev_level]
        if ec_ in level_model_lookup_ec_keys:
            break
                          
    # if no level matched, .index() raises ValueError for the level-1 EC
    lookup_col_idx = level_model_lookup_ec_keys.index(ec_)
    # distance of every screening enzyme to the chosen EC anchor (lower = closer)
    scores = level_eval_dist[:,lookup_col_idx].numpy()
    # 1 if the screening enzyme is annotated for this reaction, else 0
    labels = np.array([ int(u in reaction2unis[rxn_uni]) for u in screen_unis])
    
    # compute BEDROC / enrichment; decreasing=False is passed because scores are distances
    metrics["bedroc85"].append( bedroc_score(labels, scores, decreasing=False, alpha=85.0) )
    metrics["bedroc20"].append( bedroc_score(labels, scores, decreasing=False, alpha=20.0) )
    metrics["ef0.1"].append( enrichment_score(labels, scores, decreasing=False, chi=0.1) )
    metrics["ef0.05"].append( enrichment_score(labels, scores, decreasing=False, chi=0.05) )

    # percentile computation is disabled here, so `percentiles` stays empty
    # refscore = max(s for s,l in zip(scores,labels) if l)
    # percentiles.append( percentileofscore(scores, refscore) )

In [None]:
print(f"LEVEL: {level}")
# percentiles were not collected in the loop above, so [0] is passed in their place
print_metrics(metrics, [0])

### table 2b: re-ranked EC with CLIPZyme

In [178]:
# CLEAN EC predictions for the screening set; consumed below as UniProt -> iterable of EC numbers
clean_prediction = read_clean_prediction('CLEAN_uniprot2sequence_standard_set_structs_maxsep.csv')

In [179]:
# predicted-EC -> screening UniProts, indexed by EC level (1-4)
screen_unis_set = set(screen_unis)
clean_ec2uniprots = {i: defaultdict(set) for i in [1,2,3,4]}
for u, ecs in clean_prediction.items():
    # only enzymes that are part of the screening set are indexed
    if u in screen_unis_set:
        for ec in ecs:
            ec_fields = ec.split('.')
            for level in [1,2,3,4]:
                clean_ec2uniprots[level]['.'.join(ec_fields[:level])].add(u)

In [180]:
# NumPy copy of screen_unis so it can be indexed with argsort results and boolean masks
screen_unis_array = np.array(screen_unis)

In [332]:
# Convert the similarity matrix to NumPy for the np.argsort-based ranking below.
# The guard keeps the cell idempotent: re-running it on an ndarray is a no-op
# instead of an AttributeError.
if torch.is_tensor(sim_matrix):
    sim_matrix = sim_matrix.numpy()

In [None]:
"""STRATEGY:

Use EC prediction to create ranked list of enzymes:

Given a reaction with EC 1.1.1.x, create a list of enzymes according to:

    - [ [proteins in 1.1.1] + [proteins in 1.1] + [proteins in 1] + [other] ]
    - re-rank with CLIPZyme each subset

"""

# most specific EC level used to form the first tier of the ranked list
level = 4

percentiles = []
metrics = defaultdict(list)
reaction_set = set()

# ref_id is also the row index into sim_matrix, so data_samples must be in the
# same order as the rows of sim_matrix
for ref_id, ref_sample in tqdm(enumerate(data_samples), ncols=100, total=len(data_samples)):

    # check if reaction was in train set
    rxn_uni = ref_sample['reaction_string']
    if rxn_uni in train_reactions:
        continue 

    # check if reaction was already evaluated in this loop
    if rxn_uni in reaction_set:
        continue 

    reaction_set.add(rxn_uni)
    
    screening_scores = sim_matrix[ref_id] # rxn by enzymes

    # sort once at the beginning
    sorted_indices = np.argsort(-screening_scores) # high score should have low index to placed at the beginning of the list
    sorted_screen_unis = screen_unis_array[sorted_indices]

    
    # put predictions together in ordered list and rank with CLIPZyme at each step
    # this will determine final rank
    predictions_uniprot = []
    predictions_sofar = set()
    # tiers go from the most specific EC (level) to the coarsest (1)
    for lvl in range(level, 0, -1):
        # get ec, level
        ec = '.'.join( ref_sample["ec"].split('.')[:lvl] )
        
        # get uniprots annotated for ec, exclude those already added to list
        ec_proteins = clean_ec2uniprots[lvl].get(ec, set()) - predictions_sofar
        predictions_sofar.update(ec_proteins)
        
        # rank
        # masking the score-sorted array keeps the CLIPZyme order inside this tier
        if len(ec_proteins):
            ec_proteins_list = list(ec_proteins)
            proteins_select_mask = np.isin(sorted_screen_unis, ec_proteins_list) 
            proteins_select = sorted_screen_unis[proteins_select_mask]
            predictions_uniprot.extend(proteins_select)

    
    # rank the rest
    # everything not placed in an EC tier is appended last, again in score order
    remaining_proteins = list(screen_unis_set - predictions_sofar)

    proteins_select_mask = np.isin(sorted_screen_unis, remaining_proteins)
    proteins_select = sorted_screen_unis[proteins_select_mask]
    predictions_uniprot.extend(proteins_select)
            
    # labels follow the final list order; ranks are descending integers so the
    # first list entry gets the largest value (passed with decreasing=True below)
    labels =  np.isin(predictions_uniprot, list(reaction2unis[rxn_uni]))
    ranks = np.arange(len(labels))[::-1]

    
    # # # compute BEDROC, EF
    metrics["bedroc85"].append( bedroc_score(labels, ranks, decreasing=True, alpha=85.0) )
    metrics["bedroc20"].append( bedroc_score(labels, ranks, decreasing=True, alpha=20.0) )
    metrics["ef0.1"].append( enrichment_score(labels, ranks, decreasing=True, chi=0.1) )
    metrics["ef0.05"].append( enrichment_score(labels, ranks, decreasing=True, chi=0.05) )


In [None]:
print(f"Level: {level}")
# percentiles were not collected in the loop above, so [0] is passed in their place
print_metrics(metrics, [0])

### table 8: re-ranked EC with CLEAN distances 

In [None]:
"""STRATEGY:

Use EC prediction to create ranked list of enzymes:

Given a reaction with EC 1.1.1.x, create a list of enzymes according to:

    - [ [proteins in 1.1.1] + [proteins in 1.1] + [proteins in 1] + [other] ]
    - re-rank with EC distances within each subset

"""

# most specific EC level used both for the anchor lookup and for the first tier
level = 1

percentiles = []
metrics = defaultdict(list)
reaction_set = set()

for ref_id, ref_sample in tqdm(enumerate(data_samples), ncols=100, total=len(data_samples)):

    # check if reaction was in train set
    rxn_uni = ref_sample['reaction_string']
    if rxn_uni in train_reactions:
        continue

    # check if reaction was already evaluated in this loop
    if rxn_uni in reaction_set:
        continue

    reaction_set.add(rxn_uni)

    # distance to the most specific available EC anchor is used:
    # if rxn ec = 1.1.1.1 and it exists in CLEAN anchors then use distances to it
    # if it doesn't exist, try 1.1.1 and so on down to level 1
    for lvl in range(level, 0, -1):
        ec = '.'.join( ref_sample["ec"].split('.')[:lvl] )
        level_eval_dist, level_model_lookup_ec_keys = eval_dist_levels[lvl]
        if ec in level_model_lookup_ec_keys:
            lookup_col_idx = level_model_lookup_ec_keys.index(ec)
            screening_scores = level_eval_dist[:,lookup_col_idx].numpy()
            break
    else:
        # Without this, `screening_scores` from the previous reaction would be
        # silently reused (or be undefined on the first iteration).
        raise ValueError(
            f"No CLEAN anchor found for EC {ref_sample['ec']!r} at any level from {level} down to 1"
        )

    # sort once at the beginning
    # low score (distance) should have low index to be placed at the beginning of the list
    sorted_indices = np.argsort(screening_scores)
    sorted_screen_unis = screen_unis_array[sorted_indices]

    # put predictions together in an ordered list, ranking by CLEAN distance within each tier
    # this will determine final rank
    predictions_uniprot = []
    predictions_sofar = set()
    for lvl in range(level, 0, -1):
        # get ec, level
        ec = '.'.join( ref_sample["ec"].split('.')[:lvl] )

        # get uniprots annotated for ec, exclude those already added to list
        ec_proteins = clean_ec2uniprots[lvl].get(ec, set()) - predictions_sofar
        predictions_sofar.update(ec_proteins)

        # rank: masking the distance-sorted array keeps that order inside the tier
        if len(ec_proteins):
            ec_proteins_list = list(ec_proteins)
            proteins_select_mask = np.isin(sorted_screen_unis, ec_proteins_list)
            proteins_select = sorted_screen_unis[proteins_select_mask]
            predictions_uniprot.extend(proteins_select)

    # rank the rest: everything not placed in an EC tier goes last, in distance order
    remaining_proteins = list(screen_unis_set - predictions_sofar)

    proteins_select_mask = np.isin(sorted_screen_unis, remaining_proteins)
    proteins_select = sorted_screen_unis[proteins_select_mask]
    predictions_uniprot.extend(proteins_select)

    # labels follow the final list order; ranks are descending integers so the
    # first list entry gets the largest value (passed with decreasing=True below)
    labels =  np.isin(predictions_uniprot, list(reaction2unis[rxn_uni]))
    ranks = np.arange(len(labels))[::-1]

    # compute BEDROC, EF
    metrics["bedroc85"].append( bedroc_score(labels, ranks, decreasing=True, alpha=85.0) )
    metrics["bedroc20"].append( bedroc_score(labels, ranks, decreasing=True, alpha=20.0) )
    metrics["ef0.1"].append( enrichment_score(labels, ranks, decreasing=True, chi=0.1) )
    metrics["ef0.05"].append( enrichment_score(labels, ranks, decreasing=True, chi=0.05) )


In [None]:
print(f"Level: {level}")
# percentiles were not collected in the loop above, so [0] is passed in their place
print_metrics(metrics, [0])

### table 3: other reactions

#### unannotated enzymemap reactions

In [293]:
# reaction string -> set of EC numbers, for dataset entries that carry no protein annotation
unannotated_reaction2ec = defaultdict(set)
# the matching raw entries (one per kept entry; duplicate reactions are not removed here)
unannotated_dataset = []
for d in full_dataset:
    # entries with a protein have a string `protein_db`; anything else must be NaN
    if not isinstance(d['protein_db'], str):
        assert np.isnan(d['protein_db'])
        # without map
        # (sorted reactants and products joined as "a.b>>c.d")
        r = '.'.join(sorted(d['reactants'])) + '>>' + '.'.join(sorted(d['products']))
        # EXCLUDE IF REACTION IN DATASET
        # EXCLUDE TRAIN REACTION RULES
        if (r not in reaction2unis) and (d['rule_id'] not in train_rules):
            unannotated_reaction2ec[r].add(d['ec'])
            unannotated_dataset.append(d)

In [None]:
# (number of kept entries, number of distinct reactions among them)
len(unannotated_dataset), len(unannotated_reaction2ec)

In [None]:
# number of distinct EC numbers among the unannotated entries
len({d['ec'] for d in unannotated_dataset})

In [None]:
# histogram of how many EC numbers each unannotated reaction is associated with
print(f"ECs per reaction: {Counter(len(v) for v in unannotated_reaction2ec.values())}")

In [297]:
# need ec2uniprots: EC number -> UniProt ids
# (context manager so the file handle is closed after loading)
with open('/home/Brenda/ec2uniprot_2023_1.p', 'rb') as ec2uniprots_file:
    ec2uniprots = pickle.load(ec2uniprots_file)

In [298]:
# UniProt -> set of EC classes, one mapping per EC level
# (level n keeps the first n dot-separated fields of the EC number)
screenuni2ecs = {}
for level in [1,2,3,4]:
    s2ecs_ = defaultdict(set)
    for ec, us in ec2uniprots.items():
        # the truncated EC is the same for every UniProt under this EC
        ec_prefix = '.'.join(ec.split('.')[:level])
        for u in us:
            s2ecs_[u].add(ec_prefix)
    screenuni2ecs[level] = s2ecs_

In [None]:
# number of UniProts with at least one EC; relies on `level` as left by the loop in the previous cell
len(screenuni2ecs[level])

In [None]:
# get unannotated samples: one entry per distinct reaction that is not a training reaction
data_samples = []
# Start from an empty set. Otherwise de-duplication depends on whatever `reaction_set`
# an earlier cell left behind, and re-running this cell would yield no samples at all.
reaction_set = set()

for ref_id, ref_sample in tqdm(enumerate(unannotated_dataset), ncols=100, total=len(unannotated_dataset)):

    # same reaction-string construction as used when building unannotated_reaction2ec
    rxn_uni = '.'.join(sorted(ref_sample['reactants'])) + '>>' + '.'.join(sorted(ref_sample['products']))

    if rxn_uni in train_reactions:
        continue

    # check if reaction was already collected in this loop
    if rxn_uni in reaction_set:
        continue

    reaction_set.add(rxn_uni)
    data_samples.append(ref_sample)

In [None]:
# number of distinct unannotated reactions kept for evaluation
len(data_samples)

In [203]:
# process through rxn encoder
# split data_samples into consecutive chunks of at most batch_size items (order preserved)
batch_size = 5
batches = [data_samples[i:(i+batch_size)] for i in range(0, len(data_samples), batch_size)]

In [None]:
# per-batch reaction embeddings, in the same order as data_samples
all_substrate_features = []

with torch.no_grad():
    for batch in tqdm(batches, ncols=100, total=len(batches)):
        bsamples = []
        for ref_sample in batch:

            # protons are dropped from both sides before building the graphs
            reactants = sorted([s for s in ref_sample['mapped_reactants'] if s != "[H+]"])
            products = sorted([s for s in ref_sample['mapped_products'] if s != "[H+]"])
            # products = [p for p in products if p not in reactants]

            reactants_graph, _ = from_mapped_smiles('.'.join(reactants), encode_no_edge=True)
            products_graph, _ = from_mapped_smiles('.'.join(products), encode_no_edge=True)

            # Fail loudly with a readable message instead of the former `1/0` trick.
            # The sample cannot simply be skipped: rows of the resulting feature
            # matrix are later indexed by position in data_samples.
            if reactants_graph.x.shape != products_graph.x.shape:
                raise ValueError(
                    "Reactant and product graphs have different node-feature shapes: "
                    f"{tuple(reactants_graph.x.shape)} vs {tuple(products_graph.x.shape)}"
                )
            item = {
                    "reactants": reactants_graph,
                    "products": products_graph,
                }

            bsamples.append(item)

        bsamples = loaders.default_collate(bsamples)
        # move whichever graph entries are present to the model device
        for key in ['reactants', 'products', 'graph', 'mol']:
            if key in bsamples:
                bsamples[key] = bsamples[key].to(DEVICE)
        out = model.model(bsamples)
        substrate_features = out['hidden']
        # keep results on CPU so they can be concatenated and compared later
        substrate_features = substrate_features.cpu()
        all_substrate_features.append(substrate_features)

In [None]:
# number of reactions that were passed through the encoder
len(data_samples)

In [212]:
# Join the per-batch feature tensors along dim 0.
# The guard keeps the cell idempotent: once the name holds a single tensor,
# re-running the cell leaves it unchanged instead of failing.
if not torch.is_tensor(all_substrate_features):
    all_substrate_features = torch.concat(all_substrate_features)

In [213]:
# similarity matrix between rxns and screening hiddens
# (plain dot products: one row per reaction, one column per screening enzyme)
sim_matrix_ec = all_substrate_features @ screen_hiddens.T

In [None]:
# sanity check: (number of reactions, number of screening enzymes)
sim_matrix_ec.shape

In [None]:
# can loosen to higher levels
# EC level at which a screening enzyme counts as a match for the reaction
level = 4

percentiles = []
metrics = defaultdict(list)
reaction_set = set()

# ref_id is also the row index into sim_matrix_ec
for ref_id, ref_sample in tqdm(enumerate(data_samples), ncols=100, total=len(data_samples)):

    # get reaction
    rxn_uni = '.'.join(sorted(ref_sample['reactants'])) + '>>' + '.'.join(sorted(ref_sample['products']))

    # if in train (shouldn't be) then skip
    if rxn_uni in train_reactions:
        continue 

    # check if reaction was already evaluated in this loop
    if rxn_uni in reaction_set:
        continue 

    # add so we skip next time
    reaction_set.add(rxn_uni)

    # get scores for this rxn
    scores = sim_matrix_ec[ref_id].squeeze().numpy()

    # screenuni2ecs
    # ecs annotated for this reaction
    ecs = set( '.'.join(e.split('.')[:level]) for e in unannotated_reaction2ec[rxn_uni] )
    # label of: screening enzyme belongs to reaction ec
    # NOTE(review): screenuni2ecs[level] is a defaultdict, so looking up unseen UniProts inserts empty sets
    labels = [ bool(screenuni2ecs[level][u].intersection(ecs)) for u in screen_unis]

    # reactions with no positive enzyme in the screening set are not scored
    if not sum(labels):
        continue
    
    # compute BEDROC        
    labels, scores = np.array(labels), np.array(scores)

    metrics["bedroc85"].append( bedroc_score(labels, scores, decreasing=True, alpha=85.0) )
    metrics["bedroc20"].append( bedroc_score(labels, scores, decreasing=True, alpha=20.0) )
    metrics["ef0.1"].append( enrichment_score(labels, scores, decreasing=True, chi=0.1) )
    metrics["ef0.05"].append( enrichment_score(labels, scores, decreasing=True, chi=0.05) )

    # percentile (disabled, so `percentiles` stays empty)
    # refscore = max(s for s,l in zip(scores, labels) if l)
    # percentiles.append( percentileofscore(scores, refscore) )
# percentiles were not collected, so [0] is passed in their place
print_metrics(metrics, [0])

#### terpenes

In [165]:
# terpene synthase records (list of dicts read from JSON);
# context manager so the file handle is closed after loading
with open("/home/IOCB/TPS-rxn-carbon-mapped-curated-activity.json", 'r') as synthases_file:
    synthases = json.load(synthases_file)

In [None]:
# unmapped reaction string -> screening-set UniProts recorded for that reaction
terpene_reaction2unis = defaultdict(set)
for ref_id, synthase_sample in tqdm(enumerate(synthases), total=len(synthases), ncols=100):

    uni = synthase_sample['Uniprot ID']

    # only enzymes that are part of the screening set are considered
    if uni not in screen_unis_set:
        continue
        
    # only the reactant side of rxn_smiles is used; products are rebuilt from bond_changes
    reactants, _ = synthase_sample["rxn_smiles"].split(">>")
    reactants_= reactants.split(".")

    edits = [ (int(x), int(y), t) for x,y,t in synthase_sample['bond_changes'] ]
    rmol = Chem.MolFromSmiles(reactants)
    # NOTE(review): robust_edit_mol is defined elsewhere; it is used here as returning an iterable of product SMILES
    products_ = robust_edit_mol(rmol, edits)
    products = '.'.join(products_)
    
    # canonical key: atom maps removed, components sorted on each side
    std_reactants = sorted([remove_atom_maps(s) for s in reactants_])
    std_products = sorted([remove_atom_maps(s) for s in products_])
    rxn_uni = f"{'.'.join(std_reactants)}>>{'.'.join(std_products)}"
    terpene_reaction2unis[rxn_uni].add(uni)

In [167]:
# terpene reactions for which at least one recorded enzyme is a training UniProt
terpene_reaction2unis_train = {k:v for k,v in terpene_reaction2unis.items() if any(p in train_uniprots for p in v)}

In [None]:
# number of distinct terpene reactions with an enzyme in the screening set
len(terpene_reaction2unis)

In [None]:
# how many of those reactions involve a training UniProt
len(terpene_reaction2unis_train)

In [None]:
# Evaluate terpene synthase reactions: encode each new reaction, score every
# screening enzyme against it, and accumulate ranking metrics.
percentiles = []
metrics = defaultdict(list)
data_samples = []
data_sample_ids = []
# NOTE: rebinds all_substrate_features to a list of per-reaction tensors
all_substrate_features = []
reaction_set = set()
# UniProt ids of the synthases whose reaction was actually evaluated
synthes = set()

screen_unis_set = set(screen_unis)
with torch.no_grad():
    for ref_id, synthase_sample in tqdm(enumerate(synthases), total=len(synthases), ncols=100):

        uni = synthase_sample['Uniprot ID']

        if uni not in screen_unis_set:
            continue

        # random 1 sample error
        if ":999" in synthase_sample["rxn_smiles"]:
            print("skipping 999")
            continue 
            
        # only the reactant side of rxn_smiles is used; products are rebuilt from bond_changes
        reactants, _ = synthase_sample["rxn_smiles"].split(">>")
        reactants_= reactants.split(".")

        edits = [ (int(x), int(y), t) for x,y,t in synthase_sample['bond_changes'] ]
        rmol = Chem.MolFromSmiles(reactants)
        products_ = robust_edit_mol(rmol, edits)
        products = '.'.join(products_)
        pmol =  Chem.MolFromSmiles(products)
        
        # canonical key: atom maps removed, components sorted on each side
        std_reactants = sorted([remove_atom_maps(s) for s in reactants_])
        std_products = sorted([remove_atom_maps(s) for s in products_])
        rxn_uni = f"{'.'.join(std_reactants)}>>{'.'.join(std_products)}"

        # skip reactions for which any recorded enzyme is a training UniProt
        if rxn_uni in terpene_reaction2unis_train:
            continue

        # check if reaction in train
        if rxn_uni in train_reactions:
            continue 

        # check in already evaluated
        if rxn_uni in reaction_set:
            continue 

        # check atom map is not overused incorrectly
        # (the number of distinct map numbers must equal the atom count on each side)
        atommapnumbers_r = set()
        for atom in rmol.GetAtoms():
            atommapnumbers_r.add(atom.GetProp("molAtomMapNumber"))
        atommapnumbers_p = set()
        for atom in pmol.GetAtoms():
            atommapnumbers_p.add(atom.GetProp("molAtomMapNumber"))
        if (len(atommapnumbers_p) != pmol.GetNumAtoms()) or (len(atommapnumbers_r) != rmol.GetNumAtoms()):
            continue 
            
        reactants_graph, _ = from_mapped_smiles(reactants, encode_no_edge=True)        
        products_graph, _ = from_mapped_smiles(products, encode_no_edge=True)

        # reactions whose two sides give differently shaped node features are skipped
        if reactants_graph.x.shape != products_graph.x.shape:
            continue

        # only now is the reaction counted as evaluated
        reaction_set.add(rxn_uni)
        synthes.add(uni)
        
        data_samples.append( synthase_sample )
        data_sample_ids.append(ref_id)

        
        item = {
            "reactants": reactants_graph,
            "products": products_graph,
        }
        
        # single-item batch through the reaction encoder
        ref_sample = loaders.default_collate([item])
        for key in ['reactants', 'products', 'graph', 'mol']:
            if key in ref_sample:
                ref_sample[key] = ref_sample[key].to(DEVICE)
        out = model.model(ref_sample)
        substrate_features = out['hidden']
        substrate_features = substrate_features.cpu()
        all_substrate_features.append(substrate_features)
        
        # positives: every screening-set enzyme recorded for this reaction
        ref_unis = terpene_reaction2unis[rxn_uni]

        # dot product of each screening enzyme with the reaction embedding
        scores = screen_hiddens @ substrate_features.T
        scores = scores.squeeze().tolist()
        
        labels = [ int(u in ref_unis) for u in screen_unis]
        
        # compute metrics   
        labels, scores = np.array(labels), np.array(scores)
        metrics["bedroc85"].append( bedroc_score(labels, scores, decreasing=True, alpha=85.0) )
        metrics["bedroc20"].append( bedroc_score(labels, scores, decreasing=True, alpha=20.0) )
        metrics["ef0.1"].append( enrichment_score(labels, scores, decreasing=True, chi=0.1) )
        metrics["ef0.05"].append( enrichment_score(labels, scores, decreasing=True, chi=0.05) )
        
        # percentile
        # (of the best-scoring positive enzyme within all screening scores)
        refscore = max(s for s,l in zip(scores,labels) if l)
        percentiles.append( percentileofscore(scores, refscore) )

In [None]:
# (number of evaluated terpene reactions, number of distinct synthases they came from)
len(reaction_set), len(synthes)

In [None]:
# number of reactions that contributed a metric value
len(metrics["ef0.05"])

In [None]:
# summary of the terpene metrics and percentiles collected above
print_metrics(metrics, percentiles)

## Appendix

### Without train uniprots

In [None]:
# NOTE: reactions associated with training UniProt IDs must be removed as well!

In [51]:
# screening enzymes (ids and hidden vectors, same order) that are not training UniProts
nottrain_uniprots = [u for u in screen_unis if u not in train_uniprots]
nottrain_hiddens = torch.vstack([h for h,u in zip(screen_hiddens, screen_unis) if u not in train_uniprots])

In [53]:
# reaction-by-enzyme similarity restricted to non-training enzymes
# NOTE(review): in top-to-bottom order, `all_substrate_features` was last assigned as a
# list in the terpene cell and `data_samples` holds terpene records; this cell presumably
# expects the test-set reaction features and samples -- confirm and recompute before running.
sim_matrix = (all_substrate_features @ nottrain_hiddens.T)

In [None]:
# Screening metrics when training UniProts are removed from the candidate pool.
reaction_set = set()
percentiles = []
metrics = defaultdict(list)

# idx is also the row index into sim_matrix
for idx, ref_sample in tqdm(enumerate(data_samples), ncols=100, total=len(data_samples)):

    # each reaction is evaluated once
    if ref_sample['reaction_string'] in reaction_set:
        continue
    
    # positives among the remaining (non-training) enzymes; skip reactions with none
    labels = [ int(u in reaction2unis[ ref_sample['reaction_string'] ]) for u in nottrain_uniprots]
    if not any(labels):
        continue 

    reaction_set.add(ref_sample['reaction_string'])

    scores = sim_matrix[idx].squeeze().tolist()
    # best score achieved by any positive enzyme
    ref_score = max( scores[i] for i,l in enumerate(labels) if l)
    
    percentiles.append( percentileofscore(scores, ref_score ) )

    metrics["bedroc85"].append( bedroc_score(np.array(labels), np.array(scores), decreasing=True, alpha=85.0) )
    metrics["bedroc20"].append( bedroc_score(np.array(labels), np.array(scores), decreasing=True, alpha=20.0) )
    metrics["ef0.1"].append( enrichment_score(np.array(labels), np.array(scores), decreasing=True, chi=0.1) )
    metrics["ef0.05"].append( enrichment_score(np.array(labels), np.array(scores), decreasing=True, chi=0.05) )


In [None]:
# summary of the metrics collected with training UniProts excluded
print_metrics(metrics, percentiles)

In [None]:
# number of reactions that contributed a metric value
len(metrics["bedroc85"])

### by protein families: mmseqs

In [120]:
# mmseqs: UniProt -> cluster id (context manager so the file handle is closed after loading)
with open('/home/EnzymeMap/mmseq_clusters_updated.p', 'rb') as mmseqs_file:
    mmseqs_clusters = pickle.load(mmseqs_file)
# inverse mapping: cluster id -> set of member UniProts
mmseq_cluster2uniprots = defaultdict(set)
for k,v in mmseqs_clusters.items():
    mmseq_cluster2uniprots[v].add(k)

In [None]:
# (number of clustered UniProts, number of distinct clusters)
len(mmseqs_clusters), len(mmseq_cluster2uniprots)

In [122]:
# every UniProt that shares an mmseqs cluster with at least one training UniProt
train_mmseqs_uniprots = set()
for uni in train_uniprots:
    # NOTE(review): assumes every training UniProt has a cluster entry; a plain dict raises KeyError otherwise
    train_mmseqs_uniprots.update( mmseq_cluster2uniprots[mmseqs_clusters[uni]] )

In [123]:
# screening enzymes (ids and hidden vectors, same order) outside every training mmseqs cluster
mmseqs_all_ec_uniprots_ = [u for u in screen_unis if u not in train_mmseqs_uniprots]
mmseqs_hiddens = torch.vstack([h for h,u in zip(screen_hiddens, screen_unis) if u not in train_mmseqs_uniprots])

In [None]:
# reaction-by-enzyme similarity restricted to enzymes outside the training mmseqs clusters
# NOTE(review): relies on `all_substrate_features` / `data_samples` holding the intended reaction set -- confirm
sim_matrix = (all_substrate_features @ mmseqs_hiddens.T)

# samples that were actually evaluated (one per reaction)
mmseqs_samples = []
percentiles = []
reaction_set = set()
metrics = defaultdict(list)

# idx is also the row index into sim_matrix
for idx, ref_sample in tqdm(enumerate(data_samples), ncols=100, total=len(data_samples)):

    # each reaction is evaluated once
    if ref_sample['reaction_string'] in reaction_set:
        continue
    
    # skip samples whose EC has no entry in the training set's EC -> UniProt table
    if  train_dataset.ec2uniprot.get( ref_sample['ec'], None) is None:
        continue

    # positives among the remaining enzymes; skip reactions with none
    labels = [ int(u in reaction2unis[ ref_sample['reaction_string'] ]) for u in mmseqs_all_ec_uniprots_]
    if not any(labels):
        continue 
    
    reaction_set.add(ref_sample['reaction_string'])

    mmseqs_samples.append(ref_sample)
    
    scores = sim_matrix[idx].squeeze().tolist()

    # best score achieved by any positive enzyme
    ref_score = max( scores[i] for i,l in enumerate(labels) if l)
    
    percentiles.append( percentileofscore(scores, ref_score ) )
    
    metrics["bedroc85"].append( bedroc_score(np.array(labels), np.array(scores), decreasing=True, alpha=85.0) )
    metrics["bedroc20"].append( bedroc_score(np.array(labels), np.array(scores), decreasing=True, alpha=20.0) )
    metrics["ef0.1"].append( enrichment_score(np.array(labels), np.array(scores), decreasing=True, chi=0.1) )
    metrics["ef0.05"].append( enrichment_score(np.array(labels), np.array(scores), decreasing=True, chi=0.05) )
    

In [131]:
# per evaluated sample: does its own UniProt fall in a training mmseqs cluster?
label = [ d['uniprot_id'] in train_mmseqs_uniprots for d in mmseqs_samples]
# (count of such samples, total). Kept as the last expression so the notebook
# displays it; previously it preceded the assignment and was silently discarded.
sum(label), len(mmseqs_samples)

In [None]:
# summary of the mmseqs-filtered metrics; `label` is the per-sample flag built in the previous cell
print_metrics(metrics, percentiles, label)

### by protein families: foldseek

In [133]:
# foldseek clusters: UniProt -> cluster id
# (context manager so the file handle is closed after loading)
with open('/home/EnzymeMap/foldseek/cov90_cluster.p', 'rb') as foldseek_file:
    foldseek_clusters = pickle.load(foldseek_file)
# inverse mapping: cluster id -> set of member UniProts
foldseek_cluster2uniprots = defaultdict(set)
for k,v in foldseek_clusters.items():
    foldseek_cluster2uniprots[v].add(k)

In [134]:
# every UniProt that shares a foldseek cluster with at least one training UniProt
train_foldseek_uniprots = set()
for uni in train_uniprots:
    # NOTE(review): assumes every training UniProt has a cluster entry; a plain dict raises KeyError otherwise
    train_foldseek_uniprots.update( foldseek_cluster2uniprots[ foldseek_clusters[uni]] )

In [135]:
# screening enzymes (ids and hidden vectors, same order) outside every training foldseek cluster
foldseek_all_ec_uniprots_ = [u for u in screen_unis if u not in train_foldseek_uniprots]
foldseek_hiddens = torch.vstack([h for h,u in zip(screen_hiddens, screen_unis) if u not in train_foldseek_uniprots])

In [136]:
# reaction-by-enzyme similarity restricted to enzymes outside the training foldseek clusters
# NOTE(review): relies on `all_substrate_features` holding the intended reaction features -- confirm
sim_matrix = (all_substrate_features @ foldseek_hiddens.T)

In [None]:
# Screening metrics when all enzymes in training foldseek clusters are removed.
percentiles = []
# samples that were actually evaluated (one per reaction)
data_samples_foldseek = []
reaction_set = set()
metrics = defaultdict(list)

# idx is also the row index into sim_matrix
for idx, ref_sample in tqdm(enumerate(data_samples), ncols=100, total=len(data_samples)):

    # each reaction is evaluated once
    if ref_sample['reaction_string'] in reaction_set:
        continue
    
    # skip samples whose EC has no entry in the training set's EC -> UniProt table
    if  train_dataset.ec2uniprot.get( ref_sample['ec'], None) is None:
        continue

    # positives among the remaining enzymes; skip reactions with none
    labels = [ int(u in reaction2unis[ ref_sample['reaction_string'] ]) for u in foldseek_all_ec_uniprots_]
    if not any(labels):
        continue 

    reaction_set.add(ref_sample['reaction_string'])
    data_samples_foldseek.append(ref_sample)
    
    scores = sim_matrix[idx].squeeze().tolist()

    # best score achieved by any positive enzyme
    ref_score = max( scores[i] for i,l in enumerate(labels) if l)
    
    percentiles.append( percentileofscore(scores, ref_score ) )
    
    metrics["bedroc85"].append( bedroc_score(np.array(labels), np.array(scores), decreasing=True, alpha=85.0) )
    metrics["bedroc20"].append( bedroc_score(np.array(labels), np.array(scores), decreasing=True, alpha=20.0) )
    metrics["ef0.1"].append( enrichment_score(np.array(labels), np.array(scores), decreasing=True, chi=0.1) )
    metrics["ef0.05"].append( enrichment_score(np.array(labels), np.array(scores), decreasing=True, chi=0.05) )
    

In [138]:
# per evaluated sample: does its own UniProt fall in a training foldseek cluster?
label = [ d['uniprot_id'] in train_foldseek_uniprots for d in data_samples_foldseek]
# (count of such samples, total). Kept as the last expression so the notebook
# displays it; previously it preceded the assignment and was silently discarded.
sum(label), len(data_samples_foldseek)

In [None]:
# summary of the foldseek-filtered metrics; `label` is the per-sample flag built in the previous cell
print_metrics(metrics, percentiles, label)