In [1]:
from spatialSSL.Dataloader import EgoNetDataloader, FullImageConstracter
from spatialSSL.Utils import split_dataset



	geopandas.options.use_pygeos = True

If you intended to use PyGEOS, set the option to False.
  _check_geopandas_using_shapely()


In [2]:
# Create the FullImageConstracter that will turn the .h5ad file into graphs.

# Alternative input file (kept for reference): "./data/img_119670929.h5ad"
file_path = "./data/subset_6img_atlas_brain.h5ad"
data_constracter = FullImageConstracter(
    file_path=file_path,
    image_col="section",
    label_col="class_id_label",
    include_label=False,
    radius=20,
    node_level=1,
    batch_size=1,
)

In [3]:
# Load the data
# NOTE(review): presumably reads the .h5ad at `file_path` given to the constructor — confirm.
data_constracter.load_data()

# Construct the graph
# The returned collection is iterated graph-by-graph in the cells below and
# passed to split_dataset().
graph_list = data_constracter.construct_graph()

Constructing Graphs:   0%|          | 0/6 [00:00<?, ?it/s]

In [4]:
# Print only the first graph to inspect which fields it carries; the loop
# exits after one iteration.
for x in graph_list:
    print(x)
    break

Data(x=[26230, 550], edge_index=[2, 102942], y=[5246, 550], mask=[26230], cell_type=['22 MY GABA', '22 MY GABA', '22 MY GABA', '22 MY GABA', '22 MY GABA', ..., '31 Vascular', '31 Vascular', '31 Vascular', '31 Vascular', '31 Vascular']
Length: 26230
Categories (15, object): ['11 HY GABA', '15 HY Glut', '17 P Glut', '18 MB-HB Sero', ..., '30 OEG', '31 Vascular', '32 Immune', '33 LQ'], cell_type_masked=['24 CB GABA', '19 MY Glut', '19 MY Glut', '28 Astro-Epen', '25 CB Glut', ..., '25 CB Glut', '19 MY Glut', '29 Oligo', '25 CB Glut', '25 CB Glut']
Length: 5246
Categories (15, object): ['11 HY GABA', '15 HY Glut', '17 P Glut', '18 MB-HB Sero', ..., '30 OEG', '31 Vascular', '32 Immune', '33 LQ'], image='1199650929')


In [5]:
# Split the graphs 60/20/20 into three loaders.
# NOTE(review): the result is unpacked as (train, test, val) — confirm this
# matches the order split_dataset() actually returns; validation and test
# would be silently swapped otherwise.
train_loader, test_loader, val_loader = split_dataset(graph_list,split_percent=(0.6, 0.2, 0.2), batch_size=1)

In [6]:
# Look at the first training batch only and report how many entries of its
# mask are False. These are the same rows selected via `~data.mask` in the
# training and validation functions.
for x in train_loader:
    # Vectorised count: avoids building a Python list with one object per
    # element just to call list.count().
    print(int((~x.mask).sum()))

    break

9492


In [7]:
# Sanity check: report how many items each loader's underlying dataset holds,
# to verify that the split produced the expected proportions.
print(f"Train size: {len(train_loader.dataset)}")
print(f"Validation size: {len(val_loader.dataset)}")
print(f"Test size: {len(test_loader.dataset)}")

Train size: 4
Validation size: 1
Test size: 1


In [8]:
from torch import nn, optim, Tensor
import torch
from torch_geometric.nn import GCNConv
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import r2_score
import time

class GCN(nn.Module):
    """Two graph-convolution layers followed by a linear read-out.

    Each ``GCNConv`` layer is followed by a ``LeakyReLU`` activation; the
    final linear layer maps the second convolution's output to
    ``num_outputs`` values per node.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution.
        num_outputs: Output width of the final linear layer. Defaults to 550,
            the value that was previously hard-coded, so existing callers and
            saved checkpoints keep working unchanged.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_outputs=550):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.linear = nn.Linear(out_channels, num_outputs)
        self.act = nn.LeakyReLU()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Return one prediction row per input node.

        Args:
            x: Node feature matrix passed to the first convolution.
            edge_index: Graph connectivity passed to both convolutions.

        Returns:
            Tensor with ``num_outputs`` columns and one row per row of ``x``.
        """
        hidden = self.act(self.conv1(x, edge_index))
        hidden = self.act(self.conv2(hidden, edge_index))
        return self.linear(hidden)

# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Width of the node feature vectors. It matches the 550 columns of `x` and `y`
# in the graph printed above; change it here if the input data changes.
NUM_FEATURES = 550

# Create the model: input, hidden and second-convolution widths all equal the
# feature width.
model = GCN(NUM_FEATURES, NUM_FEATURES, NUM_FEATURES).to(device)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # multiply the LR by 0.1 every 30 scheduler steps

num_epochs = 50  # upper bound on the number of training epochs
patience = 5  # consecutive non-improving epochs tolerated before early stopping

def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and report loss and R².

    For every batch, the model's predictions for rows where ``data.mask`` is
    False are compared against ``data.y`` and one optimizer step is taken.
    NOTE(review): presumably these rows are the nodes whose features were
    hidden from the model — confirm against the dataloader.

    Args:
        model: Module called as ``model(x, edge_index)``.
        loader: Iterable of batches that also exposes ``.dataset``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss callable taking ``(predictions, targets)``.

    Returns:
        Tuple ``(loss, r2)``: the sum of per-batch losses, each weighted by
        the batch's ``num_graphs``, divided by ``len(loader.dataset)``; and
        the R² score over all selected rows seen in the epoch.
    """
    model.train()
    running_loss = 0
    epoch_targets, epoch_preds = [], []

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        predictions = model(batch.x.float(), batch.edge_index.long())
        selected = predictions[~batch.mask]  # rows where the mask is False
        batch_loss = criterion(selected, batch.y.float())
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item() * batch.num_graphs
        epoch_targets.append(batch.y.cpu().detach())
        epoch_preds.append(selected.detach().cpu())

    mean_loss = running_loss / len(loader.dataset)
    epoch_r2 = r2_score(torch.cat(epoch_targets).numpy(), torch.cat(epoch_preds).numpy())
    return mean_loss, epoch_r2

def validate_one_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without updating its parameters.

    The model is put in eval mode and run under ``torch.no_grad()``. For every
    batch, predictions for rows where ``data.mask`` is False are compared
    against ``data.y``.

    Args:
        model: Module called as ``model(x, edge_index)``.
        loader: Iterable of batches that also exposes ``.dataset``.
        criterion: Loss callable taking ``(predictions, targets)``.

    Returns:
        Tuple ``(loss, r2)``: the sum of per-batch losses, each weighted by
        the batch's ``num_graphs``, divided by ``len(loader.dataset)``; and
        the R² score over all selected rows seen in the pass.
    """
    model.eval()
    running_loss = 0
    all_targets, all_preds = [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            predictions = model(batch.x.float(), batch.edge_index.long())
            selected = predictions[~batch.mask]  # rows where the mask is False

            running_loss += criterion(selected, batch.y.float()).item() * batch.num_graphs
            all_targets.append(batch.y.cpu())
            all_preds.append(selected.cpu())

    mean_loss = running_loss / len(loader.dataset)
    return mean_loss, r2_score(torch.cat(all_targets).numpy(), torch.cat(all_preds).numpy())



In [9]:

# Training loop: train and validate once per epoch, checkpoint the weights
# whenever the validation loss improves, and stop early after `patience`
# consecutive epochs without improvement.
best_val_loss = float('inf')
best_epoch = 0
epochs_no_improve = 0

for epoch in range(num_epochs):
    start_time = time.time()
    train_loss, train_r2 = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_r2 = validate_one_epoch(model, val_loader, criterion)
    scheduler.step() # Decrease learning rate by scheduler

    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        best_epoch = epoch
        epochs_no_improve = 0
        # Keep only the weights of the best epoch so far on disk.
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        epochs_no_improve += 1

    # Log before the early-stopping check so the metrics of the epoch that
    # triggers the stop are reported too.
    print(f"Epoch {epoch+1}/{num_epochs}, train loss: {train_loss:.4f}, train r2: {train_r2:.4f},  val loss: {val_loss:.4f}, val r2: {val_r2:.4f}, Time: {time.time()-start_time}s")

    # `>=` rather than `==` so the stop still fires if the counter ever
    # overshoots; only non-improving epochs can trigger it.
    if not improved and epochs_no_improve >= patience:
        print('Early stopping!')
        break

print(f"Best val loss: {best_val_loss:.4f}, at epoch {best_epoch+1}")

Epoch 1/50, train loss: 0.9219, train r2: -0.3284,  val loss: 0.7252, val r2: -0.3058, Time: 10.841485977172852s
Epoch 2/50, train loss: 0.6982, train r2: -0.1905,  val loss: 0.6613, val r2: -0.1409, Time: 10.559338569641113s
Epoch 3/50, train loss: 0.6495, train r2: -0.1078,  val loss: 0.6369, val r2: -0.0996, Time: 10.808122873306274s
Epoch 4/50, train loss: 0.6142, train r2: -0.0552,  val loss: 0.5991, val r2: -0.0425, Time: 11.104062557220459s
Epoch 5/50, train loss: 0.5833, train r2: -0.0165,  val loss: 0.5787, val r2: -0.0132, Time: 11.197389364242554s
Epoch 6/50, train loss: 0.5636, train r2: 0.0094,  val loss: 0.5619, val r2: 0.0095, Time: 11.517719984054565s
Epoch 7/50, train loss: 0.5482, train r2: 0.0282,  val loss: 0.5497, val r2: 0.0234, Time: 11.639735221862793s
Epoch 8/50, train loss: 0.5374, train r2: 0.0434,  val loss: 0.5397, val r2: 0.0376, Time: 11.201942443847656s
Epoch 9/50, train loss: 0.5280, train r2: 0.0558,  val loss: 0.5302, val r2: 0.0508, Time: 11.33341288