Setup the seed for training:

In [1]:
import random
import torch
import os
import numpy as np
import torch.utils.data



def setup_seed(seed):
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True

setup_seed(3407)

Definet the device used for training:

In [2]:
import os
import torch
import warnings
warnings.filterwarnings('ignore')

gpu_list = [0]
gpu_list_str = ','.join(map(str, gpu_list))
os.environ.setdefault("CUDA_VISIBLE_DEVICES", gpu_list_str)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

Define the `Hist2Cell` model, and load the model on GPU:

In [3]:
from torch.nn import Linear
import torch.nn as nn
import torchvision.models as models
from torch_geometric.nn import GATv2Conv, LayerNorm
import sys,os
sys.path.append(os.path.dirname(os.getcwd()))
from model.ViT import Mlp, VisionTransformer

class Hist2Cell(nn.Module):
    def __init__(self, cell_dim=80, vit_depth=3):
        super(Hist2Cell, self).__init__()
        self.resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.resnet18 = torch.nn.Sequential(*list(self.resnet18.children())[:-1])
        
        self.embed_dim = 32 * 8
        self.head = 8
        self.dropout = 0.3
        
        self.conv1 = GATv2Conv(in_channels=512, out_channels=int(self.embed_dim/self.head), heads=self.head)
        self.norm1 = LayerNorm(in_channels=self.embed_dim)
        
        self.cell_transformer = VisionTransformer(num_classes=cell_dim, embed_dim=self.embed_dim, depth=vit_depth,
                                                  mlp_head=True, drop_rate=self.dropout, attn_drop_rate=self.dropout)
        self.spot_fc = Linear(in_features=512, out_features=256)
        self.spot_head = Mlp(in_features=256, hidden_features=512*2, out_features=cell_dim)
        self.local_head = Mlp(in_features=256, hidden_features=512*2, out_features=cell_dim)
        self.fused_head = Mlp(in_features=256, hidden_features=512*2, out_features=cell_dim)
    
    
    def forward(self, x, edge_index):
        x_spot = self.resnet18(x)
        x_spot = x_spot.squeeze()
        
        x_local = self.conv1(x=x_spot, edge_index=edge_index)
        x_local = self.norm1(x_local)
        
        x_local = x_local.unsqueeze(0)
        
        x_cell = x_local
        
        x_spot = self.spot_fc(x_spot)
        cell_predication_spot = self.spot_head(x_spot)
        x_local = x_local.squeeze(0)
        cell_prediction_local = self.local_head(x_local)
        cell_prediction_global, x_global = self.cell_transformer(x_cell)
        cell_prediction_global = cell_prediction_global.squeeze()
        x_global = x_global.squeeze()
        cell_prediction_fused = self.fused_head((x_spot+x_local+x_global)/3.0)
        cell_prediction = (cell_predication_spot + cell_prediction_local + cell_prediction_global + cell_prediction_fused) / 4.0
        
        return cell_prediction
    
    
model = Hist2Cell(vit_depth=3)
model = model.to(device)

Load train/test split file, here we train `Hist2Cell` on other 3 donors in the humanlung cell2location dataset, and test `Hist2Cell` on the left donnor A50:

There are 2 slides from donor A50 in humanlung cell2location dataset: 
- WSA_LngSP9258463
- WSA_LngSP9258467

The slides from the other 3 donors used for training are:
- WSA_LngSP8759311
- WSA_LngSP8759312
- WSA_LngSP8759313
- WSA_LngSP9258464
- WSA_LngSP9258468
- WSA_LngSP10193347
- WSA_LngSP10193348
- WSA_LngSP10193345
- WSA_LngSP10193346

In [10]:
train_slides = open("../train_test_splits/humanlung_cell2location/test_leave_New.txt").read().split('\n')
test_slides = open("../train_test_splits/humanlung_cell2location/test_leave_A50.txt").read().split('\n')

Load the processed data for each slide as the train/test dataset:

In [11]:
from torch_geometric.data import Batch


train_graph_list = list()
for item in train_slides:
    train_graph_list.append(torch.load(os.path.join("../patch/output/pts", item+'.pt')))
train_dataset = Batch.from_data_list(train_graph_list)

test_graph_list = list()
for item in test_slides:
    test_graph_list.append(torch.load(os.path.join("../example_data/humanlung_cell2location", item+'.pt')))
test_dataset = Batch.from_data_list(test_graph_list)  

Define the `DataLoader` for train/test dataset, here are 2 important parameters:
- `hop`: this parameter define receptive field when sampling the subgraphs with a group of center nodes for training/testing, in our paper, we use 2-hop subgraphs to achieve a banlance between computation cost and performance, generally, bigger receptive field will contain more neighboring information.
- `subgraph_bs`: this parameter define the number of subgraphs to be sampled during training/testing, which is the `subgraph batchsize`, we use `subgraph_bs=16` on our RTX 3090 GPU.

In [12]:
from torch_geometric.loader import NeighborLoader
import torch_geometric
torch_geometric.typing.WITH_PYG_LIB = False


hop = 2
subgraph_bs = 16

train_loader = NeighborLoader(
    train_dataset,
    num_neighbors=[-1]*hop,
    batch_size=subgraph_bs,
    directed=False,
    input_nodes=None,
    shuffle=True,
    num_workers=0,
)

test_loader = NeighborLoader(
    test_dataset,
    num_neighbors=[-1]*hop,
    batch_size=subgraph_bs,
    directed=False,
    input_nodes=None,
    shuffle=False,
    num_workers=0,
)

Define the `learning rate`, `criterion`, `optimizer` and `scheduler` used for training, we use `lr=1e-4` in our study:

In [13]:
lr = 1e-4

params = model.parameters()
criterion = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(params, lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-5, last_epoch=-1, verbose=False)

For a simple example, we train `Hist2Cell` for 5 epochs, and save the checkpoint with the best Pearson R:

In [9]:
import numpy as np
from scipy.stats import pearsonr
import time


num_epochs = 5

best_cell_abundance_all_average = 0.0
since = time.time()
for epoch in range(num_epochs):
    model.train()
    print("---------------------------------------"*4)
    print('Epoch: {} \t'.format(epoch + 1))
    print('lr = ',optimizer.param_groups[0]["lr"])
    
    train_sample_num = 0
    train_cell_pred_array = []
    train_cell_label_array = []
    train_cell_abundance_loss = 0
    for graph in train_loader:
        x = graph.x.to(device)
        y = graph.y.to(device)
        edge_index = graph.edge_index.to(device)
        cell_label = y[:, 5:]
        
        cell_pred = model(x=x, edge_index=edge_index)

        cell_loss = criterion(cell_pred, cell_label)        
        optimizer.zero_grad()
        cell_loss.backward()
        optimizer.step()
            
        center_num = len(graph.input_id)
        center_cell_label = cell_label[:center_num, :]
        center_cell_pred = cell_pred[:center_num, :]
        train_cell_label_array.append(center_cell_label.squeeze().cpu().detach().numpy())
        train_cell_pred_array.append(center_cell_pred.squeeze().cpu().detach().numpy())
        train_sample_num = train_sample_num + center_num
        train_cell_abundance_loss += cell_loss.item() * center_num

    train_cell_abundance_loss = train_cell_abundance_loss / train_sample_num
    
    if len(train_cell_pred_array[-1].shape) == 1:
        train_cell_pred_array[-1] = np.expand_dims(train_cell_pred_array[-1], axis=0)
    train_cell_pred_array = np.concatenate(train_cell_pred_array)
    if len(train_cell_label_array[-1].shape) == 1:
        train_cell_label_array[-1] = np.expand_dims(train_cell_label_array[-1], axis=0)
    train_cell_label_array = np.concatenate(train_cell_label_array)

    train_cell_abundance_all_pearson_average = 0.0
    for i in range(train_cell_pred_array.shape[1]):
        r, p = pearsonr(train_cell_pred_array[:, i], train_cell_label_array[:, i])
        train_cell_abundance_all_pearson_average = train_cell_abundance_all_pearson_average + r
    train_cell_abundance_all_pearson_average = train_cell_abundance_all_pearson_average / train_cell_pred_array.shape[1]
    
    scheduler.step()

    with torch.no_grad():
        model.eval()

        test_sample_num = 0
        test_cell_pred_array = []
        test_cell_label_array = []
        test_cell_abundance_loss = 0
        for graph in test_loader:
            x = graph.x.to(device)
            y = graph.y.to(device)
            edge_index = graph.edge_index.to(device)
            cell_label = y[:, 5:]
            
            cell_pred = model(x=x, edge_index=edge_index)

            cell_loss = criterion(cell_pred, cell_label)

            center_num = len(graph.input_id)
            center_cell_label = cell_label[:center_num, :]
            center_cell_pred = cell_pred[:center_num, :]
            
            test_cell_label_array.append(center_cell_label.squeeze().cpu().detach().numpy())
            test_cell_pred_array.append(center_cell_pred.squeeze().cpu().detach().numpy())
            test_sample_num = test_sample_num + center_num
            
            test_cell_abundance_loss += cell_loss.item() * center_num
            
        test_cell_abundance_loss = test_cell_abundance_loss / test_sample_num
 
    if len(test_cell_pred_array[-1].shape) == 1:
        test_cell_pred_array[-1] = np.expand_dims(test_cell_pred_array[-1], axis=0)
    test_cell_pred_array = np.concatenate(test_cell_pred_array)
    if len(test_cell_label_array[-1].shape) == 1:
        test_cell_label_array[-1] = np.expand_dims(test_cell_label_array[-1], axis=0)
    test_cell_label_array = np.concatenate(test_cell_label_array)
        
    test_cell_abundance_all_pearson_average = 0.0
    for i in range(test_cell_pred_array.shape[1]):
        r, p = pearsonr(test_cell_pred_array[:, i], test_cell_label_array[:, i])
        test_cell_abundance_all_pearson_average = test_cell_abundance_all_pearson_average + r
    test_cell_abundance_all_pearson_average = test_cell_abundance_all_pearson_average / test_cell_pred_array.shape[1]

    if test_cell_abundance_all_pearson_average > best_cell_abundance_all_average:
        best_cell_abundance_all_average = test_cell_abundance_all_pearson_average
        torch.save(model.state_dict(), os.path.join("../model_weights", "demo_ckpt.pth"))
        print("saving " + "best cell all abundance average " + str(test_cell_abundance_all_pearson_average))

    time_elapsed = time.time() - since
    print(f'Training complete in {(time_elapsed // 60):.0f}m {(time_elapsed % 60):.0f}s')
    print(f'Epoch: {(epoch + 1)} \tTraining Cell abundance Loss: {train_cell_abundance_loss:.6f}')
    print(f'Epoch: {(epoch + 1)} \tTraining Cell abundance pearson all average: {train_cell_abundance_all_pearson_average:.6f}')
    print(f'Epoch: {(epoch + 1)} \tTest Cell abundance Loss: {test_cell_abundance_loss:.6f}')
    print(f'Epoch: {(epoch + 1)} \tTest Cell abundance pearson all average: {test_cell_abundance_all_pearson_average:.6f}')

------------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch: 1 	
lr =  0.0001


KeyboardInterrupt: 

In [14]:
for graph in train_loader:
        
        x = graph.x.to(device)
        
        y = graph.y.to(device)
        edge_index = graph.edge_index.to(device)
        cell_label = y[:, 5:]
        cell_pred = model(x=x, edge_index=edge_index)
        print(cell_pred)        

tensor([[-0.0370,  0.0235,  0.0686,  ...,  0.1223, -0.0086,  0.0427],
        [ 0.0347, -0.0085,  0.0185,  ...,  0.1063,  0.0131,  0.0358],
        [-0.0476,  0.1360,  0.1304,  ...,  0.0942,  0.0430,  0.0147],
        ...,
        [-0.0139, -0.0156,  0.0503,  ...,  0.1532, -0.0064, -0.0480],
        [-0.0058,  0.0072,  0.0547,  ...,  0.1019,  0.0237, -0.0599],
        [-0.0116, -0.0170,  0.0524,  ...,  0.1216,  0.0048,  0.0431]],
       device='cuda:0', grad_fn=<DivBackward0>)


OutOfMemoryError: CUDA out of memory. Tried to allocate 396.00 MiB (GPU 0; 6.00 GiB total capacity; 12.06 GiB already allocated; 0 bytes free; 12.43 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF