Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
wingechr committed Jun 15, 2023
2 parents 8453ff0 + 816ea28 commit 168370a
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,6 @@ dmypy.json

# VSCODE
.vscode


example_*
11 changes: 7 additions & 4 deletions data_disaggregation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from .classes import VT, F, T, V, VT_NumericExt
from .utils import (
as_set,
group_idx_first,
group_idx_second,
is_map,
Expand Down Expand Up @@ -118,9 +119,11 @@ def transform(
assert is_mapping(data)
assert is_unique(data)

assert is_subset(
data, size_in
), "Variable index is not a subset of input dimension subset"
if not is_subset(data, size_in):
err = as_set(data) - as_set(size_in)
raise Exception(
f"Variable index is not a subset of input dimension subset: {err}"
)

# validate map
assert is_map(weight_map)
Expand All @@ -138,7 +141,7 @@ def transform(
# weights sum
sumw = sum(w for _, w in vws)
# TODO drop test
sumw <= size_out[t]
assert sumw <= size_out[t]

# drop result?
if threshold:
Expand Down
28 changes: 28 additions & 0 deletions data_disaggregation/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,36 @@ def get_idx_out(
return Index(to_levels_values[0], name=to_levels_names[0])


def as_index(x) -> Index:
if isinstance(x, Index):
return x
elif isinstance(x, Series):
return x.index

raise TypeError(f"Must be of type Index instead of {type(x).__name__}")


def as_series(x) -> Series:
if isinstance(x, Series):
return x
elif isinstance(x, Index):
return Series(1.0, index=x)

raise TypeError("Must be of type Series")


def create_weight_map(
weights: Series,
idx_in: Index,
idx_out: Index = None,
) -> Mapping[Tuple[F, T], float]:
weights = as_series(weights)
idx_in = as_index(idx_in)

if idx_out is None:
idx_out = get_idx_out(weights, idx_in)
# else:
# idx_out = as_index(idx_out)

map_levels = get_dimension_levels(weights)
map_is_multindex = is_multindex(weights)
Expand Down Expand Up @@ -123,6 +146,11 @@ def get_key(row, indices, is_multindex):
key_in = get_key(row, from_level_idcs, from_is_multindex)
key_out = get_key(row, to_level_idcs, to_is_multindex)

if key_in not in idx_in:
continue
if key_out not in idx_out:
continue

key = (key_in, key_out)
result[key] = val_map

Expand Down
4 changes: 4 additions & 0 deletions data_disaggregation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def iter_values(x):
yield x[k]


def as_set(x) -> set:
return set(as_list(x))


def as_list(x) -> List:
# meaning: is index
if is_list(x):
Expand Down
13 changes: 5 additions & 8 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,22 +116,19 @@ def test_align_map_todo(self):
d3 = Index([4], name="d3")
d23 = MultiIndex.from_product([d2, d3])

res = create_weight_map(Series(1, index=d12), d1, d2)
res = create_weight_map(Series(1, index=d12), idx_in=d1, idx_out=d2)
self.assertEqual(res[(1, 2)], 1)

res = create_weight_map(Series(1, index=d12), d1m, d2)
res = create_weight_map(Series(1, index=d12), idx_in=d1m, idx_out=d2)
self.assertEqual(res[((1,), 2)], 1)

res = create_weight_map(Series(1, index=d12), d1m, d2m)
res = create_weight_map(Series(1, index=d12), idx_in=d1m, idx_out=d2m)
self.assertEqual(res[((1,), (2,))], 1)

res = create_weight_map(Series(1, index=d23), d12, d23)
res = create_weight_map(Series(1, index=d23), idx_in=d12, idx_out=d23)
self.assertEqual(res[((1, 2), (2, 4))], 1)

res = create_weight_map(Series(1, index=d1), 0, d1)
self.assertEqual(res[(SCALAR_INDEX_KEY, 1)], 1)

res = create_weight_map(Series(1, index=d1), d1, d0)
res = create_weight_map(Series(1, index=d1), idx_in=d1, idx_out=d0)
self.assertEqual(res[(1, SCALAR_INDEX_KEY)], 1)

def test_is_scalar(self):
Expand Down

0 comments on commit 168370a

Please sign in to comment.