In [1]:
import matplotlib.pyplot as plt
import torch
from spatialSSL.Dataloader import FullImageDatasetConstructor
from spatialSSL.Utils import split_dataset
from spatialSSL.Training import train
from spatialSSL.Training import train_epoch
from spatialSSL.Testing import test
from spatialSSL.Dataset import InMemoryGraphDataset
from torch import nn, optim, Tensor
import torch
from torch_geometric.nn import GCNConv,GATConv
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import r2_score
from torch.nn import LeakyReLU, Dropout
import time
from torch.utils.checkpoint import checkpoint
import numpy as np


	geopandas.options.use_pygeos = True

If you intended to use PyGEOS, set the option to False.
  _check_geopandas_using_shapely()


In [2]:
import torch.nn.functional as F

In [3]:
class PretrainedGAT_2(nn.Module):
    """Three-layer GAT encoder used for pretraining (no classification head).

    Args:
        in_channels: size of each node's input feature vector.
        hidden_channels: output size of every GAT layer.
        out_channels: not used by this class (kept so existing calls keep working).
        dropout_rate: dropout probability applied after the first layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout_rate=0.5):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)
        self.dropout = Dropout(dropout_rate)
        
        self.act = nn.LeakyReLU()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Encode node features ``x`` over the graph given by ``edge_index``.

        Returns the LeakyReLU-activated output of ``conv3``.
        """
        x = self.dropout(self.act(self.conv1(x, edge_index)))
        # conv2 runs under activation checkpointing: its activations are recomputed
        # during backward instead of being stored. `use_reentrant` is passed
        # explicitly because PyTorch deprecates relying on its implicit default;
        # False is the variant the PyTorch docs recommend (forward output is unchanged).
        x = checkpoint(self.conv2, x, edge_index, use_reentrant=False)
        # NOTE(review): no activation between conv2 and conv3 — left as-is since
        # saved checkpoints were presumably trained with this exact forward.
        x = self.act(self.conv3(x, edge_index))  # Typically, dropout is not applied to the final layer.
        return x
    
class GAT_2(nn.Module):
    """GAT encoder followed by a linear head that emits one raw logit per class.

    The GAT layers use the same attribute names as ``PretrainedGAT_2``
    (conv1/conv2/conv3), so pretrained conv weights can be copied in layer by layer.

    Args:
        in_channels: size of each node's input feature vector.
        hidden_channels: output size of every GAT layer and input size of ``lin1``.
        out_channels: not used by this class.
        num_label: number of classes, i.e. the output size of ``lin1``.
        dropout_rate: dropout probability applied after ``conv1``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_label, dropout_rate=0.5):
        super().__init__()
        width = hidden_channels
        # Layers are created in the same order as before, so parameter
        # registration and random initialisation order are unchanged.
        self.conv1 = GATConv(in_channels, width)
        self.conv2 = GATConv(width, width)
        self.conv3 = GATConv(width, width)
        self.dropout = Dropout(dropout_rate)
        # Classification head: one logit per class.
        self.lin1 = nn.Linear(width, num_label)
        self.act = nn.LeakyReLU()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Return per-node class logits (no softmax; pair with CrossEntropyLoss)."""
        hidden = self.conv1(x, edge_index)
        hidden = self.dropout(self.act(hidden))
        # NOTE(review): conv2 is constructed (and its parameters are handed to the
        # optimizer) but is not applied here, and conv3's output goes into lin1
        # without an activation — confirm both are intended.
        hidden = self.conv3(hidden, edge_index)
        logits = self.lin1(hidden)
        return logits


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_path = "./models/model_fillomgGAT2.pt"

# Rebuild the pretraining architecture (GAT encoder without the linear head)
# so the saved state dict can be loaded into it.
pretrained_model = PretrainedGAT_2(550, 550, 550).to(device)
# map_location remaps stored tensors onto `device`; without it a checkpoint saved
# from a GPU run fails to deserialize on a CPU-only machine.
pretrained_model.load_state_dict(torch.load(model_path, map_location=device))
pretrained_model.eval()

PretrainedGAT_2(
  (conv1): GATConv(550, 550, heads=1)
  (conv2): GATConv(550, 550, heads=1)
  (conv3): GATConv(550, 550, heads=1)
  (dropout): Dropout(p=0.5, inplace=False)
  (act): LeakyReLU(negative_slope=0.01)
)

In [5]:
# Copy the weights from the pretrained model to the new model with the linear layer
# Fresh classifier: GAT encoder + linear head over 24 classes.
# NOTE(review): 24 must be >= the number of distinct `class_label` values
# (len(label_mapping), computed in a later cell) — confirm it matches the data.
model = GAT_2(550, 550, 550, 24).to(device)

# Transfer the pretrained encoder weights layer by layer. conv2 is not copied
# (GAT_2.forward does not apply it), and lin1 starts from its random init.
pretrained_conv1_state = pretrained_model.conv1.state_dict()
model.conv1.load_state_dict(pretrained_conv1_state)
pretrained_conv3_state = pretrained_model.conv3.state_dict()
model.conv3.load_state_dict(pretrained_conv3_state)

<All keys matched successfully>

In [None]:
file_path = "./data/subset_6img_atlas_brain.h5ad"

# Build one graph per tissue section (`image_col="section"`), with class labels
# kept out of the node features (`include_label=False`).
dataset_constructor = FullImageDatasetConstructor(
    file_path=file_path,
    image_col="section",
    label_col="class_label",
    include_label=False,
    radius=40,
    node_level=1,
)
dataset_constructor.load_data()
dataset = dataset_constructor.construct_graph()

# Number of cells (AnnData rows) across all sections.
total_cells = len(dataset_constructor.adata)
print(total_cells)

Constructing Graphs:   0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
# batch_size=1: each batch holds a single section graph.
# NOTE(review): these fractions sum to 0.6 — confirm how split_dataset interprets
# them (e.g. whether 40% of the sections are intentionally left out).
train_loader, val_loader, test_loader = split_dataset(
    dataset,
    split_percent=(0.2, 0.2, 0.2),
    batch_size=1,
)

In [None]:
# Define loss function and optimizer
# The model returns raw logits (no softmax in GAT_2.forward), which is the input
# CrossEntropyLoss expects; targets are integer class indices.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters; no learning-rate scheduler is used.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)


In [None]:
# Sorted array of the distinct cell-class names in the AnnData `class_label`
# column (np.unique returns its result sorted).
cell_type_labels = np.unique(dataset_constructor.adata.obs["class_label"])



In [None]:
from sklearn.preprocessing import LabelEncoder

# LabelEncoder assigns each class name an integer index in [0, n_classes); these
# indices are the targets nn.CrossEntropyLoss expects. It takes a 1-D array, so the
# labels are passed as-is — a (n, 1) column vector only triggers a
# DataConversionWarning before being flattened.
encoder = LabelEncoder()
label_encoded = encoder.fit_transform(cell_type_labels)

# class name -> integer class index (not a one-hot vector)
label_mapping = dict(zip(cell_type_labels, label_encoded))

In [None]:
# Show the class-name -> integer-index mapping used to build training targets.
label_mapping

In [None]:
# Debug dump: print every training batch (batch_size=1, so one section graph each).
for batch in train_loader:
    print(batch)


In [None]:
from tqdm.auto import tqdm

# First fine-tuning pass: node inputs come from `data.y`; targets are the
# integer indices of the graph's `cell_type_masked` class names.
# NOTE(review): `labels` is compared against every row of `outputs`, so this
# assumes `cell_type_masked[0]` has one entry per graph node, in node order. The
# later training cell restricts outputs to `data.mask` instead — confirm which
# is right for this dataset.
model.to(device)

epochs = 100
# Outputs of the most recent epoch (one list of logits per node); inspected by
# `all_elements_same` to detect collapsed predictions.
temp = []
for epoch in tqdm(range(epochs)):

    model.train()
    all_accuracy = []
    running_loss = 0.0
    # Keep only the latest epoch: storing every node's logits as Python lists
    # for all epochs grows memory without bound.
    temp = []

    for data in tqdm(train_loader, leave=False):
        # as_tensor converts without the copy + UserWarning that torch.tensor()
        # emits when given a tensor.
        inputs = torch.as_tensor(data.y, dtype=torch.float, device=device)
        labels = torch.tensor(
            [label_mapping[value] for value in data.cell_type_masked[0]],
            dtype=torch.long,
            device=device,
        )

        optimizer.zero_grad()

        # Forward + backward + optimize, all on the model's device.
        outputs = model(inputs, data.edge_index.to(device))
        temp.extend(outputs.detach().cpu().tolist())
        loss = criterion(outputs.float(), labels)

        loss.backward()
        optimizer.step()

        # Fraction of nodes whose arg-max logit matches the target class.
        accuracy = (outputs.argmax(dim=1) == labels).sum().item() / len(labels)
        all_accuracy.append(accuracy)
        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.3f}, accuracy: {sum(all_accuracy) / len(all_accuracy):.3f}')

print('Training finished!')

In [98]:
def all_elements_same(temp):
    """Return True when every element of ``temp`` compares equal (``==``) to the first.

    An empty ``temp`` is treated as uniform and yields True.
    """
    if temp:
        head = temp[0]
        return all(item == head for item in temp)
    return True


# True means every output row recorded in `temp` is identical, i.e. the model
# produced the same logit vector for every node it was fed; False means outputs vary.
result = all_elements_same(temp)
print(str(result))

True


In [None]:
from tqdm.auto import tqdm

# Second fine-tuning pass: node inputs are the section's expression matrix taken
# from the AnnData object. Loss and accuracy are computed only on the nodes
# selected by `data.mask`, which are the nodes that have entries in `labels`.
model.to(device)

epochs = 10
for epoch in tqdm(range(epochs)):

    model.train()
    all_accuracy = []
    running_loss = 0.0

    for data in tqdm(train_loader, leave=False):
        img = data.image[0]
        # NOTE(review): assumes this section's AnnData rows are in the same order
        # as the graph's nodes — confirm against FullImageDatasetConstructor.
        sub_adata = dataset_constructor.adata[dataset_constructor.adata.obs["section"] == img]
        expression = sub_adata.X
        # Densify only when X is sparse (scipy sparse matrices expose .toarray()).
        if hasattr(expression, "toarray"):
            expression = expression.toarray()
        inputs = torch.tensor(expression, dtype=torch.float, device=device)
        labels = torch.tensor(
            [label_mapping[value] for value in data.cell_type_masked[0]],
            dtype=torch.long,
            device=device,
        )
        # Put the mask on the same device as `outputs` before indexing.
        mask = torch.as_tensor(data.mask, device=device)

        optimizer.zero_grad()

        # Forward + backward + optimize on the masked nodes only.
        outputs = model(inputs, data.edge_index.to(device))
        masked_outputs = outputs[mask]
        loss = criterion(masked_outputs.float(), labels)

        loss.backward()
        optimizer.step()

        # Score the same masked nodes the loss used: comparing the full `outputs`
        # against `labels` misaligns rows whenever the mask excludes any node.
        accuracy = (masked_outputs.argmax(dim=1) == labels).sum().item() / len(labels)
        all_accuracy.append(accuracy)
        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.3f}, accuracy: {sum(all_accuracy) / len(all_accuracy):.3f}')

print('Training finished!')
