# Demo: Apply TissueFormer to user-provided dataset

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
import argparse
import os
import sys

parent_dir = os.path.abspath("..")
sys.path.insert(0, parent_dir)

from run_finetune import run_update, run_train, run_test, evaluate
from parse import parse_pretrain_method, parse_regression_method, parse_classification_method, parser_add_main_args
from utils import dataset_create, dataset_create_split, k_shot_split
from dataloader import CustomDataset

from torch_cluster import knn_graph
import scanpy as sc

import warnings

os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')



## Load dataset

Here we use the lung fibrosis data as an example

In [9]:
def data_load(dir_path, sample):
    file_path = os.path.join(dir_path, sample) + '.h5ad'
    adata = sc.read(file_path)
    X_log1p = adata.layers['X_log1p']

    gene_mask = adata.var['gene_filter_mask']
    cell_by_gene = X_log1p[:, gene_mask]
    gene_index = adata.var['gene_filtered_idx'][gene_mask]

    cell_image_emb = adata.obsm['embeddings']
    cell_location = adata.obsm['centroids']

    dataset = {}
    dataset['x'] = torch.tensor(cell_image_emb, dtype=torch.float)
    dataset['y'] = torch.tensor(cell_by_gene, dtype=torch.float)
    dataset['gene_idx'] = torch.tensor(gene_index, dtype=torch.long)
    dataset['edge_index'] = knn_graph(torch.tensor(cell_location, dtype=torch.float), k=5, loop=False)

    hvg_gene_rank = adata.var['highly_variable_rank'][gene_mask]
    dataset['hvg_gene_rank'] = torch.tensor(hvg_gene_rank, dtype=torch.long)

    return dataset

def dataset_create(dir_path, samples):
    datasets = []
    pbar = tqdm(samples, desc='loading dataset', ncols=100, ascii=True)
    for s in pbar:
        dataset = data_load(dir_path, s)
        idx = torch.arange(dataset['x'].shape[0])
        datasets.append(dataset)
        pbar.clear()
        pbar.refresh()
    return CustomDataset(datasets)

dir_path = '/ewsc/wuqitian/lung_preprocess'
meta_info = pd.read_csv("../../data/meta_info_lung.csv")

# train data can be used as the reference for in-context learning or for finetuning the model
train_samples = meta_info[meta_info['affect'] == 'Unaffected']['sample'].tolist()[:-1]

# test data for evaluation
test_samples = meta_info[meta_info['affect'] == 'Unaffected']['sample'].tolist()[-1:]

# create dataloader
train_datasets = dataset_create(dir_path, train_samples)
train_dataloader = DataLoader(train_datasets, batch_size=1, shuffle=True)
test_datasets = dataset_create(dir_path, test_samples)
test_dataloader = DataLoader(test_datasets, batch_size=1, shuffle=False)

loading dataset: 100%|################################################| 6/6 [00:12<00:00,  2.06s/it]
loading dataset: 100%|################################################| 1/1 [00:00<00:00,  1.06it/s]


## Model preparation: load pretrained checkpoints

In [12]:
import models.regression as regression
import encoders

device = torch.device("cuda:0") # torch.device("cpu")
encoder1 = encoders.Transformer(in_channels=1536, hidden_channels=1024, num_layers_prop=2, num_layers_mlp=2, num_attn_heads=1,
                                        dropout=0., use_bn=True, use_graph=True, use_residual=True).to(device)
model_ours = regression.InContext_Predict(encoder1, hidden_channels=1024, out_channels=340, batch_size=100,
                                    num_neighbors=1000, device=device).to(device)
    
pretrained_state_dict = torch.load('../../model_checkpoints/ours_pretrain_xenium_lung.pth') # can specify the model versions
encoder1_pretrained_dict = {k: v for k, v in pretrained_state_dict.items() if k.startswith("encoder1.")}
model_state_dict = model_ours.state_dict()
encoder1_model_dict = {k: v for k, v in model_state_dict.items() if k.startswith("encoder1.")}
for k, v in encoder1_pretrained_dict.items():
    assert (k in encoder1_model_dict)
    assert (v.size() == encoder1_model_dict[k].size())
model_state_dict.update(encoder1_pretrained_dict)
model_ours.load_state_dict(model_state_dict)

<All keys matched successfully>

## Predict gene expression from histology images

In [17]:
def update(model, dataloader, device, use_gene_idx=True):
    model.eval()
    for dataset in dataloader:
        data = dataset[0]
        inputs, labels = data.x.to(device), data.y.to(device)
        edge_index = data.edge_index.to(device)
        if use_gene_idx:
            gene_idx = data.gene_idx.to(device)
        else:
            gene_idx = None
        train_idx = torch.arange(labels.shape[0])

        model.update(inputs, labels, gene_idx, train_idx, edge_index)

@torch.no_grad()
def predict(model, data, device, use_gene_idx=True):
    model.eval()

    inputs, labels = data.x.to(device), data.y.to(device)
    edge_index = data.edge_index.to(device)
    if use_gene_idx:
        gene_idx = data.gene_idx.to(device)
    else:
        gene_idx = None

    # task is gene expression prediction
    outputs = model(inputs, gene_idx, edge_index)

    return labels.cpu().numpy(), outputs.cpu().numpy()

# load reference data into the model for in-context learning prediction
update(model_ours, train_dataloader, device, use_gene_idx=False)
print("loading reference data finished")

# predict gene expression for each test slide
for i, dataset in enumerate(test_dataloader):
    data = dataset[0]
    y_true, y_pred = predict(model_ours, data, device, use_gene_idx=False)
    print(f"prediction for sample {i} finished")
    print(y_true.shape)
    print(y_pred.shape)

loading reference data finished
prediction for sample 0 finished
(12870, 340)
(12870, 340)


## Extract cell embeddings and cell-cell attention maps

In [19]:
@torch.no_grad()
def get_embs_attn(model, data, device):
    model.eval()

    inputs = data.x.to(device)
    edge_index = data.edge_index.to(device)

    embs = model_ours.encoder1.get_embeddings(inputs, edge_index).cpu().numpy() # [layer num, cell num, hidden size]
    attn_maps = model_ours.encoder1.get_attentions(inputs, edge_index).cpu().numpy() # [layer num, cell num, cell num]

    return embs, attn_maps 

# get embeddings and attention maps for each slide
for i, dataset in enumerate(test_dataloader):
    data = dataset[0]
    embs, attn_maps = get_embs_attn(model_ours, data, device)
    print(embs.shape)
    print(embs)
    print(attn_maps.shape)
    print(attn_maps)


(3, 12870, 1024)
[[[ 1.3715379e+00  1.4454957e+00 -7.5154972e-01 ... -7.6502818e-01
   -7.3900068e-01 -7.2247493e-01]
  [ 9.2117399e-01  1.1292846e+00 -6.5943491e-01 ... -6.6768020e-01
   -6.6846138e-01 -6.4686137e-01]
  [ 1.7809750e+00  1.5489399e+00 -7.9155129e-01 ... -8.0272007e-01
   -8.1195861e-01 -7.8880268e-01]
  ...
  [-6.4784706e-01 -5.3436673e-01  7.5236696e-01 ...  7.0106542e-01
    7.0526987e-01  8.0448413e-01]
  [-6.2402189e-01 -4.1940057e-01  5.1424354e-01 ...  4.7407335e-01
    5.0584358e-01  8.2884425e-01]
  [ 1.6685019e+00  1.8540162e+00 -8.1977361e-01 ... -8.1474954e-01
   -8.3234191e-01 -8.5177535e-01]]

 [[-4.8717000e-02  2.1894351e-01 -1.6253795e-01 ... -5.8698375e-02
   -4.6126145e-01 -4.6564618e-01]
  [-2.2871874e-01  2.4311823e-01 -1.6629460e-01 ... -7.7653445e-02
   -4.2097694e-01 -3.9412290e-01]
  [ 2.2985333e-01  1.8899480e-01 -2.4468711e-01 ...  6.5168940e-02
   -4.6890473e-01 -4.5090723e-01]
  ...
  [-6.9585079e-01 -5.0514692e-01 -5.6706842e-02 ... -2.76369

## Finetune the model on new dataset

In [36]:
from models.pretrain import Model_Pretrain

def train(model, dataloader, optimizer, device, accumulate_steps=1):
    model.train()
    running_loss = 0.

    for i, dataset in enumerate(dataloader):
        data = dataset[0]
        x1, x2 = data.x.to(device), data.y.to(device)
        edge_index = data.edge_index.to(device)
        gene_idx = data.gene_idx.to(device)
        train_idx = torch.arange(x1.shape[0]).to(device)

        loss = model.loss(x1, x2, gene_idx, train_idx, edge_index)
        loss.backward()
        running_loss += loss.item()

        if (i + 1) % accumulate_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    return running_loss / len(dataloader)

# load pretrain model
encoder1 = encoders.Transformer(in_channels=1536, hidden_channels=1024, num_layers_prop=2, num_layers_mlp=2, num_attn_heads=1,
                                        dropout=0., use_bn=True, use_graph=True, use_residual=True).to(device)
encoder2 = encoders.MLP(in_channels=256, hidden_channels=1024, num_layers=1, dropout=0.).to(device)
gene_embeddings = torch.zeros((23258, 256), dtype=torch.float).to(device)
model_pretrain = Model_Pretrain(encoder1, encoder2, gene_embeddings, reg_w=0.5, ge_trainable=True, ge_pretrained=False).to(device)
pretrained_state_dict = torch.load('../../model_checkpoints/ours_pretrain_visium_all.pth')
model_pretrain.load_state_dict(pretrained_state_dict)

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model_pretrain.parameters()), lr=1e-5,
                       weight_decay=0.)

# finetune the model
epoch_nums = 50
for epoch in range(epoch_nums):
    train_loss = train(model_pretrain, train_dataloader, optimizer, device)

    print(f'Epoch [{epoch + 1}/{epoch_nums}], Loss: {train_loss:.4f}')

Epoch [1/50], Loss: -5.3346
Epoch [2/50], Loss: -5.3493
Epoch [3/50], Loss: -5.3589
Epoch [4/50], Loss: -5.3656
Epoch [5/50], Loss: -5.3722
Epoch [6/50], Loss: -5.3781
Epoch [7/50], Loss: -5.3848
Epoch [8/50], Loss: -5.3904
Epoch [9/50], Loss: -5.3966
Epoch [10/50], Loss: -5.4022
Epoch [11/50], Loss: -5.4073
Epoch [12/50], Loss: -5.4117
Epoch [13/50], Loss: -5.4131
Epoch [14/50], Loss: -5.4132
Epoch [15/50], Loss: -5.4165
Epoch [16/50], Loss: -5.4155
Epoch [17/50], Loss: -5.4245
Epoch [18/50], Loss: -5.4276
Epoch [19/50], Loss: -5.4311
Epoch [20/50], Loss: -5.4335
Epoch [21/50], Loss: -5.4345
Epoch [22/50], Loss: -5.4409
Epoch [23/50], Loss: -5.4424
Epoch [24/50], Loss: -5.4431
Epoch [25/50], Loss: -5.4482
Epoch [26/50], Loss: -5.4485
Epoch [27/50], Loss: -5.4530
Epoch [28/50], Loss: -5.4558
Epoch [29/50], Loss: -5.4577
Epoch [30/50], Loss: -5.4621
Epoch [31/50], Loss: -5.4632
Epoch [32/50], Loss: -5.4624
Epoch [33/50], Loss: -5.4671
Epoch [34/50], Loss: -5.4687
Epoch [35/50], Loss: -5