# Feature engineering

> In this module, we develop tools to extract features from compounds, proteins, etc.

In [None]:
#| default_exp feature

In [None]:
#| hide
import sys
sys.path.append("/notebooks/katlas")
from nbdev.showdoc import *
%matplotlib inline

In [None]:
#| export
from katlas.core import Data
import seaborn as sns
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.ML.Descriptors import MoleculeDescriptors
import pandas as pd
from rdkit.Chem import Draw
from rdkit.Chem import Descriptors
from sklearn.preprocessing import StandardScaler

from fastbook import *
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap
import esm
from tqdm.notebook import tqdm; tqdm.pandas()
import gc

## Features from amino acids

### RDKit descriptors

In [None]:
#| export
def smi2prop(df, # df needs to have SMILES and ID columns
             smi_colname = "SMILES", # column name of smiles
             id_colname = "ID", # column name of ID
             remove_duplicate=False, # remove zero-std features and features that are highly correlated
             thr = 0.95, # threshold of absolute Pearson correlation above which a feature is dropped
             normalize = True, # normalize all features using StandardScaler()
            ):
    "Extract ~209 features from smiles via rdkit.Chem.Descriptors, and optionally remove duplicate features"

    mols = [Chem.MolFromSmiles(smi) for smi in df[smi_colname]]

    # MolFromSmiles returns None for SMILES it cannot parse; fail early with the
    # offending IDs instead of computing descriptors on a missing molecule.
    invalid_ids = [id_ for id_, mol in zip(df[id_colname], mols) if mol is None]
    if invalid_ids:
        raise ValueError(f'could not parse SMILES for {id_colname}: {invalid_ids}')

    desc_names = [desc_name[0] for desc_name in Descriptors.descList]
    desc_calc = MoleculeDescriptors.MolecularDescriptorCalculator(desc_names)
    desc_values = [desc_calc.CalcDescriptors(mol) for mol in mols]

    # IDs live in the index until the final reset_index(), so every column is a feature
    feature_df = pd.DataFrame(np.stack(desc_values), index=df[id_colname], columns=desc_names)

    if remove_duplicate:
        print(f'number of {feature_df.shape[1]} features are detected')

        # remove features with zero std (constant across all rows)
        feature_std = feature_df.std()
        to_drop = feature_std[feature_std == 0].index
        feature_df = feature_df.drop(columns=to_drop)
        print(f'dropping {len(to_drop)} features, as they have zero std variance:{to_drop.tolist()}')

        corr_matrix = feature_df.corr().abs()
        # Select upper triangle of correlation matrix, so that for each correlated
        # pair only the later column is a candidate for removal
        upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
        # Find feature columns whose correlation with an earlier column exceeds the threshold
        to_drop = [column for column in upper.columns if any(upper[column] > thr)]
        # Drop the highly correlated features
        feature_df = feature_df.drop(columns=to_drop)
        print(f'dropping {len(to_drop)} features, as they have Pearson corr > {thr}:{to_drop}')
        print(f'number of {feature_df.shape[1]} features are left')

    if normalize and feature_df.shape[1] > 0:
        # Scale ALL feature columns; the ID is the index here, so there is no
        # leading non-feature column to skip.
        scaled = StandardScaler().fit_transform(feature_df)
        feature_df = pd.DataFrame(scaled, index=feature_df.index, columns=feature_df.columns)

    # Move the ID from the index back to the first column
    feature_df = feature_df.reset_index()
    return feature_df

In [None]:
show_doc(smi2prop)

---

### smi2prop

>      smi2prop (df, smi_colname='SMILES', id_colname='ID',
>                remove_duplicate=False, thr=0.95, normalize=True)

Extract ~209 features from smiles via rdkit.Chem.Descriptors, and remove duplicate features

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| df |  |  | df needs to have SMILES an ID columns |
| smi_colname | str | SMILES | column name of smiles |
| id_colname | str | ID | column name of ID |
| remove_duplicate | bool | False | remove features that are highly correlated |
| thr | float | 0.95 | threshold of Pearson correlation |
| normalize | bool | True | normalize features using StandardScaler() |

In [None]:
# Keep only the amino-acid label ('aa') and SMILES columns as input for feature extraction
df = Data.get_aa_info()[['aa','SMILES']]

In [None]:
# Compute RDKit descriptors keyed by 'aa', dropping zero-std features and features with Pearson corr > 0.9
df_feature = smi2prop(df, id_colname='aa',remove_duplicate=True,thr=0.9)

number of 209 features are detected
dropping 67 features, as they have zero std variance:['NumRadicalElectrons', 'PEOE_VSA13', 'PEOE_VSA5', 'SMR_VSA8', 'SlogP_VSA10', 'SlogP_VSA7', 'SlogP_VSA9', 'EState_VSA11', 'NumAliphaticCarbocycles', 'NumSaturatedCarbocycles', 'fr_ArN', 'fr_Ar_COO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide', 'fr_allylic_oxid', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl', 'fr_azide', 'fr_azo', 'fr_barbitur', 'fr_benzodiazepine', 'fr_diazo', 'fr_dihydropyridine', 'fr_epoxide', 'fr_ester', 'fr_ether', 'fr_furan', 'fr_halogen', 'fr_hdrzine', 'fr_hdrzone', 'fr_imide', 'fr_isocyan', 'fr_isothiocyan', 'fr_ketone', 'fr_ketone_Topliss', 'fr_lactam', 'fr_lactone', 'fr_methoxy', 'fr_morpholine', 'fr_nitrile', 'fr_nitro', 'fr_nitro_arom', 'fr_nitro_arom_nonortho', 'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_piperdine', 'fr_piperzine', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN'

In [None]:
df_feature

Unnamed: 0,aa,MaxAbsEStateIndex,MinAbsEStateIndex,MinEStateIndex,qed,MolWt,MinPartialCharge,MaxAbsPartialCharge,FpDensityMorgan1,FpDensityMorgan2,FpDensityMorgan3,BCUT2D_MWHI,BCUT2D_MWLOW,BCUT2D_CHGHI,BCUT2D_CHGLO,BCUT2D_LOGPHI,BCUT2D_LOGPLOW,BCUT2D_MRLOW,AvgIpc,BalabanJ,BertzCT,HallKierAlpha,Kappa2,Kappa3,PEOE_VSA1,PEOE_VSA10,PEOE_VSA11,PEOE_VSA12,PEOE_VSA14,PEOE_VSA2,PEOE_VSA3,PEOE_VSA4,PEOE_VSA6,PEOE_VSA7,PEOE_VSA8,PEOE_VSA9,SMR_VSA10,SMR_VSA3,SMR_VSA4,SMR_VSA5,SMR_VSA6,SMR_VSA7,SlogP_VSA1,SlogP_VSA2,SlogP_VSA3,SlogP_VSA4,SlogP_VSA8,EState_VSA1,EState_VSA10,EState_VSA2,EState_VSA3,EState_VSA4,EState_VSA6,EState_VSA7,EState_VSA8,EState_VSA9,VSA_EState3,VSA_EState4,VSA_EState5,VSA_EState7,VSA_EState8,FractionCSP3,NumAliphaticHeterocycles,NumAromaticHeterocycles,NumHAcceptors,MolLogP,fr_Al_COO,fr_Al_OH,fr_C_O,fr_NH0,fr_NH1,fr_sulfide,fr_unbrch_alkane
0,A,9.574074,1.193554,0.43059,-0.375462,-1.43977,0.228166,-0.308911,1.656688,0.103007,-0.780051,-0.526903,1.203883,-1.415519,1.465101,-1.347549,1.162418,0.678178,-1.660842,0.058766,-0.928528,0.661861,-1.704924,-1.164169,-1.088274,-0.113045,-0.213201,-0.308607,-0.521596,-0.682582,-0.458413,-0.308393,-0.647398,-0.22426,-0.480351,-0.694405,-0.882977,-0.432331,-0.495561,-0.290532,-0.695145,-0.511968,-0.482124,-1.373429,-0.884527,-0.45843,-0.213201,-0.535127,-0.66116,-0.92285,-0.621218,-0.611041,-0.440926,-0.455591,-0.769976,-0.626017,-0.748232,-0.483822,0.342808,-0.4624,0.57219,0.349099,-0.213201,-0.308607,-1.192079,-0.082356,-0.308607,-0.308607,-0.458831,-0.213201,-0.428746,-0.213201,-0.308607
1,C,9.756435,-0.594121,0.396312,-0.623715,-0.643882,0.233514,-0.311825,1.656688,1.401076,0.150357,1.981889,1.146308,-0.628535,0.696026,-0.795604,0.07706,0.765682,-0.684106,0.233728,-0.790419,1.46066,-0.601392,-0.581437,-1.088274,-0.113045,-0.213201,-0.308607,-0.521596,-0.682582,-0.458413,3.364183,-0.647398,-1.107502,0.556864,-0.694405,2.01585,-0.432331,-0.495561,-1.340339,0.914158,-0.511968,-0.482124,-0.257709,-0.884527,-0.45843,-0.213201,-0.535127,-0.66116,0.310554,-0.621218,-0.611041,-0.440926,-0.455591,2.765813,-0.626017,-0.709478,-0.437633,0.631861,-0.4624,-0.522169,0.349099,-0.213201,-0.308607,0.331133,-0.196341,-0.308607,-0.308607,-0.458831,-0.213201,-0.428746,-0.213201,-0.308607
2,D,9.846435,0.536545,0.158193,-0.369303,-0.347487,0.037668,-0.20511,-1.033743,-1.273733,-1.372129,-0.525782,0.115873,0.346261,0.398995,-0.516138,-0.084662,-3.545366,-0.722995,0.82287,-0.281516,-0.547747,-0.442263,0.20098,0.373629,-0.113045,-0.213201,-0.308607,1.50096,1.050911,-0.458413,-0.308393,-0.647398,-1.107502,-0.480351,1.110102,0.487224,-0.432331,-0.495561,-0.366786,-0.695145,-0.511968,-0.482124,0.77464,0.433147,-0.45843,-0.213201,1.741627,1.556719,-0.92285,-0.621218,-0.611041,-0.440926,-0.455591,-0.769976,0.487402,1.474811,-0.483206,-2.651958,-0.936726,-0.522169,-0.429755,-0.213201,-0.308607,0.331133,-0.772086,3.24037,-0.308607,2.179449,-0.213201,-0.428746,-0.213201,-0.308607
3,E,9.99388,-1.144174,0.26396,-0.05677,0.000656,0.028134,-0.199915,-1.248978,-1.218664,-0.661636,-0.526379,-0.271923,0.069844,0.109885,-0.172745,0.256753,-0.43718,-0.085041,0.534017,-0.174262,-0.547747,0.361706,0.487383,0.373629,-0.113045,-0.213201,-0.308607,1.50096,1.050911,-0.458413,-0.308393,-0.647398,-0.288416,0.677296,-0.694405,0.487224,-0.432331,-0.495561,0.606766,-0.695145,-0.511968,-0.482124,0.77464,0.433147,-0.45843,-0.213201,0.561766,1.556719,1.830379,-0.621218,-0.611041,-0.440926,-0.455591,-0.769976,0.487402,1.540314,-0.410105,-2.063646,-0.66203,-0.522169,0.037557,-0.213201,-0.308607,0.331133,-0.278572,3.24037,-0.308607,2.179449,-0.213201,-0.428746,-0.213201,-0.308607
4,F,10.378642,0.050359,0.433525,1.825444,0.448946,0.231079,-0.310499,-1.168265,-0.723037,0.404105,-0.526883,-0.818551,-0.069808,-0.311742,0.19235,-0.048318,0.413831,0.72687,-1.313071,0.839266,-1.118318,0.39929,-0.66948,-1.088274,-0.113045,-0.213201,-0.308607,-0.521596,-0.682582,-0.458413,-0.308393,2.903823,0.421299,-0.480351,-0.694405,-0.882977,-0.432331,-0.495561,-0.366786,-0.695145,2.309694,-0.482124,-1.373429,0.880095,-0.45843,-0.213201,-0.535127,-0.66116,-0.92285,0.598568,0.379705,-0.440926,2.89905,-0.769976,-0.626017,-0.570587,0.165606,0.349769,-0.119319,-0.522169,-1.727845,-0.213201,-0.308607,-1.192079,1.464603,-0.308607,-0.308607,-0.458831,-0.213201,-0.428746,-0.213201,-0.308607
5,G,9.243056,-0.303815,0.426781,-0.653267,-1.787914,0.192416,-0.289431,1.656688,0.268215,-0.661636,-0.52712,3.378062,-3.610682,3.529379,-2.476702,3.820823,2.339301,-1.599718,-0.777419,-1.078909,0.661861,-1.750939,0.162382,-1.088274,-2.991634,-0.213201,-0.308607,-0.521596,-0.682582,-0.458413,-0.308393,-0.647398,-1.107502,-0.480351,1.144933,-0.882977,-0.432331,-0.495561,-2.256429,1.135685,-0.511968,-0.482124,-1.275893,-0.884527,-0.45843,-0.213201,-1.645349,-0.66116,0.480336,-0.621218,-0.611041,-0.440926,-0.455591,0.835328,-1.876177,-0.821329,-0.600834,0.333776,-0.4624,-0.736399,-0.429755,-0.213201,-0.308607,-1.192079,-0.573846,-0.308607,-0.308607,-0.458831,-0.213201,-0.428746,-0.213201,-0.308607
6,H,10.262535,-0.353472,0.394788,0.451496,0.199882,0.231208,-0.310569,1.216436,2.430949,2.50329,-0.526881,-0.208737,0.101905,0.110199,-0.082449,-0.077989,0.402951,1.880175,-1.488058,0.603424,-1.14114,-0.318264,-0.911087,0.338546,-0.113045,-0.213201,-0.308607,-0.521596,-0.682582,2.354812,-0.308393,-0.647398,-1.107502,1.794561,2.684045,-0.882977,3.441635,-0.495561,-0.366786,-0.695145,0.920122,-0.482124,0.559776,0.880095,-0.45843,-0.213201,-0.535127,-0.66116,0.453765,0.460478,-0.611041,0.989332,-0.455591,2.020838,-0.626017,-0.59794,-0.002637,0.257899,2.549303,-0.522169,-1.208609,-0.213201,3.24037,0.331133,-0.150798,-0.308607,-0.308607,-0.458831,4.690416,1.543487,-0.213201,-0.308607
7,I,10.17463,-0.983742,0.47199,0.845016,-0.39534,0.230896,-0.310399,0.580516,0.378355,-0.187973,-0.526868,-1.275497,0.67423,-1.166196,0.325566,-0.7782,-0.348891,-0.005828,0.993497,-0.542208,0.661861,-0.020725,-0.588014,-1.088274,-0.113045,-0.213201,-0.308607,-0.521596,-0.682582,-0.458413,-0.308393,1.725593,-0.352571,-0.480351,-0.694405,-0.882977,-0.432331,1.434128,1.732827,-0.695145,-0.511968,-0.482124,-1.373429,-0.884527,2.235426,-0.213201,-0.535127,-0.66116,0.345941,-0.621218,0.532387,-0.440926,1.075912,-0.769976,-0.626017,-0.614086,-0.290794,0.581006,0.261874,2.374724,1.127953,-0.213201,-0.308607,-1.192079,1.215885,-0.308607,-0.308607,-0.458831,-0.213201,-0.428746,-0.213201,-0.308607
8,K,10.137222,0.496819,0.454982,-0.321846,-0.022674,0.229538,-0.309659,-0.7647,-0.227411,0.404105,-0.526884,-0.507799,-0.278604,-0.144944,-0.009012,0.363246,0.280423,0.120964,-0.028294,-0.518273,0.57057,1.71859,0.827571,0.553168,-0.113045,-0.213201,-0.308607,-0.521596,-0.682582,-0.458413,-0.308393,0.104345,1.365567,-0.480351,-0.694405,-0.882977,-0.432331,1.374052,1.580319,1.135685,-0.511968,1.482744,-0.104126,-0.884527,-0.45843,-0.213201,-0.535127,-0.66116,-0.92285,1.841899,1.675814,-0.440926,-0.455591,-0.769976,0.624142,-0.62319,1.992128,0.400658,1.46567,-0.056724,1.127953,-0.213201,-0.308607,0.331133,0.055666,-0.308607,-0.308607,-0.458831,-0.213201,-0.428746,-0.213201,3.24037
9,L,10.109769,-0.040705,0.471495,0.845016,-0.39534,0.229561,-0.309671,0.04243,-0.172341,-0.582692,-0.526883,-0.970017,0.047138,-0.616358,0.207388,0.112309,0.157398,-0.722995,0.556356,-0.560003,0.661861,-0.020725,0.580397,-1.088274,-0.113045,-0.213201,-0.308607,-0.521596,-0.682582,-0.458413,-0.308393,0.97385,0.466515,-0.480351,-0.694405,-0.882977,-0.432331,1.434128,1.732827,-0.695145,-0.511968,-0.482124,-1.373429,-0.884527,2.235426,-0.213201,-0.535127,-0.66116,-0.92285,1.722814,-0.611041,-0.440926,1.075912,-0.769976,-0.626017,-0.628416,-0.31434,1.137342,0.028423,2.481339,1.127953,-0.213201,-0.308607,-1.192079,1.215885,-0.308607,-0.308607,-0.458831,-0.213201,-0.428746,-0.213201,-0.308607


In [None]:
#df_feature.to_csv('aa_feature.csv',index=False)

### Morgan fingerprint

In [None]:
#| export
def smi2morgan(df, # a dataframe contains ID and SMILES columns
               smi_colname = "SMILES", # set smiles column name
               id_colname = "ID", # set ID column name
               radius = 2, # radius passed to the Morgan fingerprint
               n_bits = 2048, # length of the fingerprint bit vector (number of feature columns)
              ):
    "Like `smi2prop`, get `n_bits` (default 2048) morgan feature (0/1) given a dataframe that contains ID&smiles"
    mols = [Chem.MolFromSmiles(smi) for smi in df[smi_colname]]
    morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits) for mol in mols]

    # One row per molecule, one column per bit; IDs are kept in the index for now
    fp_df = pd.DataFrame(np.array(morgan_fps), index=df[id_colname])
    fp_df.columns = [f'morgan_{i}' for i in fp_df.columns]

    # Move the ID from the index back to the first column
    return fp_df.reset_index()

In [None]:
show_doc(smi2morgan)

---

### smi2morgan

>      smi2morgan (df, smi_colname='SMILES', id_colname='ID')

Like `smi2prop`, get 2048 morgan feature (0/1) given a dataframe that contains ID&smiles

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| df |  |  | a dataframe contains ID and SMILES columns |
| smi_colname | str | SMILES | set smiles columne name |
| id_colname | str | ID | set ID column name |

## Features from protein sequence

### ESM2

In [None]:
#| export
def esm_embeddings(df: pd.DataFrame, #A dataframe with one protein sequence per row.
                   seq_colname: str, #The name of the column containing the sequences.
                   model_name: str = "esm2_t33_650M_UR50D", #The name of the ESM model to use for the embeddings.
                  ) -> pd.DataFrame:
    """
    Extract one mean-pooled ESM-2 embedding per protein sequence.

    Each row's sequence is run through the model on the GPU, and the representation
    of the last layer is averaged over sequence positions. The result has one row
    per input row (in order, with a fresh default index) and one column per
    embedding dimension. Raises `ValueError` if the number of layers cannot be
    parsed from `model_name`.
    """

    # The layer count is encoded in the model name (e.g. "_t33_" -> 33). Validate it
    # before any distributed/GPU setup so that a bad name fails fast and clearly.
    match = re.search(r'_t(\d+)_', model_name)
    if match is None:
        raise ValueError(
            f"Cannot parse the number of layers from model name '{model_name}'; "
            "expected a name such as 'esm2_t33_650M_UR50D'."
        )
    number = int(match.group(1))

    # Initialize distributed world with world_size 1
    if not torch.distributed.is_initialized():
        url = "tcp://localhost:23456"
        torch.distributed.init_process_group(backend="nccl", init_method=url, world_size=1, rank=0)

    print(f"repr_layers number for model {model_name} is {number}.")
    print("You can also choose other esm2 models:",
          "\nesm2_t48_15B_UR50D\nesm2_t36_3B_UR50D\nesm2_t33_650M_UR50D\nesm2_t30_150M_UR50D\nesm2_t12_35M_UR50D\nesm2_t6_8M_UR50D\n")

    # Download model data from the hub
    model_data, regression_data = esm.pretrained._download_model_and_regression_data(model_name)

    # Initialize the model with FSDP wrapper
    fsdp_params = dict(
        mixed_precision=True,
        flatten_parameters=True,
        state_dict_device=torch.device("cpu"),  # reduce GPU mem usage
        cpu_offload=True,  # enable cpu offloading
    )

    with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
        model, vocab = esm.pretrained.load_model_and_alphabet_core(
            model_name, model_data, regression_data
        )
        batch_converter = vocab.get_batch_converter()
        model.eval()

        # Wrap each layer in FSDP separately
        for name, child in model.named_children():
            if name == "layers":
                for layer_name, layer in child.named_children():
                    wrapped_layer = wrap(layer)
                    setattr(child, layer_name, wrapped_layer)
        model = wrap(model)

        # Define the feature extraction function
        def get_feature(r, colname=seq_colname) -> np.ndarray:
            # Batch of one sequence; the label 'protein' is a placeholder
            data = [('protein', r[colname])]
            labels, strs, tokens = batch_converter(data)
            with torch.no_grad():
                results = model(tokens.cuda(), repr_layers=[number], return_contacts=False)
            rpr = results["representations"][number].squeeze()
            # Average over positions 1..len(sequence); position 0 is skipped
            # (presumably a start token added by the batch converter - TODO confirm)
            rpr = rpr[1 : len(r[colname]) + 1].mean(0).detach().cpu().numpy()

            del results, labels, strs, tokens, data #especially need to delete those on cuda: tokens, results
            gc.collect()

            return rpr

        # Apply the feature extraction function to each row in the DataFrame
        series = df.progress_apply(get_feature, axis=1)
        df_feature = pd.DataFrame(series.tolist())

        return df_feature

In [None]:
show_doc(esm_embeddings)

---

### esm_embeddings

>      esm_embeddings (df:pandas.core.frame.DataFrame, seq_colname:str,
>                      model_name:str='esm2_t33_650M_UR50D')

Extract 1280 esmfold2 embeddings from protein sequence.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| df | DataFrame |  |  |
| seq_colname | str |  | The name of the column containing the sequences. |
| model_name | str | esm2_t33_650M_UR50D | The name of the ESM model to use for the embeddings. |
| **Returns** | **DataFrame** |  |  |

### T5

In [None]:
#| export
def T5_embeddings(sequence, # a single protein sequence (string of one-letter residue codes)
                  device = 'cuda', # device on which the input tensors are placed
                  tokenizer = None, # object exposing `batch_encode_plus`; defaults to the module-level `tokenizer`
                  model = None, # callable returning an object with `last_hidden_state`; defaults to the module-level `model`
                 ):
    "Embed one protein sequence: mean of the model's last hidden state over the first `len(sequence)` positions"

    # Backward compatible fallback: earlier versions read `tokenizer` and `model`
    # from module-level globals, so keep doing that when they are not passed in.
    if tokenizer is None:
        tokenizer = globals().get("tokenizer")
    if model is None:
        model = globals().get("model")
    if tokenizer is None or model is None:
        raise NameError(
            "T5_embeddings needs a tokenizer and a model: pass `tokenizer=` and `model=`, "
            "or define module-level `tokenizer` and `model` before calling."
        )

    seq_len = len(sequence)

    # Prepare the protein sequence as a one-element batch: the residues U, Z, O, B
    # are replaced by X and residues are separated by spaces
    sequence = [" ".join(list(re.sub(r"[UZOB]", "X", sequence)))]

    # Tokenize sequences and pad up to the longest sequence in the batch
    ids = tokenizer.batch_encode_plus(sequence, add_special_tokens=True, padding="longest")
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    # Generate embeddings (the model is assumed to already live on `device`)
    with torch.no_grad():
        embedding_rpr = model(input_ids=input_ids, attention_mask=attention_mask)

    # Keep only the first seq_len positions (dropping anything the tokenizer appended
    # after the residues) and average them into one vector
    emb_mean = embedding_rpr.last_hidden_state[0][:seq_len].detach().cpu().numpy().mean(axis=0)

    return emb_mean

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
