# Running cytoself in Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/royerlab/cytoself/blob/main/example_scripts/simple_example.ipynb)

## Introduction
This Jupyter notebook shows a simple example of how to use cytoself with a few example images and a pre-trained model.
The notebook only demonstrates how to run cytoself; the example data provided here is too small to train a usable model, so do not expect meaningful training results.
To train a better cytoself model, download more data from [**Data Availability**](https://github.com/royerlab/cytoself#data-availability).


## Example demo

Let's get started with a simple example.

Note: Errors occur occasionally. If a cell fails, run it again; if the error persists, restart the runtime.

### 0. Prepare environment (required for Google Colab)

In [None]:
# install Python 3.9 (-y answers the confirmation prompt, which a notebook cell cannot do interactively)
!sudo apt-get update -y
!sudo apt-get install -y python3.9
!wget https://bootstrap.pypa.io/get-pip.py
!python3.9 get-pip.py

# register both interpreters as alternatives for python3; python3.9 gets the higher priority
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2

!pip install --upgrade git+https://github.com/royerlab/cytoself.git
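
The cell above changes which interpreter the `python3` command points to, but the notebook kernel that runs the remaining cells may not be the same interpreter. The following is a minimal sketch for checking both; the helper name `interpreter_version` is made up for this illustration and is not part of cytoself.

```python
import shutil
import subprocess
import sys


def interpreter_version(executable):
    """Return the version string reported by `<executable> --version`."""
    result = subprocess.run(
        [executable, "--version"], capture_output=True, text=True, check=True
    )
    # Very old interpreters print the version to stderr instead of stdout.
    return (result.stdout or result.stderr).strip()


# Interpreter that executes this notebook's cells.
kernel_version = "Python {}.{}.{}".format(*sys.version_info[:3])
print("Notebook kernel:", kernel_version)

# Interpreter that the shell command `python3` resolves to, if any.
python3_path = shutil.which("python3")
if python3_path is not None:
    print("python3 on PATH:", interpreter_version(python3_path))
```

If the two versions differ, packages installed with `!pip install` may end up in an environment other than the one the kernel imports from.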

### 1. Prepare data

In [None]:
from os.path import join
import numpy as np
import torch
import matplotlib.pyplot as plt

from cytoself.datamanager.opencell import DataManagerOpenCell
from cytoself.trainer.cytoselflite_trainer import CytoselfFullTrainer
from cytoself.analysis.analysis_opencell import AnalysisOpenCell
from cytoself.trainer.utils.plot_history import plot_history_cytoself

data_ch = ['pro', 'nuc']
datapath = 'sample_data'  # directory where the sample data will be downloaded
DataManagerOpenCell.download_sample_data(datapath)  # download data
datamanager = DataManagerOpenCell(datapath, data_ch, fov_col=None)
datamanager.const_dataloader(batch_size=32, label_name_position=1)

### 2. Create and train a cytoself model

In [None]:
model_args = {
    'input_shape': (2, 100, 100),
    'emb_shapes': ((25, 25), (4, 4)),
    'output_shape': (2, 100, 100),
    'fc_output_idx': [2],
    'vq_args': {'num_embeddings': 512, 'embedding_dim': 64},
    'num_class': len(datamanager.unique_labels),
    'fc_input_type': 'vqvec',
}
train_args = {
    'lr': 1e-3,
    'max_epoch': 1,
    'reducelr_patience': 3,
    'reducelr_increment': 0.1,
    'earlystop_patience': 6,
}
trainer = CytoselfFullTrainer(train_args, homepath='demo_output', model_args=model_args)
trainer.fit(datamanager, tensorboard_path='tb_logs')
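
With `max_epoch` set to 1, the patience settings in `train_args` never come into play. To get a feel for what they might do in a longer run, here is a toy simulation in plain Python. It assumes, based only on the parameter names, that the learning rate is multiplied by `reducelr_increment` after `reducelr_patience` epochs without improvement and that training stops after `earlystop_patience` such epochs. The function `simulate_schedule` is hypothetical and does not reproduce cytoself's actual trainer logic.

```python
def simulate_schedule(val_losses, lr=1e-3, reducelr_patience=3,
                      reducelr_increment=0.1, earlystop_patience=6):
    """Toy plateau schedule (assumed semantics, not cytoself's implementation).

    Returns the learning rate used in each epoch and the epoch index at
    which training stopped early (None if it never stopped).
    """
    best = float("inf")
    since_best = 0    # epochs since the best validation loss
    since_reduce = 0  # epochs without improvement since the last lr change
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)
        if loss < best:
            best = loss
            since_best = 0
            since_reduce = 0
        else:
            since_best += 1
            since_reduce += 1
        if since_best >= earlystop_patience:
            return lrs, epoch
        if since_reduce >= reducelr_patience:
            lr *= reducelr_increment
            since_reduce = 0
    return lrs, None


# Made-up validation losses: two improving epochs followed by a plateau.
lrs, stopped = simulate_schedule([1.0, 0.9] + [0.95] * 10)
print(lrs, stopped)
```

In this toy run the learning rate drops once after three stagnant epochs, and training stops after six.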

### 2.1 Generate training history

In [None]:
plot_history_cytoself(trainer.history, savepath=trainer.savepath_dict['visualization'])

### 2.2 Compare the reconstructed images as a sanity check

In [None]:
img = next(iter(datamanager.test_loader))['image'].detach().cpu().numpy()
torch.cuda.empty_cache()
reconstructed = trainer.infer_reconstruction(img)
fig, ax = plt.subplots(2, len(data_ch), figsize=(5 * len(data_ch), 5), squeeze=False)
for ii, ch in enumerate(data_ch):
    t0 = np.zeros((2 * 100, 5 * 100))
    for i, im in enumerate(img[:10, ii, ...]):
        i0, i1 = np.unravel_index(i, (2, 5))
        t0[i0 * 100 : (i0 + 1) * 100, i1 * 100 : (i1 + 1) * 100] = im
    t1 = np.zeros((2 * 100, 5 * 100))
    for i, im in enumerate(reconstructed[:10, ii, ...]):
        i0, i1 = np.unravel_index(i, (2, 5))
        t1[i0 * 100 : (i0 + 1) * 100, i1 * 100 : (i1 + 1) * 100] = im
    ax[0, ii].imshow(t0, cmap='gray')
    ax[0, ii].axis('off')
    ax[0, ii].set_title('input ' + ch)
    ax[1, ii].imshow(t1, cmap='gray')
    ax[1, ii].axis('off')
    ax[1, ii].set_title('output ' + ch)
fig.tight_layout()
fig.show()
fig.savefig(join(trainer.savepath_dict['visualization'], 'reconstructed_images.png'), dpi=300)
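
The two inner loops above paste up to ten images into a 2 x 5 mosaic. The same layout can be sketched with a single reshape and transpose; `tile_images` is a hypothetical helper written for this illustration, not a cytoself function.

```python
import numpy as np


def tile_images(images, n_rows=2, n_cols=5):
    """Arrange the first n_rows * n_cols images of shape (H, W) into one
    (n_rows * H, n_cols * W) mosaic, filling unused slots with zeros."""
    images = np.asarray(images)[: n_rows * n_cols]
    n, h, w = images.shape
    canvas = np.zeros((n_rows * n_cols, h, w), dtype=images.dtype)
    canvas[:n] = images
    # (rows, cols, H, W) -> (rows, H, cols, W) -> (rows * H, cols * W)
    return (
        canvas.reshape(n_rows, n_cols, h, w)
        .transpose(0, 2, 1, 3)
        .reshape(n_rows * h, n_cols * w)
    )


# Toy data: ten 4 x 4 "images" with distinct values.
imgs = np.arange(10 * 4 * 4, dtype=float).reshape(10, 4, 4)
tiled = tile_images(imgs)
print(tiled.shape)
```

In the cell above, `t0` would correspond to `tile_images(img[:, ii])` and `t1` to `tile_images(reconstructed[:, ii])`, assuming both arrays have shape `(batch, channel, 100, 100)`.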

### 3. Analyze embeddings

In [None]:
analysis = AnalysisOpenCell(datamanager, trainer)

### 3.1 Generate bi-clustering heatmap

In [None]:
analysis.plot_clustermap(num_workers=4)

### 3.2 Generate feature spectrum

In [None]:
vqindhist1 = trainer.infer_embeddings(img, 'vqindhist1')
ft_spectrum = analysis.compute_feature_spectrum(vqindhist1)

x_max = ft_spectrum.shape[1] + 1
x_ticks = np.arange(0, x_max, 50)
fig, ax = plt.subplots(figsize=(10, 3))
ax.stairs(ft_spectrum[0], np.arange(x_max), fill=True)
ax.spines[['right', 'top']].set_visible(False)
ax.set_xlabel('Feature index')
ax.set_ylabel('Counts')
ax.set_xlim([0, x_max])
ax.set_xticks(x_ticks, analysis.feature_spectrum_indices[x_ticks])
fig.tight_layout()
fig.show()
fig.savefig(join(analysis.savepath_dict['feature_spectra_figures'], 'feature_spectrum.png'), dpi=300)
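
The name `vqindhist1` suggests a per-image histogram of vector-quantization indices, that is, a count of how often each of the `num_embeddings` codebook entries was selected. Under that assumption, a rough sketch of such a histogram looks like this. The index map is random toy data, its 25 x 25 size merely mirrors the first entry of `emb_shapes`, and `index_histogram` is a hypothetical helper, not cytoself's implementation.

```python
import numpy as np


def index_histogram(indices, num_embeddings=512):
    """Count how often each codebook index occurs in one image's index map."""
    return np.bincount(np.asarray(indices).ravel(), minlength=num_embeddings)


rng = np.random.default_rng(0)
# Made-up index map: one codebook index per spatial position.
toy_indices = rng.integers(0, 512, size=(25, 25))
hist = index_histogram(toy_indices)
print(hist.shape, hist.sum())
```

Each histogram has one bin per codebook entry, and its counts sum to the number of spatial positions in the index map.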

### 3.3 Plot UMAP

In [None]:
umap_data = analysis.plot_umap_of_embedding_vector(
    data_loader=datamanager.test_loader,
    group_col=2,
    output_layer=f'{model_args["fc_input_type"]}2',
    title=f'UMAP {model_args["fc_input_type"]}2',
    xlabel='UMAP1',
    ylabel='UMAP2',
    s=0.3,
    alpha=0.5,
    show_legend=True,
)
