### Install Required Libraries

In [None]:
# !pip install -r requirements.txt

### Install cudf library

The commented-out commands below install cuDF in a Google Colab notebook; uncomment them to run.

In [None]:
# !git clone https://github.com/rapidsai/rapidsai-csp-utils.git
# !python rapidsai-csp-utils/colab/pip-install.py

Check that the cuDF and cuML imports work.

In [None]:
import cudf
import cuml

**Ensure `datamodules.py` is present in the local directory.** It provides `merlin_dataset_factory` and `set_default_kwargs_dataset`.

In [None]:
# Importing Libraries
import os, json
import gc
import re
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt
from os.path import join
from merlin.dataloader.torch import Loader
import torch.nn.init as init
import pickle
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle
import copy
from sklearn.metrics import classification_report, confusion_matrix


# Custom Libraries
from datamodules import merlin_dataset_factory, set_default_kwargs_dataset

def get_latest_run(folder_path):
    """Locate the newest ``model_<n>`` checkpoint in ``folder_path``.

    Returns:
        ``(path, n)`` for the largest ``n`` found, or ``(None, None)`` when the
        folder has no entry whose name starts with ``model_<digits>``.
    """
    checkpoint_re = re.compile(r'model_(\d+)')
    found = []
    for entry in os.listdir(folder_path):
        hit = checkpoint_re.match(entry)
        if hit:
            found.append(int(hit.group(1)))

    if not found:
        print("No model files found in the folder.")
        return None, None

    newest = max(found)
    print(f"The latest model is: model_{newest}")
    return os.path.join(folder_path, f"model_{newest}"), newest

def get_latest_valacc(folder_path, ds='val'):
    """Return the accuracy history stored in the newest ``<ds>_acc_<n>.json`` file.

    Args:
        folder_path: directory that ``save_val_acc`` writes into.
        ds: ``'val'`` or ``'test'``; selects which accuracy file and key to read.

    Returns:
        The list stored under ``'validation_accuracy'`` (val) or ``'test_accuracy'``
        (test) in the file with the largest ``<n>``, or a new empty list when no
        such file exists.
    """
    assert ds in ['val', 'test']
    # Escape the dot and match the whole name so stray files such as
    # "val_acc_3.json.bak" or "val_acc_3xjson" are not mistaken for results.
    pattern = re.compile(rf'{ds}_acc_(\d+)\.json')
    matches = [m for m in map(pattern.fullmatch, os.listdir(folder_path)) if m]

    if not matches:
        return []

    # Open the file that actually matched (robust to e.g. zero-padded numbers).
    latest = max(matches, key=lambda m: int(m.group(1)))
    fname = os.path.join(folder_path, latest.group(0))
    with open(fname, "r") as f:
        d = json.load(f)
    return d['validation_accuracy'] if ds == 'val' else d['test_accuracy']

def checkdir(directory):
    """Create ``directory`` (and any missing parents) if it does not already exist.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of an exists()/makedirs() pair
    # and is a no-op when the directory is already there.
    os.makedirs(directory, exist_ok=True)

# def remove_masks(model):
#     for name, module in model.net.named_modules():
#         if isinstance(module, nn.Linear):
#             m1 = module
#             m1 = prune.custom_from_mask(m1, name='weight', mask=model.mask[name].to(model.device))
#             prune.remove(m1, 'weight')

def save_files(iter, model):
    """Checkpoint ``model`` as ``model_<iter>`` inside the directory from ``model.get_path()``."""
    checkpoint_path = os.path.join(model.get_path(), f"model_{iter}")
    print(checkpoint_path)
    model.save_model(checkpoint_path)

def save_val_acc(iter_arr, acc_arr, save_dir, ds='val'):
    """Write the accuracy history to ``<save_dir>/<ds>_acc_<len(iter_arr)-1>.json``.

    The JSON object has ``'pruned_iters'`` (``iter_arr``) followed by
    ``'validation_accuracy'`` or ``'test_accuracy'`` (``acc_arr``), depending on ``ds``.
    """
    assert ds in ['val', 'test']
    accuracy_key = 'validation_accuracy' if ds == 'val' else 'test_accuracy'
    payload = {'pruned_iters': iter_arr, accuracy_key: acc_arr}

    out_file = os.path.join(save_dir, f'{ds}_acc_{len(iter_arr)-1}.json')
    print(out_file)
    with open(out_file, "w") as handle:
        json.dump(payload, handle)

### Setup: paths for the dataset and model masks



*   Set `DATA_PATH` to the dataset location.
*   Set `save_path` to the directory where the pruned models and plots will be saved.

(Sample paths are provided for both.)


In [None]:
# Root of the Merlin-format dataset; later cells read its train/, val/, test/
# and cell_type_hierarchy/ subfolders.
DATA_PATH = os.path.expanduser("~/data/sctab/merlin_cxg_2023_05_15_sf-log1p_minimal/")
# Output root: checkpoints go to saves/, predictions to preds/, F1 scores to f1_scores/.
save_path = os.path.expanduser("~/io/tmp_sparsity")

In [None]:
# Build the train/val/test datasets and loaders. Every split loads the same
# columns; only the training split uses training-mode dataset kwargs and shuffling.
def _make_split_dataset(split, training):
    """Create the Merlin dataset stored in the ``split`` subfolder of DATA_PATH."""
    return merlin_dataset_factory(
        join(DATA_PATH, split),
        columns=['cell_type', 'tissue'],
        dataset_kwargs=set_default_kwargs_dataset(training=training)
    )

train_dataset = _make_split_dataset('train', training=True)
val_dataset = _make_split_dataset('val', training=False)
test_dataset = _make_split_dataset('test', training=False)

LOADER_BATCH_SIZE = 2048
train_loader = Loader(train_dataset, batch_size=LOADER_BATCH_SIZE, shuffle=True)
val_loader = Loader(val_dataset, batch_size=LOADER_BATCH_SIZE, shuffle=False)
test_loader = Loader(test_dataset, batch_size=LOADER_BATCH_SIZE, shuffle=False)

Model definition: the original MLP has 8 hidden layers (`n_hidden=8`) followed by a linear output layer.

In [None]:
class MLP(nn.Module):
    """Feed-forward classifier.

    ``n_hidden`` blocks of Linear -> BatchNorm1d -> SiLU -> Dropout form ``encoder``;
    a final ``linear`` layer maps the hidden representation to ``output_dim`` logits.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_size: int = 128,
        n_hidden: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()
        assert n_hidden >= 1

        # The first block maps input_dim -> hidden_size; the rest keep hidden_size.
        fan_ins = [input_dim] + [hidden_size] * (n_hidden - 1)
        layers = []
        for fan_in in fan_ins:
            layers.extend(self._hidden_block(fan_in, hidden_size, dropout))

        self.encoder = nn.Sequential(*layers)
        self.linear = nn.Linear(hidden_size, output_dim)

    @staticmethod
    def _hidden_block(in_features, out_features, dropout):
        """Return the four modules that make up one hidden block, in order."""
        return [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.SiLU(),
            nn.Dropout(p=dropout),
        ]

    def forward(self, x):
        hidden = self.encoder(x)
        return self.linear(hidden)

In [None]:
import torch
import torch.nn as nn
import numpy as np

class PruningWrapper():
    """Wrap a network with per-layer pruning masks for iterative magnitude pruning.

    Keeps a copy of the network's initial weights, so each pruning round can rewind
    to them, and one mask per ``nn.Linear`` module. Masks are keyed by the module's
    qualified name from ``named_modules()``.
    """

    def __init__(self, net, device, dir):
        """
        Args:
            net: the ``nn.Module`` to prune; it is moved to ``device``.
            device: torch device used for the network and masks during training.
            dir: directory returned by ``get_path()``, used for checkpoints.
        """
        self.net = net.to(device)
        self.device = device

        # Snapshot of the starting weights; load_init_weights() rewinds to it.
        self.net_init_state = copy.deepcopy(self.net.state_dict())
        # All-ones masks: nothing is pruned before the first round.
        self.mask = {
            name: torch.ones(module.weight.shape)
            for name, module in self.net.named_modules()
            if isinstance(module, nn.Linear)
        }

        self.model_path = dir

    def load_init_weights(self):
        """Reset the network's parameters to the stored initial state."""
        self.net.load_state_dict(state_dict=self.net_init_state)

    def apply_mask_before_train(self):
        """Attach the current masks to every Linear layer via ``prune.custom_from_mask``.

        Afterwards each Linear layer has ``weight_orig`` and ``weight_mask``, and its
        ``weight`` is their product.
        """
        for name, module in self.net.named_modules():
            if isinstance(module, nn.Linear):
                prune.custom_from_mask(module, 'weight', self.mask[name].to(self.device))

    def update_mask_after_train(self, p):
        """Fold the active masks into the weights, then compute new L1-magnitude masks.

        ``p`` is passed as ``amount`` to ``prune.L1Unstructured``; a float in [0, 1]
        is the fraction of each layer's weights to prune. Because the starting mask
        is all ones, ``p`` counts all of the layer's weights, not only the surviving
        ones. Weights pruned in earlier rounds are exactly zero after ``prune.remove``,
        so they rank among the smallest magnitudes.

        Requires ``apply_mask_before_train`` to have been called, because
        ``prune.remove`` needs an existing pruning reparametrisation.
        """
        with torch.no_grad():
            for name, module in self.net.named_modules():
                if isinstance(module, nn.Linear):
                    prune.remove(module, 'weight')
                    pruner = prune.L1Unstructured(amount=p)
                    self.mask[name] = pruner.compute_mask(module.weight.data, torch.ones_like(module.weight.data))

    def print_model_info(self):
        """Print the percentage of Linear-layer weights that the current masks zero out."""
        n_masked, n_kept = 0, 0
        for mask in self.mask.values():
            n_masked += torch.sum(1 - mask).item()
            n_kept += torch.sum(mask).item()

        total = n_masked + n_kept
        # The ratio is a fraction in [0, 1]; scale it so the "%" label is correct.
        pct = 100 * n_masked / total if total else 0.0
        print(f"{pct:.2f}% of parameters are masked.")

    def get_path(self):
        """Return the checkpoint directory given at construction."""
        return self.model_path

    def pruned_proportion(self):
        """Return the fraction of Linear weights zeroed by the attached masks (3 d.p.).

        Reads ``weight_mask``, so it is only valid after ``apply_mask_before_train``.
        """
        kept, total = 0, 0
        for _, module in self.net.named_modules():
            if isinstance(module, nn.Linear):
                kept += module.weight_mask.sum().detach().item()
                total += module.weight.numel()
        return round(1 - (kept / total), 3)

    def save_model(self, path):
        """Save current weights, initial weights and masks to ``path`` with ``torch.save``."""
        d = {
                'model_state_dict': self.net.state_dict(),
                'init_state_dict': self.net_init_state,
                'mask': self.mask,
        }
        torch.save(d, path)

    @classmethod
    def load_model(cls, netType, path, device, dir):
        """Rebuild a wrapper from a checkpoint written by ``save_model``.

        Args:
            netType: a freshly constructed network instance with the same architecture.
            path: checkpoint file.
            device: target device; tensors are mapped onto it while loading.
            dir: checkpoint directory for the new wrapper.
        """
        # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.
        loader = torch.load(path, map_location=device)
        net_init_state = loader['init_state_dict']
        mask = loader['mask']
        net = netType
        net.load_state_dict(state_dict=loader['model_state_dict'])
        net = net.to(device)

        wrapper = cls(net, device, dir)
        # Restore the original initial weights and masks; the constructor would
        # otherwise snapshot the loaded (trained) weights as the "initial" state.
        wrapper.net_init_state = net_init_state
        wrapper.mask = mask

        return wrapper

In [None]:
def train(dataloader, epochs, model, optimizer, scheduler, criterion, device):
    """Train ``model`` for ``epochs`` epochs on ``cell_type`` labels.

    Each batch from ``dataloader`` is a ``(dict, _)`` pair whose dict provides ``'X'``
    (inputs) and ``'cell_type'`` (targets). ``scheduler.step()`` runs once per epoch.

    Returns:
        Mean training loss over every batch of every epoch.
    """
    model.train()
    run_correct = run_total = 0
    all_losses = []
    progress = tqdm(range(epochs))
    for _epoch in progress:
        epoch_losses = []
        epoch_correct = epoch_total = 0
        for step, (batch, _) in enumerate(dataloader):
            # Run the garbage collector every 10 batches (presumably to bound memory growth).
            if step % 10 == 0:
                gc.collect()
            x = batch['X'].to(device)
            y = batch['cell_type'].to(device)

            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            predicted = torch.max(logits, 1)[1]
            epoch_total += y.size(0)
            epoch_correct += (predicted == y).sum().item()
            loss_value = loss.item()
            all_losses.append(loss_value)
            epoch_losses.append(loss_value)
        run_total += epoch_total
        run_correct += epoch_correct
        progress.set_description(f"acc: {100 * epoch_correct / epoch_total}, loss: {np.mean(epoch_losses)}")
        scheduler.step()

    acc = 100 * run_correct / run_total
    print(f"TRAIN / Total accuracy: {acc}, Total loss: {np.mean(all_losses)}")
    torch.cuda.empty_cache()
    return np.mean(all_losses)

def test(dataloader, model, criterion, device, ds='val'):
    assert ds in ['val', 'test']
    model.eval()
    correct, total = 0, 0
    test_loss = []
    with torch.no_grad():
        for ix, (data, _) in enumerate(dataloader):
            inputs, labels = data['X'].to(device), data['cell_type'].to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            test_loss.append(loss.item())

    acc = 100 * correct / total
    print(f"{'VAL ' if ds == 'val' else 'TEST'}  / Total accuracy: {acc}, Total loss: {np.mean(test_loss)}")
    torch.cuda.empty_cache()
    return round(acc, 2)

In [None]:
# Checkpoints (model_<iter>) and accuracy histories (<ds>_acc_<n>.json) are written here.
model_save_dir = os.path.join(save_path, "saves")
checkdir(model_save_dir)
# Use the GPU when available; later cells move the model and batches to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

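For resumability, `get_latest_run` finds the most recent saved model, and training resumes from the next pruning iteration.

The model is saved only at the end of each pruning iteration, so an interrupted iteration is rerun from its start.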

In [None]:
# Resume from the newest checkpoint in model_save_dir if there is one; otherwise
# wrap a newly initialised MLP.
# NOTE(review): input_dim=19331 and output_dim=164 are hard-coded; presumably the
# feature count and number of cell types in DATA_PATH — confirm.
latest_path, latest_iter = get_latest_run(model_save_dir)
if(latest_path is None):
    model = PruningWrapper(MLP(input_dim=19331, output_dim=164), device, model_save_dir)
else:
    model = PruningWrapper.load_model(MLP(input_dim=19331, output_dim=164), latest_path, device, model_save_dir)

In [None]:
# Number of prune-and-retrain rounds.
ITERATION = 10
# Pruning schedule [0.0, 0.1, ..., 1.0]: round ix trains with masks for pruning_amounts[ix]
# and, after training, computes masks for pruning_amounts[ix + 1].
pruning_amounts = [round(0.1 * (i), 3) for i in range(ITERATION + 1)]

# Training epochs per pruning round.
epochs_per_iter = 30
criterion = nn.CrossEntropyLoss()

# Resume after the latest saved round, or start at round 0 on a fresh run.
if(latest_iter is None):
    start_iter = 0
else:
    start_iter = latest_iter + 1

# Accuracy histories saved so far (empty lists on a fresh run); each round appends to them.
val_acc = get_latest_valacc(model_save_dir, ds='val')
test_acc = get_latest_valacc(model_save_dir, ds='test')

In [None]:
from tqdm import tqdm

# Iterative pruning. Each round does the following:
#   1. Rewind to the stored initial weights.
#   2. Train with the current masks.
#   3. Record val/test accuracy.
#   4. Compute masks for the next (higher) pruning amount.
#   5. Checkpoint the result.
for ix in range(start_iter, ITERATION):
    print(f"Pruning iter {ix}")
    model.print_model_info()
    # Rewind to the initial weights stored in the wrapper (restored from the checkpoint when resuming).
    model.load_init_weights()
    # Attach the current masks to every Linear layer before training.
    model.apply_mask_before_train()
    # New optimizer and scheduler every round; train() steps StepLR once per epoch, so the LR decays by 0.9 per epoch.
    optimizer = torch.optim.AdamW(model.net.parameters(), lr=0.002, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    train_loss = train(train_loader, epochs_per_iter, model.net, optimizer, scheduler, criterion, device)
    val_acc.append(test(val_loader, model.net, criterion, device, ds='val'))
    test_acc.append(test(test_loader, model.net, criterion, device, ds='test'))

    # Fold the training mask into the weights and compute the mask for the next round.
    # On the last round this computes masks for pruning_amounts[ITERATION], which no round trains with.
    amt = pruning_amounts[ix + 1]
    model.update_mask_after_train(amt)

    # Checkpoint model_<ix> and rewrite the accuracy histories, keyed by the pruning amount each round trained at.
    save_files(ix, model)
    save_val_acc(pruning_amounts[:len(val_acc)], val_acc, model_save_dir, ds='val')
    save_val_acc(pruning_amounts[:len(test_acc)], test_acc, model_save_dir, ds='test')

To start from scratch, delete the `saves` folder and rerun the notebook.

## Evaluation

In [None]:
def get_preds(model, dataloader, device):
    """
    Get all predictions of model by iterating over dataloader

    Args:
        model: object exposing the network as ``model.net`` (e.g. a PruningWrapper).
        dataloader: yields ``(dict, _)`` pairs whose dict provides ``'X'``,
            ``'cell_type'`` and ``'tissue'`` tensors.
        device: device that inputs and outputs are placed on.

    Returns:
        Dict with ``'labels'``, ``'preds'`` and ``'tissue'`` tensors concatenated over
        all batches. They use the same floating dtype as before (default dtype promoted
        with the batch dtype), so previously saved prediction files stay comparable.
    """
    model.net.eval()
    label_chunks, pred_chunks, tissue_chunks = [], [], []

    with torch.no_grad():
        for data, _ in dataloader:
            inputs = data['X'].to(device)
            outputs = model.net(inputs)
            _, predicted = torch.max(outputs, 1)

            label_chunks.append(data['cell_type'].to(device))
            pred_chunks.append(predicted)
            tissue_chunks.append(data['tissue'].to(device))

    torch.cuda.empty_cache()

    def _concat(chunks):
        # A single concatenation avoids the quadratic cost of growing a tensor
        # with torch.cat on every batch.
        if not chunks:
            return torch.empty(0, device=device)
        stacked = torch.cat(chunks)
        return stacked.to(torch.promote_types(torch.get_default_dtype(), stacked.dtype))

    d = {'labels': _concat(label_chunks), 'preds': _concat(pred_chunks), 'tissue': _concat(tissue_chunks)}
    return d

def correct_labels(y_true: np.ndarray, y_pred: np.ndarray, child_matrix: np.ndarray):
    """
    Update predictions.
    If prediction is actually a child node of the true label -> update prediction to true value.

    E.g: Label='T cell' and prediction='CD8 positive T cell' -> update prediction to 'T cell'

    Args:
        y_true: class indices; float arrays holding whole numbers are accepted.
        y_pred: predicted class indices, same length as ``y_true``.
        child_matrix: square matrix; a non-zero entry ``[t, p]`` marks ``p`` as a child of ``t``.

    Returns:
        A copy of ``y_pred`` (same dtype) in which every prediction that is a child
        of its true label is replaced by the true label. The inputs are not modified.
    """
    updated_predictions = y_pred.copy()
    # Look up child_matrix[true, pred] for all samples at once instead of a
    # per-sample Python loop with an O(n_classes) membership test.
    rows = np.asarray(y_true).astype(np.intp)
    cols = np.asarray(y_pred).astype(np.intp)
    is_child = np.asarray(child_matrix)[rows, cols] != 0
    updated_predictions[is_child] = np.asarray(y_true)[is_child]
    return updated_predictions

def get_f1_score(path, cell_type_hierarchy):
    """Compute sklearn's classification report for a saved prediction file.

    Args:
        path: file written by ``torch.save`` holding a dict with ``'labels'`` and
            ``'preds'`` tensors (as produced by ``get_preds``).
        cell_type_hierarchy: matrix passed to ``correct_labels``; a non-zero entry at
            ``[true, pred]`` makes ``pred`` count as correct for ``true``.

    Returns:
        ``(overall, per_class)`` DataFrames. ``overall`` holds the last three rows of
        the report; ``per_class`` holds all remaining rows.
    """
    # map_location='cpu' lets prediction files saved from a GPU run be scored on a CPU-only machine.
    data = torch.load(path, map_location='cpu')
    y_true, y_preds = data['labels'].cpu().numpy(), data['preds'].cpu().numpy()

    y_preds_corr = correct_labels(y_true, y_preds, cell_type_hierarchy)

    clf_report = pd.DataFrame(classification_report(y_true, y_preds_corr, output_dict=True)).T
    # sklearn appends its aggregate rows (e.g. accuracy, macro avg, weighted avg) after the per-class rows.
    clf_report_overall = clf_report.iloc[-3:].copy()
    clf_report_per_class = clf_report.iloc[:-3].copy()

    return clf_report_overall, clf_report_per_class

def save_scores(name, versions, input_dirname, output_dirname, cell_type_hierarchy):
    """Score ``<name>_preds_<v>`` files and write their F1 reports as CSV.

    For each ``v`` in ``versions``, writes ``f1_overall_<v>.csv`` and
    ``f1_perclass_<v>.csv`` into ``output_dirname``.

    Returns:
        ``0`` without doing anything when ``name`` is neither ``"val"`` nor ``"test"``;
        otherwise ``None``.
    """
    if not (name == "val" or name == "test"):
        return 0
    prefix = "val_preds" if name == "val" else "test_preds"

    for version in versions:
        pred_file = os.path.join(input_dirname, f"{prefix}_{version}")
        print(pred_file)
        overall, per_class = get_f1_score(pred_file, cell_type_hierarchy)
        overall.to_csv(os.path.join(output_dirname, f"f1_overall_{version}.csv"))
        per_class.to_csv(os.path.join(output_dirname, f"f1_perclass_{version}.csv"))

### Saving Predictions for Val and Test Dataset

In [None]:
# Checkpoint iterations found in model_save_dir, in ascending order.
files = os.listdir(model_save_dir)
model_iters = sorted([int(re.search(r'model_(\d+)', file).group(1)) for file in files if re.match(r'model_(\d+)', file)])

# Predictions for every checkpoint are written under <save_path>/preds.
preds_save_dir = os.path.join(save_path, "preds")
checkdir(preds_save_dir)

for i in model_iters:
    model_path = os.path.join(model_save_dir, f"model_{i}")
    # NOTE: this rebinds `model`, replacing the wrapper used by the training loop.
    # The saved weights already have that round's training mask folded in
    # (update_mask_after_train calls prune.remove before save_files), so no mask is applied here.
    model = PruningWrapper.load_model(MLP(input_dim=19331, output_dim=164), model_path, device, model_save_dir)

    valpred_path = os.path.join(preds_save_dir, f"val_preds_{i}")
    testpred_path = os.path.join(preds_save_dir, f"test_preds_{i}")
    val_pred_dict = get_preds(model, val_loader, device)
    test_pred_dict = get_preds(model, test_loader, device)

    # Each file holds a dict of 'labels', 'preds' and 'tissue' tensors.
    torch.save(val_pred_dict, valpred_path)
    torch.save(test_pred_dict, testpred_path)

### Creating and Saving F1-Score (overall and per cell type)

Create the output directory and load the cell-type hierarchy (child) matrix.

In [None]:
# F1 scores are written under <save_path>/f1_scores/{val,test}.
scores_save_dir = os.path.join(save_path, "f1_scores")
checkdir(scores_save_dir)

# Child matrix used by correct_labels: a non-zero [true, pred] entry credits pred as correct for true.
cell_type_hierarchy = np.load(join(DATA_PATH, 'cell_type_hierarchy/child_matrix.npy'))
# Listing of saved prediction files, reused by the val and test scoring cells below.
files = os.listdir(preds_save_dir)

Saving F1-scores for Validation Dataset

In [None]:
# Iterations that have a saved validation-prediction file, in ascending order.
val_pred_iters = sorted([int(re.search(r'val_preds_(\d+)', file).group(1)) for file in files if re.match(r'val_preds_(\d+)', file)])

val_scores_save_dir = os.path.join(scores_save_dir, "val")
checkdir(val_scores_save_dir)

# Writes f1_overall_<i>.csv and f1_perclass_<i>.csv for each iteration.
save_scores("val", val_pred_iters, preds_save_dir, val_scores_save_dir, cell_type_hierarchy)

Saving F1-scores for Test Dataset

In [None]:
# Iterations that have a saved test-prediction file, in ascending order.
test_pred_iters = sorted([int(re.search(r'test_preds_(\d+)', file).group(1)) for file in files if re.match(r'test_preds_(\d+)', file)])

test_scores_save_dir = os.path.join(scores_save_dir, "test")
checkdir(test_scores_save_dir)

# Score the test-prediction iterations found above. Passing val_pred_iters here would
# look up test_preds_<i> for the wrong set of iterations.
save_scores("test", test_pred_iters, preds_save_dir, test_scores_save_dir, cell_type_hierarchy)