In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import re
from scipy.spatial import Voronoi, Delaunay
from sklearn.model_selection import StratifiedShuffleSplit

import torch
from torch_geometric.loader import DataLoader
from torch.nn.functional import cross_entropy
from torch.optim import Adam

In [None]:
# Make sure the folders that later cells write into exist:
# "images" receives saved figures, "results" receives weights and metric logs.
for output_directory in ("images", "results"):
    os.makedirs(output_directory, exist_ok = True)

# GRAPHICAL CONSTRUCTION

In [None]:
from source import graph_utilities

In [None]:
# Load every cell-type assignment table found under the CELESTA output directory and
# build its spatial representations, keyed by sample identifier and grouped by region
# ("center" goes to center_representations, anything else to edge_representations).
#
# Expected layout below the output directory:
#     <regionalization>/<nodal_status>/<sample_identifier>/<file containing "final_cell_type">
edge_representations = {}
center_representations = {}

celesta_output_directory = "../celesta/output"

for root, directories, files in os.walk(celesta_output_directory):
    # os.walk yields entries in arbitrary, filesystem-dependent order. Sorting the
    # directory list in place fixes the traversal order (and so the dict insertion
    # order), which the seeded random draws in later cells depend on.
    directories.sort()

    # Read the path fields relative to the walk root instead of by absolute position
    # in the raw path, so the parsing does not depend on how the root is spelled or
    # on which separator it uses.
    relative_parts = os.path.relpath(root, celesta_output_directory).split(os.path.sep)

    if len(relative_parts) < 3:
        # Not inside a <regionalization>/<nodal_status>/<sample_identifier> folder.
        continue

    regionalization, nodal_status, sample_identifier = relative_parts[:3]

    for file in sorted(files):
        if "final_cell_type" not in file:
            continue

        # Drop rows with missing values and renumber the remaining rows from 0.
        assignments = pd.read_csv(os.path.join(root, file)).dropna().reset_index(drop = True)

        # Both spatial structures are built from the same X/Y coordinates.
        coordinates = assignments[["X", "Y"]].values

        voronoi = Voronoi(coordinates)
        delaunay = Delaunay(coordinates)

        graph = graph_utilities.construct_sample_graph(delaunay, assignments)

        representation = {
            "regionalization": regionalization,
            "nodal_status": nodal_status,
            "assignments": assignments,
            "voronoi": voronoi,
            "delaunay": delaunay,
            "graph": graph
        }

        (center_representations if regionalization == "center" else edge_representations)[sample_identifier] = representation

Here, we visualize an example of the Voronoi diagram and the Delaunay triangulation constructed across an entire edge sample.

In [None]:
# Pick one edge sample with a fixed seed and draw its Voronoi diagram and Delaunay
# triangulation on the same axis.
random.seed(42)

edge_sample_name, edge_sample = random.choice(list(edge_representations.items()))

fig, ax = plt.subplots(figsize = (15, 10))

# Both plotting helpers take (structure, assignments, axis); call them in the same
# order as before: Voronoi first, Delaunay second.
for draw_structure, structure_key in (
    (graph_utilities.plot_voronoi_diagram, "voronoi"),
    (graph_utilities.plot_delaunay_triangulation, "delaunay"),
):
    draw_structure(edge_sample[structure_key], edge_sample["assignments"], ax)

ax.set_title(f'{edge_sample_name} ({edge_sample["regionalization"]} / {edge_sample["nodal_status"]})')

plt.tight_layout()

In [None]:
# Toggle: work on a random subset of the loaded samples (True) or on all of them (False).
sample = True

# Number of samples drawn per region when subsampling.
sample_size = 6

if sample:
    random.seed(42)

    # random.sample raises ValueError when asked for more items than the population
    # holds, so the draw is capped at the number of samples actually loaded.
    sampled_edge_representations = dict(random.sample(list(edge_representations.items()),
                                                      min(sample_size, len(edge_representations))))
    sampled_center_representations = dict(random.sample(list(center_representations.items()),
                                                        min(sample_size, len(center_representations))))

else:
    sampled_edge_representations = edge_representations
    sampled_center_representations = center_representations

In [None]:
# Build the microenvironments of the selected samples, one collection per region.
build_microenvironments = graph_utilities.construct_microenvironments

edge_microenvironments = build_microenvironments(sampled_edge_representations)
center_microenvironments = build_microenvironments(sampled_center_representations)

Here, we visualize a microenvironment from an edge sample; a microenvironment is defined as the 3-hop neighborhood of a cell.

In [None]:
# Pick one edge sample and one center node within it (fixed seed), draw that node's
# microenvironment, and save the figure under images/.
random.seed(224)

edge_sample_name, edge_sample_microenvironments = random.choice(list(edge_microenvironments.items()))

# The center node is drawn as a position in 0 .. len - 1 and also used as the lookup key.
candidate_center_nodes = range(len(edge_sample_microenvironments))
edge_center_node = random.choice(candidate_center_nodes)
chosen_microenvironment = edge_sample_microenvironments[edge_center_node]

fig, ax = plt.subplots(figsize = (10, 10))

graph_utilities.plot_microenvironment(chosen_microenvironment, edge_center_node, ax)

ax.set_title(f"microenvironment from sample: {edge_sample_name}")
plt.tight_layout()

figure_path = f"images/{edge_sample_name}_microenvironment.png"
plt.savefig(figure_path)

# MODELS

In [None]:
from source import data
from source import models
from source import training_and_evaluation

In [None]:
# Cell-type labels in code order: a label's position in this list is its integer
# code, starting at 0 for "unknown".
cell_type_labels = [
    "unknown",
    "epithelial cell (cytokeratin+)",
    "endothelial cell (CD31+)",
    "fibroblast (FAP+)",
    "stromal cell (CD90+)",
    "mesenchymal cell (podoplanin+)",
    "T cell (CD3+)",
    "B cell (CD20+)",
    "granulocyte (CD15+)",
    "dendritic cell (CD11c+)",
    "macrophage (CD68+)",
    "macrophage (CD163+)",
    "macrophage (CD68+ CD163+)",
    "macrophage (CD68+ CD163-)",
    "fibroblast (FAP+ CD90+)",
    "fibroblast (FAP+ CD90-)",
    "cytotoxic T cell (CD8+)",
    "helper T cell (CD4+)",
    "regulatory T cell (CD4+ FOXP3+)",
]

# Label -> integer code lookup.
mapping = {label: code for code, label in enumerate(cell_type_labels)}

In [None]:
# Turn each region's microenvironments into model inputs. The third argument differs
# per region (1 for edge, 0 for center) — presumably the class label attached to
# every item; confirm against data.prepare_data.
#
# The prepared data gets its own names instead of overwriting the inputs, so this
# cell can be re-run and the earlier cells that read edge_microenvironments /
# center_microenvironments keep working after it has executed.
edge_dataset = data.prepare_data(edge_microenvironments, mapping, 1)
center_dataset = data.prepare_data(center_microenvironments, mapping, 0)

# Combined collection consumed by the train / test split.
all_microenvironments = edge_dataset + center_dataset

In [None]:
# Set up the classifier, its optimizer, and a stratified 80 / 20 train / test split
# of all microenvironments wrapped in mini-batch loaders.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global generator so weight initialisation and the shuffling of the
# training loader repeat across runs (the split below is already seeded).
torch.manual_seed(42)

# NOTE(review): 19 equals len(mapping) — presumably one input feature per cell
# type; confirm against data.prepare_data before changing either.
input_dim = 19
output_dim = 2
hidden_dim = 64

# model_type names the artifacts written by the training cell; keep it in sync with
# the classifier class instantiated on the next line.
model_type = "GCN"
model = models.TumorGCNClassifier(input_dim, hidden_dim, output_dim).to(device)

optimizer = Adam(model.parameters(), lr = 0.001)

features = [microenvironment.x for microenvironment in all_microenvironments]
labels = [microenvironment.y.item() for microenvironment in all_microenvironments]

splitter = StratifiedShuffleSplit(n_splits = 1, test_size = 0.2, random_state = 42)

# n_splits = 1, so the generator yields exactly one (train, test) pair of index arrays.
train_index, test_index = next(splitter.split(features, labels))

train_data = [all_microenvironments[i] for i in train_index]
test_data = [all_microenvironments[i] for i in test_index]

train_loader = DataLoader(train_data, batch_size = 32, shuffle = True)
test_loader = DataLoader(test_data, batch_size = 32, shuffle = False)

In [None]:
# Train for a fixed number of epochs, evaluating on the test loader after each one,
# then save the weights and the per-epoch logs under results/.
epochs = 30

train_losses_list, train_metrics_list, test_metrics_list = [], [], []

for epoch in range(1, epochs + 1):
    train_loss, train_metrics = training_and_evaluation.train_epoch(
        model, train_loader, optimizer, cross_entropy, device
    )
    test_metrics = training_and_evaluation.evaluate_epoch(model, test_loader, device)

    train_losses_list.append(train_loss)
    train_metrics_list.append(train_metrics)
    test_metrics_list.append(test_metrics)

    print(f'epoch {epoch}, loss: {train_loss:.4f}, test accuracy: {test_metrics["accuracy"]:.4f}')

torch.save(model.state_dict(), f"results/{model_type}.pth")

# One log file per history. Every line is the repr of a dict holding the 1-based
# epoch number followed by that epoch's values.
epoch_logs = {
    f"results/{model_type}_training_losses.txt": [
        {"epoch": number, "train_loss": value}
        for number, value in enumerate(train_losses_list, 1)
    ],
    f"results/{model_type}_training_metrics.txt": [
        {"epoch": number, **values}
        for number, values in enumerate(train_metrics_list, 1)
    ],
    f"results/{model_type}_evaluation_metrics.txt": [
        {"epoch": number, **values}
        for number, values in enumerate(test_metrics_list, 1)
    ],
}

for log_path, epoch_summaries in epoch_logs.items():
    with open(log_path, "w") as log_file:
        log_file.writelines(f"{epoch_summary}\n" for epoch_summary in epoch_summaries)

# INTERPRETABILITY

In [None]:
from source import interpretability

In [None]:
# Architectures compared in this section, in plotting order.
# NOTE(review): the cells visible in this notebook only train and save the model
# named by model_type; the artifacts for the other architectures must already
# exist under results/ — confirm how they are produced.
model_types = ["GCN",
               "GIN",
               "GAT"]

# Training-loss log of each architecture, derived from model_types so the two lists
# cannot drift apart (same naming pattern the training cell writes).
training_loss_files = [f"results/{architecture}_training_losses.txt" for architecture in model_types]

In [None]:
# Plot the logged training losses of the compared architectures together.
# NOTE(review): the rendering is defined in source.interpretability — presumably one
# curve per entry of model_types, read from the file at the same position in
# training_loss_files; confirm there.
interpretability.plot_joint_training_losses(training_loss_files, model_types)

In [None]:
# Side-by-side panels (one per architecture) built from each architecture's training
# and evaluation metric logs, using the "f1" metric; the figure is saved under images/.
fig, axes = plt.subplots(1, 3, figsize = (20, 5))

# The loop variable is deliberately not called model_type: that notebook-level name
# selects the file names the training cell writes, and rebinding it here would make
# a later re-run of that cell save its artifacts under the wrong architecture.
for ax, architecture in zip(axes, model_types):
    interpretability.plot_performance_measures(f"results/{architecture}_training_metrics.txt",
                                               f"results/{architecture}_evaluation_metrics.txt",
                                               architecture, metric = "f1", ax = ax)

plt.tight_layout()

plt.savefig("images/performance_measures_comparison.png")

In [None]:
# Rebuild the GAT classifier, restore the weights saved for the last entry of
# model_types, and extract embeddings, predicted probabilities and cell-type
# proportions over the training loader.
# NOTE(review): heads = 3 presumably has to match the configuration the saved
# weights were trained with — confirm.
model = models.TumorGATClassifier(input_dim, hidden_dim, output_dim, heads = 3).to(device)

# map_location remaps the stored tensors onto the current device, so weights saved on
# a GPU machine still load when only a CPU is available.
model.load_state_dict(torch.load(f"results/{model_types[-1]}.pth", map_location = device))

embeddings, probability_predictions, cell_type_proportions = interpretability.extract_embeddings(model, train_loader, device)

In [None]:
# Visualize the outputs of extract_embeddings for the last architecture in
# model_types, passing the cell-type mapping along.
# NOTE(review): the trailing positional 5 is not explained in this notebook — check
# the signature of interpretability.visualize_embeddings for its meaning.
interpretability.visualize_embeddings(embeddings,
                                      probability_predictions,
                                      cell_type_proportions,
                                      model_types[-1],
                                      mapping, 5)