# Getting Started: Multi-Task Decoding (Task 3)


<br>
<a href="https://colab.research.google.com/github/MTNeuro/MTNeuro/blob/main/notebooks/task3_getting_started.ipynb">
    <img align="left" alt="Open in Colab button" src="https://colab.research.google.com/assets/colab-badge.svg" width="150" height="60">
</a>
<br>    

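This **MTNeuro** Jupyter notebook walks you through running Task 3. It takes an encoder, computes embeddings of image cutouts from four brain regions, and reports the R² score of a linear regression from those embeddings to each of several semantic features (blood vessel, cell, and axon statistics).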

For more details on the tasks and dataset, please refer to our paper:

> Quesada, J., Sathidevi, L., Liu, R., Ahad, N., Jackson, J.M., Azabou, M., ... & Dyer, E. L. (2022). MTNeuro: A Benchmark for Evaluating Representations of Brain Structure Across Multiple Levels of Abstraction. *Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track*.
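The scores this notebook reports are coefficients of determination (R²). As a quick reference, the sketch below computes R² = 1 − SS_res / SS_tot by hand for a linear regression fit on synthetic data and checks that it matches scikit-learn's `LinearRegression.score` and `r2_score`. The random `embeddings` and the target `feature` are made-up placeholders, not MTNeuro data.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 16))                 # hypothetical (N, D) embeddings
w = rng.normal(size=16)
feature = embeddings @ w + 0.5 * rng.normal(size=200)   # hypothetical semantic feature

reg = LinearRegression().fit(embeddings, feature)
pred = reg.predict(embeddings)

# R^2 = 1 - SS_res / SS_tot
ss_res = np.sum((feature - pred) ** 2)
ss_tot = np.sum((feature - feature.mean()) ** 2)
r2_manual = 1.0 - ss_res / ss_tot

print(r2_manual, reg.score(embeddings, feature), r2_score(feature, pred))
```

An R² of 1 means the feature is perfectly linearly decodable from the embeddings; 0 means the regression does no better than predicting the mean.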


### Clone the MTNeuro Repo and Install the MTNeuro Package

In [None]:
!git clone https://github.com/MTNeuro/MTNeuro && cd MTNeuro && pip install -e .
!pip install pandas
!pip install tensorboard

### Import the Required Packages

In [7]:
#Import libraries
import os
import sys
import json
import torch
import math
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
import argparse
import umap

#PyTorch imports
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import torch.nn.functional as F

#scikit-learn imports
from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor
from sklearn import preprocessing
from sklearn.decomposition import PCA

#MTNeuro modules
#sys.path.append('./MTNeuro')                 #setting the location to look for the required packages
from MTNeuro.annots.features import extract_cell_stats,extract_axon_stats,extract_blood_stats
from MTNeuro.annots.get_cutouts import get_cutout_data
from MTNeuro.annots.latents import get_latents, get_unsup_latents


### Load the Encoder

The encoder file path **and** the encoder type are required here.

Any of the models used in Task 1 can be used as an encoder. The model weights are available on [Dropbox](https://www.dropbox.com/sh/jkk1i9wopqqrgne/AABHDICD0Cfwl_wm5q2ueIS8a).

The encoder types are `ssl`, `supervised`, `PCA`, and `NMF`.
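Unlike the `ssl` and `supervised` options, `PCA` and `NMF` are not trained networks but linear factorizations of the raw images. The sketch below illustrates that idea under the assumption that each 2D slice is flattened into one vector; the random data, the 32×32 slice size, and `n_components = 8` are placeholders, and MTNeuro's `get_unsup_latents` may preprocess or size things differently. NMF requires nonnegative input, which raw image intensities satisfy.

```python
import numpy as np
from sklearn.decomposition import NMF, PCA

rng = np.random.default_rng(0)
# Hypothetical stand-in for raw EM slices: 100 slices of 32x32 pixels, values in [0, 1)
slices = rng.random((100, 32, 32))

# Flatten each 2D slice into one row vector: shape (N, H*W)
X = slices.reshape(len(slices), -1)

n_components = 8  # hypothetical embedding size
pca_embeddings = PCA(n_components=n_components).fit_transform(X)
nmf_embeddings = NMF(
    n_components=n_components, init="nndsvda", max_iter=500, random_state=0
).fit_transform(X)

print(pca_embeddings.shape, nmf_embeddings.shape)  # (100, 8) (100, 8)
```

Either way, each slice ends up as a low-dimensional vector, which is what the linear readout at the end of this notebook consumes.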

In [None]:
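#In this demo, the BYOL (self-supervised) encoder is used.
#The wget command below downloads the BYOL weights into the current directory. Alternatively, download them from the Dropbox folder and place the file in the same directory as this notebook.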
!wget https://www.dropbox.com/s/htjw410bk0grhj4/ckpt-800.pt

encoder_file_path = "ckpt-800.pt"
encoder_type = "ssl"

#--Do not modify below--
if encoder_type == 'ssl':
    ssl_encoder = 1
    unsupervised = 0
elif encoder_type == 'supervised':
    ssl_encoder = 0
    unsupervised = 0
elif encoder_type == 'PCA': 
    unsupervised = 1
    set_pca = 1
elif encoder_type == 'NMF':
    unsupervised = 1
    set_pca = 0
else:
    # Fail early instead of hitting an undefined variable in a later cell
    raise ValueError("Incorrectly specified encoder type: {!r} (expected 'ssl', 'supervised', 'PCA', or 'NMF')".format(encoder_type))

### Specify Cutout Coordinates

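Cutout coordinates are specified in the Task 3 JSON file at `MTNeuro/taskconfig/task3.json`. The `config_file_path` below is relative to the repository's `notebooks/` folder; adjust it if you run the notebook from elsewhere (for example, from the directory you cloned the repo into).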

In [9]:
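config_file_path = "../MTNeuro/taskconfig/task3.json"

# Load the cutout coordinates; `with` closes the file even if parsing fails
try:
    with open(config_file_path, 'r') as jsonFile:
        slices = json.load(jsonFile)
except IOError as err:
    raise FileNotFoundError("JSON file not found: {}".format(config_file_path)) from err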
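The notebook reads a fixed set of keys from this file. The hypothetical helper `check_task3_config` below lists any of those keys that are missing; the key names come from the download cell, while the example config is a placeholder whose values are not real coordinates.

```python
import json

# Keys read by the download cell; suffixes map to regions:
# cor (Cortex), stri (Striatum), vp (VP), zi (ZI)
REGIONS = ["cor", "stri", "vp", "zi"]
REQUIRED_KEYS = (
    [f"xrange_{r}" for r in REGIONS]
    + [f"yrange_{r}" for r in REGIONS]
    + ["zrange", "image_chan", "annotation_chan"]
)


def check_task3_config(config: dict) -> list:
    """Hypothetical helper: return the required keys missing from a loaded config."""
    return [key for key in REQUIRED_KEYS if key not in config]


# Placeholder config for illustration only; only key presence is checked
example = json.loads('{"xrange_cor": null, "zrange": null}')
missing = check_task3_config(example)
print(missing)
```

After the loading cell, `check_task3_config(slices)` should return an empty list if the file contains every key the notebook uses.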

#### Downloading the Image and Annotation Cutouts Specified in the JSON File

Each brain region (Cortex, Striatum, VP, and ZI) contributes one cutout of shape (Z, Y, X) = (360, 256, 256).

The four cutouts are concatenated along the z-axis to form `data_array_raw` (image data) and `data_array_anno` (annotations), each of shape (1440, 256, 256).
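To make the shape bookkeeping concrete, the sketch below concatenates four zero-filled stand-ins of shape (360, 256, 256) along axis 0 and builds one region label per z-slice. It collects the arrays in a list and calls `np.concatenate` once, an alternative to the incremental concatenation in the download cell. Note that the notebook's `label_array` additionally repeats each label `up_sample` times; this sketch keeps one label per slice.

```python
import numpy as np

Z, Y, X = 360, 256, 256  # per-region cutout shape stated above
class_list = ["Cortex", "Striatum", "VP", "ZI"]

# Zero-filled stand-ins for the four downloaded cutouts (uint8 keeps each ~24 MB)
cutouts = [np.zeros((Z, Y, X), dtype=np.uint8) for _ in class_list]

# Collect first, then concatenate once along the z-axis (axis 0)
data_array = np.concatenate(cutouts, axis=0)

# One region label per z-slice: 0 for Cortex, 1 for Striatum, 2 for VP, 3 for ZI
slice_labels = np.repeat(np.arange(len(class_list)), Z)

print(data_array.shape, slice_labels.shape)  # (1440, 256, 256) (1440,)
```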

In [None]:
#Retrieving Data from JSON
xrange_list = [slices['xrange_cor'],slices['xrange_stri'],slices['xrange_vp'],slices['xrange_zi']]
yrange_list = [slices['yrange_cor'],slices['yrange_stri'],slices['yrange_vp'],slices['yrange_zi']]
class_list = ["Cortex","Striatum","VP","ZI"]
zrange = slices['zrange']


boss_dict = {}
boss_dict['image_chan']=slices['image_chan']
boss_dict['annotation_chan'] = slices['annotation_chan']

data_array_raw = []
data_array_anno = []
label_array  = []
up_sample = 4 

#Pulling Data from BossDB
for i in range(0,len(xrange_list)):
    cutout_data_raw,cutout_data_anno = get_cutout_data(xrange_list[i],yrange_list[i],zrange,name=class_list[i])
    
    data_raw = cutout_data_raw[:,:,:]
    data_anno = cutout_data_anno[:,:,:]
    # Stack this region's slices under the previous regions along z (axis 0)
    data_array_raw = np.concatenate((data_array_raw,data_raw),axis=0) if len(data_array_raw) else data_raw
    data_array_anno = np.concatenate((data_array_anno,data_anno),axis=0) if len(data_array_anno) else data_anno
    
    # Region label i, with up_sample entries per z-slice of this cutout
    labels = i*np.ones(up_sample*len(data_raw)).reshape(-1,)
    label_array = np.concatenate((label_array,labels),axis=0) if len(label_array) else labels

### Extracting Features and Calculating Linear Readout Scores

In [None]:
print('Extracting cell stats...')
stats_cell= extract_cell_stats(np.copy(data_array_anno))

print('Extracting axon stats...')
stats_axon = extract_axon_stats(np.copy(data_array_anno))

print('Extracting blood stats...')
stats_blood = extract_blood_stats(np.copy(data_array_anno))

# Get embeddings from the chosen encoder
if encoder_type == 'ssl' or encoder_type == 'supervised':
    embeddings = get_latents(data_array_raw,encoder_file_path,ssl_encoder)
elif encoder_type == 'PCA' or encoder_type == 'NMF':
    embeddings = get_unsup_latents(data_array_raw,set_pca)

# Get linear readout scores: fit a linear regression from the embeddings to each
# semantic feature and report its R^2 (scored on the same samples used for fitting)
X = embeddings

y = stats_blood[:,1]
reg = LinearRegression().fit(X,y)
blood_vsl_score = reg.score(X,y)
print("Blood Vessel Score: {}".format(blood_vsl_score))

y = stats_cell[:,1]
reg = LinearRegression().fit(X,y)
numb_cell = reg.score(X,y)
print("Cell Count Score: {}".format(numb_cell))

y = stats_cell[:,2]
reg = LinearRegression().fit(X,y)
avg_dist_nn_cell = reg.score(X,y)
print("Avg Cell Distance Score: {}".format(avg_dist_nn_cell))

y = stats_cell[:,4]
reg = LinearRegression().fit(X,y)
cell_size = reg.score(X,y)
print("Cell Size Score: {}".format(cell_size))

y = stats_axon[:,1]
reg = LinearRegression().fit(X,y)
axon_rslt = reg.score(X,y)
print("Axon Score: {}".format(axon_rslt))
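The cell above fits and scores each regression on the same samples, so each reported R² is an in-sample fit. As an aside, and not the benchmark's own protocol, the sketch below contrasts in-sample R² with 5-fold cross-validated R² on synthetic data where the target is pure noise. The sample count, the 64-dimensional embedding size, and the data are placeholders; the point is that with high-dimensional embeddings, an in-sample score can be positive even when no real relationship exists.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
n_samples, n_dims = 300, 64
X = rng.normal(size=(n_samples, n_dims))  # hypothetical embeddings
y = rng.normal(size=n_samples)            # target with NO relation to X

reg = LinearRegression()
in_sample_r2 = reg.fit(X, y).score(X, y)
held_out_r2 = cross_val_score(reg, X, y, cv=5, scoring="r2").mean()

print(f"in-sample R^2: {in_sample_r2:.3f}, 5-fold held-out R^2: {held_out_r2:.3f}")
```

Comparing the two numbers for a given encoder is a cheap way to see how much of its readout score could come from embedding dimensionality alone.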