In [1]:
# Dataset names and model-type keys used throughout this notebook. "GP" and
# "CP" are the type strings that get_model() accepts.
# Presumably these lists fix the display order in plots/tables — TODO confirm
# against the cells that consume them.
DATASET_ORDER = ["TYK2", "USP7", "D2R", "Mpro"]
MODEL_ORDER = ["GP", "CP"]

In [2]:
import warnings
import pandas as pd
import numpy as np
warnings.filterwarnings("ignore")

from exs.ale.featurisers import Featuriser
from typing import Callable
from oekit.io import standardise_smiles

def load_dataset(name: str, path: str, smiles_col: str, affinity_col: str, affinity_conversion: Callable) -> pd.DataFrame:
    """Read a CSV of compounds and return one featurised row per unique SMILES.

    Args:
        name: Value written to the ``target`` column of every row.
        path: Location of the CSV file passed to ``pd.read_csv``.
        smiles_col: Name of the CSV column holding SMILES strings.
        affinity_col: Name of the CSV column holding the raw affinity values.
        affinity_conversion: Callable applied to the raw affinity column; its
            result is stored as ``affinity`` (pXC50 scale).

    Returns:
        DataFrame with columns ``SMILES``, ``affinity`` (mean over duplicate
        SMILES), ``target``, boolean ``top_2p`` / ``top_5p`` flags and ``fps``.
    """
    raw = pd.read_csv(path)

    # Affinity on the pXC50 scale.
    raw['affinity'] = affinity_conversion(raw[affinity_col])

    # Canonical column name, then standardised structures.
    raw = raw.rename(columns={smiles_col: 'SMILES'})
    raw["SMILES"] = raw.SMILES.apply(standardise_smiles)

    # Collapse duplicate structures onto their mean affinity.
    data = raw.groupby('SMILES').agg({'affinity': 'mean'}).reset_index()
    data["target"] = name

    # Flag the most active 2% and 5% of compounds.
    for fraction, flag_col in ((0.02, 'top_2p'), (0.05, 'top_5p')):
        n_top = round(len(data) * fraction)
        ranked = data.sort_values(by='affinity', ascending=False)
        data[flag_col] = False
        data.loc[ranked[:n_top].index, flag_col] = True

    # Pre-compute fingerprints with the ECFP8 preset.
    featuriser = Featuriser(smiles_col="SMILES", presets="ECFP8")
    data["fps"] = featuriser.featurise(data)

    return data

In [2]:
import warnings
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from typing import Callable
warnings.filterwarnings("ignore")

def standardise_smiles(smiles):
    """Round-trip a SMILES string through RDKit.

    Returns the string produced by ``Chem.MolToSmiles`` for the parsed
    molecule, or ``None`` when ``Chem.MolFromSmiles`` returns ``None``.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        return None
    return Chem.MolToSmiles(parsed)

def load_dataset(name: str, path: str, smiles_col: str, affinity_col: str, affinity_conversion: Callable) -> pd.DataFrame:
    """Read a CSV of compounds and return one fingerprinted row per unique SMILES.

    Args:
        name: Value written to the ``target`` column of every row.
        path: Location of the CSV file passed to ``pd.read_csv``.
        smiles_col: Name of the CSV column holding SMILES strings.
        affinity_col: Name of the CSV column holding the raw affinity values.
        affinity_conversion: Callable applied to the raw affinity column; its
            result is stored as ``affinity`` (pXC50 scale).

    Returns:
        DataFrame with columns ``SMILES``, ``affinity`` (mean over duplicate
        SMILES), ``target``, boolean ``top_2p`` / ``top_5p`` flags and ``fps``
        (radius-4, 4096-bit Morgan fingerprint as a numpy array, or NaN).
    """

    def _fingerprint(smi):
        # ECFP8 == Morgan fingerprint with radius 4.
        parsed = Chem.MolFromSmiles(smi)
        if not parsed:
            return np.nan
        bits = AllChem.GetMorganFingerprintAsBitVect(parsed, radius=4, nBits=4096)
        return np.array(bits)

    def _flag_top(frame, fraction, flag_col):
        # Mark the `fraction` highest-affinity rows of `frame` in `flag_col`.
        frame[flag_col] = False
        n_top = round(len(frame) * fraction)
        frame.loc[frame['affinity'].nlargest(n_top).index, flag_col] = True

    raw = pd.read_csv(path)

    # Affinity on the pXC50 scale.
    raw['affinity'] = affinity_conversion(raw[affinity_col])

    # Canonical column name, standardised structures, unparsable rows removed.
    raw = raw.rename(columns={smiles_col: 'SMILES'})
    raw["SMILES"] = raw.SMILES.apply(standardise_smiles)
    raw = raw.dropna(subset=["SMILES"])

    # Collapse duplicate structures onto their mean affinity.
    data = raw.groupby('SMILES').agg({'affinity': 'mean'}).reset_index()
    data["target"] = name

    _flag_top(data, 0.02, 'top_2p')
    _flag_top(data, 0.05, 'top_5p')

    data["fps"] = data.SMILES.apply(_fingerprint)

    return data


In [3]:
from exs.ale.models import GPModel, GenericModel,ChempropModel
from exs.mtl_bundle.models.chemprop import model_config as chemprop_config

def get_model(type: str, df_train: pd.DataFrame) -> GenericModel:
    """Build a fresh, unfitted model for the requested model type.

    Args:
        type: ``"GP"`` for a ``GPModel`` on the ``fps`` column, or ``"CP"`` for
            a ``ChempropModel`` on the ``SMILES`` column. (The name shadows the
            ``type`` builtin; it is kept so keyword callers keep working.)
        df_train: Current training set. Only used for ``"CP"``, where a copy
            is passed as the validation frame (``df_val``).

    Returns:
        The constructed model, predicting the ``affinity`` column.

    Raises:
        Exception: If ``type`` is neither ``"GP"`` nor ``"CP"``.
    """
    # Local import: `os` is not imported by any of the visible notebook cells.
    import os

    if type == "GP":
        return GPModel(X_col='fps', y_col='affinity')
    elif type == "CP":  # Chemprop
        # The original literal listed 'num_epochs' (50, then 500) and
        # 'warmup_epochs' (5, then 10) twice. In a dict display the last value
        # wins, so only the effective values are kept here.
        # NOTE(review): `wd` must exist as a notebook-level global; it is not
        # defined in the visible cells — confirm.
        # NOTE(review): validation uses a copy of the training data, so
        # 'eval_metric' is measured on the training set — confirm intended.
        model_args = {
            'model_name': "ale_chemprop",
            'fit_args': chemprop_config,
            'save_to_s3': False,
            'batch_size': 50,
            'num_epochs': 500,
            'eval_metric': "r2",
            'init_lr': 0.0001,
            'max_lr': 0.001,
            'final_lr': 0.0001,
            'warmup_epochs': 10,
            'encoder_from_pretrained_path': os.path.join(wd, "pretrained/chemprop_model.bin"),
            'freeze_encoder': False,
            'df_val': df_train.copy(),
            'silence': True,
        }

        return ChempropModel(X_col='SMILES', y_col='affinity', **model_args)
    else:
        raise Exception(f"Unknown model type: {type}")

In [4]:
from exs.ale.selectors import RandomSelector, UCBSelector

def get_selector(type: str, seed: int) -> Callable:
    """Return the selector object for a named acquisition strategy.

    ``"random"`` gives a ``RandomSelector`` seeded with ``seed``. ``"explore"``
    and ``"exploit"`` give a ``UCBSelector`` over the ``pred`` / ``std``
    columns with (alpha, beta) of (0, 1) and (1, 0) respectively. Any other
    name raises ``Exception``.
    """
    if type == "random":
        return RandomSelector(seed=seed)

    # (alpha, beta) weights for the two UCB-based strategies.
    ucb_variants = (("explore", (0, 1)), ("exploit", (1, 0)))
    for label, (alpha, beta) in ucb_variants:
        if type == label:
            return UCBSelector(pred_col="pred", std_col="std", alpha=alpha, beta=beta)

    raise Exception(f"Unknown selector type: {type}")

In [5]:
from typing import Any, List
from exs.ale.models import NoModel
from exs.ale.pipelines import ALPipelineDev, MulticycleAnalysis

def selected_in_cycle(row):
    """Return the position of the first ``'selection'`` entry in ``row.set``.

    Args:
        row: Object whose ``set`` attribute exposes ``tolist()``; in this
            notebook it is a row of the pivoted results frame, where
            ``row.set`` holds one value per cycle.

    Returns:
        The zero-based position of the first ``'selection'`` value, or the
        string ``'None'`` when no entry equals ``'selection'``.
    """
    sets = row.set.tolist()
    try:
        return sets.index('selection')
    except ValueError:
        # list.index raises ValueError when the value is absent. Catching only
        # that keeps unrelated errors and KeyboardInterrupt visible.
        return 'None'

def active_learning(data: pd.DataFrame, selection_protocol: List[Any], model_type: str, seed: int):
    """Run a multi-cycle active-learning experiment over ``data``.

    Args:
        data: Full compound table; it starts out entirely in the pool.
        selection_protocol: One ``(selector_type, batch_size)`` pair per cycle.
        model_type: Model key passed to ``get_model`` from the second cycle on
            (the first cycle uses ``NoModel``).
        seed: Seed passed to ``get_selector``.

    Returns:
        ``(data_pivot, analysis_outputs)``: the per-compound results pivoted to
        one row per compound with ``pred`` / ``std`` / ``set`` per cycle plus a
        ``selected_in_cycle`` column, and ``MulticycleAnalysis.outputs``.
    """
    # Start with an empty training set and everything in the pool.
    train = pd.DataFrame(columns=data.columns)
    pool = data.copy()
    print(model_type)

    # Chemprop consumes SMILES directly; every other model uses fingerprints.
    if model_type == "CP":
        print('SMILES')
        feature_col = "SMILES"
    else:
        feature_col = "fps"
    al_featuriser = Featuriser(fps_col=feature_col)

    results = []

    for cycle, (selector_type, batch_size) in enumerate(selection_protocol):
        # No model can be fitted before the first batch has been acquired.
        al_model = NoModel(X_col="fps", y_col="affinity") if cycle == 0 else get_model(model_type, train)
        al_selector = get_selector(selector_type, seed)

        al_pipeline = ALPipelineDev(al_model, al_selector, al_featuriser, batch_size)
        al_pipeline.set_data(train, pool)
        al_pipeline.run()

        cycle_df = al_pipeline.outputs["combined_df"].copy()
        cycle_df.drop(columns=["index"], inplace=True)
        results.append(cycle_df)

        # The next cycle continues from the updated train/pool split.
        train = al_pipeline.outputs["train_new"].copy()
        pool = al_pipeline.outputs["pool_new"].copy()

    # Label every per-cycle frame with its cycle number.
    for cycle, cycle_df in enumerate(results):
        cycle_df['cycle'] = cycle

    # Multi-cycle analysis over the list of per-cycle frames.
    multi_ana = MulticycleAnalysis(label_col="affinity", clean_data=False, orion_compat=False)
    multi_ana.set_data(results)
    multi_ana.run()

    # One row per compound, one column block per cycle.
    stacked = pd.concat(results)
    compound_keys = ["SMILES", "affinity", "top_2p", "top_5p"]
    data_pivot = stacked.pivot(index=compound_keys, columns=["cycle"], values=["pred", "std", "set"]).reset_index()
    data_pivot["selected_in_cycle"] = data_pivot.apply(selected_in_cycle, axis=1)

    return data_pivot, multi_ana.outputs


In [6]:
# number of top 2% cpds in a randomly selected batch of num_acquired cpds
from typing import Any, List

def cycles_to_cpds(protocol: List[Any], cycle: Any) -> int:
    """Total batch size of the first ``cycle`` protocol entries.

    Args:
        protocol: Sequence of entries whose element ``[1]`` is a batch size.
        cycle: Number of leading entries to sum, used as a slice bound. The
            notebook also passes the string ``'None'`` for compounds that were
            never selected.

    Returns:
        The summed batch sizes, or ``None`` when ``cycle`` cannot be used as a
        slice bound (for example the ``'None'`` placeholder).
    """
    try:
        return sum(entry[1] for entry in protocol[:cycle])
    except TypeError:
        # A non-integer slice bound raises TypeError. Other errors, such as a
        # malformed protocol entry, are left to propagate instead of being
        # silently turned into None.
        return None

def get_baseline(protocol: List[Any], cycle: Any, percent: float):
    """Expected number of top-``percent``% compounds in a random pick.

    Sums the batch sizes (element ``[1]``) of the first ``cycle`` protocol
    entries and scales the total by ``percent`` / 100.

    Args:
        protocol: Sequence of entries whose element ``[1]`` is a batch size.
        cycle: Number of leading entries to include, used as a slice bound.
            The notebook also passes the string ``'None'`` here.
        percent: Percentage of the dataset counted as "top" (e.g. 2 or 5).

    Returns:
        ``0.01 * percent * compounds_acquired``, or ``None`` when ``cycle``
        cannot be used as a slice bound.
    """
    try:
        cpds_acquired = sum(entry[1] for entry in protocol[:cycle])
        return 0.01 * percent * cpds_acquired
    except TypeError:
        # Only the type failure (e.g. the 'None' placeholder) maps to None;
        # other errors are left to propagate.
        return None
    
def normalise_recall(top_N: float, baseline_N: float, total_N: int):
    """Rescale a hit count so the baseline maps to 0 and the total to 1.

    Args:
        top_N: Number of top compounds found.
        baseline_N: Number expected from the baseline; may be ``None``.
        total_N: Total number of top compounds available.

    Returns:
        ``(top_N - baseline_N) / (total_N - baseline_N)``, or ``None`` when an
        operand is not numeric (e.g. ``baseline_N`` is ``None``) or the
        denominator is zero for plain Python numbers.
    """
    try:
        return (top_N - baseline_N) / (total_N - baseline_N)
    except (TypeError, ZeroDivisionError):
        # These are the two failures this arithmetic can raise; anything else
        # is left to propagate.
        return None

In [7]:
from typing import Tuple

def tag_dataset(df: pd.DataFrame, dataset: str, model: str, protocol_name: str, seed: int) -> pd.DataFrame:
    """Stamp run metadata onto ``df`` in place and return the same frame.

    Adds (or overwrites) the constant columns ``Dataset``, ``Model``,
    ``Protocol`` and ``seed``, in that order.
    """
    labels = (
        ("Dataset", dataset),
        ("Model", model),
        ("Protocol", protocol_name),
        ("seed", seed),
    )
    for column, value in labels:
        df[column] = value
    return df

def do_analysis(res_df: pd.DataFrame, ana: pd.DataFrame, dataset: str, model: str, protocol_name: str, 
                protocol: List[Any], seed: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Summarise one active-learning run as a recall table and a metrics table.

    Args:
        res_df: Per-compound results with ``selected_in_cycle``, ``top_2p``
            and ``top_5p`` columns.
        ana: Analysis outputs; the entry ``"retrospective_metrics_metrics_df"``
            is read. NOTE(review): annotated as a DataFrame but indexed like a
            mapping of frames — confirm the real type.
        dataset, model, protocol_name, seed: Run labels written to both tables.
        protocol: ``(selector_type, batch_size)`` entries, one per cycle.

    Returns:
        ``(recall, metrics)``. ``recall`` has one row per ``selected_in_cycle``
        value with per-cycle and cumulative top-2% / top-5% counts, baselines
        and normalised recall. ``metrics`` is the retrospective metrics frame
        with run labels and ``compounds_acquired`` added.
    """
    percents = (2, 5)

    # Count top-2% / top-5% compounds picked in each cycle and join the two.
    hits_2p = res_df.groupby('selected_in_cycle').top_2p.sum().reset_index()
    hits_5p = res_df.groupby('selected_in_cycle').top_5p.sum().reset_index()
    recall = tag_dataset(hits_2p.merge(hits_5p, on="selected_in_cycle"), dataset, model, protocol_name, seed)
    recall["compounds_acquired"] = recall.selected_in_cycle.apply(lambda cyc: cycles_to_cpds(protocol, cyc))

    # Three passes, each covering both percentages, so the column order is
    # cum/cum, baseline/baseline, normalised/normalised.
    for pct in percents:
        recall[f"top_{pct}p_cum"] = np.cumsum(recall[f"top_{pct}p"])
    for pct in percents:
        recall[f"baseline_{pct}p"] = recall.selected_in_cycle.apply(
            lambda cyc, pct=pct: get_baseline(protocol, cyc, pct)
        )
    for pct in percents:
        total_hits = res_df[f"top_{pct}p"].sum()
        recall[f"normalised_{pct}p"] = recall.apply(
            lambda row, pct=pct, total_hits=total_hits: normalise_recall(
                row[f"top_{pct}p_cum"], row[f"baseline_{pct}p"], total_hits
            ),
            axis=1,
        )

    # Retrospective metrics, labelled the same way.
    metrics = tag_dataset(ana["retrospective_metrics_metrics_df"].reset_index(), dataset, model, protocol_name, seed)
    metrics["compounds_acquired"] = metrics.cycle.apply(lambda cyc: cycles_to_cpds(protocol, cyc))

    return recall, metrics

In [8]:
from exs.ale.analysis.plot_chemical_space import plot_chemical_space
import pickle

def plot_cpds_found_on_fmap(results: pd.DataFrame, dataset: str, set = "top_2p", model=None, protocol_name=None, noise_level=None, ax=None, xlabel=None, ylabel=None, legend=False, cycle=None):
    """Plot, on a dataset's feature map, how often each top-2% compound was acquired.

    Args:
        results: Pivoted results with ``Dataset``, ``Model``, ``Protocol``,
            ``SMILES``, ``top_2p`` columns and ``("set", <cycle>)`` columns.
        dataset: Dataset to plot; also selects ``data/feature_maps/{dataset}.pkl``.
        set: Not referenced in the body (``df.set`` below is column access).
            Kept for interface compatibility.
        model, protocol_name, noise_level: Optional equality filters on the
            ``Model``, ``Protocol`` and ``noise_level`` columns.
        ax, xlabel, ylabel, legend: Forwarded to ``plot_chemical_space``.
        cycle: Cycle whose ``set`` column decides what counts as acquired;
            defaults to the last ``set`` column.

    Returns:
        None; the plot is drawn by ``plot_chemical_space``.
    """

    df = results[results.Dataset == dataset].reset_index(drop=True)

    if model is not None:
        df = df[df.Model == model].reset_index(drop=True)

    if protocol_name is not None:
        df = df[df.Protocol == protocol_name].reset_index(drop=True)

    if noise_level is not None:
        df = df[df.noise_level == noise_level].reset_index(drop=True)

    # Drop every column that contains a missing value.
    df = df.dropna(axis="columns")

    # selected = 1 where the compound is in the training set at the chosen cycle.
    df["selected"] = 0
    if cycle is None:
        n_cycles = len(df.set.columns)-1
    else: 
        n_cycles = cycle
    df.loc[df[("set", n_cycles)] == "train", "selected"] = 1

    # Average over repeated rows per compound, so `selected` becomes the
    # fraction of rows in which the compound was acquired.
    # NOTE(review): the frame still holds non-numeric columns here; whether
    # .mean() accepts them depends on the pandas version — confirm.
    df = df.groupby(["SMILES"]).mean().reset_index()
    df = df.sort_values(by="SMILES").reset_index(drop=True)
    df["selected"] = df.selected.apply(lambda x: round(x, 2))

    # Compounds outside the top 2% get the sentinel -1 (drawn in gray).
    df.loc[df.top_2p == False, "selected"]  = -1

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use feature maps from a trusted source.
    with open(f"data/feature_maps/{dataset}.pkl", "rb") as fmap_file:
        feature_map = pickle.load(fmap_file)

    levels = len(df.selected.unique()) - 1
    # Guard the denominator: when levels == 1 the original k/(levels-1) raised
    # ZeroDivisionError. For levels >= 2 the keys are unchanged.
    spread = max(levels - 1, 1)
    zorder_sel ={k/spread+1: k for k in range(levels)}
    zorder = {"not in top 2%": 0, **zorder_sel}

    plot_chemical_space(df, feature_map, set_col="selected",  \
                    set_colors={-1: "gray"}, zorder=zorder, cmap="copper_r", markersize=10, ax=ax, labels={-1: "not in top 2%"}, set_markersize=10, xlabel=xlabel, ylabel=ylabel, legend=legend)

In [9]:
#import matplotlib.pyplot as plt
#plt.rcParams['text.usetex'] = True

In [10]:
from typing import Dict

def make_recall_for_plot(recall: pd.DataFrame, total_2p: Dict, total_5p: Dict) -> pd.DataFrame:
    """Build the plotting table of recall and F1 curves plus random baselines.

    Rows whose ``selected_in_cycle`` is ``"None"`` or whose ``Protocol`` is
    ``"random-explore"`` are dropped. For the 2% and 5% cut-offs the function
    adds ``F1 (p%)`` and ``Recall (p%)`` columns and appends "Baseline" rows:
    the values expected from random acquisition, for every dataset, model and
    acquisition size.

    Args:
        recall: Recall table with ``selected_in_cycle``, ``Protocol``,
            ``Dataset``, ``Model``, ``compounds_acquired``, ``top_2p_cum`` and
            ``top_5p_cum`` columns.
        total_2p: Number of top-2% compounds per dataset name.
        total_5p: Number of top-5% compounds per dataset name.

    Returns:
        The augmented table, with ``compounds_acquired`` renamed to
        ``"Compounds acquired"``.
    """

    recall_for_plot = recall[(recall.selected_in_cycle != "None") & (recall.Protocol != "random-explore")].reset_index()

    # Acquisition sizes at which the random baseline is evaluated:
    # 11 evenly spaced values from 60 to 360.
    batch = np.linspace(60, 360, 11)

    for percent, total_p in zip([2,5], [total_2p, total_5p]):
        # F1 = 2*TP / (acquired + total positives); recall = TP / total positives.
        recall_for_plot[f"F1 ({percent}%)"] = recall_for_plot.apply(lambda row: 2*row[f"top_{percent}p_cum"] / (row.compounds_acquired + total_p[row.Dataset]), axis=1)
        recall_for_plot[f"Recall ({percent}%)"] = recall_for_plot.apply(lambda row: row[f"top_{percent}p_cum"] / total_p[row.Dataset], axis=1)

        for dataset, n_total in total_p.items():
            for acquired in batch:
                # Hits expected when `acquired` compounds are drawn at random.
                expected_hits = 0.01 * percent * acquired
                for model in recall_for_plot.Model.unique():
                    recall_for_plot.loc[len(recall_for_plot)] = {
                        "Dataset": dataset,
                        "Model": model,
                        "Protocol": "Baseline",
                        "noise_level": "Baseline",
                        "compounds_acquired": acquired,
                        f"Recall ({percent}%)": expected_hits / n_total,
                        # Same F1 formula as the real curves. The original
                        # omitted the factor 2, halving the baseline.
                        f"F1 ({percent}%)": 2 * expected_hits / (acquired + n_total)
                    }

    recall_for_plot = recall_for_plot.rename(columns={"compounds_acquired": "Compounds acquired"})
    return recall_for_plot