# Running ONTraC on MERFISH dataset

## Notes

This notebook walks through the process of running ONTraC on a MERFISH dataset.

We assume that you have installed ONTraC according to the [installation tutorial](../../tutorials/installation.md) and opened this notebook with the installed Python kernel (Python 3.11 (ONTraC)).

## Running ONTraC on MERFISH data

If your default shell is not Bash, please adjust the commands in the cell below accordingly.

ONTraC will run on CPU if CUDA is not available.

Warning: The MERFISH dataset is quite large and will take a long time to run on CPU only.

In [None]:
%%bash

source ~/.bash_profile
conda activate ONTraC
mkdir -p log  # tee cannot create the log directory itself
ONTraC -d merfish_dataset.csv --preprocessing-dir merfish_preprocessing_dir --GNN-dir merfish_GNN --NTScore-dir merfish_NTScore \
       --device cuda --epochs 1000 --batch-size 5 -s 42 --patience 100 --min-delta 0.001 --min-epochs 50 --lr 0.03 --hidden-feats 4 -k 6 \
       --modularity-loss-weight 0.3 --regularization-loss-weight 0.1 --purity-loss-weight 300 --beta 0.03 2>&1 | tee log/merfish.log
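
Because the command above requests `--device cuda` and the CPU-only run is slow for this dataset, it can help to check whether the kernel sees a GPU before starting. The following is a minimal sketch, assuming that PyTorch (which ONTraC depends on) is importable in the current kernel. `cuda_is_available` is a hypothetical helper name and not part of ONTraC.

```python
def cuda_is_available() -> bool:
    """Return True if PyTorch is importable and reports a usable CUDA device."""
    try:
        # Assumption: ONTraC's PyTorch dependency lives in this kernel's environment.
        import torch
    except ImportError:
        # Without PyTorch we cannot query CUDA, so report False.
        return False
    return bool(torch.cuda.is_available())


if cuda_is_available():
    print("CUDA is visible to this kernel.")
else:
    print("No CUDA device found: expect a long CPU-only run on the MERFISH dataset.")
```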

## Results visualization

Please see the [post-analysis tutorial](../../tutorials/post_analysis.md) for details.

### Install required packages

If your default shell is not Bash, please adjust this code.


In [None]:
%%bash

source ~/.bash_profile
conda activate ONTraC

pip install matplotlib seaborn
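
The `pip install` above runs in the conda environment activated by the Bash cell, which may differ from the environment behind the notebook kernel. As a quick check, the sketch below lists which of the packages imported later in this notebook cannot be found by the current interpreter. `find_missing` is a hypothetical helper introduced only for this illustration.

```python
import importlib.util


def find_missing(packages):
    """Return the package names that the current interpreter cannot locate."""
    return [name for name in packages if importlib.util.find_spec(name) is None]


# Packages imported by the cells in this notebook.
required = ["numpy", "pandas", "matplotlib", "seaborn"]

missing = find_missing(required)
print("Missing packages:", missing if missing else "none")
```

If anything is reported as missing, the kernel is probably not running from the environment in which the packages were installed.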

### Loading results

In [1]:
from optparse import Values
from typing import List, Tuple

import numpy as np
import pandas as pd

from ONTraC.utils import read_yaml_file, get_rel_params


def load_data(options: Values) -> pd.DataFrame:
    """
    Load the per-cell results written by ONTraC.
    :param options, Values. dataset, preprocessing_dir and NTScore_dif needed.
    :return data_df, pd.DataFrame
    """
    data_df = pd.DataFrame()
    params = read_yaml_file(f'{options.preprocessing_dir}/samples.yaml')
    rel_params = get_rel_params(options, params)
    # lookup table from integer cell type code (index) to cell type name
    cell_type_code_df = pd.read_csv(f'{options.preprocessing_dir}/cell_type_code.csv', index_col=0)
    for sample in rel_params['Data']:
        NTScore_df = pd.read_csv(f'{options.NTScore_dif}/{sample["Name"]}_NTScore.csv.gz', index_col=0)
        # the feature file is read without a header, so name its columns by cell type code
        cell_type_composition_df = pd.read_csv(sample['Features'], header=None)
        cell_type_composition_df.columns = cell_type_code_df.loc[np.arange(cell_type_composition_df.shape[1]), 'Cell_Type'].tolist()
        # combine the two tables row by row (by position), then restore the NT score index
        sample_df = pd.concat([NTScore_df.reset_index(drop=True), cell_type_composition_df], axis=1)
        sample_df.index = NTScore_df.index
        sample_df['sample'] = [sample["Name"]] * sample_df.shape[0]
        data_df = pd.concat([data_df, sample_df])

    # attach the cell type annotation from the input dataset, matched on the index
    raw_df = pd.read_csv(options.dataset, index_col=0)
    data_df = data_df.join(raw_df[['Cell_Type']])
    return data_df



In [None]:
options = Values()
options.dataset = 'merfish_dataset.csv'
options.preprocessing_dir = 'merfish_preprocessing_dir'
options.NTScore_dif = 'merfish_NTScore'

data_df = load_data(options = options)
samples = data_df['sample'].unique().tolist()
cell_types = data_df['Cell_Type'].unique().tolist()
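
To make the alignment inside `load_data` easier to follow, here is a small sketch on made-up data. It assumes, as the plotting cells below suggest, that the NT score table carries `x`, `y` and `Cell_NTScore` columns indexed by cell ID. All cell IDs and values are invented for illustration.

```python
import numpy as np
import pandas as pd

# Toy stand-in for one sample's NT score file, indexed by cell ID.
ntscore_df = pd.DataFrame(
    {'x': [0.0, 1.0, 2.0], 'y': [0.0, 1.0, 0.5], 'Cell_NTScore': [0.1, 0.5, 0.9]},
    index=['cell_a', 'cell_b', 'cell_c'])

# Toy stand-in for cell_type_code.csv: integer code -> cell type name.
cell_type_code_df = pd.DataFrame({'Cell_Type': ['VLMC', 'L2/3 IT']}, index=[0, 1])

# Toy stand-in for the header-less feature file: one row per cell, one column per code.
composition_df = pd.DataFrame([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
composition_df.columns = cell_type_code_df.loc[
    np.arange(composition_df.shape[1]), 'Cell_Type'].tolist()

# Rows are matched by position, so drop the cell IDs first and restore them afterwards.
sample_df = pd.concat([ntscore_df.reset_index(drop=True), composition_df], axis=1)
sample_df.index = ntscore_df.index

# The raw dataset may list cells in a different order; join matches on the cell ID.
raw_df = pd.DataFrame({'Cell_Type': ['L2/3 IT', 'VLMC', 'VLMC']},
                      index=['cell_c', 'cell_a', 'cell_b'])
data_df_toy = sample_df.join(raw_df[['Cell_Type']])
print(data_df_toy)
```

The positional `concat` only gives correct results if both files list the cells in the same order, which is what `load_data` relies on.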

### Plotting preparation

In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# embed TrueType fonts so that text stays editable in exported PDF/PS files
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.sans-serif'] = 'Arial'

### Spatial cell type distribution

In [None]:
selected_cell_types = ['VLMC', 'L2/3 IT', 'L4/5 IT', 'L5 IT', 'L5 ET', 'L5/6 NP', 'L6 IT', 'L6 CT', 'L6 IT Car3']

rainbow_cmap = mpl.colormaps['gist_rainbow']

my_pal = {"VLMC": rainbow_cmap(0)}
my_pal.update({cell_type: rainbow_cmap(0.3 + 0.7 * (i - 1) / (len(selected_cell_types) - 1)) for i, cell_type in enumerate(selected_cell_types[1:])})
my_pal.update({cell_type: 'gray' for cell_type in cell_types if cell_type not in selected_cell_types})
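
When a palette is passed to seaborn as a dictionary, every hue level needs an entry, and seaborn generally raises a `ValueError` if one is missing. The palette above covers all cell types by construction, but if you edit `selected_cell_types` a quick check like the following sketch can catch gaps early. `missing_palette_keys` is a hypothetical helper, and the toy palette and cell types are illustrative only.

```python
def missing_palette_keys(palette, hue_levels):
    """Return the hue levels that have no colour assigned in the palette dict."""
    return sorted(set(hue_levels) - set(palette))


# Toy stand-ins for `my_pal` and `cell_types`.
toy_palette = {"VLMC": "red", "L2/3 IT": "green"}
toy_cell_types = ["VLMC", "L2/3 IT", "Astro"]

print(missing_palette_keys(toy_palette, toy_cell_types))  # → ['Astro']
```

In the notebook the equivalent call would be `missing_palette_keys(my_pal, cell_types)`, which should return an empty list.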

In [None]:
with sns.axes_style('white', rc={
        'xtick.bottom': True,
        'ytick.left': True
}), sns.plotting_context('paper',
                         rc={
                             'axes.titlesize': 8,
                             'axes.labelsize': 8,
                             'xtick.labelsize': 6,
                             'ytick.labelsize': 6,
                             'legend.fontsize': 6
                         }):
    N = len(samples)
    fig, axes = plt.subplots(1, N, figsize = (4 * N, 3))
    for i, sample in enumerate(samples):
        sample_df = data_df.loc[data_df['sample'] == sample]
        ax = axes[i] if N > 1 else axes
        sns.scatterplot(data = sample_df,
                        x = 'x',
                        y = 'y',
                        hue = 'Cell_Type',
                        palette = my_pal,
                        hue_order = selected_cell_types + [x for x in cell_types if x not in selected_cell_types],
                        edgecolor=None,
                        s = 4,
                        ax=ax)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"{sample}")
        ax.legend(loc='upper left', bbox_to_anchor=(1,1))


    fig.tight_layout()
    fig.savefig('spatial_cell_type.png', dpi=300)

![spatial cell type distribution](img/merfish_spatial_cell_type.png)
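
The plotting cell uses `axes[i] if N > 1 else axes` because `plt.subplots` returns a single `Axes` object rather than an array when only one panel is requested. One alternative, shown here as a sketch with a hypothetical sample name, is to pass `squeeze=False` so that the return value is always a 2D array.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for a standalone run; omit this line inside the notebook
import matplotlib.pyplot as plt

samples = ["sample_1"]  # hypothetical name; a single sample is the case that needs care
N = len(samples)

# With squeeze=False, axes is always a 2D array of shape (1, N), even when N == 1.
fig, axes = plt.subplots(1, N, figsize=(4 * N, 3), squeeze=False)
for i, sample in enumerate(samples):
    ax = axes[0, i]
    ax.set_title(sample)

print(axes.shape)
plt.close(fig)
```

Either approach works; the conditional used in this notebook simply avoids the extra index.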

### Cell-level NT score spatial distribution

In [None]:
N = len(samples)
fig, axes = plt.subplots(1, N, figsize = (3.5 * N, 3))
for i, sample in enumerate(samples):
    sample_df = data_df.loc[data_df['sample'] == sample]
    ax = axes[i] if N > 1 else axes
    # 1 - Cell_NTScore reverses the direction of the NT score
    scatter = ax.scatter(sample_df['x'], sample_df['y'], c=1 - sample_df['Cell_NTScore'], cmap='rainbow', vmin=0, vmax=1, s=1)
    # use the following line instead if you do not need to change the direction of the NT score
    # scatter = ax.scatter(sample_df['x'], sample_df['y'], c=sample_df['Cell_NTScore'], cmap='rainbow', vmin=0, vmax=1, s=1)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.colorbar(scatter, ax=ax)  # attach the colorbar explicitly to this sample's axes
    ax.set_title(f"{sample} cell-level NT score")


fig.tight_layout()
fig.savefig('cell_level_NT_score.png', dpi=300)

![cell-level NT score](img/merfish_cell_level_NT_score.png)
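
The plotting cell above colours cells by `1 - Cell_NTScore` to reverse the direction of the score. The sketch below, using made-up scores, shows that this keeps the values in the range from 0 to 1 and exactly reverses the ordering of the cells.

```python
import numpy as np

# Made-up cell-level NT scores in [0, 1].
cell_nt_score = np.array([0.0, 0.25, 0.5, 1.0])

# Reverse the direction: the cell with the lowest score now has the highest one.
flipped = 1 - cell_nt_score

print(flipped.tolist())                 # → [1.0, 0.75, 0.5, 0.0]
print(np.argsort(flipped).tolist())     # → [3, 2, 1, 0]
```

Relative distances between cells along the score are unchanged, so only the colour direction in the plot differs.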