# NNDataset for training neural networks
An NNDataset is a collection of MPPData objects (one each for train/val/test). It is created from a parameter file that defines its properties.

This notebook shows:
- how to create a dataset from parameters
- how to iterate over an existing dataset
TODO for actual examples, need to handle downloading of files
TODO add more info text 
TODO change dataset?

+ example creation from params dict here!
+ after creation load NNDataset + show how to iterate over samples for training

In [1]:
import os
from miann.data import NNDataset

## Create NNDataset
TODO

In [2]:
# Parameters describing how an NNDataset is built from the raw data.
data_params = dict(
    # Name of the dataset defined by these params (relative to DATA_DIR/datasets).
    dataset_name='184A1_test_dataset',
    # Name of the data config (registered in config.ini).
    data_config="NascentRNA",  # TODO change to example data_config
    # Directories to read data from (relative to DATA_DIR defined in the data config),
    # listed treatment by treatment, well by well.
    data_dirs=[
        os.path.join(treatment, well)
        for treatment, wells in (
            ('184A1_unperturbed', ('I09', 'I11')),
            ('184A1_meayamycin', ('I12', 'I20')),
        )
        for well in wells
    ],
    channels=[
        '00_DAPI', '07_H2B', '01_CDK9_pT186', '03_CDK9', '05_GTF2B', '07_SETD1A',
        '08_H3K4me3', '09_SRRM2', '10_H3K27ac', '11_KPNA2_MAX', '12_RB1_pS807_S811',
        '13_PABPN1', '14_PCNA', '15_SON', '16_H3', '17_HDAC3', '19_KPNA1_MAX',
        '20_SP100', '21_NCL', '01_PABPC1', '02_CDK7', '03_RPS6', '05_Sm',
        '07_POLR2A', '09_CCNT1', '10_POL2RA_pS2', '11_PML', '12_YAP1',
        '13_POL2RA_pS5', '15_U2SNRNPB', '18_NONO', '20_ALYREF', '21_COIL',
    ],
    # List of conditions; each should be defined in the data config.
    # A '_one_hot' suffix converts the condition into a one-hot encoded vector.
    # Conditions are concatenated, unless they are given as a list of lists:
    # in that case the condition is a pairwise combination of the listed conditions.
    condition=['perturbation_duration_one_hot', 'cell_cycle_one_hot'],
    condition_kwargs=dict(cond_params={}),
    # Train/val/test split.
    split_kwargs=dict(train_frac=0.8, val_frac=0.1),
    test_img_size=225,
    # Subset to objects with certain metadata.
    subset=True,
    # kwargs to MPPData.subset() defining which objects to subset to.
    subset_kwargs=dict(
        frac=None,  # special kwarg for random subsetting of objects
        nona_condition=True,  # special kwarg for removing all objects with NAN condition
        cell_cycle='NO_NAN',
    ),
    # Subsampling of pixels (only for train/val).
    subsample=True,
    # kwargs for MPPData.subsample() defining the fraction of pixels to be sampled.
    subsample_kwargs=dict(
        frac=None,
        frac_per_obj=0.05,
        num=None,
        num_per_obj=None,
    ),
    # Neighborhood information.
    neighborhood=True,
    neighborhood_size=3,
    # Normalisation.
    normalise=True,
    normalise_kwargs=dict(
        # background_value is a column name in CHANNELS_METADATA, or a list of floats per channel
        background_value='mean_background',
        percentile=98.0,
        rescale_values=[],
    ),
    # Make results reproducible.
    seed=42,
)

## Use NNDataset

In [3]:
# Load an existing dataset by name, using the 'NascentRNA' data config.
# NOTE(review): this name differs from data_params['dataset_name'] in the creation
# example — presumably a dataset that was created beforehand; confirm.
dataset_name = '184A1_all_frac005_neigh3_cond_pert-CC'
ds = NNDataset(dataset_name, data_config='NascentRNA')

In [3]:
# The dataset provides x and y (NN input + output), accessed per split.
# x is either the mpp array alone (is_conditional=False) ...
x = ds.x('val', is_conditional=False)
print(x.shape)

# ... or the mpp array together with the condition array (is_conditional=True).
x, c = ds.x('train', is_conditional=True)
print(x.shape, c.shape)

(78814, 3, 3, 33)
(622394, 3, 3, 33) (622394, 14)


In [4]:
# The dataset has a `data` attribute holding the train/val/test data and an `img`
# attribute holding the val/test image data.
# Each split is represented as an MPPData object; printing it shows a short summary.
print(ds.data['train'])

MPPData for NascentRNA (622394 mpps with shape (3, 3, 33) from 964 objects). Data keys: ['x', 'y', 'mpp', 'obj_ids', 'labels', 'conditions'].


In [7]:
# The dataset can return a tf dataset for use with e.g. keras.
tf_ds = ds.get_tf_dataset(split='train', is_conditional=True)
print(tf_ds)

# Inspect a single element. With is_conditional=True, x is an (mpp, condition) pair
# and y is the target, as shown by the printed element spec of the tf dataset.
for x,y in tf_ds.take(1):
    print(x)
    print(y)

[(tf.float32, tf.float32), tf.float32]
[(TensorShape([3, 3, 33]), TensorShape([14])), TensorShape([33])]
<FlatMapDataset shapes: (((3, 3, 33), (14,)), (33,)), types: ((tf.float32, tf.float32), tf.float32)>
(<tf.Tensor: shape=(3, 3, 33), dtype=float32, numpy=
array([[[0.15203735, 0.62895805, 0.03757562, 0.06181163, 0.65934986,
         0.11308727, 0.19545957, 0.08731578, 0.09365417, 0.10214146,
         0.14875232, 0.00492055, 0.25848264, 0.08706211, 0.3823452 ,
         0.13618195, 0.27064598, 0.21471222, 0.211353  , 0.14106484,
         0.4476323 , 0.        , 0.27191448, 0.4585612 , 0.45791107,
         0.09481694, 0.70633245, 0.04939254, 0.02446112, 0.01210488,
         0.03928334, 0.22237381, 0.35262993],
        [0.07494983, 0.9156723 , 0.1192574 , 0.01916671, 0.97870934,
         0.16651575, 0.24637985, 0.02988384, 0.09365417, 0.16761032,
         0.21968962, 0.00492055, 0.25848264, 0.15798756, 0.29212594,
         0.18086219, 0.30235702, 0.2597287 , 0.25311935, 0.09678983,
     

In [6]:
# The dataset has a function for mapping channel names to channel ids; the ids are
# returned in the order of the requested names. TODO when do we need this?
ds.get_channel_ids(['16_H3', '09_CCNT1'])

[23, 10]