# Tutorial: Evaluating a Prototypical Network on FSL-CP 

Since there is no 'standard' way of doing meta/few-shot learning, models can have vastly different architectures and train very differently. Hence, one might have to change some details in this notebook to work with their own architecture. Nonetheless, this is a good starting point to familiarise yourself with the process of evaluating models on the dataset. 

In [1]:
# Setting things up for importing local modules
import os 
import sys

# Add folder FSL_CP to sys.path
FSL_CP_PATH = os.path.join(os.environ['HOME'], 'FSL_CP')
sys.path.insert(0, FSL_CP_PATH)

# Change working directory to FSL_CP
os.chdir(FSL_CP_PATH)


In [2]:
# Imports
from fsl_cp.utils.misc import sliding_average
from fsl_cp.utils.metrics import delta_auprc
from fsl_cp.datamodule.base import BaseDatasetCP, BaseSamplerCP

import torch 
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import balanced_accuracy_score, f1_score, cohen_kappa_score, roc_auc_score, accuracy_score

We evaluate every model over 100 test episodes take average of the scores. In addition, performances are recorded over different support set sizes (8, 16, 32, 64, 96).

In [3]:
# Initialisations
support_set_sizes = [8, 16, 32, 64, 96]
query_set_size = 32
num_episodes_val = 100
num_episodes_test = 100
log_update_freq = 10
val_freq = 100

# Hyperparameters to tune
num_episodes_train = 1000
meta_batch_size=3
step_size = 500

json_path = os.path.join(FSL_CP_PATH, 'data/output/data_split.json')
label_df_path= os.path.join(FSL_CP_PATH, 'data/output/FINAL_LABEL_DF.csv')
df_assay_id_map_path = os.path.join(FSL_CP_PATH, 'data/output/assay_target_map.csv') 
cp_f_path=[os.path.join(FSL_CP_PATH,'data/output/norm_CP_feature_df.csv')]

# What device to use? Change to 'cpu' if you don't have gpu
device = 'cuda:0'

# Output paths
result_summary_path1 = os.path.join(FSL_CP_PATH, "notebook/result_summary/protonet_cp_auroc_result_summary.csv") 
result_summary_path2 = os.path.join(FSL_CP_PATH, "notebook/result_summary/protonet_cp_dauprc_result_summary.csv") 
result_summary_path3 = os.path.join(FSL_CP_PATH, "notebook/result_summary/protonet_cp_bacc_result_summary.csv") 
result_summary_path4 = os.path.join(FSL_CP_PATH, "notebook/result_summary/protonet_cp_f1_result_summary.csv") 
result_summary_path5 = os.path.join(FSL_CP_PATH, "notebook/result_summary/protonet_cp_kappa_result_summary.csv") 

Information about which assays are in the train/validation/test set is stored in a JSON file.

In [4]:
# Load the assay keys.
with open(json_path) as f:
    data = json.load(f)
train_split = data['train']
val_split = data['val']
test_split = data['test']

Model performances are stored in Pandas dataframes and saved in csv files. We record performances according to 5 metrics: area under ROC curve, balanced accuray, f1 score, Cohen's Kappa and (delta) area under Precision-Recall curve.

In [5]:
# Final result dictionary.

final_result_auroc = {
    '8': [],
    '16': [],
    '32': [],
    '64': [],
    '96': []
}
final_result_dauprc = {
    '8': [],
    '16': [],
    '32': [],
    '64': [],
    '96': []
}
final_result_bacc = {
    '8': [],
    '16': [],
    '32': [],
    '64': [],
    '96': []
}
final_result_f1 = {
    '8': [],
    '16': [],
    '32': [],
    '64': [],
    '96': []
}
final_result_kappa = {
    '8': [],
    '16': [],
    '32': [],
    '64': [],
    '96': []
}

final_result_auroc['ASSAY_ID'] = test_split
final_result_dauprc['ASSAY_ID'] = test_split
final_result_bacc['ASSAY_ID'] = test_split
final_result_f1['ASSAY_ID'] = test_split
final_result_kappa['ASSAY_ID'] = test_split

For this example, we pick a simple Prototypical Network (Snell et al., 2017) using a 2-hidden-layer fully-connected neural network as backbone.

Every compound has an image and a feature representation, but for the sake of simplicity, we will focus on feature representation (or 'CP features' as it is called in the paper).

In [6]:
# Prototypical Network with a simple backbone 

class fnn(nn.Module):
    """Simple fully-connected neural network with 2 hidden layers.
    """

    def __init__(self, num_classes=512, input_shape=None):
        super(fnn, self).__init__()
        assert input_shape
        fc_units = 2048
        drop_prob = 0.5
        
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(input_shape, fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_prob),
            nn.Linear(fc_units, fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_prob),
            nn.Linear(fc_units, num_classes),
        )
        
        # init
        self.init_parameters()
    
    def init_parameters(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                nn.init.constant_(module.bias, 0)
    
    def forward(self, x):
        return self.classifier(x)
    

class ProtoNet(nn.Module):
    """ Prototypical Network with Euclidean distance.
    """
    def __init__(self, backbone: nn.Module):
        super(ProtoNet, self).__init__()
        self.backbone = backbone
    
    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor
    ):
        """
        Predict query labels using support images with labels.
        """
        z_support = self.backbone.forward(support_images)
        z_query = self.backbone.forward(query_images)
        #n_way = 2
        n_way = len(torch.unique(support_labels))
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels==label)].mean(0)
                for label in range(n_way)
            ]
        )

        dists = torch.cdist(z_query, z_proto)
        scores = -dists

        return scores

Below are some helper functions for training.

In [7]:
# Helper functions

def fit(
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
        query_labels: torch.Tensor,
        criterion, 
        optimizer,
        model,
        device
    ) -> float:
    """(Meta-)Train a protonet model on support and query images and labels. 
    Return: Loss with the gradient calculated.
    """
    optimizer.zero_grad()
    classification_scores = model(
        support_images.to(device), support_labels.to(device), query_images.to(device)
    )
    s = nn.LogSoftmax(dim=1)
    log_p = s(classification_scores)
    loss = criterion(log_p, query_labels.to(device))
    loss.backward()
    #optimizer.step()

    return loss.item()


def evaluate_on_one_task(
    model,
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
    device
):
    """Returns the prediction of the protonet model and the real label."""
    s = nn.Softmax(dim=1)
    pred_float = s(model(support_images.to(device), support_labels.to(device), query_images.to(device))).detach().cpu().numpy()
    pred_float = [i[1] for i in pred_float]
    pred_round = (
    torch.max(
        model(support_images.to(device), support_labels.to(device), query_images.to(device))
        .detach()
        .data,
        1,
    )[1]).cpu().numpy()
    return pred_float, pred_round, query_labels.cpu().numpy()


def evaluate(model, data_loader: DataLoader, device):
    """ Evaluate the model on a DataLoader object.
    Return means and standard deviations of 5 metrics below."""
    AUROC_scores = []
    dAUPRC_scores = []
    bacc_scores = []
    F1_scores = []
    kappa_scores = []
    model.eval()
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in enumerate(data_loader):

            y_float, y_pred, y_true = evaluate_on_one_task(
                model, support_images, support_labels, query_images, query_labels, device
            )
            AUROC_score = roc_auc_score(y_true, y_float)
            dAUPRC_score = delta_auprc(y_true, y_float)
            bacc_score = balanced_accuracy_score(y_true, y_pred, adjusted=True)
            F1_score = f1_score(y_true, y_pred)
            kappa_score = cohen_kappa_score(y_true, y_pred)

            AUROC_scores.append(AUROC_score)
            dAUPRC_scores.append(dAUPRC_score)
            bacc_scores.append(bacc_score)
            F1_scores.append(F1_score)
            kappa_scores.append(kappa_score)

    return np.mean(AUROC_scores), np.std(AUROC_scores), np.mean(dAUPRC_scores), np.std(dAUPRC_scores), np.mean(bacc_scores), \
        np.std(bacc_scores), np.mean(F1_scores), np.std(F1_scores), np.mean(kappa_scores), np.std(kappa_scores)


def eval(model, data_loader: DataLoader, device):
    """ Evaluate the model on a DataLoader object.
    Only returns accuracy."""
    acc_scores = []
    with torch.no_grad():
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
        ) in enumerate(data_loader):

            y_float, y_pred, y_true = evaluate_on_one_task(
                model, support_images, support_labels, query_images, query_labels, device
            )
            acc_scores.append(accuracy_score(y_true, y_pred))
    return np.mean(acc_scores)
           

This is the main traning loop. It loops through all support set sizes, train the model then to few-shot prediction on the test assays. 

In [8]:
# Loop through all support set size, performing few-shot prediction.
for support_set_size in support_set_sizes:
    tqdm.write(f"Analysing for support set size {support_set_size}")

    # Load train data.
    train_data = BaseDatasetCP(
        train_split, 
        label_df_path= label_df_path, 
        cp_f_path=cp_f_path
    )
    train_sampler = BaseSamplerCP(
            task_dataset=train_data,
            support_set_size=support_set_size,
            query_set_size=query_set_size,
            num_episodes=num_episodes_train,
            meta_batch_size=meta_batch_size
    )
    train_loader = DataLoader(
            train_data,
            batch_sampler=train_sampler,
            num_workers=12,
            pin_memory=True,
            collate_fn=train_sampler.episodic_collate_fn,
    )

    # Load val data.
    val_data = BaseDatasetCP(
        val_split, 
        label_df_path= label_df_path, 
        cp_f_path=cp_f_path
    )
    val_sampler = BaseSamplerCP(
            task_dataset=val_data,
            support_set_size=support_set_size,
            query_set_size=query_set_size,
            num_episodes=num_episodes_val,
            meta_batch_size=meta_batch_size
    )
    val_loader = DataLoader(
            val_data,
            batch_sampler=val_sampler,
            num_workers=12,
            pin_memory=True,
            collate_fn=val_sampler.episodic_collate_fn,
    )

    # Load model.
    input_shape=len(train_data[3][0])
    backbone = fnn(num_classes=512, input_shape=input_shape) 
    model = ProtoNet(backbone).to(device)

    # Meta-pretraining the protonet.
    criterion = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=0.0001)
    scheduler = StepLR(opt, step_size=step_size, gamma=0.1)
    all_loss = []
    model.train()
    with tqdm(enumerate(train_loader), total=len(train_loader), leave=True) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            
            # Fit model to data
            loss_value = fit(support_images, support_labels, query_images, query_labels, criterion, opt, model, device) 

            # Optimizer steps after accumulating gradients from 'meta_batch_size' batches 
            if (episode_index+1) % meta_batch_size == 0:
                opt.step() 
                scheduler.step() # Learning rate decay after a fixed time 

            # Update the loss on the progress bar
            all_loss.append(loss_value)
            if episode_index % log_update_freq == 0:
                if episode_index % val_freq == 0:
                    val_acc = eval(model, val_loader, device)

                tqdm_train.set_postfix(val_acc=val_acc, train_loss=sliding_average(all_loss, log_update_freq), lr=scheduler.get_last_lr())
            

    # Perform inference on all test assays.
    for test_assay in tqdm(test_split):
        
        # Load test data.
        test_data = BaseDatasetCP(
            test_split, 
            label_df_path= label_df_path, 
            cp_f_path=cp_f_path
        )
        test_sampler = BaseSamplerCP(
            task_dataset=test_data,
            support_set_size=support_set_size,
            query_set_size=query_set_size,
            num_episodes=num_episodes_test,
            meta_batch_size=meta_batch_size
        )
        test_loader = DataLoader(
            test_data,
            batch_sampler=test_sampler,
            num_workers=12,
            pin_memory=True,
            collate_fn=test_sampler.episodic_collate_fn,
        )

        # Evaluate the performance of the model.
        auroc_mean, auroc_std, dauprc_mean, dauprc_std, bacc_mean, bacc_std, f1_mean, f1_std, kappa_mean, kappa_std = evaluate(
            model, 
            test_loader, 
            device
        )
        final_result_auroc[str(support_set_size)].append(f"{auroc_mean:.2f}+/-{auroc_std:.2f}")
        final_result_dauprc[str(support_set_size)].append(f"{dauprc_mean:.2f}+/-{dauprc_std:.2f}")
        final_result_bacc[str(support_set_size)].append(f"{bacc_mean:.2f}+/-{bacc_std:.2f}")
        final_result_f1[str(support_set_size)].append(f"{f1_mean:.2f}+/-{f1_std:.2f}")
        final_result_kappa[str(support_set_size)].append(f"{kappa_mean:.2f}+/-{kappa_std:.2f}")

Analysing for support set size 8


100%|██████████| 3000/3000 [02:30<00:00, 19.98it/s, lr=[1e-05], train_loss=1.24, val_acc=0.763]  
100%|██████████| 18/18 [02:46<00:00,  9.24s/it]


Analysing for support set size 16


100%|██████████| 3000/3000 [02:29<00:00, 20.04it/s, lr=[1e-05], train_loss=0.478, val_acc=0.79]  
100%|██████████| 18/18 [02:38<00:00,  8.79s/it]


Analysing for support set size 32


100%|██████████| 3000/3000 [02:34<00:00, 19.40it/s, lr=[1e-05], train_loss=0.51, val_acc=0.786]  
100%|██████████| 18/18 [02:43<00:00,  9.09s/it]


Analysing for support set size 64


100%|██████████| 3000/3000 [02:34<00:00, 19.45it/s, lr=[1e-05], train_loss=0.421, val_acc=0.789] 
100%|██████████| 18/18 [02:47<00:00,  9.28s/it]


Analysing for support set size 96


100%|██████████| 3000/3000 [02:34<00:00, 19.44it/s, lr=[1e-05], train_loss=0.391, val_acc=0.761] 
100%|██████████| 18/18 [02:50<00:00,  9.47s/it]


The above loop takes around 30 mins to train

Finally, save the dataframes to csv files in the correct format.

In [9]:
# Create and save result summary csv files.
df_assay_id_map = pd.read_csv(df_assay_id_map_path)
df_assay_id_map = df_assay_id_map.astype({'ASSAY_ID': str})

df_score = pd.DataFrame(data=final_result_auroc)
df_final = pd.merge(df_assay_id_map[['ASSAY_ID', 'assay_chembl_id']], df_score, on='ASSAY_ID', how='right')
df_final.to_csv(result_summary_path1, index=False)

df_score = pd.DataFrame(data=final_result_dauprc)
df_final = pd.merge(df_assay_id_map[['ASSAY_ID', 'assay_chembl_id']], df_score, on='ASSAY_ID', how='right')
df_final.to_csv(result_summary_path2, index=False)

df_score = pd.DataFrame(data=final_result_bacc)
df_final = pd.merge(df_assay_id_map[['ASSAY_ID', 'assay_chembl_id']], df_score, on='ASSAY_ID', how='right')
df_final.to_csv(result_summary_path3, index=False)

df_score = pd.DataFrame(data=final_result_f1)
df_final = pd.merge(df_assay_id_map[['ASSAY_ID', 'assay_chembl_id']], df_score, on='ASSAY_ID', how='right')
df_final.to_csv(result_summary_path4, index=False)

df_score = pd.DataFrame(data=final_result_kappa)
df_final = pd.merge(df_assay_id_map[['ASSAY_ID', 'assay_chembl_id']], df_score, on='ASSAY_ID', how='right')
df_final.to_csv(result_summary_path5, index=False)