diff --git a/rerpy/data.py b/rerpy/data.py index 417134f..f055737 100644 --- a/rerpy/data.py +++ b/rerpy/data.py @@ -131,6 +131,8 @@ def test_DataFormat(): assert df.ms_to_ticks(1000.1, round="up") == 1025 assert df.ms_to_ticks(1000.9, round="up") == 1025 + assert_raises(ValueError, df.ms_to_ticks, 1000, round="sideways") + assert df.ms_span_to_ticks(-1000, 1000) == (-1024, 1025) assert df.ms_span_to_ticks(-999.99, 999.99) == (-1023, 1024) assert df.ms_span_to_ticks(-1000.01, 1000.01) == (-1024, 1025) @@ -146,7 +148,6 @@ def test_DataFormat(): [0, 0.5, 0], [0, 0, 1]]) - from nose.tools import assert_raises assert_raises(ValueError, df.compute_symbolic_transform, "A2/2, A2/3") assert_raises(ValueError, df.compute_symbolic_transform, "A2/2 + 1") @@ -155,6 +156,8 @@ def __init__(self, data_format): self.data_format = data_format self._events = rerpy.events.Events() self._recspans = [] + self._lazy_recspans = [] + self._lazy_transforms = [] self.recspan_infos = [] def transform(self, matrix, exclude=[]): @@ -166,34 +169,93 @@ def transform(self, matrix, exclude=[]): raise ValueError("exclude= can only be specified if matrix= " "is a symbolic expression") matrix = np.asarray(matrix) - for i, recspan in enumerate(self._recspans): - new_data = np.dot(recspan, matrix.T) - self._recspans[i] = pandas.DataFrame(new_data, - columns=recspan.columns, - index=recspan.index) + for i in xrange(len(self._recspans)): + if self._recspans[i] is not None: + recspan = self._recspans[i] + new_data = np.dot(recspan, matrix.T) + self._recspans[i] = pandas.DataFrame(new_data, + columns=recspan.columns, + index=recspan.index) + else: + old_transform = self._lazy_transforms[i] + if old_transform is None: + old_transform = np.eye(self.data_format.num_channels) + self._lazy_transforms[i] = np.dot(matrix, old_transform) - def add_recspan(self, data, metadata): - data = np.asarray(data, dtype=np.float64) - if data.shape[1] != self.data_format.num_channels: - raise ValueError("wrong number of channels, array should have " - "shape (ticks, %s)" - % (self.data_format.num_channels,)) - ticks = data.shape[0] - recspan_id = len(self._recspans) + def _add_recspan_info(self, ticks, metadata): + recspan_id = len(self.recspan_infos) recspan_info = self._events.add_recspan_info(recspan_id, ticks, metadata) self.recspan_infos.append(recspan_info) - index = np.arange(ticks) * float(self.data_format.approx_sample_period_ms) + + def _decorate_recspan(self, data): + ticks = data.shape[0] + index = np.arange(ticks, dtype=float) + index *= self.data_format.approx_sample_period_ms df = pandas.DataFrame(data, columns=self.data_format.channel_names, index=index) - self._recspans.append(df) + return df + + def add_recspan(self, data, metadata): + data = np.asarray(data, dtype=np.float64) + if data.shape[1] != self.data_format.num_channels: + raise ValueError("wrong number of channels, array should have " + "shape (ticks, %s)" + % (self.data_format.num_channels,)) + ticks = data.shape[0] + self._add_recspan_info(ticks, metadata) + self._recspans.append(self._decorate_recspan(data)) + self._lazy_recspans.append(None) + self._lazy_transforms.append(None) + + def add_lazy_recspan(self, loader, ticks, metadata): + self._add_recspan_info(ticks, metadata) + self._recspans.append(None) + self._lazy_recspans.append(loader) + self._lazy_transforms.append(None) + + def add_dataset(self, dataset): + # Metadata + if self.data_format != dataset.data_format: + raise ValueError("data format mismatch") + # Recspans + our_recspan_id_base = len(self._recspans) + for recspan_info in dataset.recspan_infos: + self._add_recspan_info(recspan_info.ticks, dict(recspan_info)) + self._recspans += dataset._recspans + self._lazy_recspans += dataset._lazy_recspans + self._lazy_transforms += dataset._lazy_transforms + # Events + for their_event in dataset.events_query(): + self.add_event(their_event.recspan_id + our_recspan_id_base, + their_event.start_tick, + their_event.stop_tick, + dict(their_event)) # We act like a sequence of recspan data objects def __len__(self): return len(self._recspans) + def raw_slice(self, recspan_id, start_tick, stop_tick): + if start_tick < 0 or stop_tick < 0: + raise IndexError("only positive indexes allowed") + ticks = stop_tick - start_tick + recspan = self._recspans[recspan_id] + if recspan is not None: + result = np.asarray(recspan.iloc[start_tick:stop_tick, :]) + else: + lr = self._lazy_recspans[recspan_id] + lazy_data = lr.get_slice(start_tick, stop_tick) + transform = self._lazy_transforms[recspan_id] + if transform is not None: + lazy_data = np.dot(lazy_data, transform.T) + result = lazy_data + if result.shape[0] != ticks: + raise IndexError("slice spans missing data") + return result + def __getitem__(self, key): if not isinstance(key, int) and hasattr(key, "__index__"): key = key.__index__() @@ -201,7 +263,12 @@ def __getitem__(self, key): raise TypeError("Dataset indexing allows only a single integer " "(no slicing or other fanciness!)") # May raise IndexError, which is what we want: - return self._recspans[key] + recspan = self._recspans[key] + if recspan is None: + ticks = self.recspan_infos[key].ticks + raw = self.raw_slice(key, 0, ticks) + recspan = self._decorate_recspan(raw) + return recspan def __iter__(self): for i in xrange(len(self)): @@ -357,21 +424,6 @@ def epochs_ticks(self, event_query, start_tick, stop_tick, major_axis=time_array, minor_axis=self.data_format.channel_names) - def add_dataset(self, dataset): - # Metadata - if self.data_format != dataset.data_format: - raise ValueError("data format mismatch") - # Recspans - our_recspan_id_base = len(self._recspans) - for recspan, recspan_info in zip(dataset, dataset.recspan_infos): - self.add_recspan(recspan, dict(recspan_info)) - # Events - for their_event in dataset.events_query(): - self.add_event(their_event.recspan_id + our_recspan_id_base, - their_event.start_tick, - their_event.stop_tick, - dict(their_event)) - def merge_df(self, df, on, restrict=None): # 'on' is like {df_colname: event_key} # or just [colname] diff --git a/rerpy/io/erpss.py b/rerpy/io/erpss.py index e2c10c1..f76f37b 100644 --- a/rerpy/io/erpss.py +++ b/rerpy/io/erpss.py @@ -261,7 +261,7 @@ def read_next_chunk(self, return_data): # Check for EOF: if not buf: return None - codes = np.fromstring(buf[:512], dtype=np.uint16) + codes = np.fromstring(buf[:512], dtype=" self._recspan_stop: + raise IndexError("attempt to index beyond end of recspan") output = np.empty((stop_tick - start_tick, self._nchans), dtype=self._dtype) cursor = 0 @@ -333,7 +341,7 @@ def get_slice(self, start_tick, stop_tick): tick = chunk_number * 256 if tick >= stop_tick: break - data = self._chunk_fetcher.get_chunk(chunk_number) + data = self._fetcher.get_chunk(chunk_number) data.resize((256, self._nchans)) low = max(tick, start_tick) high = min(tick + 256, stop_tick) @@ -343,18 +351,36 @@ def get_slice(self, start_tick, stop_tick): chunk_number += 1 return output +def test_LazyRecspan(): + from nose.tools import assert_raises + from rerpy.test import test_data_path + for suffix in ["crw", "raw"]: + (fetcher, hz, channames, codes, data, info) = read_raw( + open(test_data_path("erpss/tiny-complete.%s" % (suffix,)), "rb"), + "u2", True) + # This fake recspan is chosen to cover part of the first and last + # chunks, plus the entire middle chunk. It's exactly 512 samples long. + lr = LazyRecspan(fetcher, "u2", len(channames), 128, 640) + assert_raises(IndexError, lr.get_slice, 0, 513) + for (start, stop) in [(0, 512), (10, 20), (256, 266), (500, 510), + (120, 130)]: + assert np.all(lr.get_slice(start, stop) + == data[128 + start:128 + stop]) + def assert_files_match(p1, p2): - (_, hz1, channames1, codes1, data1, info1) = read_raw(open(p1), "u2", True) + (_, hz1, channames1, codes1, data1, info1) = read_raw(open(p1, "rb"), "u2", True) for (p, load_data) in [(p1, False), (p2, True), (p2, False)]: (fetcher2, hz2, channames2, codes2, data2, info2 - ) = read_raw(open(p), "u2", load_data) + ) = read_raw(open(p, "rb"), "u2", load_data) assert hz1 == hz2 assert (channames1 == channames2).all() assert (codes1 == codes2).all() if not load_data: assert data2 is None - loader = DemandLoader(fetcher2, "u2", len(channames2)) - data2 = loader.get_slice(0, len(codes2)) + # Slight abuse, pretend that there's one recspan that has the whole + # file + loader2 = LazyRecspan(fetcher2, "u2", len(channames2), 0, len(codes2)) + data2 = loader2.get_slice(0, len(codes2)) assert (data1 == data2).all() for k in set(info1.keys() + info2.keys()): if k != "erpss_raw_header": @@ -375,7 +401,7 @@ def test_read_raw_on_test_data(): def test_64bit_channel_names(): from rerpy.test import test_data_path - stream = open(test_data_path("erpss/two-chunks-64chan.raw")) + stream = open(test_data_path("erpss/two-chunks-64chan.raw"), "rb") (_, hz, channel_names, codes, data, info) = read_raw(stream, int, False) # "Correct" channel names as listed by headinfo(1): assert (channel_names == @@ -468,12 +494,8 @@ def t(data, expected): ) t(data, expected) -# XX someday should fix this so that it has the option to delay reading the -# actual data until needed (to avoid the giant memory overhead of loading in -# lots of data sets together). The way to do it for crw files is just to read -# through the file without decompressing to find where each block is located -# on disk, and then we can do random access after we know that. def load_erpss(raw, log, calibration_events="condition == 0", + preload=True, calibrate=False, calibrate_half_width_ticks=5, calibrate_low_cursor_time=None, @@ -493,7 +515,7 @@ def load_erpss(raw, log, calibration_events="condition == 0", log = maybe_open(log) (fetcher, hz, channel_names, raw_codes, data, header_metadata - ) = read_raw(raw, dtype, True) + ) = read_raw(raw, dtype, preload) metadata.update(header_metadata) if calibrate: units = "uV" @@ -501,6 +523,8 @@ def load_erpss(raw, log, calibration_events="condition == 0", units = "RAW" data_format = DataFormat(hz, units, channel_names) + total_ticks = raw_codes.shape[0] + raw_log_events = read_log(log) expanded_log_codes = np.zeros(raw_codes.shape, dtype=int) try: @@ -533,14 +557,20 @@ def load_erpss(raw, log, calibration_events="condition == 0", break_ticks += 1 span_edges = np.concatenate(([0], break_ticks)) assert span_edges[0] == 0 - assert span_edges[-1] == data.shape[0] + assert span_edges[-1] == total_ticks span_slices = [slice(span_edges[i], span_edges[i + 1]) for i in xrange(len(span_edges) - 1)] dataset = Dataset(data_format) for span_slice in span_slices: - dataset.add_recspan(data[span_slice, :], metadata) + if preload: + dataset.add_recspan(data[span_slice, :], metadata) + else: + lr = LazyRecspan(fetcher, dtype, len(channel_names), + span_slice.start, span_slice.stop) + dataset.add_lazy_recspan(lr, span_slice.stop - span_slice.start, + metadata) span_starts = [s.start for s in span_slices] recspan_ids = [] @@ -612,105 +642,124 @@ def test_load_erpss(): # are supposed to be reserved for special stuff and deleted events, but # it happens the file I was using as a basis violated this rule. Oh # well. - dataset = load_erpss(test_data_path("erpss/tiny-complete.crw"), - test_data_path("erpss/tiny-complete.log")) - assert len(dataset) == 2 - assert dataset[0].shape == (512, 32) - assert dataset[1].shape == (256, 32) - - assert dataset.data_format.exact_sample_rate_hz == 250 - assert dataset.data_format.units == "RAW" - assert list(dataset.data_format.channel_names) == [ - "lle", "lhz", "MiPf", "LLPf", "RLPf", "LMPf", "RMPf", "LDFr", "RDFr", - "LLFr", "RLFr", "LMFr", "RMFr", "LMCe", "RMCe", "MiCe", "MiPa", "LDCe", - "RDCe", "LDPa", "RDPa", "LMOc", "RMOc", "LLTe", "RLTe", "LLOc", "RLOc", - "MiOc", "A2", "HEOG", "rle", "rhz", - ] - - for recspan_info in dataset.recspan_infos: - assert recspan_info["raw_file"].endswith("tiny-complete.crw") - assert recspan_info["log_file"].endswith("tiny-complete.log") - assert recspan_info["experiment"] == "brown-1" - assert recspan_info["subject"] == "Subject p3 2008-08-20" - assert recspan_info["odelay"] == 8 - assert len(recspan_info["erpss_raw_header"]) == 512 - - assert dataset.recspan_infos[0].ticks == 512 - assert dataset.recspan_infos[1].ticks == 256 - assert dataset.recspan_infos[1]["deleted"] - - assert len(dataset.events()) == 14 - # 2 are calibration events - assert len(dataset.events("has code")) == 12 - for ev in dataset.events("has code"): - assert ev["condition"] in (64, 65) - assert ev["flag"] == 0 - assert not ev["flag_data_error"] - assert not ev["flag_polinv"] - assert not ev["flag_rejected"] - for ev in dataset.events("calibration_pulse"): - assert dict(ev) == {"calibration_pulse": True} - def check_ticks(query, recspan_ids, start_ticks): - events = dataset.events(query) - assert len(events) == len(recspan_ids) == len(start_ticks) - for ev, recspan_id, start_tick in zip(events, recspan_ids, start_ticks): - assert ev.recspan_id == recspan_id - assert ev.start_tick == start_tick - assert ev.stop_tick == start_tick + 1 - - check_ticks("condition == 64", - [0] * 8, [21, 221, 304, 329, 379, 458, 483, 511]) - check_ticks("condition == 65", - [1] * 4, - [533 - 512, 733 - 512, 762 - 512, 767 - 512]) - check_ticks("calibration_pulse", [0, 0], [250, 408]) - - # check calibration_events option - dataset2 = load_erpss(test_data_path("erpss/tiny-complete.crw"), - test_data_path("erpss/tiny-complete.log"), - calibration_events="condition == 65") - assert len(dataset2.events("condition == 65")) == 0 - assert len(dataset2.events("condition == 0")) == 2 - assert len(dataset2.events("calibration_pulse")) == 4 - - # check calibration - # idea: if calibration works, then the "calibration erp" will have been - # set to be the same size as whatever we told it to be. - dataset_cal = load_erpss(test_data_path("erpss/tiny-complete.crw"), + for preload in [True, False]: + dataset = load_erpss(test_data_path("erpss/tiny-complete.crw"), test_data_path("erpss/tiny-complete.log"), - calibration_events="condition == 65", - calibrate=True, - calibrate_half_width_ticks=2, - calibrate_low_cursor_time=-16, - calibrate_high_cursor_time=21, - calibrate_pulse_size=12.34, - calibrate_polarity=-1) - assert dataset_cal.data_format.units == "uV" - # -16 ms +/-2 ticks = -24 to -8 ms - low_cal = dataset_cal.rerp("calibration_pulse", -24, -8, "1", - all_or_nothing=True, - overlap_correction=False) - # 21 ms rounds to 20 ms, +/-2 ticks for the window = 12 to 28 ms - high_cal = dataset_cal.rerp("calibration_pulse", 12, 28, "1", - all_or_nothing=True, - overlap_correction=False) - low = low_cal.betas["Intercept"].mean(axis=0) - high = high_cal.betas["Intercept"].mean(axis=0) - assert np.allclose(high - low, -1 * 12.34) - - # check that we can load from file handles (not sure if anyone cares but - # hey you never know...) - assert len(load_erpss(open(test_data_path("erpss/tiny-complete.crw")), - open(test_data_path("erpss/tiny-complete.log")))) == 2 - - # check that code/raw mismatch is detected - from nose.tools import assert_raises - for bad in ["bad-code", "bad-tick", "bad-tick2"]: - assert_raises(ValueError, - load_erpss, - test_data_path("erpss/tiny-complete.crw"), - test_data_path("erpss/tiny-complete.%s.log" % (bad,))) - # But if the only mismatch is an event that is "deleted" (sign bit set) in - # the log file, but not in the raw file, then that is okay: - load_erpss(test_data_path("erpss/tiny-complete.crw"), - test_data_path("erpss/tiny-complete.code-deleted.log")) + preload=preload) + assert len(dataset) == 2 + assert dataset[0].shape == (512, 32) + assert dataset[1].shape == (256, 32) + + assert dataset.data_format.exact_sample_rate_hz == 250 + assert dataset.data_format.units == "RAW" + assert list(dataset.data_format.channel_names) == [ + "lle", "lhz", "MiPf", "LLPf", "RLPf", "LMPf", "RMPf", "LDFr", "RDFr", + "LLFr", "RLFr", "LMFr", "RMFr", "LMCe", "RMCe", "MiCe", "MiPa", "LDCe", + "RDCe", "LDPa", "RDPa", "LMOc", "RMOc", "LLTe", "RLTe", "LLOc", "RLOc", + "MiOc", "A2", "HEOG", "rle", "rhz", + ] + + for recspan_info in dataset.recspan_infos: + assert recspan_info["raw_file"].endswith("tiny-complete.crw") + assert recspan_info["log_file"].endswith("tiny-complete.log") + assert recspan_info["experiment"] == "brown-1" + assert recspan_info["subject"] == "Subject p3 2008-08-20" + assert recspan_info["odelay"] == 8 + assert len(recspan_info["erpss_raw_header"]) == 512 + + assert dataset.recspan_infos[0].ticks == 512 + assert dataset.recspan_infos[1].ticks == 256 + assert dataset.recspan_infos[1]["deleted"] + + assert len(dataset.events()) == 14 + # 2 are calibration events + assert len(dataset.events("has code")) == 12 + for ev in dataset.events("has code"): + assert ev["condition"] in (64, 65) + assert ev["flag"] == 0 + assert not ev["flag_data_error"] + assert not ev["flag_polinv"] + assert not ev["flag_rejected"] + for ev in dataset.events("calibration_pulse"): + assert dict(ev) == {"calibration_pulse": True} + def check_ticks(query, recspan_ids, start_ticks): + events = dataset.events(query) + assert len(events) == len(recspan_ids) == len(start_ticks) + for ev, recspan_id, start_tick in zip(events, recspan_ids, start_ticks): + assert ev.recspan_id == recspan_id + assert ev.start_tick == start_tick + assert ev.stop_tick == start_tick + 1 + + check_ticks("condition == 64", + [0] * 8, [21, 221, 304, 329, 379, 458, 483, 511]) + check_ticks("condition == 65", + [1] * 4, + [533 - 512, 733 - 512, 762 - 512, 767 - 512]) + check_ticks("calibration_pulse", [0, 0], [250, 408]) + + # check calibration_events option + dataset2 = load_erpss(test_data_path("erpss/tiny-complete.crw"), + test_data_path("erpss/tiny-complete.log"), + preload=preload, + calibration_events="condition == 65") + assert len(dataset2.events("condition == 65")) == 0 + assert len(dataset2.events("condition == 0")) == 2 + assert len(dataset2.events("calibration_pulse")) == 4 + + # check calibration + # idea: if calibration works, then the "calibration erp" will have been + # set to be the same size as whatever we told it to be. + dataset_cal = load_erpss(test_data_path("erpss/tiny-complete.crw"), + test_data_path("erpss/tiny-complete.log"), + preload=preload, + calibration_events="condition == 65", + calibrate=True, + calibrate_half_width_ticks=2, + calibrate_low_cursor_time=-16, + calibrate_high_cursor_time=21, + calibrate_pulse_size=12.34, + calibrate_polarity=-1) + assert dataset_cal.data_format.units == "uV" + # -16 ms +/-2 ticks = -24 to -8 ms + low_cal = dataset_cal.rerp("calibration_pulse", -24, -8, "1", + all_or_nothing=True, + overlap_correction=False) + # 21 ms rounds to 20 ms, +/-2 ticks for the window = 12 to 28 ms + high_cal = dataset_cal.rerp("calibration_pulse", 12, 28, "1", + all_or_nothing=True, + overlap_correction=False) + low = low_cal.betas["Intercept"].mean(axis=0) + high = high_cal.betas["Intercept"].mean(axis=0) + assert np.allclose(high - low, -1 * 12.34) + + # check that we can load from file handles (not sure if anyone cares but + # hey you never know...) + crw = open(test_data_path("erpss/tiny-complete.crw"), "rb") + log = open(test_data_path("erpss/tiny-complete.log"), "rb") + assert len(load_erpss(crw, log, preload=preload)) == 2 + + # check that code/raw mismatch is detected + from nose.tools import assert_raises + for bad in ["bad-code", "bad-tick", "bad-tick2"]: + assert_raises(ValueError, + load_erpss, + test_data_path("erpss/tiny-complete.crw"), + test_data_path("erpss/tiny-complete.%s.log" % (bad,)), + preload=preload) + # But if the only mismatch is an event that is "deleted" (sign bit + # set) in the log file, but not in the raw file, then that is okay: + load_erpss(test_data_path("erpss/tiny-complete.crw"), + test_data_path("erpss/tiny-complete.code-deleted.log"), + preload=preload) + + # Compare preload to no-preload directly + pre = load_erpss(test_data_path("erpss/tiny-complete.crw"), + test_data_path("erpss/tiny-complete.log"), + preload=True) + lazy = load_erpss(test_data_path("erpss/tiny-complete.crw"), + test_data_path("erpss/tiny-complete.log"), + preload=False) + from pandas.util.testing import assert_frame_equal + assert len(pre) == len(lazy) + for pre_recspan, lazy_recspan in zip(pre, lazy): + assert_frame_equal(pre_recspan, lazy_recspan) diff --git a/rerpy/rerp.py b/rerpy/rerp.py index 4e9d41d..0f1ad05 100644 --- a/rerpy/rerp.py +++ b/rerpy/rerp.py @@ -1238,10 +1238,9 @@ def _fit_by_epoch(dataset, analysis_subspans, rerps): Xs_Ys_by_rerp = {rerp: ([], []) for rerp in rerps} for epoch in epochs: this_X = epoch.design_row - recspan = dataset[epoch.recspan_id] - y_data = recspan.iloc[epoch.start_tick:epoch.stop_tick, :] - ticks = epoch.stop_tick - epoch.start_tick - this_Y = np.asarray(y_data).reshape((1, ticks * channels)) + y_data = dataset.raw_slice(epoch.recspan_id, + epoch.start_tick, epoch.stop_tick) + this_Y = y_data.reshape((1, -1)) Xs, Ys = Xs_Ys_by_rerp[epoch.rerp] Xs.append(this_X) Ys.append(this_Y) @@ -1312,8 +1311,8 @@ def _fit_continuous(dataset, analysis_subspans, rerps, log_stream): rows = 0 with ProgressBar(len(analysis_subspans), stream=log_stream) as progress_bar: for subspan in analysis_subspans: - recspan = dataset[subspan.start[0]] - data = np.asarray(recspan.iloc[subspan.start[1]:subspan.stop[1], :]) + data = dataset.raw_slice(subspan.start[0], + subspan.start[1], subspan.stop[1]) rows += data.shape[0] nnz = 0 for epoch in subspan.epochs: diff --git a/rerpy/test_data.py b/rerpy/test_data.py index 9ede428..f1ece58 100644 --- a/rerpy/test_data.py +++ b/rerpy/test_data.py @@ -8,15 +8,27 @@ from rerpy.data import Dataset, DataFormat +class FakeLazyRecspan(object): + def __init__(self, data): + self._data = data + + def get_slice(self, start, stop): + return self._data[start:stop, :] + def mock_dataset(num_channels=4, num_recspans=4, ticks_per_recspan=100, - hz=250): + hz=250, lazy="mixed"): + assert lazy in ["all", "mixed", "none"] data_format = DataFormat(hz, "uV", ["MOCK%s" % (i,) for i in xrange(num_channels)]) dataset = Dataset(data_format) r = np.random.RandomState(0) for i in xrange(num_recspans): data = r.normal(size=(ticks_per_recspan, num_channels)) - dataset.add_recspan(data, {}) + if lazy == "all" or (lazy == "mixed" and i % 2 == 0): + lr = FakeLazyRecspan(data) + dataset.add_lazy_recspan(lr, ticks_per_recspan, {}) + else: + dataset.add_recspan(data, {}) return dataset def test_Dataset(): @@ -29,8 +41,10 @@ def test_Dataset(): assert list(dataset) == [] dataset.add_recspan(np.ones((10, 2)) * 0, {"a": 0}) - dataset.add_recspan(np.ones((20, 2)) * 1, {"a": 1}) - dataset.add_recspan(np.ones((30, 2)) * 0, {"a": 2}) + dataset.add_lazy_recspan(FakeLazyRecspan(np.ones((20, 2)) * 1), + 20, {"a": 1}) + dataset.add_lazy_recspan(FakeLazyRecspan(np.ones((30, 2)) * 0), + 30, {"a": 2}) dataset.add_recspan(np.ones((40, 2)) * 1, {"a": 3}) assert len(dataset) == 4 @@ -55,6 +69,17 @@ def t(ds, recspan_id, expected_values=None): local_recspan_id = recspan_id % 2 expected_values = local_recspan_id assert np.allclose(recspan, expected_values) + assert np.allclose(ds.raw_slice(recspan_id, 0, recspan.shape[0]), + recspan) + assert_raises(IndexError, + ds.raw_slice, recspan_id, -1, 10) + assert_raises(IndexError, + ds.raw_slice, recspan_id, 10, -1) + assert ds.raw_slice(recspan_id, 2, 2).shape == (0, 2) + assert np.all(ds.raw_slice(recspan_id, 2, 5) + == recspan.iloc[2:5, :]) + assert_raises(IndexError, + ds.raw_slice, recspan_id, 0, 200) assert ds.recspan_infos[recspan_id]["a"] == recspan_id assert ds.recspan_infos[recspan_id].ticks == expected_ticks @@ -104,6 +129,9 @@ def t(ds, recspan_id, expected_values=None): for i in xrange(4): assert_frame_equal(recspans[i], dataset[i]) + # Smoke test + repr(dataset) + def test_Dataset_add_recspan(): dataset = mock_dataset(num_channels=2, num_recspans=4) dataset.add_recspan([[1, 2], [3, 4], [5, 6]], {"a": 31337})