# Train and evaluate models
- experiment class + parameters
- training
- evaluation of val split (with clustering)


TODO: add a more detailed explanation of the model classes?

In [1]:
import importlib
import logging

# Show INFO-level log messages in the cell outputs.
logging.basicConfig(level=logging.INFO)

import miann.tl._experiment
import miann.tl._estimator
import miann.tl._evaluate
import miann.data._data

# Reload the miann modules used below so that the versions currently on disk
# are used by this kernel session.
for module in (
    miann.tl._experiment,
    miann.tl._estimator,
    miann.tl._evaluate,
    miann.data._data,
):
    importlib.reload(module)

from miann.constants import get_data_config, EXPERIMENT_DIR
from miann.utils import merged_config, load_config

# Data configuration for the NascentRNA dataset.
data_config = get_data_config('NascentRNA')
# Display the value of EXPERIMENT_DIR for reference.
print(EXPERIMENT_DIR)

INFO:numexpr.utils:NumExpr defaulting to 8 threads.


/Users/hannah.spitzer/projects/pelkmans/local_experiments/NascentRNA_new


## Experiment class handles config files
Models are trained and evaluated from an experiment_params file. This file contains the model/experiment parameters for several experiments, so that several models can easily be trained at the same time. The parameter dictionary of each experiment contains the following sections:
- experiment (where to save experiment)
- data (which dataset to use for training)
- model (model class definition)
- training (training hyperparameters)
- evaluation (evaluation on val/test split)
- cluster (clustering on val/test split)

The Experiment class is initialised from a parameter dictionary for one specific experiment and is passed to specific classes for training, evaluation, and clustering.

In [2]:
# Load the example parameter file and build the configuration used for training.
config = load_config("params/example_experiment_params.py")

# Merge the base configuration with the first entry of the variable configurations.
base_params = config.base_config
first_variant = config.variable_config[0]
exp_config = merged_config(base_params, first_variant)

# Display the resulting parameter dictionary.
exp_config

{'experiment': {'dir': 'test', 'name': 'VAE', 'save_config': True},
 'data': {'load_dataset': True,
  'data_config': 'NascentRNA',
  'dataset_name': '184A1_test_dataset',
  'output_channels': None},
 'model': {'model_cls': <ModelEnum.VAEModel: 'VAEModel'>,
  'model_kwargs': {'num_neighbors': 3,
   'num_channels': 34,
   'num_output_channels': 34,
   'latent_dim': 16,
   'encoder_conv_layers': [32],
   'encoder_conv_kernel_size': [1],
   'encoder_fc_layers': [32, 16],
   'decoder_fc_layers': []},
  'init_with_weights': False},
 'training': {'learning_rate': 0.001,
  'epochs': 10,
  'batch_size': 128,
  'loss': {'decoder': <LossEnum.SIGMA_MSE: 'sigma_vae_mse'>,
   'latent': <LossEnum.KL: 'kl_divergence'>},
  'metrics': {'decoder': <LossEnum.MSE_metric: 'mean_squared_error_metric'>,
   'latent': <LossEnum.KL: 'kl_divergence'>},
  'save_model_weights': True,
  'save_history': True,
  'overwrite_history': True},
 'evaluation': {'split': 'val',
  'predict_reps': ['latent', 'decoder'],
  'img

In [3]:
# Initialise the Experiment from the merged parameter dictionary.
# With this config, the log output shows the experiment being set up under
# test/VAE and the config being saved to test/VAE/config.json.
exp = miann.tl._experiment.Experiment(exp_config)

INFO:Experiment:Setting up experiment test/VAE
INFO:Experiment:Saving config to test/VAE/config.json


## NN training and prediction with Estimator
The Estimator handles model setup, training, and prediction


In [8]:
# Create the Estimator for this experiment.
# Per the log output, this creates the model (VAEModel here) and several
# MPPData objects for the NascentRNA data.
est = miann.tl._estimator.Estimator(exp)

INFO:VAEModel:Creating model
INFO:MPPData:Created new: MPPData for NascentRNA (246467 mpps with shape (3, 3, 34) from 1768 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (11848 mpps with shape (3, 3, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (14231 mpps with shape (3, 3, 34) from 101 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (1184845 mpps with shape (1, 1, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (1423116 mpps with shape (1, 1, 34) from 101 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].


In [9]:
# Train the model. With this config the log reports 10 epochs of training and
# the model being saved to <EXPERIMENT_DIR>/test/VAE/weights_epoch010; the
# table shown below lists the per-epoch training and validation losses.
est.train_model()

INFO:Estimator:Training model for 10 epochs


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


INFO:Estimator:Saving model to /Users/hannah.spitzer/projects/pelkmans/local_experiments/NascentRNA_new/test/VAE/weights_epoch010


Unnamed: 0_level_0,loss,decoder_loss,latent_loss,decoder_mean_squared_error,latent_kl_loss,val_loss,val_decoder_loss,val_latent_loss,val_decoder_mean_squared_error,val_latent_kl_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,-1632.197754,-1642.930908,10.732606,0.041395,10.732114,-2282.504639,-2291.129395,8.624417,0.020534,8.626506
1,-2384.553955,-2392.40625,7.853926,0.019779,7.853745,-2558.535889,-2565.875244,7.340002,0.01812,7.341238
2,-2517.866943,-2524.963623,7.099556,0.018591,7.099494,-2587.557617,-2594.43457,6.87666,0.017876,6.877465
3,-2539.678467,-2546.547852,6.870928,0.0184,6.8709,-2609.353027,-2616.270264,6.917495,0.017672,6.918108
4,-2551.162354,-2557.935791,6.77248,0.018306,6.772503,-2632.658936,-2639.449219,6.790499,0.017467,6.790955
5,-2558.330811,-2564.964111,6.63102,0.018243,6.631079,-2478.487549,-2485.154541,6.667218,0.018739,6.667554
6,-2561.404541,-2567.890381,6.487192,0.018226,6.487279,-2635.505127,-2642.008545,6.503323,0.017445,6.503577
7,-2562.494141,-2568.830322,6.334205,0.018219,6.334208,-2627.247314,-2633.653564,6.406629,0.017515,6.406843
8,-2562.961182,-2569.169678,6.205832,0.018206,6.205904,-2640.666016,-2646.894043,6.227886,0.017404,6.228083
9,-2563.543213,-2569.621582,6.079512,0.018205,6.079517,-2633.823486,-2640.162842,6.338755,0.017458,6.338951


## Predict val split and images with Predictor

In [8]:
# Create a Predictor for the experiment.
# Per the log output, the model is re-created and initialised with the saved
# weights (weights_epoch010) before the MPPData objects are set up.
pred = miann.tl._evaluate.Predictor(exp)

INFO:Predictor:Creating Predictor for test/VAE
INFO:VAEModel:Creating model
INFO:Estimator:Initializing model with weights from /Users/hannah.spitzer/projects/pelkmans/local_experiments/NascentRNA_new/test/VAE/weights_epoch010
INFO:MPPData:Created new: MPPData for NascentRNA (246467 mpps with shape (3, 3, 34) from 1768 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (11848 mpps with shape (3, 3, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (14231 mpps with shape (3, 3, 34) from 101 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (1184845 mpps with shape (1, 1, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (1423116 mpps with shape (1, 1, 34) from 101 objects).

In [5]:
# Evaluate the model with the Predictor (presumably on the split given in the
# 'evaluation' section of the config, 'val' here - TODO confirm).
# NOTE(review): the saved output of this cell is a NameError because it was
# executed in a kernel session where `pred` had not been created yet; run the
# Predictor cell above first and re-execute this cell.
pred.evaluate_model()

NameError: name 'pred' is not defined

## Cluster resulting latent space with Cluster

In [4]:
# Create a Cluster object from the experiment via Cluster.from_exp_split.
# Per the log output, this re-creates the model and initialises it with the
# saved weights (weights_epoch010).
cl = miann.tl._evaluate.Cluster.from_exp_split(exp)

2021-11-07 16:14:12.924029: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO:VAEModel:Creating model
INFO:Estimator:Initializing model with weights from /Users/hannah.spitzer/projects/pelkmans/local_experiments/NascentRNA_new/test/VAE/weights_epoch010
INFO:MPPData:Created new: MPPData for NascentRNA (246467 mpps with shape (3, 3, 34) from 1768 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (11848 mpps with shape (3, 3, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (14231 mpps with shape (3, 3, 34) from 101 objects). Data keys: ['x', 'y', '

In [6]:
# Create the clustering. Per the log output, this runs a leiden clustering,
# calculates a umap, creates a pynndescent index for the latent representation,
# and saves the resulting mpp data to <EXPERIMENT_DIR>/test/VAE/results_epoch010/val.
cl.create_clustering()

INFO:Cluster:Creating leiden clustering
OMP: Info #271: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
INFO:Cluster:Calculating umap
INFO:Cluster:Creating pynndescent index for latent
INFO:MPPData:Saving mpp data to /Users/hannah.spitzer/projects/pelkmans/local_experiments/NascentRNA_new/test/VAE/results_epoch010/val (keys: ['clustering', 'y', 'umap', 'x', 'obj_ids'])


In [5]:
# Predict cluster images for the experiment. Per the log output, the clustering
# is projected onto the image data and the result is saved to
# <EXPERIMENT_DIR>/test/VAE/results_epoch010/val_imgs.
# The trailing semicolon suppresses the cell output. Assigning the result to `_`
# instead would shadow IPython's last-output variable, and IPython then stops
# updating `_`, `__` and `___` for the rest of the session.
cl.predict_cluster_imgs(exp);

INFO:Predictor:Creating Predictor for test/VAE
INFO:VAEModel:Creating model
INFO:Estimator:Initializing model with weights from /Users/hannah.spitzer/projects/pelkmans/local_experiments/NascentRNA_new/test/VAE/weights_epoch010
INFO:MPPData:Created new: MPPData for NascentRNA (246467 mpps with shape (3, 3, 34) from 1768 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (11848 mpps with shape (3, 3, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (14231 mpps with shape (3, 3, 34) from 101 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (1184845 mpps with shape (1, 1, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (1423116 mpps with shape (1, 1, 34) from 101 objects).

Cannot read with memmap:  /Users/hannah.spitzer/projects/pelkmans/local_experiments/NascentRNA_new/test/VAE/results_epoch010/val/clustering.npy


INFO:Cluster:Projecting clustering to {len(mpp_data.x)} sampled
INFO:Cluster:processing chunk 0
INFO:MPPData:Saving mpp data to /Users/hannah.spitzer/projects/pelkmans/local_experiments/NascentRNA_new/test/VAE/results_epoch010/val_imgs (keys: ['clustering', 'y', 'x', 'obj_ids'])


## Plot results using ModelComparator

In [None]:
# TODO: or provide individual plotting functions that can also be used with any mpp_data?
# TODO: ModelComparator should use these individual plotting functions and just make the comparison more convenient
# TODO: maybe Experiment could have functions for returning all kinds of MPPData, e.g. exp.get_clustered_data()
# or exp.get_val_data() (with clustering etc.)