Skip to content

Commit

Permalink
finish stubbing out tests for the data loader
Browse files Browse the repository at this point in the history
  • Loading branch information
MekWarrior committed May 13, 2020
1 parent f8b4688 commit 9cad1b0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 42 deletions.
13 changes: 7 additions & 6 deletions caliban_toolbox/pre_annotation/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self,
imaging_types,
specimen_types,
compartments=None,
markers=['all'], # these and the following should be sets to prevent double 'all's etc
markers=['all'], # the following should be sets to prevent double 'alls etc
exp_ids=['all'],
sessions=['all'],
positions=['all'],
Expand Down Expand Up @@ -184,11 +184,11 @@ def _path_builder(self, root_path, list_of_dirs):
and verify that these new paths exist.
Args:
root_path (path):
list_of_dirs (list):
root_path (path): base path to add to
list_of_dirs (list): directory names to add to the base path
Returns:
list:
list: combined path of length equal to number of dirs in list_of_dirs
"""
new_paths = []
for item in list_of_dirs:
Expand All @@ -202,7 +202,8 @@ def _path_builder(self, root_path, list_of_dirs):
return new_paths

def _assemble_paths(self):
"""
"""Go through permuations of parameters and assemble paths that lead to the
directories of interest (containing a metadata json file) as well as img stacks
"""
# maybe a dictionary would be better here? need to map multiple tiff files to a data dir
# probably should be a class per dataset
Expand Down Expand Up @@ -458,4 +459,4 @@ def load_imagedata(self):

# predict on data
# need to have a dictionary of models to run
# curate-seg-track job
# curate-seg-track job
72 changes: 36 additions & 36 deletions caliban_toolbox/pre_annotation/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,52 +29,52 @@
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import pandas as pd
import skimage as sk
import random

import pytest

from caliban_toolbox.pre_annotation import data_loader


def _get_dummy_inputs(object):
possible_data_type = [['2d', 'static'], ['2d', 'dynamic'], ['3d', 'static'], ['3d', 'dynamic']]
possible_imaging_types = [[]]
possible_specimen_types
possible_compartments=None
possible_markers=['all']
possible_exp_ids=['all']
possible_sessions=['all']
possible_positions=['all']
possible_file_type='.tif'
possible_data_type = random.choice([['2d', 'static'],
['2d', 'dynamic'],
['3d', 'static'],
['3d', 'dynamic']])
possible_imaging_types = random.choice([['fluo'], ['phase'], ['fluo', 'phase'], ['all']])
possible_specimen_types = random.choice([['HEK293'], ['HeLa'], ['HEK293', 'HeLa'], ['all']])
possible_compartments = random.choice([[None], ['nuclear'], ['nuclear', 'wholecell'], ['all']])
possible_markers = ['all']
possible_exp_ids = ['all']
possible_sessions = ['all']
possible_positions = ['all']
possible_file_type = '.tif'

loader_inputs = [possible_data_type,
possible_imaging_types,
possible_specimen_types,
possible_compartments,
possible_markers,
possible_exp_ids,
possible_sessions,
possible_positions,
possible_file_type]

return loader_inputs


class TestUniversalDataLoader(object): # pylint: disable=useless-object-inheritance

def test_simple(self):
loader_inputs = _get_dummy_inputs(self)
_ = data_loader.UniversalDataLoader(loader_inputs)

# test data with bad rank
with pytest.raises(ValueError):
data_loader.UniversalDataLoader(
np.random.random((32, 32, 1)),
np.random.randint(num_objects, size=(32, 32, 1)),
model=model)

# test mismatched x and y shape
with pytest.raises(ValueError):
data_loader.UniversalDataLoader(
np.random.random((3, 32, 32, 1)),
np.random.randint(num_objects, size=(2, 32, 32, 1)),
model=model)

# test bad features
with pytest.raises(ValueError):
data_loader.UniversalDataLoader(x, y, model=model, features=None)

# test bad data_format
with pytest.raises(ValueError):
data_loader.UniversalDataLoader(x, y, model=model, data_format='invalid')
# test with standard inputs
_ = data_loader.UniversalDataLoader(data_type=loader_inputs[0],
imaging_types=loader_inputs[1],
specimen_types=loader_inputs[2],
compartments=loader_inputs[3],
markers=loader_inputs[4],
exp_ids=loader_inputs[5],
sessions=loader_inputs[6],
positions=loader_inputs[7],
file_type=loader_inputs[8])

0 comments on commit 9cad1b0

Please sign in to comment.