# Training GCN
Train a Graph Neural Network on histology and gene-expression information.

We are going to build a brain-layer classifier using [PyTorch Geometric's GCN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GCN.html?highlight=gcn#torch-geometric-nn-models-gcn).
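PyG's `GCN` model stacks `GCNConv` layers. By default, each layer follows the Kipf and Welling propagation rule: add self-loops to the adjacency matrix, normalize it symmetrically by node degree, then aggregate neighbour features and project them with a weight matrix. The NumPy sketch below computes one such layer followed by a ReLU on a toy three-node graph. The identity weights and the toy graph are illustrative assumptions. In the real model the weights are learned.

```python
import numpy as np

def gcn_layer(adj: np.ndarray, h: np.ndarray, w: np.ndarray) -> np.ndarray:
    """One GCN step followed by ReLU: ReLU(D^-1/2 (A + I) D^-1/2 H W)."""
    a_hat = adj + np.eye(adj.shape[0])           # add self-loops
    deg = a_hat.sum(axis=1)                       # degree of each node, self-loop included
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    norm = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]  # symmetric normalization
    return np.maximum(norm @ h @ w, 0.0)          # aggregate, project, ReLU

# Toy path graph 0 - 1 - 2 with two features per node
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
h = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
w = np.eye(2)  # identity weights, so the output is just the normalized aggregation
out = gcn_layer(adj, h, w)
print(out.shape)  # (3, 2)
```

Each node's new representation is a degree-weighted mix of its own features and its neighbours' features. Stacking layers therefore lets information from adjacent spots flow into each spot's layer prediction.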

## Loading the data

For a detailed exploration and analysis of the data, see the `DataAnalysis` notebook in the same directory as this one.

### Downloading the data

In [1]:
%%sh
./dataset/getdata.sh

~/projects/GNNCellClassification/dataset ~/projects/GNNCellClassification
Don't download data: Both data and images exists
~/projects/GNNCellClassification


### Actual loading

In [2]:
import sys, os
sys.path.append(os.path.abspath("src"))

In [3]:
%load_ext autoreload
%autoreload 2
import importlib
import preprocess

data_dir, img_dir, graph_dir = "dataset/data", "dataset/images", "out/graphs"
ann_data, histology_imgs = preprocess.main(data_dir, img_dir, graph_dir)

Loading AnnData for sample 151676 …
Loading AnnData for sample 151669 …
Loading AnnData for sample 151507 …
Loading AnnData for sample 151508 …
Loading AnnData for sample 151672 …
Loading AnnData for sample 151670 …
Loading AnnData for sample 151673 …
Loading AnnData for sample 151675 …
Loading AnnData for sample 151510 …
Loading AnnData for sample 151671 …
Loading AnnData for sample 151674 …
Loading AnnData for sample 151509 …
Loading Image for sample 151676 …
Loading Image for sample 151669 …
Loading Image for sample 151507 …
Loading Image for sample 151508 …
Loading Image for sample 151672 …
Loading Image for sample 151670 …
Loading Image for sample 151673 …
Loading Image for sample 151675 …
Loading Image for sample 151510 …
Loading Image for sample 151671 …
Loading Image for sample 151674 …
Loading Image for sample 151509 …


## Features used for training

We are going to use the following features for training:

**Edge Features**:
- `Spatial connectivities` between spots
- Pixel `distance` (adjusted by color)

**Node Features**:
- Log-normalized `UMI` counts per gene
- `Color` in the `neighbourhood` of the spot
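One plausible way to combine the two node features into a single matrix is to concatenate them column-wise, keeping one row per spot. The helper `build_node_features` and the toy shapes below are hypothetical. This notebook does not show how the final `data.x` is assembled.

```python
import numpy as np

def build_node_features(log_expr: np.ndarray, colors: np.ndarray) -> np.ndarray:
    """Hypothetical helper: stack per-spot expression and neighbourhood colour."""
    if log_expr.shape[0] != colors.shape[0]:
        raise ValueError("expression and colour must have one row per spot")
    # [num_spots, num_genes + num_color_channels]
    return np.hstack([log_expr, colors]).astype(np.float32)

# Toy example: 4 spots, 5 genes, 3 colour channels
rng = np.random.default_rng(0)
log_expr = np.log1p(rng.poisson(2.0, size=(4, 5)))
colors = rng.random((4, 3))
x = build_node_features(log_expr, colors)
print(x.shape)  # (4, 8)
```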


## PyTorch Geometric's data structure

From the official [documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GCN.html?highlight=gcn#torch-geometric-nn-models-gcn):

A single graph in PyG is described by an instance of `torch_geometric.data.Data`, which holds the following attributes by default:

- `data.x`: Node feature matrix with shape `[num_nodes, num_node_features]`

- `data.edge_index`: Graph connectivity in `COO` format with shape `[2, num_edges]` and type `torch.long`

- `data.edge_attr`: Edge feature matrix with shape `[num_edges, num_edge_features]`

- `data.y`: Target to train against (may have arbitrary shape), e.g., node-level targets of shape `[num_nodes, *]` or graph-level targets of shape `[1, *]`

- `data.pos`: Node position matrix with shape `[num_nodes, num_dimensions]`
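The following toy three-node graph uses NumPy arrays as stand-ins for tensors to make these shapes concrete. With PyTorch Geometric installed, the same arrays could be converted with `torch.from_numpy` and passed to `Data(x=..., edge_index=..., edge_attr=..., y=..., pos=...)`.

```python
import numpy as np

# Undirected triangle 0-1-2: each edge is stored once in each direction
edge_index = np.array([[0, 1, 1, 2, 2, 0],
                       [1, 0, 2, 1, 0, 2]], dtype=np.int64)       # [2, num_edges]
x = np.array([[0.1, 0.2],
              [0.3, 0.4],
              [0.5, 0.6]], dtype=np.float32)                       # [num_nodes, num_node_features]
edge_attr = np.ones((edge_index.shape[1], 1), dtype=np.float32)    # [num_edges, num_edge_features]
y = np.array([0, 1, 1], dtype=np.int64)                            # node-level labels, [num_nodes]
pos = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.8]])               # [num_nodes, num_dimensions]

num_nodes = x.shape[0]
# Consistency checks that mirror the shape contract listed above
assert edge_index.shape[0] == 2 and edge_index.max() < num_nodes
assert edge_attr.shape[0] == edge_index.shape[1]
assert y.shape[0] == num_nodes and pos.shape[0] == num_nodes
```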

## Creating the required data structures for training

We need to convert the data into the structures that PyTorch Geometric expects.

In [4]:
import torch

### data.edge_index
We need to convert the spatial connectivity matrix into a `PyTorch` tensor in `COO` format.
Let's start with a reference patient:

In [5]:
ann_data['151676'].obsp['spatial_connectivities']

<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 20052 stored elements and shape (3460, 3460)>

In [6]:
type(ann_data['151676'].obsp['spatial_connectivities'])

scipy.sparse._csr.csr_matrix

In [7]:
coo_matrix = ann_data['151676'].obsp['spatial_connectivities'].tocoo()
coo_matrix

<COOrdinate sparse matrix of dtype 'float64'
	with 20052 stored elements and shape (3460, 3460)>

In [8]:
type(coo_matrix)

scipy.sparse._coo.coo_matrix
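A CSR matrix stores only the column indices and a row-pointer array (`indptr`). COO stores explicit row and column arrays, which is exactly the `[2, num_edges]` layout of `edge_index`. The NumPy sketch below performs that expansion by hand on a small 3×3 adjacency matrix, to show what `.tocoo()` gives us conceptually. The function name `csr_to_coo` is hypothetical.

```python
import numpy as np

def csr_to_coo(indptr: np.ndarray, indices: np.ndarray, n_rows: int):
    """Expand CSR row pointers into explicit (row, col) pairs, as in COO format."""
    counts = np.diff(indptr)                      # number of stored entries per row
    rows = np.repeat(np.arange(n_rows), counts)   # repeat each row id once per entry
    cols = indices.copy()                         # CSR column indices are already per entry
    return rows, cols

# Symmetric toy adjacency with edges 0-1 and 1-2
indptr = np.array([0, 1, 3, 4])
indices = np.array([1, 0, 2, 1])
rows, cols = csr_to_coo(indptr, indices, 3)
edge_index = np.stack([rows, cols])  # shape [2, num_edges]
print(edge_index)
# [[0 1 1 2]
#  [1 0 2 1]]
```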

In [9]:
coo_connections = {patient: data.obsp['spatial_connectivities'].tocoo()
                   for patient, data in ann_data.items()}
for patient, coo in coo_connections.items():
    print(f"Patient {patient}: {coo.shape}")

Patient 151676: (3460, 3460)
Patient 151669: (3661, 3661)
Patient 151507: (4226, 4226)
Patient 151508: (4384, 4384)
Patient 151672: (4015, 4015)
Patient 151670: (3498, 3498)
Patient 151673: (3639, 3639)
Patient 151675: (3592, 3592)
Patient 151510: (4634, 4634)
Patient 151671: (4110, 4110)
Patient 151674: (3673, 3673)
Patient 151509: (4789, 4789)


In [10]:
edge_indices = {}

for patient, coo in coo_connections.items():
    row = torch.from_numpy(coo.row).long()
    col = torch.from_numpy(coo.col).long()
    edge_indices[patient] = torch.stack([row, col], dim=0)

    print(f"{patient}: {edge_indices[patient].shape}")

151676: torch.Size([2, 20052])
151669: torch.Size([2, 21194])
151507: torch.Size([2, 24770])
151508: torch.Size([2, 25698])
151672: torch.Size([2, 23382])
151670: torch.Size([2, 20370])
151673: torch.Size([2, 21124])
151675: torch.Size([2, 20762])
151510: torch.Size([2, 27198])
151671: torch.Size([2, 24052])
151674: torch.Size([2, 21258])
151509: torch.Size([2, 28172])
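Every edge count above is even. This is consistent with, though it does not prove, the spatial graph storing both directions of each edge. GCN message passing on an undirected graph relies on that. The quick NumPy check below tests the property directly. The helper name is hypothetical. PyG provides a similar function, `torch_geometric.utils.is_undirected`.

```python
import numpy as np

def is_undirected(edge_index: np.ndarray) -> bool:
    """True if every stored edge (i, j) is matched by a reversed edge (j, i)."""
    forward = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    backward = set(zip(edge_index[1].tolist(), edge_index[0].tolist()))
    return forward == backward

sym = np.array([[0, 1, 1, 2],
                [1, 0, 2, 1]])
asym = np.array([[0, 1],
                 [1, 2]])
print(is_undirected(sym), is_undirected(asym))  # True False
```

Applied to `edge_indices[patient].numpy()`, this would confirm whether each sample's graph is symmetric.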


### data.edge_attr
Edge feature matrix with shape `[num_edges, num_edge_features]`.
For now, we only keep the pixel distances between spots that are spatially connected.
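Selecting one distance per edge is a NumPy/PyTorch fancy-indexing operation. Indexing a dense pairwise matrix with the `row` and `col` arrays of `edge_index` returns one value per edge, in edge order. The sketch below uses a made-up 3×3 distance matrix in place of the `*_adj.npy` files.

```python
import numpy as np

# Made-up symmetric pairwise distances for 3 spots (zero on the diagonal)
dist = np.array([[0.0, 2.0, 5.0],
                 [2.0, 0.0, 3.0],
                 [5.0, 3.0, 0.0]])
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
row, col = edge_index
# One distance per edge, with a trailing feature dimension: [num_edges, 1]
edge_attr = dist[row, col][:, None].astype(np.float32)
print(edge_attr.ravel())  # [2. 2. 3. 3.]
```

If these distances are later passed to PyG's `GCN`, note that `GCNConv` takes a one-dimensional `edge_weight` of shape `[num_edges]` rather than a multi-column `edge_attr`. The trailing dimension of `edge_features` would therefore need to be squeezed.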

In [11]:
import numpy as np
import os

edge_features = {}
for patient in ann_data.keys():
    filename = f"{patient}_adj.npy"
    adj_distances = np.load(os.path.join(graph_dir, filename))
    adj_tensor = torch.from_numpy(adj_distances)
    row, col = edge_indices[patient]
    distances = adj_tensor[row, col]
    edge_features[patient] = distances.unsqueeze(1).float()
    print(f"{patient}: {edge_features[patient].shape}")

151676: torch.Size([20052, 1])
151669: torch.Size([21194, 1])
151507: torch.Size([24770, 1])
151508: torch.Size([25698, 1])
151672: torch.Size([23382, 1])
151670: torch.Size([20370, 1])
151673: torch.Size([21124, 1])
151675: torch.Size([20762, 1])
151510: torch.Size([27198, 1])
151671: torch.Size([24052, 1])
151674: torch.Size([21258, 1])
151509: torch.Size([28172, 1])


### data.x
Node feature matrix with shape `[num_nodes, num_node_features]`

In [12]:
import scanpy as sc

node_features = {}
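for patient, data in ann_data.items():
    # Both calls modify the AnnData object in place: re-running this cell
    # would normalize and log-transform the counts a second time.
    sc.pp.normalize_total(data)
    sc.pp.log1p(data)

    # Densify the sparse expression matrix: [num_spots, num_genes]
    node_features[patient] = data.X.toarray()
    print(node_features[patient].shape)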

(3460, 33538)
(3661, 33538)
(4226, 33538)
(4384, 33538)
(4015, 33538)
(3498, 33538)
(3639, 33538)
(3592, 33538)
(4634, 33538)
(4110, 33538)
(3673, 33538)
(4789, 33538)
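With its default `target_sum=None`, `sc.pp.normalize_total` scales every spot so that its total count equals the median total count across spots. `sc.pp.log1p` then applies log(1 + x). The NumPy sketch below approximates these two steps. It assumes a dense counts matrix and ignores scanpy options such as `exclude_highly_expressed`. Scanpy's exact handling of edge cases, such as zero-count spots, may differ.

```python
import numpy as np

def normalize_log1p(counts: np.ndarray, target_sum=None) -> np.ndarray:
    """Approximation of normalize_total + log1p on a dense [spots, genes] matrix."""
    totals = counts.sum(axis=1, keepdims=True)
    if target_sum is None:
        target_sum = np.median(totals[totals > 0])  # median total count across spots
    # Per-spot scale factor; spots with zero counts are left at zero
    scale = np.divide(target_sum, totals,
                      out=np.zeros_like(totals, dtype=float), where=totals > 0)
    return np.log1p(counts * scale)

counts = np.array([[1.0, 3.0],
                   [2.0, 2.0],
                   [0.0, 8.0]])
out = normalize_log1p(counts)  # every row now sums to 4 before the log
```

Densifying has a memory cost. Assuming float32 (4 bytes per value), the largest sample (4789 × 33538) takes roughly 640 MB. All 12 samples together (47,681 spots) take roughly 6.4 GB.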


In [15]:
from graph import get_region_colors

offsets = {'151676': 310, '151669': 276, '151507': 236, '151508': 232,
           '151672': 264, '151670': 339, '151673': 260, '151675': 228,
           '151510': 204, '151671': 238, '151674': 234, '151509': 220}
thickness = 48

normalized_color_avgs = {}
for patient_id, data in ann_data.items():
    spatial_info = data.uns['spatial'][patient_id]

    # Convert full-resolution spot coordinates to pixel positions in the hires image
    hires_scale = spatial_info['scalefactors']['tissue_hires_scalef']
    spot_pixels = (data.obsm['spatial'] * hires_scale).astype(int)

    image = spatial_info['images']['hires']
    # Sanity check: all spot coordinates lie within the hires image
    assert min(image.shape[0], image.shape[1]) > spot_pixels.max()

    flipped_image = np.flip(image, 0)  # flip vertically (along the row axis)
    x_pixels = spot_pixels[:, 0]
    y_pixels = spot_pixels[:, 1]

    normalized_color_avgs[patient_id] = get_region_colors(
        x_pixels, y_pixels, offset=offsets[patient_id], image=flipped_image,
        thickness=thickness, alpha=1)

Variances of color0, color1, color2 =  [0.00134047 0.00493318 0.00063244]
Variances of color0, color1, color2 =  [0.00100299 0.00408081 0.00090281]
Variances of color0, color1, color2 =  [0.00058533 0.00256631 0.00049386]
Variances of color0, color1, color2 =  [0.00059817 0.00289116 0.00048598]
Variances of color0, color1, color2 =  [0.00160404 0.00482472 0.00077629]
Variances of color0, color1, color2 =  [0.00137143 0.00389194 0.00066271]
Variances of color0, color1, color2 =  [0.00084196 0.00644021 0.00071149]
Variances of color0, color1, color2 =  [0.00166604 0.00636242 0.00100769]
Variances of color0, color1, color2 =  [0.00063251 0.00230641 0.00054535]
Variances of color0, color1, color2 =  [0.00228128 0.00711623 0.00111823]
Variances of color0, color1, color2 =  [0.00123189 0.00723225 0.00067618]
Variances of color0, color1, color2 =  [0.00053539 0.00314728 0.00054947]
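This notebook does not show the implementation of `get_region_colors`. As a rough illustration of a "colour in the neighbourhood of the spot" feature, the hypothetical helper below averages RGB values in a square window around each spot and clips the window at the image border. It is not the project's function and does not reproduce its `offset`, `thickness` or `alpha` handling.

```python
import numpy as np

def mean_patch_color(image: np.ndarray, xs: np.ndarray, ys: np.ndarray, half: int) -> np.ndarray:
    """Hypothetical: mean colour over a (2*half+1)-pixel square around each spot.

    image is indexed as image[row, col, channel]; xs are columns, ys are rows.
    """
    h, w = image.shape[:2]
    out = np.empty((len(xs), image.shape[2]), dtype=float)
    for i, (x, y) in enumerate(zip(xs, ys)):
        r0, r1 = max(y - half, 0), min(y + half + 1, h)  # clip rows at the border
        c0, c1 = max(x - half, 0), min(x + half + 1, w)  # clip columns at the border
        out[i] = image[r0:r1, c0:c1].reshape(-1, image.shape[2]).mean(axis=0)
    return out

img = np.zeros((5, 5, 3))
img[:, :, 0] = 1.0             # red everywhere
img[2, 2] = [0.0, 0.0, 9.0]    # a single blue pixel at the centre
colors = mean_patch_color(img, np.array([2]), np.array([2]), half=1)
corner = mean_patch_color(img, np.array([0]), np.array([0]), half=1)
```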
