# NNDataset for training neural networks
An NNDataset is a collection of MPPData objects, one each for the train, val, and test splits. It is created from a parameter file that defines its properties.

This notebook shows:
- how to create a dataset from parameters
- how to iterate over an existing dataset

TODO:
- for actual examples, need to handle downloading of files
- add more info text
- change dataset?

+ example creation from params dict here!
+ after creation load NNDataset + show how to iterate over samples for training

In [2]:
import logging
logging.basicConfig(level=logging.INFO)

from miann.data import NNDataset
from miann.data import create_dataset
from miann.utils import load_config


## Create NNDataset
An NNDataset is created from a config file that specifies which folders to read the data from and how to preprocess it.

The code below creates a dataset and saves it to DATASET_DIR/184A1_test_dataset. Alternatively, the NNDataset can be created from the command line:
`python create_dataset.py params/example_data_params.py`
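
To make the split settings in such a parameter file concrete, here is a small plain-Python sketch. It is not miann code: the `data_params` dict reproduces only a few of the keys that `config.data_params` displays in this notebook, and `implied_test_frac` is a hypothetical helper that assumes the test split receives whatever fraction is left after train and val.

```python
# Excerpt of the parameters displayed in this notebook (only a few keys are reproduced).
data_params = {
    "dataset_name": "184A1_test_dataset",
    "data_config": "NascentRNA",
    "split_kwargs": {"train_frac": 0.9, "val_frac": 0.05},
}


def implied_test_frac(split_kwargs):
    """Hypothetical helper: assume the test split gets the remaining fraction."""
    train = split_kwargs["train_frac"]
    val = split_kwargs["val_frac"]
    if not (0 <= train <= 1 and 0 <= val <= 1 and train + val <= 1):
        raise ValueError("train_frac and val_frac must be in [0, 1] and sum to at most 1")
    # Round to hide floating-point noise from the subtraction.
    return round(1.0 - train - val, 10)


test_frac = implied_test_frac(data_params["split_kwargs"])
print(f"test fraction: {test_frac}")
```

Under this assumption, the example parameters would leave 5% of the data for the test split.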

In [3]:
config = load_config("params/example_data_params.py")
config.data_params


{'dataset_name': '184A1_test_dataset',
 'data_config': 'NascentRNA',
 'data_dirs': ['184A1_unperturbed/I09',
  '184A1_unperturbed/I11',
  '184A1_meayamycin/I12',
  '184A1_meayamycin/I20'],
 'channels': ['01_CDK9_pT186',
  '01_PABPC1',
  '02_CDK7',
  '03_CDK9',
  '03_RPS6',
  '05_GTF2B',
  '05_Sm',
  '07_POLR2A',
  '07_SETD1A',
  '08_H3K4me3',
  '09_CCNT1',
  '09_SRRM2',
  '10_H3K27ac',
  '10_POL2RA_pS2',
  '11_KPNA2_MAX',
  '11_PML',
  '12_RB1_pS807_S811',
  '12_YAP1',
  '13_PABPN1',
  '13_POL2RA_pS5',
  '14_PCNA',
  '15_SON',
  '15_U2SNRNPB',
  '16_H3',
  '17_HDAC3',
  '17_SRSF2',
  '18_NONO',
  '19_KPNA1_MAX',
  '20_ALYREF',
  '20_SP100',
  '21_COIL',
  '21_NCL',
  '00_DAPI',
  '07_H2B'],
 'condition': ['perturbation_duration_one_hot', 'cell_cycle_one_hot'],
 'condition_kwargs': {'cond_params': {}},
 'split_kwargs': {'train_frac': 0.9, 'val_frac': 0.05},
 'test_img_size': 225,
 'subset': True,
 'subset_kwargs': {'frac': None,
  'nona_condition': True,
  'cell_cycle': 'NO_NAN'},
 'sub

In [4]:
create_dataset(config.data_params)

INFO:root:Creating train/val/test datasets with params:
INFO:root:{
    "dataset_name": "184A1_test_dataset",
    "data_config": "NascentRNA",
    "data_dirs": [
        "184A1_unperturbed/I09",
        "184A1_unperturbed/I11",
        "184A1_meayamycin/I12",
        "184A1_meayamycin/I20"
    ],
    "channels": [
        "01_CDK9_pT186",
        "01_PABPC1",
        "02_CDK7",
        "03_CDK9",
        "03_RPS6",
        "05_GTF2B",
        "05_Sm",
        "07_POLR2A",
        "07_SETD1A",
        "08_H3K4me3",
        "09_CCNT1",
        "09_SRRM2",
        "10_H3K27ac",
        "10_POL2RA_pS2",
        "11_KPNA2_MAX",
        "11_PML",
        "12_RB1_pS807_S811",
        "12_YAP1",
        "13_PABPN1",
        "13_POL2RA_pS5",
        "14_PCNA",
        "15_SON",
        "15_U2SNRNPB",
        "16_H3",
        "17_HDAC3",
        "17_SRSF2",
        "18_NONO",
        "19_KPNA1_MAX",
        "20_ALYREF",
        "20_SP100",
        "21_COIL",
        "21_NCL",
        "00_DAPI",


## Use NNDataset

In [5]:
dataset_name = '184A1_test_dataset'
ds = NNDataset(dataset_name, data_config='NascentRNA')

INFO:MPPData:Created new: MPPData for NascentRNA (246467 mpps with shape (3, 3, 34) from 1768 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (11848 mpps with shape (3, 3, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (14231 mpps with shape (3, 3, 34) from 101 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (1184845 mpps with shape (1, 1, 34) from 88 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].
INFO:MPPData:Created new: MPPData for NascentRNA (1423116 mpps with shape (1, 1, 34) from 101 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].


In [8]:
# dataset provides x and y (NN input and output) for each split
# x is either the mpp alone or a tuple (mpp, condition)
x = ds.x('val', is_conditional=False)
print(x.shape)

x, c = ds.x('train', is_conditional=True)
print(x.shape, c.shape)

(27999, 3, 3, 33)
(219054, 3, 3, 33) (219054, 14)
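
The width of `c` (14 here) would be consistent with the two one-hot conditions from the config being concatenated into a single vector, although this notebook does not show how miann builds it. The sketch below illustrates that idea in plain Python under this assumption. The category names are made up for illustration, so the resulting vector has length 6 rather than 14, and `one_hot` and `condition_vector` are hypothetical helpers, not miann functions.

```python
def one_hot(value, categories):
    """Return a one-hot list for `value` within the ordered `categories`."""
    if value not in categories:
        raise ValueError(f"unknown category: {value!r}")
    return [1.0 if cat == value else 0.0 for cat in categories]


# Hypothetical category sets, chosen only for illustration.
CELL_CYCLE = ["G1", "S", "G2"]
PERTURBATION_DURATION = ["none", "30min", "120min"]


def condition_vector(duration, phase):
    """Concatenate the two one-hot encodings into one condition vector."""
    return one_hot(duration, PERTURBATION_DURATION) + one_hot(phase, CELL_CYCLE)


c = condition_vector("30min", "S")
print(c, len(c))
```

The length of the concatenated vector is the sum of the number of categories of each condition.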


In [9]:
# dataset has a data attribute with the train/val/test data and an img attribute with val/test image data
# each split is represented as an MPPData object
print(ds.data['train'])

MPPData for NascentRNA (219054 mpps with shape (3, 3, 33) from 1567 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].


In [10]:
# dataset can return a tf dataset for use with e.g. keras
tf_ds = ds.get_tf_dataset(split='train', is_conditional=True)
print(tf_ds)

for x,y in tf_ds.take(1):
    print(x)
    print(y)

<FlatMapDataset shapes: (((3, 3, 33), (14,)), (33,)), types: ((tf.float32, tf.float32), tf.float32)>
(<tf.Tensor: shape=(3, 3, 33), dtype=float32, numpy=
array([[[0.71139973, 0.63698417, 0.56148666, 0.96583056, 1.        ,
         0.67017716, 0.89312196, 0.16038842, 0.91077   , 0.78846544,
         0.943886  , 0.56684643, 0.7920429 , 0.05747046, 0.7373726 ,
         0.68158317, 0.9722624 , 0.3875561 , 0.7465759 , 0.2821306 ,
         0.6750523 , 0.5649243 , 0.85024047, 0.81841373, 0.71348804,
         0.7679738 , 0.8613879 , 0.7724328 , 0.871772  , 0.33624777,
         0.80515504, 0.72853166, 0.7531507 ],
        [0.7979798 , 0.63698417, 1.0626447 , 0.97722036, 0.89395314,
         0.5212249 , 1.1360266 , 0.16377395, 0.589542  , 1.1021202 ,
         0.8597149 , 0.45395002, 0.9426325 , 0.0255203 , 0.6969684 ,
         0.829684  , 0.965328  , 0.37308893, 0.80993193, 0.26581538,
         0.60241693, 0.89123106, 0.4841617 , 0.854731  , 0.7288369 ,
         0.6451364 , 0.5980249 , 0.563829

2021-11-06 16:25:32.124709: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-11-06 16:25:32.263399: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
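
Note that with `is_conditional=True` each element is nested as `((x, c), y)`, so in the loop above `x` is bound to the tuple of input and condition. The following plain-Python sketch mimics that element structure to show how the pieces can be unpacked and grouped into batches. `fake_samples` and `batches` are hypothetical stand-ins that generate random values; they use neither miann nor TensorFlow. With a real `tf.data.Dataset`, batching would typically be done with its own `batch` method instead.

```python
import itertools
import random


def fake_samples(n_channels=33, n_conditions=14, seed=0):
    """Stand-in generator yielding ((x, c), y) tuples with the shapes printed above."""
    rng = random.Random(seed)
    while True:
        # x: 3x3 pixel neighbourhood with n_channels values per pixel
        x = [[[rng.random() for _ in range(n_channels)] for _ in range(3)] for _ in range(3)]
        # c: condition vector (a single random one-hot entry here)
        c = [0.0] * n_conditions
        c[rng.randrange(n_conditions)] = 1.0
        # y: one value per channel (random here; real targets come from the dataset)
        y = [rng.random() for _ in range(n_channels)]
        yield (x, c), y


def batches(samples, batch_size):
    """Group an iterator of samples into lists of at most batch_size elements."""
    samples = iter(samples)
    while True:
        batch = list(itertools.islice(samples, batch_size))
        if not batch:
            return
        yield batch


first_batch = next(batches(fake_samples(), batch_size=4))
for (x, c), y in first_batch:
    # Unpack the nested structure: inputs (x, c) and target y
    print(len(x), len(x[0]), len(x[0][0]), len(c), len(y))
```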


In [11]:
# dataset has a function for mapping channel names to channel ids  TODO when do we need this?
ds.get_channel_ids(['16_H3', '09_CCNT1'])

[14, 24]
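
One plausible use for such ids is selecting individual channels along the last axis of `x`. The ids printed above do not correspond to positions in the `channels` list displayed earlier, so the ordering presumably comes from the stored dataset rather than from the config. As a minimal sketch of the lookup itself, assuming it is a simple name-to-position mapping (`channel_ids` and the short `order` list are hypothetical, not miann's implementation):

```python
def channel_ids(channel_order, names):
    """Return the position of each requested channel name in channel_order."""
    index = {name: i for i, name in enumerate(channel_order)}
    missing = [name for name in names if name not in index]
    if missing:
        raise KeyError(f"channels not in dataset: {missing}")
    return [index[name] for name in names]


# Hypothetical channel ordering, for illustration only.
order = ["00_DAPI", "07_H2B", "09_CCNT1", "16_H3"]
ids = channel_ids(order, ["16_H3", "09_CCNT1"])
print(ids)
```

The ids are returned in the order of the requested names, not in the order of the stored channels.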