In [1]:
import matplotlib.pyplot as plt
import torch
import spatialSSL

file_path = "../example_files/img_119670929_1199650932.h5ad"

# Create the dataloader
dataset_constructor = spatialSSL.Dataloader.EgoNetDatasetConstructor(file_path=file_path, image_col="section",
                                                                     label_col="class_label", include_label=False,
                                                                     radius=20, node_level=2)

# Load the data
dataset_constructor.load_data()

# Construct the graph
#dataset = dataset_constructor.construct_graph()

dataset = torch.load("../dataset/dataset.pt")
total_cells = len(dataset_constructor.adata)
print(len(dataset), "of", total_cells)

52528 of 55530


In [67]:
from torch import Tensor, nn
from torch.nn import Linear
from torch_geometric.nn import GCNConv, GATConv

class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin1 = nn.Linear(hidden_channels, 20)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.lin1(x).softmax(dim=1)
        return x

In [68]:
dataset_constructor.adata.obs.class_label.unique()#.cat.codes.values

['MY GABA', 'MY Glut', 'Astro-Epen', 'Oligo', 'P GABA', ..., 'CNU-HYa GABA', 'MB Glut', 'CNU-HYa Glut', 'TH Glut', 'MB Dopa']
Length: 20
Categories (20, object): ['Astro-Epen', 'CB GABA', 'CB Glut', 'CNU-HYa GABA', ..., 'P GABA', 'P Glut', 'TH Glut', 'Vascular']

In [69]:
for batch in dataset:
    print(batch.x)
    break

tensor([    0,  3600, 17561])


In [70]:
train_loader, val_loader, test_loader = spatialSSL.Utils.split_dataset(dataset=dataset, split_percent=(0.8, 0.1, 0.1), batch_size = 64)


In [71]:
for batch in train_loader:
    print(batch.x)
    break

tensor([34235, 34489, 34503, 34521, 34530, 34554, 34649, 34664, 34679, 34727,
        34732, 34753, 34763, 34764, 34778, 34847, 35176, 35265, 36053, 36103,
        36169, 37044, 37189, 37190, 39536, 39620, 39662, 30556, 36370, 37042,
        37079, 37201, 48195, 39845, 40944, 44569, 46698, 46710, 47602, 47814,
         4937,  6217,  6469,  6473,  6541, 14017, 14999, 16382, 17000, 17633,
        17881, 17890, 17986, 19320, 23212, 24141, 24472, 27006, 30318, 30732,
        30927, 30928, 39006, 40183, 40457,  5181,  5203,  5738,  6243, 14117,
        15352, 15690, 17220, 17236, 17257, 17327, 17339, 17833, 17922, 18457,
        23295, 11980, 12434, 24937,   847,  1648,  2971,  7222, 11353, 21332,
        21699, 10348, 10514, 11298, 11302, 23570, 36964, 45937, 49323, 54208,
        27826, 33314, 33457, 33988, 41861, 41953, 42013, 42501, 42687, 43800,
        45205, 45535, 45786, 46014, 46081, 46310, 46580, 46868, 47224, 47707,
         7448,  8209,  9729, 12735, 16244, 16996, 17741, 18274, 

In [72]:
from tqdm.auto import tqdm
from torch import optim

net = GCN(in_channels=550, hidden_channels=550, out_channels=550)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
# Train the model
epochs = 10
for epoch in tqdm(range(epochs)):

    net.train()
    all_accuracy = []

    running_loss = 0.0
    for data in tqdm(train_loader, leave=False):
        #inputs, _ = data.
        input = torch.tensor(dataset_constructor.adata.X[data.x].toarray(), dtype=torch.double).to(device).float()

        labels = torch.tensor(dataset_constructor.adata[data.x.numpy()].obs.class_label.cat.codes.values).to(device).long()
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = net(input, data.edge_index)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        accuracy = (outputs.argmax(dim=1) == labels).sum().item() / len(labels)
        all_accuracy.append(accuracy)
        running_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.3f}, accuracy: {sum(all_accuracy) / len(all_accuracy):.3f}')

print('Training finished!')


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/657 [00:00<?, ?it/s]

Epoch 1, Loss: 2.650, accuracy: 0.449


  0%|          | 0/657 [00:00<?, ?it/s]

Epoch 2, Loss: 2.577, accuracy: 0.499


  0%|          | 0/657 [00:00<?, ?it/s]

Epoch 3, Loss: 2.546, accuracy: 0.543


  0%|          | 0/657 [00:00<?, ?it/s]

Epoch 4, Loss: 2.511, accuracy: 0.574


  0%|          | 0/657 [00:00<?, ?it/s]

Epoch 5, Loss: 2.505, accuracy: 0.577


  0%|          | 0/657 [00:00<?, ?it/s]

Epoch 6, Loss: 2.501, accuracy: 0.580


  0%|          | 0/657 [00:00<?, ?it/s]

Epoch 7, Loss: 2.496, accuracy: 0.581


  0%|          | 0/657 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [82]:
from statistics import mean

net = GCN(in_channels=550, hidden_channels=550, out_channels=550)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
# Train the model
# training loop
for epoch in tqdm(range(epochs)):
    net.train()

    train_accs = []
    running_loss = 0.0

    for data in train_loader:
        #inputs, _ = data.
        input = torch.tensor(dataset_constructor.adata.X[data.x].toarray(), dtype=torch.double).to(device).float()

        labels = torch.tensor(dataset_constructor.adata[data.x.numpy()].obs.class_label.cat.codes.values).to(device).long()
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = net(input, data.edge_index)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        accuracy = (outputs.argmax(dim=1) == labels).sum().item() / len(labels)
        train_accs.append(accuracy)
        running_loss += loss.item()

    running_loss = running_loss / len(train_loader)
       # Validation
    net.eval()
    val_accs = []
    with torch.no_grad():
        val_loss = 0.0
        for data in val_loader:

            input = torch.tensor(dataset_constructor.adata.X[data.x].toarray(), dtype=torch.double).to(device).float()
            labels = torch.tensor(dataset_constructor.adata[data.x.numpy()].obs.class_label.cat.codes.values).to(device).long()
            outputs = net(input, data.edge_index)
            loss = criterion(outputs, labels)

            accuracy = (outputs.argmax(dim=1) == labels).sum().item() / len(labels)
            val_accs.append(accuracy)
            val_loss += loss.item()

        #val_accuracy = 100 * correct / total
        avg_val_loss = val_loss / len(val_loader)


    print(f'Epoch: {epoch+1}/{epochs}, Train Loss: {running_loss:.4f}, Val Loss: {avg_val_loss:.4f} Train Acc: {mean(train_accs):.4f}, Val Acc: {mean(val_accs):.4f}')



    #print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.3f}, accuracy: {sum(all_accuracy) / len(all_accuracy):.3f}')

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1/10, Train Loss: 2.6559, Val Loss: 2.5806 Train Acc: 0.4355, Val Acc: 0.4947
Epoch: 2/10, Train Loss: 2.5610, Val Loss: 2.5548 Train Acc: 0.5220, Val Acc: 0.5259
Epoch: 3/10, Train Loss: 2.5454, Val Loss: 2.5472 Train Acc: 0.5343, Val Acc: 0.5315
Epoch: 4/10, Train Loss: 2.5385, Val Loss: 2.5426 Train Acc: 0.5409, Val Acc: 0.5375
Epoch: 5/10, Train Loss: 2.5340, Val Loss: 2.5360 Train Acc: 0.5460, Val Acc: 0.5485
Epoch: 6/10, Train Loss: 2.5252, Val Loss: 2.5298 Train Acc: 0.5581, Val Acc: 0.5496
Epoch: 7/10, Train Loss: 2.5189, Val Loss: 2.5217 Train Acc: 0.5625, Val Acc: 0.5596
Epoch: 8/10, Train Loss: 2.5092, Val Loss: 2.5193 Train Acc: 0.5714, Val Acc: 0.5612
Epoch: 9/10, Train Loss: 2.5086, Val Loss: 2.5135 Train Acc: 0.5707, Val Acc: 0.5700
Epoch: 10/10, Train Loss: 2.4855, Val Loss: 2.4903 Train Acc: 0.5996, Val Acc: 0.5934
