In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from rogi import RoughnessIndex

from argparse import Namespace
import os
import copy
import pickle
import numpy as np
import pandas as pd
import rdkit
from rdkit import Chem
import argparse
from tqdm.notebook import tqdm
from scipy.spatial.distance import squareform
from rdkit.Chem import AllChem
from sklearn.metrics import r2_score
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import adjusted_rand_score
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy import interpolate
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import cm
from collections import defaultdict

import yaml
import random
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import to_dense_batch
from torch.utils.data import Subset
from sklearn.metrics import roc_auc_score

from molmcl.finetune.loader import MoleculeDataset
from molmcl.finetune.model import GNNPredictor

from molmcl.splitters import scaffold_split, moleculeace_split
from molmcl.utils.draw import draw_mols, draw_mol_with_highlight
from molmcl.utils.moleculeace import moleculeace_similarity, get_fc

import sys
sys.path.append('../scripts')
from finetune import train as train_func
from finetune import eval as eval_func
from finetune import get_optimizer, set_seed, optimize_prompt_weight_ri

In [None]:
# Load the base MoleculeACE fine-tuning configuration from disk.
with open('../config/moleculeace/chembl.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Dataset overrides used by this notebook.
config['dataset'].update(
    data_name='CHEMBL237_Ki',
    data_dir='../data/finetune/moleculeace',
    feat_type='basic',
)

# Optimisation overrides: one shared learning rate for all parameter groups.
config['optim'].update(
    scheduler='cos_anneal',
    gradient_clip=5,
    prompt_lr=0.0005,
    pretrain_lr=0.0005,
    finetune_lr=0.0005,
)

# Model overrides: 5-layer, 300-dim GNN backbone with prompts enabled and
# all dropout switched off.
config['model'].update(
    temperature=0.7,
    layernorm=False,
    normalize=False,
    use_prompt=True,
    emb_dim=300,
    num_layer=5,
    checkpoint='../checkpoint/zinc-gnn_basic.pt',
    backbone='gnn',
    dropout_ratio=0,
    attn_dropout_ratio=0,
)

# Display the resulting configuration as the cell output.
config

In [None]:
def get_dataloaders(config, few_shot_num=10, few_shot_ratio=0, few_shot_cliff=False, seed=42):
    """Build the dataset, its splits, matched-pair indices and data loaders.

    Args:
        config: Configuration dict. Reads ``config['dataset']['data_dir']``,
            ``config['dataset']['data_name']`` and
            ``config['dataset']['feat_type']``.
        few_shot_num: Number of training molecules to select for the few-shot
            loader. Overridden when ``few_shot_ratio > 0``.
        few_shot_ratio: Fraction of the training split to select for the
            few-shot loader; takes precedence over ``few_shot_num``.
        few_shot_cliff: Accepted for interface compatibility; this function
            does not read it.
        seed: Random state of the KMeans clustering used for the few-shot
            selection. ``None`` or a negative value leaves it unseeded.

    Returns:
        A dict with the keys
        ``'dataset'`` (the full ``MoleculeDataset``),
        ``'splits'`` (``[train_idx, val_idx, test_idx]``),
        ``'mmp_inds'`` (``[noncliff_pair_inds, cliff_pair_inds]``, index pairs
        into the *training split*), and
        ``'dataloader'`` (``[train, val, test, full, few_shot]`` loaders; the
        few-shot entry is ``None`` when no few-shot subset was requested).
    """
    dataset = MoleculeDataset(config['dataset']['data_dir'],
                              config['dataset']['data_name'],
                              config['dataset']['feat_type'])
    num_task = dataset.num_task
    print('Loading dataset of size {} with {} target.'.format(len(dataset), num_task))

    if 'CHEMBL' in config['dataset']['data_name']:
        train_idx, val_idx, test_idx = moleculeace_split(dataset.smiles, dataset.labels, val_size=0.1, test_size=0.1)
    else:
        train_idx, val_idx, test_idx = scaffold_split(dataset.smiles, frac_valid=0.1, frac_test=0.1, balanced=False)

    # full: train, then val, then test, in that fixed order (no shuffling).
    # NOTE(review): the `+` below assumes the splitters return Python lists.
    full_dataset = Subset(dataset, train_idx + val_idx + test_idx)
    full_loader = DataLoader(full_dataset, batch_size=32, shuffle=False)

    train_dataset, val_dataset, test_dataset = \
        Subset(dataset, train_idx), Subset(dataset, val_idx), Subset(dataset, test_idx)

    # train:
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    # val:
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    # tst:
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    # extract train-cliff pairs: pairs flagged by `moleculeace_similarity`
    # (sim == 1) are split by whether `get_fc` exceeds 10.
    train_smiles = [dataset.smiles[i] for i in train_idx]
    train_labels = [dataset.labels[i] for i in train_idx]
    sim = moleculeace_similarity(train_smiles)
    fc = (get_fc(train_labels, in_log10=True) > 10).astype(int)
    cliffs = np.logical_and(sim == 1, fc == 1).astype(int)
    cliff_pair_inds = np.argwhere(cliffs)
    noncliffs = np.logical_and(sim == 1, fc == 0).astype(int)
    noncliff_pair_inds = np.argwhere(noncliffs)

    # few_shot:
    few_shot_loader = None
    if few_shot_num > 0 or few_shot_ratio > 0:
        if few_shot_ratio > 0:
            # A tiny ratio must not round down to zero clusters.
            few_shot_num = max(1, int(len(train_idx) * few_shot_ratio))

        if few_shot_num < len(train_idx):
            # Fingerprints are only needed when a strict subset is selected.
            train_fps = np.array([np.array(get_fp(smi)) for smi in train_smiles])
            kmeans_seed = seed if seed is not None and seed >= 0 else None
            kmeans = KMeans(n_clusters=few_shot_num, random_state=kmeans_seed).fit(train_fps)
            # For every cluster centre keep the closest training molecule.
            nearest = np.argmin(euclidean_distances(train_fps, kmeans.cluster_centers_), axis=0)
            # De-duplicate and sort so the subset order is reproducible.
            few_shot_idx = [train_idx[i] for i in sorted(set(nearest.tolist()))]
        else:
            few_shot_idx = train_idx

        few_shot_dataset = Subset(dataset, few_shot_idx)
        # One batch that holds the whole few-shot subset.
        few_shot_loader = DataLoader(few_shot_dataset, batch_size=len(few_shot_dataset), shuffle=True)

    return {'dataset': dataset,
            'splits': [train_idx, val_idx, test_idx],
            'mmp_inds': [noncliff_pair_inds, cliff_pair_inds],
            'dataloader': [train_loader, val_loader, test_loader, full_loader, few_shot_loader]}

def sample_triplet(cliff_pair_inds, noncliff_pair_inds, seed=-1):
    """Sample a cliff pair and a non-cliff pair that share their first index.

    Args:
        cliff_pair_inds: Array of index pairs, one ``(i, j)`` row per cliff pair.
        noncliff_pair_inds: Array of index pairs, one row per non-cliff pair.
        seed: When greater than -1, NumPy's *global* random state is reseeded
            with this value before sampling.

    Returns:
        ``(c_pair, nc_pair)``: a row of ``cliff_pair_inds`` and a row of
        ``noncliff_pair_inds`` with ``c_pair[0] == nc_pair[0]``.

    Raises:
        ValueError: If either input is empty, or if no cliff pair shares its
            first index with any non-cliff pair (rejection sampling would
            otherwise never terminate).
    """
    cliff_pair_inds = np.asarray(cliff_pair_inds)
    noncliff_pair_inds = np.asarray(noncliff_pair_inds)
    if len(cliff_pair_inds) == 0 or len(noncliff_pair_inds) == 0:
        raise ValueError('sample_triplet needs at least one cliff pair and one non-cliff pair.')

    # The loop below only ends once a cliff anchor also anchors a non-cliff
    # pair, so make sure such an anchor exists before entering it.
    cliff_anchors = set(cliff_pair_inds[:, 0].tolist())
    noncliff_anchors = set(noncliff_pair_inds[:, 0].tolist())
    if not cliff_anchors & noncliff_anchors:
        raise ValueError('No cliff pair shares its first index with a non-cliff pair.')

    if seed > -1:
        np.random.seed(seed)
    # Rejection sampling, kept as-is so a given seed yields the same draws.
    while True:
        c_idx = np.random.choice(np.arange(len(cliff_pair_inds)))
        c_pair = cliff_pair_inds[c_idx]
        nc_idx = np.argwhere(c_pair[0] == noncliff_pair_inds[:, 0])
        if len(nc_idx):
            nc_idx = np.random.choice(nc_idx.flatten())
            nc_pair = noncliff_pair_inds[nc_idx]
            break
    return c_pair, nc_pair

def get_reps_local(ori_model, loader, config, best_checkpoint=None, channel_idx=-1):
    """Collect SMILES, graph representations and labels for every batch in ``loader``.

    Args:
        ori_model: Model to probe. It is deep-copied, so the caller's instance
            is not modified.
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``label`` and ``smi`` (plus ``pe`` when
            the backbone is ``'gps'``).
        config: Configuration dict; reads ``config['device']`` and
            ``config['model']['normalize']``.
        best_checkpoint: Optional state dict loaded into the copy before use.
        channel_idx: ``-1`` combines all prompt channels with the model's
            prompt weights; any other value selects that single channel.

    Returns:
        ``(smiles_list, graph_rep, labels)`` with both tensors detached and
        moved to the CPU.
    """
    probe = copy.deepcopy(ori_model)
    if best_checkpoint is not None:
        probe.load_state_dict(best_checkpoint)
    probe.eval()

    channel_chunks, label_chunks, smiles_list = [], [], []
    for batch in loader:
        batch.to(config['device'])
        with torch.no_grad():
            # Encode the batch; the GPS backbone additionally consumes `pe`.
            if probe.backbone == 'gps':
                h_g, node_repres = probe.gnn(
                    batch.x, batch.pe, batch.edge_index, batch.edge_attr, batch.batch)
            else:
                h_g, node_repres = probe.gnn(
                    batch.x, batch.edge_index, batch.edge_attr, batch.batch)

            # Regroup the flat node embeddings per graph before aggregating.
            dense_x, dense_mask = to_dense_batch(node_repres, batch.batch)

            # One aggregated graph embedding per prompt token.
            per_channel = []
            for aggr_idx in range(len(probe.prompt_token)):
                h_g, h_x, _ = probe.aggrs[aggr_idx](dense_x, dense_mask)
                per_channel.append(
                    F.normalize(h_g, dim=-1) if config['model']['normalize'] else h_g)

        channel_chunks.append(torch.stack(per_channel))
        label_chunks.append(batch.label.view(-1, probe.num_tasks))
        smiles_list.extend(batch.smi)

    # Channels stay on dim 0; batches are joined along dim 1.
    stacked = torch.concat(channel_chunks, dim=1)
    if channel_idx != -1:
        graph_rep = stacked[channel_idx]
    else:
        prompt_weight = probe.get_prompt_weight(act=probe.act)
        graph_rep = torch.matmul(stacked.transpose(0, 2), prompt_weight).transpose(0, 1)

    labels = torch.concat(label_chunks, dim=0).detach().cpu()
    return smiles_list, graph_rep.detach().cpu(), labels

def get_fp(key):
    """Return the Morgan fingerprint bit vector (radius 2, 512 bits) of a SMILES string.

    Args:
        key: SMILES string of the molecule.

    Raises:
        ValueError: If RDKit cannot parse ``key`` (``MolFromSmiles`` returned
            ``None``).
    """
    mol = Chem.MolFromSmiles(key)
    if mol is None:
        raise ValueError('Could not parse SMILES: {!r}'.format(key))
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=512)

## Visualization (0 $\to$ 100 epochs):
This section reproduces the representation-space probing experiment shown in Figure 3.

For a fair comparison, the model in this experiment is trained with the basic feature set (the same as MolCLR and GraphLoG), the same pre-training dataset (ZINC), the same model architecture (GIN), and the same hyperparameters (num_layer=5, emb_dim=300, ...).

To perform the same analysis on MolCLR and GraphLoG, download the checkpoints from their GitHub repositories ([MolCLR](https://github.com/yuyangw/MolCLR) and [GraphLoG](https://github.com/DeepGraphLearning/GraphLoG/tree/main)) and load only the weights of the GNN module. Make sure to set `use_prompt=False`.

In [None]:
# Distance metric handed to RoughnessIndex in the cells below.
ri_metric = 'euclidean'  # euclidean, cosine
# Set to a precomputed prompt-weight tensor to skip the optimisation below.
best_initialization = None

# Get Dataloader (no few-shot subset is requested here):
outputs = get_dataloaders(config, few_shot_num=0, few_shot_ratio=0, few_shot_cliff=False)
dataset = outputs['dataset']
train_idx, val_idx, test_idx = outputs['splits']
noncliff_pair_inds, cliff_pair_inds = outputs['mmp_inds']
train_loader, val_loader, test_loader, full_loader, few_shot_loader = outputs['dataloader']
print('Number of cliff pairs:', len(cliff_pair_inds))

# Build Model and Load Checkpoint:
model = GNNPredictor(num_layer=config['model']['num_layer'],
                     emb_dim=config['model']['emb_dim'],
                     num_tasks=dataset.num_task,
                     normalize=config['model']['normalize'],
                     atom_feat_dim=None,  # for basic_feature
                     bond_feat_dim=None,  # for basic_feature
                     drop_ratio=config['model']['dropout_ratio'],
                     attn_drop_ratio=config['model']['attn_dropout_ratio'],
                     temperature=config['model']['temperature'],
                     use_prompt=config['model']['use_prompt'],
                     model_head=config['model']['heads'],
                     layer_norm_out=config['model']['layernorm'],
                     backbone=config['model']['backbone'])

if config['model']['checkpoint']:
    print('Loading checkpoint from {}'.format(config['model']['checkpoint']))
    # Deserialise onto the CPU so a checkpoint saved from a GPU also loads on
    # CPU-only machines; the model is moved to the target device below.
    checkpoint_state = torch.load(config['model']['checkpoint'], map_location='cpu')
    model.load_state_dict(checkpoint_state['wrapper'], strict=True)

model.freeze_aggr_module()
model.to(config['device'])

set_seed(24)
# Initialize Prompt Weights:
if config['model']['use_prompt']:
    if best_initialization is None:
        best_initialization = \
            optimize_prompt_weight_ri(model, train_loader, None, config, act=model.act)

    model.set_prompt_weight(best_initialization.to(config['device']))
    # Keep CPU copies of the starting prompt weights, with the model's
    # activation applied ('probs') and with act='none' ('weights').
    initial_prompt_probs = model.get_prompt_weight(model.act).data.cpu()
    initial_prompt_weights = model.get_prompt_weight('none').data.cpu()

In [None]:
# Fine-tune the model and snapshot its representations at selected epochs.
method = 'composite'
setting2stats = {}
r2_history = []
smiles_for_plot, scores_history = [], []
criterion = nn.MSELoss(reduction='none')
optimizer = get_optimizer(model, config['optim'])
best_score, best_checkpoint = -float('inf'), None

set_seed(42)
for ep in range(101):
    if ep in (0, 10, 20, 50, 80, 100):
        # Representations are taken with `best_checkpoint` loaded (the
        # lowest-training-loss weights so far; the live model while it is None).
        smiles_list, graph_reps, labels = get_reps_local(
            model, full_loader, config, best_checkpoint)
        setting2stats['{}-{}'.format(method, ep)] = (smiles_list, graph_reps, labels)

        # full_loader lists the training molecules first, so the leading rows
        # belong to the training split.
        n_train = len(train_idx)
        X = graph_reps.numpy()[:n_train]
        Y = labels.numpy()[:n_train, 0]
        rogi = RoughnessIndex(Y=Y, X=X, metric=ri_metric, verbose=False)
        ri = rogi.compute_index()
        print('Roughness Index (ep{}): {}'.format(ep, ri))

        # Record the scores of the first validation batch only.
        batch = next(iter(val_loader))
        with torch.no_grad():
            _, scores = model.get_representations(
                batch.to(config['device']), return_score=True)
        scores_history.append(scores)
        if not smiles_for_plot:
            smiles_for_plot = batch.smi

        # The last snapshot needs no further training.
        if ep == 100:
            break

    loss, _ = train_func(model, train_loader, criterion, optimizer, None, config)

    # R2 on the train / val / test splits after this epoch, in that order.
    train_score, val_score, test_score = (
        eval_func(model, split_loader, config, metric='r2')
        for split_loader in (train_loader, val_loader, test_loader))
    r2_history.append((train_score, val_score, test_score))

    # Track the weights with the lowest training loss seen so far.
    neg_loss = -loss
    if neg_loss > best_score:
        best_score = neg_loss
        best_checkpoint = copy.deepcopy(model.state_dict())

    if ep % 10 == 0:
        print('[epoch {}] Train R2: {} Val R2: {}'.format(ep, train_score, val_score))

#### Plot the training curve:

In [None]:
# Train / validation R2 per epoch for the run recorded in `r2_history`.
plt.figure(figsize=(5, 2), tight_layout=True)
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap`
# works on both old and new releases.
cmap = plt.get_cmap("Set1")
train_r2 = [t[0] for t in r2_history]
val_r2 = [t[1] for t in r2_history]

plt.plot(train_r2, color=cmap(0), label='{}-{}'.format(method, 'train'))
plt.plot(val_r2, '--', color=cmap(0), label='{}-{}'.format(method, 'val'))
# The curves are labelled, so draw the legend that shows those labels.
plt.legend()
plt.show()

#### Visualization of representation space:

In [None]:
# Representation-space figure: one column per snapshot epoch, with the
# training split on the top row and the validation split on the bottom row.
perplexity = 25
n_clusters = 10
cliff_distances_dict = defaultdict(list)
# t-SNE seeds per column: first list for the train row, second for the val row.
tsne_seeds = [[42, 43, 44, 44, 44], [42, 43, 48, 48, 46]]

method = 'composite'

n_train, n_val = len(train_idx), len(val_idx)
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap`
# works on both old and new releases.
cmap = plt.get_cmap("coolwarm")

# Validation cliff / non-cliff pairs depend only on the split, not on the
# epoch, so they are located once up front.
# NOTE: this rebinds `cliff_pair_inds` / `noncliff_pair_inds`, which an
# earlier cell set to the training-split pairs.
val_smiles = [dataset.smiles[i] for i in val_idx]
val_labels = [dataset.labels[i] for i in val_idx]
sim = moleculeace_similarity(val_smiles)
fc = (get_fc(val_labels, in_log10=True) > 10).astype(int)
cliffs = np.logical_and(sim == 1, fc == 1).astype(int)
cliff_pair_inds = np.unique(np.argwhere(cliffs), axis=0)
noncliffs = np.logical_and(sim == 1, fc == 0).astype(int)
noncliff_pair_inds = np.unique(np.argwhere(noncliffs), axis=0)

fig, axes = plt.subplots(2, 5, figsize=(25, 6))
for ei, ep in enumerate([0, 10, 20, 50, 100]):
    ax_train, ax_val = axes[0, ei], axes[1, ei]
    key = '{}-{}'.format(method, ep)
    print(key)

    # Rows follow the full_loader order: train first, then val, then test.
    X = setting2stats[key][1].numpy()
    # First label column, 1-D, matching how the fine-tuning cell calls
    # RoughnessIndex.
    Y = setting2stats[key][2].numpy()[:, 0]
    X_train, X_val = X[:n_train], X[n_train:n_train + n_val]
    Y_train, Y_val = Y[:n_train], Y[n_train:n_train + n_val]

    # 2d-plot:
    ## Subplot-1 (training split):
    rogi = RoughnessIndex(Y=Y_train, X=X_train, metric=ri_metric, verbose=False)
    ri = rogi.compute_index()
    print('Roughness Index:', ri)

    ### cluster once on fingerprints and once on the learned representations:
    X_smi = setting2stats[key][0][:n_train]
    X_fps = np.array([np.array(get_fp(smi)) for smi in X_smi])
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X_fps)
    kmeans_label = kmeans.labels_
    kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X_train)
    kmeans_label2 = kmeans.labels_

    ### dimension reduction on the pairwise distances held by `rogi`:
    tsne = TSNE(n_components=2, learning_rate='auto', init='random',
                metric='precomputed', random_state=tsne_seeds[0][ei], perplexity=perplexity)
    X_2d = tsne.fit_transform(squareform(rogi._Dx))

    ### main plot:
    scatter1 = ax_train.scatter(
        X_2d[:, 0], X_2d[:, 1], c=rogi._Y, s=40, cmap=cmap,
        alpha=0.8, zorder=10, edgecolors='grey', linewidths=0.4)

    ### outline every fingerprint cluster with a smoothed convex hull:
    for ci in set(kmeans_label):
        points = X_2d[kmeans_label == ci]
        if len(points) < 3:
            # A 2-D convex hull needs at least three points.
            continue
        hull = ConvexHull(points)
        # Close the polygon by repeating its first vertex.
        x_hull = np.append(points[hull.vertices, 0],
                           points[hull.vertices, 0][0])
        y_hull = np.append(points[hull.vertices, 1],
                           points[hull.vertices, 1][0])
        # interpolate a periodic spline parameterised by arc length
        dist = np.sqrt((x_hull[:-1] - x_hull[1:])**2 + (y_hull[:-1] - y_hull[1:])**2)
        dist_along = np.concatenate(([0], dist.cumsum()))
        spline, u = interpolate.splprep([x_hull, y_hull], u=dist_along, s=0, per=1)
        interp_d = np.linspace(dist_along[0], dist_along[-1], 50)
        interp_x, interp_y = interpolate.splev(interp_d, spline)
        # plot shape
        ax_train.fill(interp_x, interp_y, '--', c='lightgrey',
                      edgecolor='grey', linewidth=1, alpha=0.3, zorder=9)

    plot_text = 'ROGI: {}\nRand Index: {}'.format(
        round(ri, 3),
        round(adjusted_rand_score(kmeans_label2, kmeans_label), 3))
    ax_train.text(0.03, 0.05, plot_text, fontsize=12,
                  transform=ax_train.transAxes, zorder=20,
                  bbox=dict(facecolor='white', edgecolor='grey', boxstyle='round', alpha=0.7))
    ax_train.tick_params(
        left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

    ## Subplot-2 (validation split):
    rogi = RoughnessIndex(Y=Y_val, X=X_val, metric=ri_metric, verbose=False)
    ri = rogi.compute_index()

    # Expand the condensed distances once and reuse the square matrix for
    # both t-SNE and the pair-distance lookups.
    val_dist = squareform(rogi._Dx)

    ### dimension reduction:
    tsne = TSNE(n_components=2, learning_rate='auto', init='random',
                metric='precomputed', random_state=tsne_seeds[1][ei], perplexity=perplexity)
    X_2d = tsne.fit_transform(val_dist)

    # representation-space distances of the validation cliff / non-cliff pairs:
    cliff_distances = [val_dist[idx_1, idx_2] for idx_1, idx_2 in cliff_pair_inds]
    noncliff_distances = [val_dist[idx_1, idx_2] for idx_1, idx_2 in noncliff_pair_inds]
    dist_ratio = np.mean(cliff_distances) / np.mean(noncliff_distances)
    print('Average cliff pair distance:', np.mean(cliff_distances))
    print('Average noncliff pair distance:', np.mean(noncliff_distances))
    print('Cliff-noncliff distance ratio:', dist_ratio)
    print('----------')
    # Fixed seed: the same cliff / non-cliff pair is highlighted in every column.
    c_pair, nc_pair = sample_triplet(cliff_pair_inds, noncliff_pair_inds, seed=12)

    scatter2 = ax_val.scatter(X_2d[:, 0], X_2d[:, 1], c=rogi._Y, s=40,
                              cmap=cmap, alpha=0.8, zorder=10, edgecolors='grey', linewidths=0.4)

    # Red arrow: sampled cliff pair; blue arrow: non-cliff pair with the same
    # first molecule.
    ax_val.arrow(X_2d[c_pair[0], 0], X_2d[c_pair[0], 1],
                 X_2d[c_pair[1], 0]-X_2d[c_pair[0], 0], X_2d[c_pair[1], 1]-X_2d[c_pair[0], 1],
                 color='red', alpha=0.7, zorder=12, width=0.3, head_width=1.3)
    ax_val.arrow(X_2d[nc_pair[0], 0], X_2d[nc_pair[0], 1],
                 X_2d[nc_pair[1], 0]-X_2d[nc_pair[0], 0], X_2d[nc_pair[1], 1]-X_2d[nc_pair[0], 1],
                 color='blue', alpha=0.7, zorder=12, width=0.3, head_width=1.2)

    # r2_history[k] is recorded after training pass k, so the snapshot taken
    # at epoch `ep` corresponds to entry ep-1; epoch 0 has no score yet.
    if ep == 0:
        r2 = 'N/A'
    else:
        r2 = round(r2_history[ep-1][1], 3)
    plot_text = 'R-Squared: {}\nCliff-noncliff Distance Ratio: {}'.format(r2, round(dist_ratio, 3))
    ax_val.text(0.03, 0.05, plot_text, fontsize=12,
                transform=ax_val.transAxes, zorder=12,
                bbox=dict(facecolor='white', edgecolor='grey', boxstyle='round', alpha=0.7))

    ax_val.tick_params(
        left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

plt.tight_layout()
# One shared colour bar, taken from the last validation scatter.
plt.colorbar(scatter2, ax=axes.ravel().tolist())
plt.show()