In [3]:
import matplotlib.pyplot as plt
import torch
from spatialSSL.Dataloader import FullImageDatasetConstructor
from spatialSSL.Utils import split_dataset
from spatialSSL.Training import train
from spatialSSL.Training import train_epoch
from spatialSSL.Testing import test
from spatialSSL.Dataset import InMemoryGraphDataset
from torch import nn, optim, Tensor
import torch
from torch_geometric.nn import GCNConv,GATConv
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import r2_score
from torch.nn import LeakyReLU, Dropout
import time
from torch.utils.checkpoint import checkpoint
import numpy as np

In [4]:
import torch.nn.functional as F

In [5]:
class PretrainedGAT_2(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout_rate=0.5):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels)
        self.conv2 = GATConv(hidden_channels, out_channels)
        self.dropout = Dropout(dropout_rate)
        self.act = nn.LeakyReLU()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = self.dropout(self.act(self.conv1(x, edge_index)))
        x = self.act(self.conv2( x, edge_index)) # Using checkpointing on the final layer
        return x
    

class CellTypeClassificationModel(nn.Module):
    def __init__(self, pre_trained_model, num_cell_types):
        super(CellTypeClassificationModel, self).__init__()
        # Copy the pre-trained model's architecture and weights
        self.conv1 = pre_trained_model.conv1
        self.dropout = pre_trained_model.dropout
        self.act = pre_trained_model.act

        # Replace the last GAT layer with a new one for classification
        in_channels = pre_trained_model.conv2.in_channels
        self.conv2 = GATConv(in_channels, num_cell_types)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = self.dropout(self.act(self.conv1(x, edge_index)))
        x = self.act(self.conv2(x, edge_index))
        return x


In [6]:
file_path = "./data/subset_6img_atlas_brain.h5ad"

# Create the dataloader
dataset_constructor = FullImageDatasetConstructor(file_path=file_path, image_col="section",
                                                                     label_col="class_label", include_label=False,
                                                                     radius=30, node_level=1,downstream = True)

# Load the data
dataset_constructor.load_data()

# Construct the graph
dataset = dataset_constructor.construct_graph()


total_cells = len(dataset_constructor.adata)
print(total_cells)

  self.adata.obs['cell_type_encoded'] = encode_cell_type.transform(self.adata.obs[self.label_col].values)


Constructing Graphs:   0%|          | 0/6 [00:00<?, ?it/s]

240945


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = "./models/GAT_2_random20_all_model.pt"

# Load the pretrained model without the linear layer
pretrained_model = PretrainedGAT_2(550, 550, 550).to(device)
#pretrained_model.load_state_dict(torch.load(model_path))
pretrained_model.eval()

PretrainedGAT_2(
  (conv1): GATConv(550, 550, heads=1)
  (conv2): GATConv(550, 550, heads=1)
  (dropout): Dropout(p=0.5, inplace=False)
  (act): LeakyReLU(negative_slope=0.01)
)

In [6]:
# Now create the fine-tuned model for cell type classification
num_cell_types = 24 # Number of cell types you want to classify
cell_type_model = CellTypeClassificationModel(pretrained_model, num_cell_types)
#cell_type_model.conv1.load_state_dict(pretrained_model.conv1.state_dict())
cell_type_model

CellTypeClassificationModel(
  (conv1): GATConv(550, 550, heads=1)
  (dropout): Dropout(p=0.5, inplace=False)
  (act): LeakyReLU(negative_slope=0.01)
  (conv2): GATConv(550, 24, heads=1)
)

In [7]:
# Copy the weights from the pretrained model to the new model with the linear layer
#model = GAT_2(550, 550, 550, 24).to(device)
#model.conv1.load_state_dict(pretrained_model.conv1.state_dict())
#model.conv2.load_state_dict(pretrained_model.conv2.state_dict())
#model.conv3.load_state_dict(pretrained_model.conv3.state_dict())

In [8]:
train_loader, val_loader, test_loader = split_dataset(dataset,split_percent=(0.8, 0.2, 0.2), batch_size=1)

In [9]:
# Print out the size of each set to verify
print(f"Train size: {len(train_loader.dataset)}")
print(f"Validation size: {len(val_loader.dataset)}")
print(f"Test size: {len(test_loader.dataset)}")

Train size: 9
Validation size: 2
Test size: 1


In [30]:
for x in train_loader:
    print(x)


DataBatch(x=[51012, 550], edge_index=[2, 419170], y=[5101, 550], mask=[51012], cell_type=[51012], cell_type_masked=[5101], image=[1], num_nodes=51012, batch=[51012], ptr=[2])
DataBatch(x=[29300, 550], edge_index=[2, 259136], y=[2930, 550], mask=[29300], cell_type=[29300], cell_type_masked=[2930], image=[1], num_nodes=29300, batch=[29300], ptr=[2])
DataBatch(x=[29300, 550], edge_index=[2, 259136], y=[2930, 550], mask=[29300], cell_type=[29300], cell_type_masked=[2930], image=[1], num_nodes=29300, batch=[29300], ptr=[2])
DataBatch(x=[47461, 550], edge_index=[2, 410372], y=[4746, 550], mask=[47461], cell_type=[47461], cell_type_masked=[4746], image=[1], num_nodes=47461, batch=[47461], ptr=[2])
DataBatch(x=[47461, 550], edge_index=[2, 410372], y=[4746, 550], mask=[47461], cell_type=[47461], cell_type_masked=[4746], image=[1], num_nodes=47461, batch=[47461], ptr=[2])
DataBatch(x=[36043, 550], edge_index=[2, 250394], y=[3604, 550], mask=[36043], cell_type=[36043], cell_type_masked=[3604], im

In [10]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cell_type_model.parameters(), lr=0.001)
#scheduler = StepLR(optimizer, step_size=80, gamma=0.1) # learning rate scheduler


In [11]:
# use gene expression and adjcent gene expression to predict the cell types

In [12]:
train_loader

<torch_geometric.loader.dataloader.DataLoader at 0x7f7ab43bae90>

In [None]:
from tqdm.auto import tqdm
from torch import optim
from torcheval.metrics import MulticlassAccuracy, MulticlassAUPRC, MulticlassConfusionMatrix

# use masked gene expression and adjacent gene expression to predict the cell types
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cell_type_model.to(device)
# Train the model
epochs = 100

multiclass_accuracy = MulticlassAccuracy()
multiclass_auprc = MulticlassAUPRC(num_classes = 24)
multiclass_confusion_matrix = MulticlassConfusionMatrix(num_classes = 24)

for epoch in tqdm(range(epochs)):

    cell_type_model.train()
    all_accuracy = []
    
    running_loss = 0.0
    for data in tqdm(train_loader):
        
        inputs = data.x.float().to(device)
        labels = data.cell_type_masked.long().to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward + backward + optimize
        outputs = cell_type_model(inputs, data.edge_index.to(device))
        #print(outputs.shape)
        loss = criterion(outputs[data.mask], labels)
        #print(outputs[data.mask].shape)
        loss.backward()
        optimizer.step()

        # Compute additional metrics
        predicted = outputs[data.mask].cpu()
        multiclass_accuracy.update(predicted, labels.cpu())
        multiclass_auprc.update(predicted, labels.cpu())
        multiclass_confusion_matrix.update(predicted, labels.cpu())

        # Print statistics
        #accuracy = (predicted == labels.cpu()).sum().item() / len(labels)
        #all_accuracy.append(accuracy)
        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.3f}')
    print('Multiclass Accuracy:', multiclass_accuracy.compute())
    print('Multiclass AUPRC:', multiclass_auprc.compute())
    #print('Multiclass Confusion Matrix:', multiclass_confusion_matrix.compute())

    # Reset metrics after each epoch if needed
    #multiclass_accuracy.reset()
    #multiclass_auprc.reset()
    #multiclass_confusion_matrix.reset()

print('Training finished!')


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 1, Loss: 2.793
Multiclass Accuracy: tensor(0.1470)
Multiclass AUPRC: tensor(0.0415)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 2, Loss: 2.155
Multiclass Accuracy: tensor(0.1616)
Multiclass AUPRC: tensor(0.0418)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 3, Loss: 2.109
Multiclass Accuracy: tensor(0.1660)
Multiclass AUPRC: tensor(0.0418)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 4, Loss: 2.091
Multiclass Accuracy: tensor(0.1688)
Multiclass AUPRC: tensor(0.0418)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 5, Loss: 2.076
Multiclass Accuracy: tensor(0.1730)
Multiclass AUPRC: tensor(0.0420)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 6, Loss: 2.087
Multiclass Accuracy: tensor(0.1749)
Multiclass AUPRC: tensor(0.0420)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 7, Loss: 2.067
Multiclass Accuracy: tensor(0.1765)
Multiclass AUPRC: tensor(0.0420)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 8, Loss: 2.061
Multiclass Accuracy: tensor(0.1776)
Multiclass AUPRC: tensor(0.0420)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 9, Loss: 2.062
Multiclass Accuracy: tensor(0.1788)
Multiclass AUPRC: tensor(0.0421)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 10, Loss: 2.059
Multiclass Accuracy: tensor(0.1799)
Multiclass AUPRC: tensor(0.0423)


  0%|          | 0/9 [00:00<?, ?it/s]

In [None]:
from tqdm.auto import tqdm
from torch import optim
from torcheval.metrics import MulticlassAccuracy, MulticlassAUPRC, MulticlassConfusionMatrix

# use masked gene expression and adjacent gene expression to predict the cell types
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cell_type_model.to(device)
# Train the model
epochs = 100

multiclass_accuracy = MulticlassAccuracy()
multiclass_auprc = MulticlassAUPRC(num_classes = 24)
multiclass_confusion_matrix = MulticlassConfusionMatrix(num_classes = 24)

for epoch in tqdm(range(epochs)):

    cell_type_model.train()
    all_accuracy = []
    
    running_loss = 0.0
    for data in tqdm(train_loader):
        
        inputs = data.x.float().to(device)
        labels = data.cell_type_masked.long().to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward + backward + optimize
        outputs = cell_type_model(inputs, data.edge_index.to(device))
        #print(outputs.shape)
        loss = criterion(outputs[data.mask], labels)
        #print(outputs[data.mask].shape)
        loss.backward()
        optimizer.step()

        # Compute additional metrics
        predicted = outputs[data.mask].cpu()
        multiclass_accuracy.update(predicted, labels.cpu())
        multiclass_auprc.update(predicted, labels.cpu())
        multiclass_confusion_matrix.update(predicted, labels.cpu())

        # Print statistics
        #accuracy = (predicted == labels.cpu()).sum().item() / len(labels)
        #all_accuracy.append(accuracy)
        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.3f}')
    print('Multiclass Accuracy:', multiclass_accuracy.compute())
    print('Multiclass AUPRC:', multiclass_auprc.compute())
    #print('Multiclass Confusion Matrix:', multiclass_confusion_matrix.compute())

    # Reset metrics after each epoch if needed
    #multiclass_accuracy.reset()
    #multiclass_auprc.reset()
    #multiclass_confusion_matrix.reset()

print('Training finished!')


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 1, Loss: 2.057
Multiclass Accuracy: tensor(0.1688)
Multiclass AUPRC: tensor(0.0421)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 2, Loss: 2.035
Multiclass Accuracy: tensor(0.1790)
Multiclass AUPRC: tensor(0.0423)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 3, Loss: 2.037
Multiclass Accuracy: tensor(0.1802)
Multiclass AUPRC: tensor(0.0423)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 4, Loss: 2.028
Multiclass Accuracy: tensor(0.1831)
Multiclass AUPRC: tensor(0.0424)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 5, Loss: 2.028
Multiclass Accuracy: tensor(0.1832)
Multiclass AUPRC: tensor(0.0424)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 6, Loss: 2.021
Multiclass Accuracy: tensor(0.1852)
Multiclass AUPRC: tensor(0.0425)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 7, Loss: 2.023
Multiclass Accuracy: tensor(0.1866)
Multiclass AUPRC: tensor(0.0426)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 8, Loss: 2.014
Multiclass Accuracy: tensor(0.1881)
Multiclass AUPRC: tensor(0.0426)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 9, Loss: 2.009
Multiclass Accuracy: tensor(0.1892)
Multiclass AUPRC: tensor(0.0427)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 10, Loss: 2.001
Multiclass Accuracy: tensor(0.1908)
Multiclass AUPRC: tensor(0.0427)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 11, Loss: 2.004
Multiclass Accuracy: tensor(0.1914)
Multiclass AUPRC: tensor(0.0428)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 12, Loss: 1.992
Multiclass Accuracy: tensor(0.1927)
Multiclass AUPRC: tensor(0.0429)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 13, Loss: 1.988
Multiclass Accuracy: tensor(0.1940)
Multiclass AUPRC: tensor(0.0430)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 14, Loss: 1.972
Multiclass Accuracy: tensor(0.1961)
Multiclass AUPRC: tensor(0.0431)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 15, Loss: 1.967
Multiclass Accuracy: tensor(0.1980)
Multiclass AUPRC: tensor(0.0432)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 16, Loss: 1.964
Multiclass Accuracy: tensor(0.1995)
Multiclass AUPRC: tensor(0.0434)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 17, Loss: 1.952
Multiclass Accuracy: tensor(0.2014)
Multiclass AUPRC: tensor(0.0435)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 18, Loss: 1.945
Multiclass Accuracy: tensor(0.2034)
Multiclass AUPRC: tensor(0.0437)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 19, Loss: 1.930
Multiclass Accuracy: tensor(0.2055)
Multiclass AUPRC: tensor(0.0439)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 20, Loss: 1.922
Multiclass Accuracy: tensor(0.2078)
Multiclass AUPRC: tensor(0.0440)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 21, Loss: 1.902
Multiclass Accuracy: tensor(0.2104)
Multiclass AUPRC: tensor(0.0443)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 22, Loss: 1.893
Multiclass Accuracy: tensor(0.2129)
Multiclass AUPRC: tensor(0.0445)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 23, Loss: 1.871
Multiclass Accuracy: tensor(0.2156)
Multiclass AUPRC: tensor(0.0448)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 24, Loss: 1.855
Multiclass Accuracy: tensor(0.2184)
Multiclass AUPRC: tensor(0.0451)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 25, Loss: 1.849
Multiclass Accuracy: tensor(0.2213)
Multiclass AUPRC: tensor(0.0454)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 26, Loss: 1.825
Multiclass Accuracy: tensor(0.2242)
Multiclass AUPRC: tensor(0.0457)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 27, Loss: 1.808
Multiclass Accuracy: tensor(0.2272)
Multiclass AUPRC: tensor(0.0461)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 28, Loss: 1.798
Multiclass Accuracy: tensor(0.2301)
Multiclass AUPRC: tensor(0.0465)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 29, Loss: 1.787
Multiclass Accuracy: tensor(0.2330)
Multiclass AUPRC: tensor(0.0469)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 30, Loss: 1.768
Multiclass Accuracy: tensor(0.2362)
Multiclass AUPRC: tensor(0.0473)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 31, Loss: 1.746
Multiclass Accuracy: tensor(0.2393)
Multiclass AUPRC: tensor(0.0477)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 32, Loss: 1.732
Multiclass Accuracy: tensor(0.2425)
Multiclass AUPRC: tensor(0.0482)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 33, Loss: 1.715
Multiclass Accuracy: tensor(0.2456)
Multiclass AUPRC: tensor(0.0487)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 34, Loss: 1.687
Multiclass Accuracy: tensor(0.2490)
Multiclass AUPRC: tensor(0.0492)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 35, Loss: 1.660
Multiclass Accuracy: tensor(0.2524)
Multiclass AUPRC: tensor(0.0498)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 36, Loss: 1.636
Multiclass Accuracy: tensor(0.2561)
Multiclass AUPRC: tensor(0.0504)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 37, Loss: 1.622
Multiclass Accuracy: tensor(0.2596)
Multiclass AUPRC: tensor(0.0511)


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch 38, Loss: 1.593
Multiclass Accuracy: tensor(0.2633)
Multiclass AUPRC: tensor(0.0518)


  0%|          | 0/9 [00:00<?, ?it/s]

In [98]:
def all_elements_same(temp):
    if not temp:
        return True  # An empty list is considered to have all elements the same
    first_element = temp[0]
    return all(element == first_element for element in temp)


result = all_elements_same(temp)
print(result)  # Output: True

True


In [None]:
from tqdm.auto import tqdm
from torch import optim
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# Train the model
epochs = 10
for epoch in tqdm(range(epochs)):

    model.train()
    all_accuracy = []

    running_loss = 0.0
    for data in tqdm(train_loader, leave=False):
        img = data.image[0]
        sub_adata = dataset_constructor.adata[dataset_constructor.adata.obs["section"] == img]
        inputs = torch.tensor(sub_adata.X.toarray(), dtype=torch.float).to(device)
        # Convert the cell_type_masked to a flat list of labels
        labels = torch.tensor([label_mapping[value] for value in data.cell_type_masked[0]], dtype=torch.long).to(device) # Make sure to use long data type

        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward + backward + optimize
        outputs = model(inputs, data.edge_index.to(device))

        loss = criterion(outputs[data.mask].float(), labels) # Removed the one-hot encoding part

        loss.backward()
        optimizer.step()

        # Print statistics
        accuracy = (outputs.argmax(dim=1) == labels).sum().item() / len(labels)
        all_accuracy.append(accuracy)
        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.3f}, accuracy: {sum(all_accuracy) / len(all_accuracy):.3f}')

print('Training finished!')
