# Running ONTraC on MERFISH dataset

## Notes

This notebook shows how to run ONTraC on a MERFISH dataset.

## ONTraC installation

We assume that you have installed ONTraC following the instructions below and opened this notebook with the installed Python kernel (Python 3.11 (ONTraC)).

In [None]:
conda create -n ONTraC python=3.11
conda activate ONTraC
pip install "ONTraC[analysis]==1.*"
pip install ipykernel
python -m ipykernel install --user --name ONTraC --display-name "Python 3.11 (ONTraC)"
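
To confirm that the kernel sees the pinned release (`ONTraC[analysis]==1.*`), you can query the installed distribution with the standard library. This is a minimal sketch; `ontrac_version` and `is_major_1` are hypothetical helpers written for this check and are not part of ONTraC.

```python
from importlib import metadata


def ontrac_version(dist_name="ONTraC"):
    """Return the installed version string of `dist_name`, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def is_major_1(version):
    """True if `version` falls under the `==1.*` pin used in the pip command above."""
    return version is not None and version.split(".")[0] == "1"


version = ontrac_version()
print("ONTraC version:", version, "| matches 1.*:", is_major_1(version))
```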

## Running ONTraC on MERFISH data

ONTraC will run on CPU if CUDA is not available.
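
If you prefer to pass the device explicitly instead of relying on that fallback, a small check like the one below picks the value for `--device`. It assumes that ONTraC's GNN runs on PyTorch, so `torch.cuda.is_available()` reflects what ONTraC can use. `pick_device` is a hypothetical helper, not an ONTraC function.

```python
def pick_device():
    """Return "cuda" if PyTorch can see a GPU, otherwise "cpu" (hypothetical helper)."""
    try:
        import torch  # assumption: PyTorch is installed in the ONTraC environment
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print(f"--device {pick_device()}")
```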

Download `merfish_dataset_meta_input.csv` from [Zenodo](https://zenodo.org/records/XXXXXX) into the working directory.
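
Before launching a long run, it can help to check the header of the downloaded file. This sketch uses only the standard library. The column list is an assumption: `Sample`, `Cell_Type`, `x` and `y` are the columns this notebook reads later, and `Cell_ID` is assumed to be required by ONTraC's meta-input format. `missing_columns` is a hypothetical helper.

```python
import csv

# Assumption: columns expected in the ONTraC meta-input table.
# "Sample", "Cell_Type", "x" and "y" are used later in this notebook; "Cell_ID" is assumed.
REQUIRED_COLUMNS = ("Cell_ID", "Sample", "Cell_Type", "x", "y")


def missing_columns(path, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from the CSV header."""
    # utf-8-sig strips a byte-order mark if the file was saved by a spreadsheet tool
    with open(path, newline="", encoding="utf-8-sig") as fh:
        header = next(csv.reader(fh), [])
    header = [name.strip() for name in header]
    return [col for col in required if col not in header]
```

In this notebook, `missing_columns("merfish_dataset_meta_input.csv")` returning an empty list means every assumed column is present.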

In [None]:
%%bash

source ~/.bash_profile  # load shell setup (e.g. conda init) so `conda activate` works in this cell
conda activate ONTraC
ONTraC --meta-input merfish_dataset_meta_input.csv --preprocessing-dir merfish_preprocessing --GNN-dir merfish_GNN --NTScore-dir merfish_NTScore \
       --device cuda --epochs 1000 --batch-size 10 -s 42 --patience 100 --min-delta 0.001 --min-epochs 50 --lr 0.03 --hidden-feats 4 -k 6 \
       --modularity-loss-weight 1 --purity-loss-weight 30 --regularization-loss-weight 0.1 --beta 0.3 2>&1 | tee merfish.log
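
The `--patience`, `--min-delta` and `--min-epochs` options are early-stopping settings, and `--epochs` caps the total number of epochs. The function below is a generic sketch of how such settings are commonly combined, written for illustration only. `should_stop` is hypothetical, and ONTraC's exact stopping criterion may differ.

```python
def should_stop(losses, patience=100, min_delta=0.001, min_epochs=50):
    """Illustrative early-stopping rule (an assumption, not ONTraC's code).

    `losses` holds one training loss per completed epoch. Training stops once at
    least `min_epochs` epochs have run and the best loss within the last `patience`
    epochs is not lower than the earlier best loss by more than `min_delta`.
    """
    if patience < 1:
        raise ValueError("patience must be at least 1")
    # Need enough history for both the minimum epoch count and a full patience window.
    if len(losses) < max(min_epochs, patience + 1):
        return False
    best_before_window = min(losses[:-patience])
    best_in_window = min(losses[-patience:])
    return best_before_window - best_in_window <= min_delta


# A loss that plateaus triggers stopping; a steadily falling loss does not.
plateau = [1.0 / (epoch + 1) for epoch in range(20)] + [0.05] * 200
falling = [1.0 - 0.004 * epoch for epoch in range(220)]
print(should_stop(plateau), should_stop(falling))  # True False
```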

## Results visualization

We show only two simple examples here; see the [post analysis tutorial](../../tutorials/post_analysis.md) for details and more figures.

### Plotting preparation

In [None]:
import numpy as np
import pandas as pd

import matplotlib as mpl

mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'Arial'
import matplotlib.pyplot as plt
import seaborn as sns

from ONTraC.analysis.data import AnaData

### Loading ONTraC results

In [None]:
from optparse import Values

options = Values()
# These paths must match the output locations passed to the ONTraC command above
options.preprocessing_dir = 'merfish_preprocessing'
options.GNN_dir = 'merfish_GNN'
options.NTScore_dir = 'merfish_NTScore'
options.log = 'merfish.log'
options.reverse = True  # set to False if you do not want to reverse the NT score direction

ana_data = AnaData(options)
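
If `AnaData(options)` fails, a common cause is that the paths in `options` do not match the output locations of the ONTraC run. A small hypothetical helper, `missing_outputs`, lists which of those paths are absent on disk. It only checks that the paths exist, not that their contents are complete.

```python
import os
from optparse import Values


def missing_outputs(options):
    """Return the names of the path options in `options` that do not exist on disk."""
    path_options = ("preprocessing_dir", "GNN_dir", "NTScore_dir", "log")
    # An unset option falls back to "", which os.path.exists reports as missing
    return [name for name in path_options
            if not os.path.exists(getattr(options, name, ""))]
```

In this notebook, `missing_outputs(options)` should return an empty list before `AnaData(options)` is called.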

### Spatial cell type distribution

In [None]:
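selected_cell_types = ['VLMC', 'L2/3 IT', 'L4/5 IT', 'L5 IT', 'L5 ET', 'L5/6 NP', 'L6 IT', 'L6 CT', 'L6 IT Car3']
# All cell types present in the dataset; types not in selected_cell_types are drawn in gray
cell_types = sorted(ana_data.meta_data['Cell_Type'].unique())

rainbow_cmap = mpl.colormaps['gist_rainbow']

# VLMC takes the start of the colormap, the other selected types are spread over later positions
my_pal = {'VLMC': rainbow_cmap(0)}
my_pal.update({cell_type: rainbow_cmap(0.3 + 0.7 * (i - 1) / (len(selected_cell_types) - 1)) for i, cell_type in enumerate(selected_cell_types[1:])})
my_pal.update({cell_type: 'gray' for cell_type in cell_types if cell_type not in selected_cell_types})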

samples = ['mouse1_slice91', 'mouse1_slice131']
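
The palette logic above can be wrapped in a small hypothetical function, `build_palette`, which makes the colormap positions easy to inspect. It keeps the same formula: the first selected type (`VLMC`) takes position 0. Because `enumerate` counts from 0, the other eight selected types span roughly 0.21 to 0.83 of the colormap rather than starting exactly at 0.3. Any callable that maps a float in [0, 1] to a color works as `cmap`, including a Matplotlib colormap such as `rainbow_cmap`.

```python
def build_palette(selected, all_types, cmap, other_color="gray"):
    """Map cell types to colors the same way as the cell above (hypothetical helper).

    selected[0] gets cmap(0); selected[1:] get
    cmap(0.3 + 0.7 * (i - 1) / (len(selected) - 1)) with i counting from 0;
    every other cell type in all_types gets `other_color`.
    """
    palette = {selected[0]: cmap(0)}
    for i, cell_type in enumerate(selected[1:]):
        palette[cell_type] = cmap(0.3 + 0.7 * (i - 1) / (len(selected) - 1))
    for cell_type in all_types:
        if cell_type not in selected:
            palette[cell_type] = other_color
    return palette


# Passing an identity function as the "colormap" returns the raw positions.
# In the notebook: my_pal = build_palette(selected_cell_types, cell_types, rainbow_cmap)
```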

In [None]:
data_df = ana_data.meta_data

N = len(samples)
fig, axes = plt.subplots(1, N, figsize = (5 * N, 3))
for i, sample in enumerate(samples):
    sample_df = data_df.loc[data_df['Sample'] == sample]
    ax = axes[i] if N > 1 else axes
    sns.scatterplot(data = sample_df,
                    x = 'x',
                    y = 'y',
                    hue = 'Cell_Type',
                    palette = my_pal,
                    hue_order = selected_cell_types + [x for x in cell_types if x not in selected_cell_types],
                    edgecolor=None,
                    s = 4,
                    ax=ax)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"{sample}")
    ax.legend(loc='upper left', bbox_to_anchor=(1,1), ncol=2)


fig.tight_layout()
fig.savefig('spatial_cell_type.png', dpi=300)

![spatial cell type distribution](img/merfish_spatial_cell_type.png)

### Cell-level NT score spatial distribution

In [None]:
data_df = ana_data.NT_score

N = len(samples)
fig, axes = plt.subplots(1, N, figsize = (3.5 * N, 3))
for i, sample in enumerate(samples):
    sample_df = data_df.loc[data_df['Sample'] == sample]
    ax = axes[i] if N > 1 else axes
    # Flip the NT score direction when options.reverse is True (set when loading the results)
    nt_score = 1 - sample_df['Cell_NTScore'] if options.reverse else sample_df['Cell_NTScore']
    scatter = ax.scatter(sample_df['x'], sample_df['y'], c=nt_score, cmap='rainbow', vmin=0, vmax=1, s=1)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.colorbar(scatter, ax=ax)  # attach the colorbar to this panel, not the last-created axes
    ax.set_title(f"{sample} cell-level NT score")


fig.tight_layout()
fig.savefig('cell_level_NT_score.png', dpi=300)

![cell-level NT score](img/merfish_cell_level_NT_score.png)
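
The direction flip applied when plotting is a plain `1 - score` transform on NT scores in [0, 1]. A hypothetical standard-library helper, `display_nt_score`, makes this explicit. It also clips values to the `vmin=0`, `vmax=1` range used by the scatter plot.

```python
def display_nt_score(scores, reverse=True):
    """Return the values used to color cells (hypothetical helper).

    reverse=True flips the NT score direction (1 - score), mirroring
    options.reverse = True; results are clipped to [0, 1] to match vmin/vmax.
    """
    values = [1.0 - s if reverse else float(s) for s in scores]
    return [min(1.0, max(0.0, v)) for v in values]


print(display_nt_score([0.0, 0.25, 1.0]))  # [1.0, 0.75, 0.0]
```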