Skip to content

Commit

Permalink
simplify the folding.
Browse files Browse the repository at this point in the history
  • Loading branch information
v0lta committed Aug 1, 2023
1 parent f44800e commit 2fc5173
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
17 changes: 9 additions & 8 deletions src/ptwt/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,16 @@ def _pad_symmetric(
def _fold_channels(data: torch.Tensor) -> torch.Tensor:
"""Fold [batch, channel, height width] into [batch*channel, height, widht]."""
ds = data.shape
fold_data = torch.permute(data, [2, 3, 0, 1])
fold_data = torch.reshape(fold_data, [ds[2], ds[3], ds[0] * ds[1]])
return torch.permute(fold_data, [2, 0, 1])
return torch.reshape(
data,
[
ds[0] * ds[1],
ds[2],
ds[3],
],
)


def _unfold_channels(data: torch.Tensor, ds: List[int]) -> torch.Tensor:
"""Unfold [batch*channel, height, widht] into [batch, channel, height, width]."""
unfold_data = torch.permute(data, [1, 2, 0])
unfold_data = torch.reshape(
unfold_data, [data.shape[1], data.shape[2], ds[0], ds[1]]
)
return torch.permute(unfold_data, [2, 3, 0, 1])
return torch.reshape(data, [ds[0], ds[1], data.shape[1], data.shape[2]])
4 changes: 2 additions & 2 deletions tests/test_convolution_fwt_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def _cat_batch_list(batch_lists: List) -> List:
cat_list = batch_list
else:
for pos, (cat_el, batch_el) in enumerate(zip(cat_list, batch_list)):
if type(cat_el) == np.ndarray:
if type(cat_el) is np.ndarray:
cat_list[pos] = np.concatenate([cat_el, batch_el])
elif type(cat_el) == dict:
elif type(cat_el) is dict:
for key, tensor in cat_el.items():
cat_el[key] = np.concatenate([tensor, batch_el[key]])
else:
Expand Down

0 comments on commit 2fc5173

Please sign in to comment.