In [186]:
import pandas as pd
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
import umap

from rdkit.Chem import MolFromSmiles, DataStructs
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect

In [55]:
# Load the saved checkpoint onto the CPU regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
w = torch.load('../analysis/DDC_KIBA_best_epoch_23.ckpt', map_location=torch.device('cpu'))

In [56]:
# Stand-alone encoder: one 1024 -> 1024 linear layer followed by a ReLU.
# It is filled with the checkpoint's `d_encoder` weights in a later cell.
d_enc = nn.Sequential(
    nn.Linear(in_features=1024, out_features=1024),
    nn.ReLU(),
)

In [57]:
# Copy the checkpoint's `d_encoder.*` weights into the stand-alone encoder.
# The prefix includes the trailing dot so that only keys under `d_encoder.` match,
# and its length (not a hard-coded 10) is what gets stripped from each key.
d_enc_prefix = 'd_encoder.'
d_enc_state = {
    k[len(d_enc_prefix):]: v
    for k, v in w['state_dict'].items()
    if k.startswith(d_enc_prefix)
}

# strict=False still tolerates extra checkpoint keys, but a *missing* key would
# leave that parameter at its random initial value and silently invalidate every
# embedding computed from d_enc, so fail loudly in that case.
load_result = d_enc.load_state_dict(d_enc_state, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"d_encoder weights not found in checkpoint for: {load_result.missing_keys}"
    )

# Keep the cell's displayed value: the missing/unexpected key report.
load_result

In [58]:
# Display the encoder's module structure.
d_enc

In [99]:
# Load the pairs table from CSV. Later cells read its `target`, `drug1`, `drug2`,
# `cliff` and `smiles2` columns.
kiba_ddc = pd.read_csv('../analysis/kiba_cliff_pairs_ta1_ts0.9_r_wt.csv')

In [84]:
# Display the loaded pairs table.
kiba_ddc

In [87]:
# Number of distinct targets in the table.
# (`.unique().sum()` added up the distinct target IDs themselves, which is not a count.)
kiba_ddc['target'].nunique()

In [89]:
# For every target, count the distinct drug pairs. Each row's (drug1, drug2) is
# turned into a frozenset so that (a, b) and (b, a) count as the same pair.
target_pairs_count = (
    kiba_ddc
    .groupby('target')
    .apply(lambda grp: grp[['drug1', 'drug2']].apply(frozenset, axis=1).nunique())
)

In [90]:
# Display the number of distinct drug pairs per target.
target_pairs_count

In [109]:
# Keep only the rows whose target is 70.
t70 = kiba_ddc.loc[kiba_ddc['target'] == 70]

In [110]:
# Count how many target-70 rows each `drug1` value appears in.
t70['drug1'].value_counts()

In [115]:
# Distribution of `cliff` values among target-70 rows whose drug1 is CHEMBL1241487
# (row filter and column selection done in a single .loc call).
t70.loc[t70['drug1'] == 'CHEMBL1241487', 'cliff'].value_counts()

In [119]:
# Target-70 rows in which CHEMBL1241487 is the first drug of the pair.
t70_d1_d = t70.loc[t70['drug1'] == 'CHEMBL1241487']

In [205]:
# Display the rows selected for CHEMBL1241487 on target 70.
t70_d1_d

In [210]:
# Add the reference drug itself as an extra row (paired with itself) so that it
# gets its own point in the embedding plot. It carries cliff value 2.
# NOTE(review): 2 looks like a sentinel that sets the reference apart from the real
# cliff labels — confirm against the values present in the CSV.
ref_drug = 'CHEMBL1241487'
ref_smiles = 'C1CCC(C1)N2C3=C(C(=N2)C4=CC(=C(C=C4)N)O)C(=NC=N3)N'

dr1 = pd.DataFrame({'drug1': [ref_drug],
                    'drug2': [ref_drug],
                    'cliff': [2],
                    'target': [70],
                    'smiles1': [ref_smiles],
                    'smiles2': [ref_smiles]})

# Rebuild from `t70` instead of appending to `t70_d1_d` itself: the original
# self-referential concat added one more copy of the reference row every time
# the cell was re-run.
t70_d1_d = pd.concat([t70[t70['drug1'] == ref_drug], dr1], axis=0)

In [213]:
# SMILES strings of the second drug of every selected row, as a plain Python list.
smiles = t70_d1_d['smiles2'].to_list()

In [214]:
# Number of SMILES strings to featurise.
len(smiles)

In [217]:
# Turn every SMILES string into a 1024-bit Morgan fingerprint (radius 2) stored as
# a float32 tensor, in the same order as `smiles`.
features = []
unparsed = []
for s in smiles:
    mol = MolFromSmiles(s)
    if mol is None:
        # RDKit could not parse this string; report it as before and remember it.
        print(s)
        unparsed.append(s)
        continue
    fp = GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
    # ConvertToNumpyArray resizes the target array to the fingerprint length.
    arr = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    features.append(torch.tensor(arr, dtype=torch.float32))

# Silently dropping a molecule would leave `features` shorter than `t70_d1_d`, whose
# `cliff` column is later used row-for-row to colour the embedding. Stop here with
# a clear message instead of producing misaligned (or failing) plots downstream.
if unparsed:
    raise ValueError(
        f"{len(unparsed)} SMILES could not be parsed; drop these rows from "
        f"t70_d1_d before featurising: {unparsed}"
    )

In [218]:
# Display the list of fingerprint tensors.
# NOTE(review): this prints every tensor in full; `features[:3]` would be enough to eyeball.
features

In [219]:
# Number of fingerprints produced; compare with len(smiles) to spot skipped molecules.
len(features)

In [221]:
# Run all fingerprints through the encoder as one batch and hand the result to
# NumPy. Autograd is switched off for the forward pass, so no detach is needed.
with torch.no_grad():
    out = d_enc(torch.stack(features)).numpy()

In [222]:
# Shape of the encoder output array.
out.shape

In [224]:
# UMAP is stochastic; fix the seed so the projection (and the figure built from it)
# is the same on every Restart & Run All.
u = umap.UMAP(random_state=42)

In [225]:
# Fit UMAP on the encoder outputs and keep the projected coordinates;
# the plotting cell reads columns 0 and 1 of this array.
umap_res = u.fit_transform(out)

In [226]:
# Scatter plot of the UMAP projection, one point per row of `t70_d1_d`, coloured
# by that row's `cliff` value.

# Points and colours are matched purely by position, so their lengths must agree.
if len(umap_res) != len(t70_d1_d):
    raise ValueError(
        f"umap_res has {len(umap_res)} points but t70_d1_d has {len(t70_d1_d)} rows; "
        "the cliff labels would not line up with the embedded molecules"
    )

fig, ax = plt.subplots()
points = ax.scatter(umap_res[:, 0], umap_res[:, 1], c=t70_d1_d['cliff'], cmap='viridis')
# The original title was a blank string; give the figure one that stands alone.
ax.set_title('UMAP of d_encoder embeddings: target 70, drug2 partners of CHEMBL1241487')
ax.set_xlabel('UMAP Dimension 1')
ax.set_ylabel('UMAP Dimension 2')
fig.colorbar(points, ax=ax, label='Cliff Value')
plt.show()