Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New BatchSchema class to generate patch selectors and combine into batch selectors #132

Merged
merged 32 commits into from
Jan 3, 2023

Conversation

maxrjones
Copy link
Member

@maxrjones maxrjones commented Dec 2, 2022

Description of proposed changes

This PR is a major overhaul of xbatcher’s internals. Users should not notice any external changes, apart from bug fixes and some performance differences.

New BatchSchema class

In v0.2.0, information about the batch_selectors was stored as a dict in the ._batch_selectors attribute of the BatchGenerator class. This PR introduces BatchSchema which contains all necessary information about the batch selectors. BatchGenerator now creates an instance of BatchSchema, which handles the creation of selectors object. The purpose of BatchGenerator is to create batches (xarray Datasets/DataArrays) from those selector objects. This opens up possibilities including serializing/deserializing the BatchSchema (e.g., BatchSchema.to_json() and BatchSchema.from_json()), caching BatchSchema objects separate from caching batches, and relatedly applying one BatchSchema instance to multiple xarray datasets.

The representation of the batch selectors (BatchSchema.selectors) is the same regardless of whether concat_input_dimsis True or False(in contrast to v0.2.0 where the type varied based on that parameter). The selectors object is still a dict with keys representing the batch index and values representing the batch selector. The batch selector is a list of dicts, where the keys are the dimensions included in batch_dims and/or input_dims and the values are slices used to index the xarray dataset/dataarray. It’s simplest to refer to the subset of data created by each of these dictionaries as a patch. If concat_input_dims is False, each list has length 1 (i.e., 1 patch per batch). If concat_input_dims is True, the lists can contain multiple patches, which will be concatenated to form a batch by the generator.

Bug fixes

Handling batch_dims (#131, #121)

As explained #121 (comment), this issue stemmed from information about the batch_dims indices not being stored in the selectors object. After this PR, each patch contains information about all dims that are indexed on (i.e., any dim included in input_dims or batch_dims). Those patches are then combined into batches based on the number of items in the list for each batch selector.

Handling of overlapping samples

In v.2.0, the batch_dims were sliced first and subsequently index_dims were sliced. This had the consequence of missing away any patch that overlapped two batches. Now, the smallest patches created by the combination of batch_dims and input_dims are defined first and these are then combined into batches based on the starting index (see simple example below).

image

Performance changes

Running the benchmarks from #140 on main, 4d8e2c84a2d405e237f60f1df5286dd766e06ff0, and this PR revealed a mix of performance increases and decreases. Based on a quick investigation, I think that the regressions mostly relate to loading each batch individually, whereas before this PR if a dimension was included in batch_dims and input_dims, the data associated with the slice from batch_dims would be loaded even if concat_input_dims was False and that slice shouldn’t matter. Since loading more data improves performance, those benchmarks are now slower. Also, the handling of overlapping samples means there are more patches which will impact performance. In contrast, the benchmarks affected by #121 are now much faster.

$ asv continuous main batch-dims-bug
before           after         ratio
     [c0dd0029]       [9d0fcb05]
     <main>           <batch-dims-bug>
+     2.59±0.05ms       21.9±0.2ms     8.47  benchmarks.Generator.time_batch_preload(True)
+      82.9±0.4μs        112±0.5μs     1.35  benchmarks.TorchLoader.time_iterable_dataset
+      85.8±0.4μs        114±0.4μs     1.33  benchmarks.TorchLoader.time_map_dataset
+     16.6±0.05ms       19.3±0.2ms     1.16  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {}, {'x': 2})
+     10.4±0.08ms      12.0±0.07ms     1.16  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {}, {})
+        5.77±0ms      6.65±0.04ms     1.15  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {}, {'x': 1})
+      63.9±0.7ms      73.4±0.08ms     1.15  benchmarks.Accessor.time_accessor_input_dim({'x': 2, 'y': 2})
+      32.1±0.1ms      36.9±0.07ms     1.15  benchmarks.Accessor.time_accessor_input_dim({'x': 4, 'y': 2})
+     6.36±0.05ms      7.29±0.07ms     1.15  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {}, {'x': 2})
+     5.25±0.01ms      6.01±0.01ms     1.14  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {}, {})
+     10.4±0.07ms      11.9±0.07ms     1.14  benchmarks.Generator.time_batch_concat({'x': 5, 'y': 5}, False)
+     12.6±0.06ms       14.3±0.1ms     1.14  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {}, {'x': 1})
-     18.8±0.08ms      7.32±0.02ms     0.39  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 30}, {'x': 2})
-      49.9±0.3ms      19.1±0.06ms     0.38  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 30}, {'x': 2})
-     31.4±0.09ms      12.1±0.09ms     0.38  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 30}, {})
-     15.8±0.04ms       6.05±0.1ms     0.38  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 30}, {})
-     17.3±0.07ms      6.58±0.04ms     0.38  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 30}, {'x': 1})
-      37.7±0.1ms       14.4±0.1ms     0.38  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 30}, {'x': 1})
-      32.9±0.2ms      11.3±0.04ms     0.34  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 30}, {'x': 1})
-      43.9±0.2ms      15.1±0.02ms     0.34  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 30}, {'x': 2})
-     27.4±0.03ms      9.38±0.03ms     0.34  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 30}, {})
-     13.8±0.04ms      4.70±0.03ms     0.34  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 30}, {})
-     15.2±0.04ms      5.19±0.02ms     0.34  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 30}, {'x': 1})
-     16.6±0.07ms      5.67±0.02ms     0.34  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 30}, {'x': 2})
-     26.1±0.07ms      6.14±0.04ms     0.23  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 20}, {})
-      31.5±0.4ms      7.32±0.06ms     0.23  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 20}, {'x': 2})
-      62.6±0.2ms      14.5±0.07ms     0.23  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 20}, {'x': 1})
-     28.8±0.04ms      6.66±0.03ms     0.23  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 20}, {'x': 1})
-      83.6±0.3ms       19.2±0.1ms     0.23  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 20}, {'x': 2})
-      52.4±0.3ms       12.0±0.1ms     0.23  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 20}, {})
-      25.2±0.4ms      5.25±0.03ms     0.21  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 20}, {'x': 1})
-      72.6±0.6ms       15.1±0.1ms     0.21  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 20}, {'x': 2})
-      27.7±0.1ms      5.69±0.03ms     0.21  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 20}, {'x': 2})
-      45.7±0.5ms      9.37±0.08ms     0.20  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 20}, {})
-      55.2±0.2ms      11.3±0.03ms     0.20  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 20}, {'x': 1})
-     23.2±0.04ms      4.69±0.02ms     0.20  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 20}, {})
-         676±3ms        130±0.5ms     0.19  benchmarks.Generator.time_batch_concat_4d({'x': 5}, {'x': 10, 'y': 10}, True)
-       135±0.5ms      25.9±0.09ms     0.19  benchmarks.Generator.time_batch_concat_4d({'x': 5}, {'x': 10}, True)
-         1.50±0s        166±0.3ms     0.11  benchmarks.Generator.time_batch_concat_4d({'x': 5, 'y': 5}, {'x': 10}, True)
-       120±0.6ms       12.5±0.1ms     0.10  benchmarks.Generator.time_batch_concat_4d({'x': 5}, {'x': 10}, False)
-         1.02±0s        105±0.4ms     0.10  benchmarks.Generator.time_batch_concat_4d({'x': 5, 'y': 5}, {'x': 10}, False)
-         598±4ms       61.0±0.1ms     0.10  benchmarks.Generator.time_batch_concat_4d({'x': 5}, {'x': 10, 'y': 10}, False)
-      7.57±0.03s        224±0.4ms     0.03  benchmarks.Generator.time_batch_concat_4d({'x': 5, 'y': 5}, {'x': 10, 'y': 10}, True)
-      5.07±0.02s        104±0.5ms     0.02  benchmarks.Generator.time_batch_concat_4d({'x': 5, 'y': 5}, {'x': 10, 'y': 10}, False)

SOME BENCHMARKS HAVE CHANGED SIGNIFICANTLY.
PERFORMANCE DECREASED.

$ asv continuous 4d8e2c84a2d405e237f60f1df5286dd766e06ff0 batch-dims-bug
before           after         ratio
     [4d8e2c84]       [9d0fcb05]
     <v0.2.0~6>       <batch-dims-bug>
+     2.49±0.04ms       22.1±0.3ms     8.86  benchmarks.Generator.time_batch_preload(True)
+      32.6±0.4μs          111±1μs     3.41  benchmarks.TorchLoader.time_iterable_dataset
+      33.3±0.1μs        112±0.8μs     3.37  benchmarks.TorchLoader.time_map_dataset
+     4.93±0.02ms      7.32±0.06ms     1.48  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 30}, {'x': 2})
+     4.93±0.03ms      6.68±0.05ms     1.35  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 30}, {'x': 1})
+      14.2±0.1ms      19.1±0.07ms     1.34  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 30}, {'x': 2})
+     10.9±0.04ms      14.4±0.09ms     1.32  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 20}, {'x': 1})
+     5.60±0.07ms      7.26±0.02ms     1.30  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 20}, {'x': 2})
+     4.39±0.03ms      5.66±0.05ms     1.29  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 30}, {'x': 2})
+     11.3±0.06ms      14.4±0.04ms     1.28  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 30}, {'x': 1})
+     9.62±0.02ms       12.0±0.1ms     1.25  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 30}, {})
+     4.99±0.02ms      6.03±0.09ms     1.21  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 30}, {})
+        5.61±0ms      6.69±0.04ms     1.19  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {'x': 20}, {'x': 1})
+     9.58±0.04ms      11.3±0.01ms     1.18  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 20}, {'x': 1})
+     12.8±0.05ms      15.1±0.06ms     1.18  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 30}, {'x': 2})
+     4.39±0.04ms      5.18±0.06ms     1.18  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 30}, {'x': 1})
+      16.3±0.3ms      19.1±0.08ms     1.18  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 20}, {'x': 2})
+     5.26±0.04ms      6.13±0.01ms     1.17  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {}, {})
+     5.78±0.03ms      6.62±0.05ms     1.14  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {}, {'x': 1})
+      9.88±0.1ms      11.3±0.04ms     1.14  benchmarks.Generator.time_batch_input({'x': 5}, {'x': 30}, {'x': 1})
+     6.32±0.04ms      7.21±0.03ms     1.14  benchmarks.Generator.time_batch_input({'x': 10, 'y': 5}, {}, {'x': 2})
+      32.5±0.5ms       37.0±0.3ms     1.14  benchmarks.Accessor.time_accessor_input_dim({'x': 4, 'y': 2})
+     10.6±0.07ms      12.1±0.08ms     1.14  benchmarks.Generator.time_batch_concat({'x': 5, 'y': 5}, False)
+     16.6±0.05ms      18.9±0.07ms     1.14  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {}, {'x': 2})
+     5.04±0.06ms      5.70±0.03ms     1.13  benchmarks.Generator.time_batch_input({'x': 10}, {'x': 20}, {'x': 2})
+      65.0±0.3ms       73.5±0.2ms     1.13  benchmarks.Accessor.time_accessor_input_dim({'x': 2, 'y': 2})
+      10.6±0.1ms      12.0±0.04ms     1.12  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {}, {})
+      12.8±0.1ms      14.1±0.03ms     1.10  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {}, {'x': 1})
+      11.0±0.2ms       12.1±0.2ms     1.10  benchmarks.Generator.time_batch_input({'x': 5, 'y': 5}, {'x': 20}, {})

SOME BENCHMARKS HAVE CHANGED SIGNIFICANTLY.
PERFORMANCE DECREASED.

To-Do:

Note - One test fails (marked with xfail) due to #126 which will be addressed separately

Fixes #131
Fixes #121
Fixes #30

@codecov-commenter
Copy link

codecov-commenter commented Dec 16, 2022

Codecov Report

Merging #132 (95e81b6) into main (13c69fe) will decrease coverage by 1.14%.
The diff coverage is 95.23%.

@@            Coverage Diff             @@
##             main     #132      +/-   ##
==========================================
- Coverage   98.37%   97.23%   -1.15%     
==========================================
  Files           6        6              
  Lines         246      325      +79     
  Branches       49       68      +19     
==========================================
+ Hits          242      316      +74     
- Misses          4        6       +2     
- Partials        0        3       +3     
Impacted Files Coverage Δ
xbatcher/testing.py 100.00% <ø> (ø)
xbatcher/generators.py 96.96% <95.23%> (-3.04%) ⬇️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@maxrjones maxrjones changed the title WIP: Fix regression for batch_dims parameter New BatchSchema class to generate patches and combine into batches Dec 19, 2022
@maxrjones maxrjones changed the title New BatchSchema class to generate patches and combine into batches New BatchSchema class to generate patch selectors and combine into batch selectors Dec 19, 2022
@maxrjones maxrjones changed the title New BatchSchema class to generate patch selectors and combine into batch selectors WIP: New BatchSchema class to generate patch selectors and combine into batch selectors Dec 19, 2022
@maxrjones maxrjones changed the title WIP: New BatchSchema class to generate patch selectors and combine into batch selectors New BatchSchema class to generate patch selectors and combine into batch selectors Dec 20, 2022
@maxrjones
Copy link
Member Author

@jhamman this xbatcher refactor that we discussed over many meetings is now complete. It would be fantastic to hear any thoughts on the PR, but I also recognize that you are busy and on your own time now. Thanks for your brainstorming on these issues!

Copy link
Contributor

@norlandrhagen norlandrhagen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice work here @maxrjones! Super detailed examples. I don't have any meaningful comments though.

pytest.param(
False,
marks=pytest.mark.xfail(
reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/126"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't know you can add xfail within a parameterized test, very cool!

xbatcher/generators.py Outdated Show resolved Hide resolved
@andersy005
Copy link
Member

@maxrjones, this looks great overall! i really like the graphics in #132 (comment). Having some of these graphics included in the docs would be awesome!

maxrjones and others added 2 commits December 30, 2022 21:51
Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com>
@maxrjones maxrjones merged commit 221dfb0 into main Jan 3, 2023
@maxrjones maxrjones deleted the batch-dims-bug branch January 3, 2023 21:13
@maxrjones maxrjones added enhancement New feature or request bug Something isn't working labels Jan 4, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request
Projects
None yet
4 participants