# Training GCN

Train a Graph Neural Network (GNN) on histology and gene-expression information.

We are going to build a brain layer classifier using [Torch Geometric's GCN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GCN.html?highlight=gcn#torch-geometric-nn-models-gcn).

This notebook walks through the lower-level steps that the project's wrappers hide: acquiring the data, building the data loaders, and using them in a training loop.

## Loading the data

For a detailed exploration and analysis of what the data actually contains, visit the `DataAnalysis` notebook, located in the same directory as this one.

### Downloading the data

In [1]:
%%sh
./dataset/getdata.sh

~/projects/GNNCellClassification/dataset ~/projects/GNNCellClassification
Downloading data if needed...
Don't download data: Both data and images already exist
~/projects/GNNCellClassification


### Actual loading + preprocessing

In [2]:
import sys, os
sys.path.append(os.path.abspath("src"))

In [3]:
%load_ext autoreload
%autoreload 2
import importlib
import preprocess
import load_data

data_dir, img_dir, graph_dir = "dataset/data", "dataset/images", "out/graphs"
ann_data, histology_imgs = preprocess.main(data_dir, img_dir, graph_dir)

Loading AnnData for sample 151676 …
Loading AnnData for sample 151669 …
Loading AnnData for sample 151507 …
Loading AnnData for sample 151508 …
Loading AnnData for sample 151672 …
Loading AnnData for sample 151670 …
Loading AnnData for sample 151673 …
Loading AnnData for sample 151675 …
Loading AnnData for sample 151510 …
Loading AnnData for sample 151671 …
Loading AnnData for sample 151674 …
Loading AnnData for sample 151509 …
Loading Image for sample 151676 …
Loading Image for sample 151669 …
Loading Image for sample 151507 …
Loading Image for sample 151508 …
Loading Image for sample 151672 …
Loading Image for sample 151670 …
Loading Image for sample 151673 …
Loading Image for sample 151675 …
Loading Image for sample 151510 …
Loading Image for sample 151671 …
Loading Image for sample 151674 …
Loading Image for sample 151509 …
Creating Graphs for sample 151676 …
Calculating adj matrix using histology image...


[ WARN:0@8.482] global loadsave.cpp:848 imwrite_ Unsupported depth image for selected encoder is fallbacked to CV_8U.


Variances of color0, color1, color2 =  [0.00134047 0.00493318 0.00063244]
Var of x, y, z =  99580.40478290287 115934.31450257276 115934.31589490475
Max value:  1942.6682021906095
Creating Graphs for sample 151669 …
Calculating adj matrix using histology image...
Variances of color0, color1, color2 =  [0.00100299 0.00408081 0.00090281]
Var of x, y, z =  115055.61305404994 109785.12825823567 115055.62076952032
Max value:  1990.445338189566
Creating Graphs for sample 151507 …
Calculating adj matrix using histology image...
Variances of color0, color1, color2 =  [0.00058533 0.00256631 0.00049386]
Var of x, y, z =  119744.2901290638 140702.3686761846 140702.35921238275
Max value:  2613.207049095901
Creating Graphs for sample 151508 …
Calculating adj matrix using histology image...
Variances of color0, color1, color2 =  [0.00059817 0.00289116 0.00048598]
Var of x, y, z =  121008.39013112546 148803.96100264232 148803.96221282193
Max value:  2685.33787054901
Creating Graphs for sample 151672 …

## Features used for training

We are going to use the following features for training:

**Edge Features**:
- Spatial `connectivities` between spots
- Pixel `distance` (adjusted by color)

**Node Features**:
- `UMI` count (log-normalized, then reduced with PCA)
- `Color` in the `neighbourhood` of the spot


## PyTorch Geometric's data structure

From the official [documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GCN.html?highlight=gcn#torch-geometric-nn-models-gcn), we can see that:

A single graph in PyG is described by an instance of `torch_geometric.data.Data`, which holds the following attributes by default:

- `data.x`: Node feature matrix with shape `[num_nodes, num_node_features]`

- `data.edge_index`: Graph connectivity in `COO` format with shape `[2, num_edges]` and type `torch.long`

- `data.edge_attr`: Edge feature matrix with shape `[num_edges, num_edge_features]`

- `data.y`: Target to train against (may have arbitrary shape), e.g., node-level targets of shape `[num_nodes, *]` or graph-level targets of shape `[1, *]`

- `data.pos`: Node position matrix with shape `[num_nodes, num_dimensions]`
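To make these shape conventions concrete, here is a small numpy sketch. `check_graph_shapes` is a hypothetical helper written for illustration, not part of PyG; it only checks the rules listed above on a toy 3-node graph.

```python
import numpy as np


def check_graph_shapes(x, edge_index, edge_attr=None, y=None, pos=None):
    """Check arrays against the PyG `Data` shape conventions listed above.

    Hypothetical helper for illustration only. Returns (num_nodes, num_edges)
    and raises ValueError on the first violated rule.
    """
    if x.ndim != 2:
        raise ValueError("x must be [num_nodes, num_node_features]")
    num_nodes = x.shape[0]
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError("edge_index must be [2, num_edges]")
    if not np.issubdtype(edge_index.dtype, np.integer):
        raise ValueError("edge_index must hold integer node ids (torch.long in PyG)")
    num_edges = edge_index.shape[1]
    if num_edges and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        raise ValueError("edge_index refers to a node that does not exist")
    if edge_attr is not None and (edge_attr.ndim != 2 or edge_attr.shape[0] != num_edges):
        raise ValueError("edge_attr must be [num_edges, num_edge_features]")
    if y is not None and y.shape[0] != num_nodes:
        raise ValueError("node-level y must have num_nodes rows")
    if pos is not None and (pos.ndim != 2 or pos.shape[0] != num_nodes):
        raise ValueError("pos must be [num_nodes, num_dimensions]")
    return num_nodes, num_edges


# Toy graph: 3 spots on a line (0 - 1 - 2); each undirected edge is stored
# in both directions, as in a symmetric adjacency matrix.
x = np.zeros((3, 56), dtype=np.float32)            # 56 node features per spot
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]], dtype=np.int64)
edge_attr = np.ones((4, 1), dtype=np.float32)       # one feature (distance) per edge
y = np.array([0, 2, 6], dtype=np.int64)             # one class index per spot
pos = np.zeros((3, 2), dtype=np.float32)            # 2-D spot coordinates

print(check_graph_shapes(x, edge_index, edge_attr, y, pos))  # (3, 4)
```

Once arrays pass these checks, they can be converted with `torch.from_numpy` and passed as keyword arguments to `torch_geometric.data.Data(x=..., edge_index=..., edge_attr=..., y=..., pos=...)`.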

## Creating the required data structures for training

We need to convert the data into the structures PyTorch Geometric expects.

In [4]:
import torch

We randomly assign each patient to the training, validation or test set:

In [5]:
from dataloader import train_val_test_split


train_patients, val_patients, test_patients = train_val_test_split(ann_data=ann_data, seed=42)

print(f"train_patients: {train_patients}")
print(f"val_patients: {val_patients}")
print(f"test_patients: {test_patients}")

train_patients: ['151674', '151669', '151676', '151672', '151508', '151509', '151673', '151671']
val_patients: ['151507', '151675']
test_patients: ['151670', '151510']
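The split is done per patient rather than per spot, so all spots of one tissue section land in the same set and the validation and test scores measure generalization to unseen sections. A minimal sketch of such a split follows. `split_patients` is hypothetical and will not necessarily reproduce the assignment printed above, since the internals of `train_val_test_split` are not shown here.

```python
import random


def split_patients(patient_ids, seed=42, n_val=2, n_test=2):
    """Shuffle patient ids and cut them into train / val / test lists.

    Hypothetical stand-in for dataloader.train_val_test_split.
    """
    ids = sorted(patient_ids)      # sort first so the result ignores input order
    rng = random.Random(seed)      # local RNG: leaves the global random state alone
    rng.shuffle(ids)
    test = ids[:n_test]
    val = ids[n_test:n_test + n_val]
    train = ids[n_test + n_val:]
    return train, val, test


ids = ["151676", "151669", "151507", "151508", "151672", "151670",
       "151673", "151675", "151510", "151671", "151674", "151509"]
train, val, test = split_patients(ids, seed=42)
print(len(train), len(val), len(test))  # 8 2 2
```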


### data.edge_index
We need to transform the `spatial_connectivities` sparse matrix into a `PyTorch` tensor in `COO` format.
Let's start with a reference patient:

In [6]:
ann_data['151676'].obsp['spatial_connectivities']

<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 20052 stored elements and shape (3460, 3460)>

In [7]:
type(ann_data['151676'].obsp['spatial_connectivities'])

scipy.sparse._csr.csr_matrix

In [8]:
coo_matrix = ann_data['151676'].obsp['spatial_connectivities'].tocoo()
coo_matrix

<COOrdinate sparse matrix of dtype 'float64'
	with 20052 stored elements and shape (3460, 3460)>

In [9]:
type(coo_matrix)

scipy.sparse._coo.coo_matrix

In [10]:
from dataloader import get_coo_connections

coo_connections = get_coo_connections(ann_data)
for patient, coo in coo_connections.items():
    print(f"Patient {patient}: {coo.shape}")

Patient 151676: (3460, 3460)
Patient 151669: (3661, 3661)
Patient 151507: (4226, 4226)
Patient 151508: (4384, 4384)
Patient 151672: (4015, 4015)
Patient 151670: (3498, 3498)
Patient 151673: (3639, 3639)
Patient 151675: (3592, 3592)
Patient 151510: (4634, 4634)
Patient 151671: (4110, 4110)
Patient 151674: (3673, 3673)
Patient 151509: (4789, 4789)


In [11]:
from dataloader import get_edge_indices

edge_indices = get_edge_indices(coo_connections)

for patient, index in edge_indices.items():
    print(f"{patient}: {index.shape}")

151676: torch.Size([2, 20052])
151669: torch.Size([2, 21194])
151507: torch.Size([2, 24770])
151508: torch.Size([2, 25698])
151672: torch.Size([2, 23382])
151670: torch.Size([2, 20370])
151673: torch.Size([2, 21124])
151675: torch.Size([2, 20762])
151510: torch.Size([2, 27198])
151671: torch.Size([2, 24052])
151674: torch.Size([2, 21258])
151509: torch.Size([2, 28172])
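`get_edge_indices` presumably relies on the COO `row`/`col` arrays shown above. To show what `.tocoo()` does for a CSR matrix, here is a numpy-only sketch (`csr_to_edge_index` is a hypothetical name) that expands the CSR `indptr`/`indices` arrays into the `[2, num_edges]` layout. Every stored element becomes one directed edge, which is why sample 151676 has 20052 stored elements and an `edge_index` of shape `[2, 20052]`.

```python
import numpy as np


def csr_to_edge_index(indptr, indices):
    """Expand CSR (indptr, indices) arrays into a [2, num_edges] COO edge index.

    Row i owns the column entries indices[indptr[i]:indptr[i + 1]], so
    repeating each row id by its number of stored entries yields the
    source row of every stored element.
    """
    indptr = np.asarray(indptr)
    num_rows = len(indptr) - 1
    rows = np.repeat(np.arange(num_rows), np.diff(indptr))
    cols = np.asarray(indices)
    return np.stack([rows, cols]).astype(np.int64)


# Toy symmetric adjacency for a 3-spot path graph 0 - 1 - 2:
# row 0 -> [1], row 1 -> [0, 2], row 2 -> [1]
indptr = np.array([0, 1, 3, 4])
indices = np.array([1, 0, 2, 1])
edge_index = csr_to_edge_index(indptr, indices)
print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

With SciPy, the equivalent is `coo = csr.tocoo()` followed by `np.vstack([coo.row, coo.col])`, which `torch.from_numpy(...).long()` turns into a PyG-ready tensor.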


### data.edge_attr
Edge feature matrix with shape `[num_edges, num_edge_features]`.
For now, we only compute distances between spots that are spatially connected, so there is exactly one distance per edge in `data.edge_index`.

In [12]:
from dataloader import get_edge_features

edge_features = get_edge_features(ann_data, edge_indices, graph_dir)
for patient in ann_data.keys():
    print(f"{patient}: {edge_features[patient].shape}")

151676: torch.Size([20052, 1])
151669: torch.Size([21194, 1])
151507: torch.Size([24770, 1])
151508: torch.Size([25698, 1])
151672: torch.Size([23382, 1])
151670: torch.Size([20370, 1])
151673: torch.Size([21124, 1])
151675: torch.Size([20762, 1])
151510: torch.Size([27198, 1])
151671: torch.Size([24052, 1])
151674: torch.Size([21258, 1])
151509: torch.Size([28172, 1])
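`get_edge_features` isn't shown here. The logs above print `Var of x, y, z`, with z's variance matching the larger of x's and y's. This hints (an inference, not confirmed by the source) that each spot gets a third, color-derived coordinate, as in SpaGCN, and that the "distance adjusted by color" is measured in that 3-D space. The sketch below computes one Euclidean distance per edge for coordinates of any dimensionality; `edge_distances` is a hypothetical helper.

```python
import numpy as np


def edge_distances(coords, edge_index):
    """Euclidean distance between the two endpoints of every edge.

    coords: [num_nodes, d] array; d = 2 for pixel (x, y), or d = 3 if a
    color-derived z coordinate is appended (assumption, see text).
    Returns a [num_edges, 1] array, matching the edge_attr layout.
    """
    src, dst = edge_index
    diffs = coords[src] - coords[dst]
    return np.linalg.norm(diffs, axis=1, keepdims=True)


coords = np.array([[0.0, 0.0, 0.0],
                   [3.0, 4.0, 0.0],
                   [3.0, 4.0, 12.0]])
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
d = edge_distances(coords, edge_index)
print(d.ravel())  # [ 5.  5. 12. 12.]
```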


### data.x
Node feature matrix with shape `[num_nodes, num_node_features]`

#### Normalizing UMI count data

In [13]:
from dataloader import get_normalized_umi_count

normalized_data = get_normalized_umi_count(ann_data)
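The code of `get_normalized_umi_count` isn't shown. A common scheme, assumed here, is what scanpy's `normalize_total` followed by `log1p` does: rescale each spot so its counts sum to a fixed total, then apply `log(1 + x)`. A numpy sketch with a hypothetical `normalize_log1p`:

```python
import numpy as np


def normalize_log1p(counts, target_sum=1e4):
    """Library-size normalization followed by log1p (assumed scheme).

    counts: [num_spots, num_genes] raw UMI counts. Each spot is rescaled
    so its counts sum to target_sum, which removes differences in
    sequencing depth; log1p then compresses the heavy right tail.
    """
    counts = np.asarray(counts, dtype=np.float64)
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0      # avoid division by zero for empty spots
    return np.log1p(counts / totals * target_sum)


# Two spots with the same expression profile but 10x different depth.
counts = np.array([[1, 3, 0],
                   [10, 30, 0]])
out = normalize_log1p(counts)
print(np.allclose(out[0], out[1]))  # True
```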

#### Reducing Dimensionality of data.x
Apply Principal Component Analysis (`PCA`) to the normalized gene expression counts to reduce the dimensionality of the data.

In [14]:
from dataloader import get_pca_reduced

reduced_data = get_pca_reduced(normalized_data, train_patients, n_components=55)
for data in reduced_data.values():
    print(data.shape)

(3460, 55)
(3661, 55)
(4226, 55)
(4384, 55)
(4015, 55)
(3498, 55)
(3639, 55)
(3592, 55)
(4634, 55)
(4110, 55)
(3673, 55)
(4789, 55)
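Because `get_pca_reduced` receives `train_patients`, it presumably fits the projection on the training patients only and then applies it to every patient, which keeps validation and test data out of the fit. A numpy sketch of that pattern (helper names and toy data are hypothetical; the notebook uses 55 components):

```python
import numpy as np


def fit_pca(train_matrix, n_components):
    """Fit PCA with an SVD of the mean-centered training matrix."""
    mean = train_matrix.mean(axis=0)
    # Rows of vt are the principal axes, sorted by explained variance.
    _, _, vt = np.linalg.svd(train_matrix - mean, full_matrices=False)
    return mean, vt[:n_components]


def pca_transform(matrix, mean, components):
    """Project any patient's spots onto the axes learned from training data."""
    return (matrix - mean) @ components.T


rng = np.random.default_rng(0)
# Hypothetical per-patient normalized matrices sharing the same 20 genes.
patients = {pid: rng.normal(size=(n, 20)) for pid, n in [("a", 50), ("b", 40), ("c", 30)]}
train_ids = ["a", "b"]

train_matrix = np.vstack([patients[p] for p in train_ids])
mean, components = fit_pca(train_matrix, n_components=5)
reduced = {p: pca_transform(m, mean, components) for p, m in patients.items()}
for p, r in reduced.items():
    print(p, r.shape)
```

scikit-learn's `PCA(n_components=55).fit(train_matrix)` followed by `.transform(...)` on each patient implements the same pattern (component signs may differ).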


#### Retrieving histology color information for data.x
We add it to the data.x matrix as an extra feature: one scalar per spot, as the printed shapes show.

In [15]:
from dataloader import get_normalized_color_avgs

normalized_color_avgs = get_normalized_color_avgs(ann_data)

for patient_id in normalized_color_avgs.keys():
    print(normalized_color_avgs[patient_id].shape)


Variances of color0, color1, color2 =  [0.00134047 0.00493318 0.00063244]
Variances of color0, color1, color2 =  [0.00100299 0.00408081 0.00090281]
Variances of color0, color1, color2 =  [0.00058533 0.00256631 0.00049386]
Variances of color0, color1, color2 =  [0.00059817 0.00289116 0.00048598]
Variances of color0, color1, color2 =  [0.00160404 0.00482472 0.00077629]
Variances of color0, color1, color2 =  [0.00137143 0.00389194 0.00066271]
Variances of color0, color1, color2 =  [0.00084196 0.00644021 0.00071149]
Variances of color0, color1, color2 =  [0.00166604 0.00636242 0.00100769]
Variances of color0, color1, color2 =  [0.00063251 0.00230641 0.00054535]
Variances of color0, color1, color2 =  [0.00228128 0.00711623 0.00111823]
Variances of color0, color1, color2 =  [0.00123189 0.00723225 0.00067618]
Variances of color0, color1, color2 =  [0.00053539 0.00314728 0.00054947]
(3460,)
(3661,)
(4226,)
(4384,)
(4015,)
(3498,)
(3639,)
(3592,)
(4634,)
(4110,)
(3673,)
(4789,)
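The internals of `get_normalized_color_avgs` aren't shown, but the per-channel variances it prints and its one-scalar-per-spot output suggest a SpaGCN-like scheme. Under that assumption, the sketch below averages RGB in a square window around each spot, combines the channels weighted by their variance across spots, and z-scores the result. `neighbourhood_color` and `half_window` are hypothetical.

```python
import numpy as np


def neighbourhood_color(image, pixel_rows, pixel_cols, half_window=2):
    """One color scalar per spot from a square window of the histology image.

    Assumed scheme (not confirmed by the source): average RGB in the window,
    combine channels weighted by their variance across spots, then z-score.
    image: [H, W, 3] array; pixel_rows / pixel_cols: spot centers in pixels.
    """
    h, w, _ = image.shape
    means = []
    for r, c in zip(pixel_rows, pixel_cols):
        # Clip the window at the image border.
        r0, r1 = max(r - half_window, 0), min(r + half_window + 1, h)
        c0, c1 = max(c - half_window, 0), min(c + half_window + 1, w)
        means.append(image[r0:r1, c0:c1].reshape(-1, 3).mean(axis=0))
    means = np.asarray(means)                       # [num_spots, 3]
    var = means.var(axis=0)                         # cf. "Variances of color0, color1, color2"
    weights = var / var.sum() if var.sum() > 0 else np.full(3, 1 / 3)
    combined = means @ weights                      # [num_spots]
    std = combined.std()
    return (combined - combined.mean()) / (std if std > 0 else 1.0)


rng = np.random.default_rng(0)
image = rng.random((20, 20, 3))                     # toy RGB image with values in [0, 1)
rows = np.array([2, 5, 10, 17])
cols = np.array([3, 8, 12, 19])
color = neighbourhood_color(image, rows, cols)
print(color.shape)  # (4,)
```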


#### Integrating UMI count + color information
Integrate the normalized and dimensionality-reduced `UMI` count information with the `color` information to form the data.x matrix.

In [16]:
from dataloader import get_data_x

data_x = get_data_x(ann_data, reduced_data, normalized_color_avgs)
for patient_id in data_x.keys():
    print(data_x[patient_id].shape)

torch.Size([3460, 56])
torch.Size([3661, 56])
torch.Size([4226, 56])
torch.Size([4384, 56])
torch.Size([4015, 56])
torch.Size([3498, 56])
torch.Size([3639, 56])
torch.Size([3592, 56])
torch.Size([4634, 56])
torch.Size([4110, 56])
torch.Size([3673, 56])
torch.Size([4789, 56])
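The 56 columns are the 55 PCA components plus the single color value per spot. `get_data_x` presumably concatenates them column-wise (an assumption; its code isn't shown). Because the color vector is 1-D, it needs a trailing axis before stacking, as in this numpy sketch with placeholder values:

```python
import numpy as np

n_spots = 4
pca_features = np.zeros((n_spots, 55), dtype=np.float32)   # placeholder PCA output
color_feature = np.arange(n_spots, dtype=np.float32)      # one scalar per spot

# Add a trailing axis ([n] -> [n, 1]) so the color becomes one extra column.
x = np.hstack([pca_features, color_feature[:, None]])
print(x.shape)  # (4, 56)
```

In PyTorch the same step is `torch.cat([pca, color.unsqueeze(1)], dim=1)`.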


### data.y
Target to train against (may have arbitrary shape), e.g., node-level targets of shape `[num_nodes, *]` or graph-level targets of shape `[1, *]`.

In our case, the target is the brain layer of each node, stored as one integer class index per node, so `data.y` has shape `[num_nodes]` (see the printed tensors below).

Since each brain layer label is a string category, we use its integer category code instead. This is easy to do with the categorical `obs` column in AnnData:

- `Layer1` -> `0`
- `Layer2` -> `1`
- `Layer3` -> `2`
- `Layer4` -> `3`
- `Layer5` -> `4`
- `Layer6` -> `5`
- `WM` -> `6`

In [17]:
print("layer categories:" , ann_data['151676'].obs["sce.layer_guess"].head(4))
print("--------")
print("layer codes:" , ann_data['151676'].obs["sce.layer_guess"].cat.codes.head(4))

layer categories: AAACAAGTATCTCCCA-1    Layer3
AAACAATCTACTAGCA-1    Layer1
AAACACCAATAACTGC-1        WM
AAACAGAGCGACTCCT-1    Layer3
Name: sce.layer_guess, dtype: category
Categories (7, object): ['Layer1', 'Layer2', 'Layer3', 'Layer4', 'Layer5', 'Layer6', 'WM']
--------
layer codes: AAACAAGTATCTCCCA-1    2
AAACAATCTACTAGCA-1    0
AAACACCAATAACTGC-1    6
AAACAGAGCGACTCCT-1    2
dtype: int8


There are a few `NaN` labels in the dataset. For now, we map them to an extra class, `7` (`Unknown`).

In [18]:
from dataloader import get_data_y

data_y = get_data_y(ann_data)
for patient_id in data_y.keys():
    print(data_y[patient_id])

patient_id: 151676
patient_id: 151669
patient_id: 151507
patient_id: 151508
patient_id: 151672
patient_id: 151670
patient_id: 151673
patient_id: 151675
patient_id: 151510
patient_id: 151671
patient_id: 151674
patient_id: 151509
tensor([2, 0, 6,  ..., 6, 5, 0])
tensor([1, 3, 0,  ..., 3, 2, 0])
tensor([0, 2, 0,  ..., 6, 5, 0])
tensor([2, 0, 6,  ..., 4, 6, 0])
tensor([2, 3, 0,  ..., 3, 3, 0])
tensor([1, 3, 0,  ..., 3, 2, 0])
tensor([2, 0, 6,  ..., 5, 6, 1])
tensor([0, 2, 6,  ..., 6, 5, 6])
tensor([0, 4, 2,  ..., 2, 5, 2])
tensor([2, 3, 0,  ..., 2, 4, 0])
tensor([2, 0, 6,  ..., 5, 6, 0])
tensor([0, 2, 5,  ..., 3, 5, 1])
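pandas encodes missing categorical values as code `-1`, which is not a valid class index for a cross-entropy loss with default settings. A minimal sketch of remapping them to an extra class, assuming `get_data_y` does something similar (the exact code isn't shown; `UNKNOWN` is a hypothetical name):

```python
import numpy as np
import pandas as pd

labels = pd.Series(["Layer3", "Layer1", "WM", np.nan, "Layer3"], dtype="category")
# Fix the category order so the codes match the mapping above.
labels = labels.cat.set_categories(
    ["Layer1", "Layer2", "Layer3", "Layer4", "Layer5", "Layer6", "WM"])

codes = labels.cat.codes.to_numpy()     # NaN becomes -1
UNKNOWN = 7                             # extra class for missing labels
y = np.where(codes == -1, UNKNOWN, codes).astype(np.int64)
print(y)  # [2 0 6 7 2]
```

An alternative is to keep the `-1` codes and pass `ignore_index=-1` to `torch.nn.CrossEntropyLoss`, which excludes those spots from the loss instead of learning an extra class.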


### data.pos
Node position matrix with shape `[num_nodes, num_dimensions]`

In [19]:
from dataloader import get_data_pos

data_pos = get_data_pos(ann_data)
for patient_id in data_pos.keys():
    print(data_pos[patient_id].shape)

torch.Size([3460, 2])
torch.Size([3661, 2])
torch.Size([4226, 2])
torch.Size([4384, 2])
torch.Size([4015, 2])
torch.Size([3498, 2])
torch.Size([3639, 2])
torch.Size([3592, 2])
torch.Size([4634, 2])
torch.Size([4110, 2])
torch.Size([3673, 2])
torch.Size([4789, 2])


## Creating the Data Loaders with all the gathered information
We create one loader for each split:
- Training
- Validation
- Testing

In [20]:
import yaml
from dataloader import get_dataloaders

# Load the training section of the config; the context manager closes the file.
with open("params.yaml") as parfile:
    params = yaml.safe_load(parfile)['train']

patients = (train_patients, val_patients, test_patients)

train_loader, val_loader, test_loader = get_dataloaders(patients, data_x,
                                                        edge_indices, edge_features,
                                                        data_pos, data_y, params)
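`get_dataloaders` presumably wraps PyG's `DataLoader` (an assumption; its code isn't shown). PyG batches graphs by concatenating them into one large disconnected graph: node features are stacked, each graph's `edge_index` is shifted by the number of nodes before it, and a `batch` vector records which graph every node belongs to. That is why `batch.y` in the next cell holds the labels of every spot in the batch. The numpy sketch below (`collate_graphs` is hypothetical) reproduces that bookkeeping for two toy graphs.

```python
import numpy as np


def collate_graphs(graphs):
    """Merge several graphs into one disconnected graph (PyG-style batching).

    graphs: list of (x, edge_index) with x [n_i, f] and edge_index [2, e_i].
    Node ids of graph k are shifted by the number of nodes in graphs 0..k-1.
    """
    xs, edges, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)              # renumber nodes
        batch.append(np.full(x.shape[0], graph_id))    # graph id per node
        offset += x.shape[0]
    return np.vstack(xs), np.hstack(edges), np.concatenate(batch)


g1 = (np.ones((3, 2)), np.array([[0, 1], [1, 2]]))   # 3 nodes, 2 edges
g2 = (np.zeros((2, 2)), np.array([[0], [1]]))        # 2 nodes, 1 edge
x, edge_index, batch = collate_graphs([g1, g2])
print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
```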

In [21]:
num_classes = 8
layer_names = ['Layer I', 'Layer II', 'Layer III', 'Layer IV', 'Layer V', 'Layer VI', 'White Matter', 'Unknown']

def class_occurrences(loader, num_classes):
    """Count the spots of each class in `loader` and print the distribution."""
    class_counts = torch.zeros(num_classes, dtype=torch.long)

    for batch in loader:
        class_counts += torch.bincount(batch.y, minlength=num_classes)

    total_samples = class_counts.sum().item()
    print(total_samples)

    for i in range(num_classes):
        count = class_counts[i].item()
        percentage = (count / total_samples) * 100
        print(f"{i:<8} {layer_names[i]:<12} {count:<8} {percentage:.2f}%")


class_occurrences(train_loader, num_classes)

31731
0        Layer I      8631     27.20%
1        Layer II     2541     8.01%
2        Layer III    7977     25.14%
3        Layer IV     3494     11.01%
4        Layer V      4121     12.99%
5        Layer VI     2831     8.92%
6        White Matter 2037     6.42%
7        Unknown      99       0.31%


### Optional: Visualize one graph

Rendering a graph takes a long time, so this cell is commented out.

In [22]:
# from dataloader import visualize_data


# visualize_data(next(iter(train_loader)))

## Import Model

In [23]:
from model import get_model

with open('params.yaml') as parfile:
    all_params = yaml.safe_load(parfile)
    featurize_params = all_params['featurize']
    train_params = all_params['train']
    train_params['pca_components'] = featurize_params['pca_components']
    train_params['seed'] = featurize_params['seed']
    tracking_params = all_params['tracking']

device, model = get_model(train_params)
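`get_model` hides the architecture. Assuming it builds PyG's `GCN` as the introduction states, each layer applies the propagation rule of `GCNConv`: add self-loops, normalize symmetrically by node degree, then aggregate neighbor features and apply a learned linear map. The dense numpy sketch below (hypothetical `gcn_layer`, bias and nonlinearity omitted) shows that rule on a small graph; PyG computes the same thing with sparse message passing.

```python
import numpy as np


def gcn_layer(x, edge_index, weight, edge_weight=None):
    """One GCN propagation step: D^-1/2 (A + I) D^-1/2 X W (dense, for illustration).

    x: [n, f_in] node features, edge_index: [2, e], weight: [f_in, f_out].
    edge_weight: optional non-negative [e] weights (default 1 per edge).
    Messages flow from edge_index[0] (source) to edge_index[1] (target).
    """
    n = x.shape[0]
    if edge_weight is None:
        edge_weight = np.ones(edge_index.shape[1])
    a = np.zeros((n, n))
    np.add.at(a, (edge_index[1], edge_index[0]), edge_weight)  # a[target, source]
    a += np.eye(n)                         # self-loops keep each node's own features
    deg = a.sum(axis=1)                    # >= 1 thanks to the self-loop
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    a_norm = d_inv_sqrt[:, None] * a * d_inv_sqrt[None, :]
    return a_norm @ x @ weight


# Path graph 0 - 1 - 2 with identity features and weights exposes a_norm itself.
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
out = gcn_layer(np.eye(3), edge_index, np.eye(3))

# Shapes as in this notebook: 56 input features mapped to 8 class scores.
rng = np.random.default_rng(0)
logits = gcn_layer(rng.normal(size=(3, 56)), edge_index, rng.normal(size=(56, 8)))
print(logits.shape)  # (3, 8)
```

Note that `GCNConv` accepts a one-dimensional `edge_weight` rather than a multi-column `edge_attr`, so the `[num_edges, 1]` distances would have to be flattened (and probably transformed, since larger raw distances would otherwise mean stronger connections) before being used as weights; whether `get_model` does this is not shown.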

## Optimizer, Loss and Scheduler

In [24]:
from model import get_criterion, get_optimizer, get_scheduler

# obtain class frequencies from the training set
class_counts = torch.zeros(model.num_classes, dtype=torch.long)
for batch in train_loader:
    class_counts += torch.bincount(batch.y, minlength=model.num_classes)
total = class_counts.sum().item()
class_freqs = class_counts.float() / total
class_freqs = class_freqs.to(device)

optimizer = get_optimizer(model, train_params)
criterion = get_criterion(class_freqs, train_params)
scheduler = get_scheduler(optimizer, train_params)
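How `get_criterion` uses `class_freqs` is not shown. One common option, assumed here, is inverse-frequency class weighting in a cross-entropy loss, which up-weights rare layers such as Layer II and White Matter. The numpy sketch below (hypothetical helper names) computes such weights from the training counts printed earlier and evaluates a weighted cross-entropy with the same mean reduction as `torch.nn.CrossEntropyLoss(weight=...)`.

```python
import numpy as np


def inverse_frequency_weights(class_freqs):
    """w_c = 1 / (K * f_c); classes absent from training get weight 0."""
    class_freqs = np.asarray(class_freqs, dtype=np.float64)
    k = len(class_freqs)
    weights = np.zeros(k)
    present = class_freqs > 0
    weights[present] = 1.0 / (k * class_freqs[present])
    return weights


def weighted_cross_entropy(logits, targets, weights):
    """Weighted NLL averaged over the summed weights of the targets,
    the reduction torch.nn.CrossEntropyLoss(weight=...) uses for 'mean'."""
    shifted = logits - logits.max(axis=1, keepdims=True)      # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]
    w = weights[targets]
    return (w * nll).sum() / w.sum()


# Training-set class counts from the table above (classes 0..7).
counts = np.array([8631, 2541, 7977, 3494, 4121, 2831, 2037, 99])
weights = inverse_frequency_weights(counts / counts.sum())
loss = weighted_cross_entropy(np.zeros((3, 8)), np.array([0, 7, 2]), weights)
print(np.round(weights, 2))
```

With these counts the weights range from about 0.46 (Layer I) to about 40 (Unknown), so the 99 unlabeled spots would carry a lot of weight; tempering the weights (e.g., with a square root) or excluding class 7 via `ignore_index` are alternatives worth considering.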


## Setting experiment tracking

We use [MLflow](https://mlflow.org) to track experiments.

In [25]:
from mlflow_server import start_mlflow_server, stop_mlflow_server

port = "5000"
artifacts_dir = "artifacts"
pid_file_path = "mlflow.pid"
log_dir = "logs"
experiment_name = "BrainLayerClassifier"

start_mlflow_server(port=port, artifacts_dir=artifacts_dir, pid_file_path=pid_file_path)

MLflow server started (PID 234060)


## Training

In [26]:
from train import train_loop, start_tracking_experiment


writer = start_tracking_experiment(exp_name=experiment_name, port=port, log_dir=log_dir)

loaders = (train_loader, val_loader)
train_loop(model, optimizer, criterion, scheduler, loaders, device, train_params, writer)

epoch:  1


  num_correct = mask.new_zeros(num_classes).scatter_(0, target, mask, reduce="add")


Epoch 001 : Loss: 2.2319 | Val Loss: 2.0663 | Val Acc: 0.1197
epoch:  2
Epoch 002 : Loss: 2.1322 | Val Loss: 2.0384 | Val Acc: 0.1537
epoch:  3
Epoch 003 : Loss: 2.1098 | Val Loss: 2.0293 | Val Acc: 0.1054
epoch:  4
Epoch 004 : Loss: 2.0873 | Val Loss: 2.0190 | Val Acc: 0.0950
epoch:  5
Epoch 005 : Loss: 2.0663 | Val Loss: 2.0100 | Val Acc: 0.0979
epoch:  6
Epoch 006 : Loss: 2.0463 | Val Loss: 2.0011 | Val Acc: 0.1343
epoch:  7
Epoch 007 : Loss: 2.0235 | Val Loss: 1.9901 | Val Acc: 0.2128
epoch:  8
Epoch 008 : Loss: 1.9968 | Val Loss: 1.9762 | Val Acc: 0.2751
epoch:  9
Epoch 009 : Loss: 1.9687 | Val Loss: 1.9580 | Val Acc: 0.2666
epoch:  10
Epoch 010 : Loss: 1.9505 | Val Loss: 1.9379 | Val Acc: 0.2168
epoch:  11
Epoch 011 : Loss: 1.9440 | Val Loss: 1.9184 | Val Acc: 0.2260
epoch:  12
Epoch 012 : Loss: 1.9118 | Val Loss: 1.8988 | Val Acc: 0.2583
epoch:  13
Epoch 013 : Loss: 1.8935 | Val Loss: 1.8781 | Val Acc: 0.3086
epoch:  14
Epoch 014 : Loss: 1.8825 | Val Loss: 1.8561 | Val Acc: 0.35

epoch:  115
Epoch 115 : Loss: 0.7245 | Val Loss: 0.8502 | Val Acc: 0.7622
epoch:  116
Epoch 116 : Loss: 0.7141 | Val Loss: 0.8381 | Val Acc: 0.7691
epoch:  117
Epoch 117 : Loss: 0.6993 | Val Loss: 0.8369 | Val Acc: 0.7589
epoch:  118
Epoch 118 : Loss: 0.6980 | Val Loss: 0.8374 | Val Acc: 0.7549
epoch:  119
Epoch 119 : Loss: 0.6841 | Val Loss: 0.8314 | Val Acc: 0.7598
epoch:  120
Epoch 120 : Loss: 0.6828 | Val Loss: 0.8201 | Val Acc: 0.7710
epoch:  121
Epoch 121 : Loss: 0.6860 | Val Loss: 0.8051 | Val Acc: 0.7749
epoch:  122
Epoch 122 : Loss: 0.6703 | Val Loss: 0.7981 | Val Acc: 0.7686
epoch:  123
Epoch 123 : Loss: 0.6641 | Val Loss: 0.8017 | Val Acc: 0.7617
epoch:  124
Epoch 124 : Loss: 0.6571 | Val Loss: 0.8106 | Val Acc: 0.7588
epoch:  125
Epoch 125 : Loss: 0.6506 | Val Loss: 0.8182 | Val Acc: 0.7613
epoch:  126
Epoch 126 : Loss: 0.6538 | Val Loss: 0.8183 | Val Acc: 0.7632
epoch:  127
Epoch 127 : Loss: 0.6465 | Val Loss: 0.8139 | Val Acc: 0.7667
epoch:  128
Epoch 128 : Loss: 0.6384 |

epoch:  226
Epoch 226 : Loss: 0.4158 | Val Loss: 0.6969 | Val Acc: 0.7993
epoch:  227
Epoch 227 : Loss: 0.4114 | Val Loss: 0.7004 | Val Acc: 0.7992
epoch:  228
Epoch 228 : Loss: 0.4058 | Val Loss: 0.7066 | Val Acc: 0.7992
epoch:  229
Epoch 229 : Loss: 0.4142 | Val Loss: 0.7114 | Val Acc: 0.7991
epoch:  230
Epoch 230 : Loss: 0.4098 | Val Loss: 0.7109 | Val Acc: 0.7985
epoch:  231
Epoch 231 : Loss: 0.4056 | Val Loss: 0.7098 | Val Acc: 0.7987
epoch:  232
Epoch 232 : Loss: 0.4093 | Val Loss: 0.7085 | Val Acc: 0.7976
epoch:  233
Epoch 233 : Loss: 0.4071 | Val Loss: 0.7064 | Val Acc: 0.7975
epoch:  234
Epoch 234 : Loss: 0.4061 | Val Loss: 0.7032 | Val Acc: 0.7973
epoch:  235
Epoch 235 : Loss: 0.4007 | Val Loss: 0.7021 | Val Acc: 0.7971
epoch:  236
Epoch 236 : Loss: 0.4017 | Val Loss: 0.7014 | Val Acc: 0.7982
epoch:  237
Epoch 237 : Loss: 0.4062 | Val Loss: 0.7009 | Val Acc: 0.7989
epoch:  238
Epoch 238 : Loss: 0.3982 | Val Loss: 0.7008 | Val Acc: 0.7993
epoch:  239
Epoch 239 : Loss: 0.3977 |

## Stopping experiment tracking

In [27]:
from mlflow_server import stop_mlflow_server

stop_mlflow_server(pid_file_path=pid_file_path)

MLflow server stopped (PID 234060)
