# DWT Image for MIMO-UNet

## dwt with 3 levels

In [50]:
import torch
from pytorch_wavelets import DWTForward, DWTInverse
from einops import rearrange, reduce, repeat
from PIL import Image

from data import valid_dataloader

In [51]:
# Pull a single validation batch from the 'Hday2night' subset and unpack it
# into its three tensors.
dataloader = valid_dataloader(subset='Hday2night')
comp, real, mask = data = next(iter(dataloader))
# Inspect the batch: shape/dtype of `comp` and `mask`, plus the mask's value range
# (the captured output below shows float32 masks in [0, 1] for this batch).
(comp.shape, comp.dtype,
 mask.shape, mask.dtype,
 mask.max(), mask.min())

(torch.Size([64, 3, 256, 256]),
 torch.float32,
 torch.Size([64, 1, 256, 256]),
 torch.float32,
 tensor(1.),
 tensor(0.))

## Show a batch from the dataloader

In [None]:
def f2i(image):
    """Convert a float image tensor in [0, 1] to ``torch.uint8`` in [0, 255].

    Values outside [0, 1] are clipped first. The scaled values are rounded to
    the nearest integer before the cast: a bare ``.to(torch.uint8)`` truncates
    toward zero, which biases every pixel down by up to one grey level
    (e.g. 0.999 would become 254 instead of 255).

    Args:
        image: floating-point tensor of any shape; callers are expected to have
            mapped it to [0, 1] already.

    Returns:
        A ``torch.uint8`` tensor with the same shape as ``image``.
    """
    image = 255 * torch.clip(image, 0, 1.0)
    # Round (half-to-even, per torch.round) instead of truncating on the cast.
    return torch.round(image).to(torch.uint8)

In [None]:
# Preview: for every sample, [comp | real | mask] side by side.
# comp/real are shifted from [-1, 1] to [0, 1] (assumes the dataloader
# normalizes images to [-1, 1] — TODO confirm); the mask is already in [0, 1].
cs = comp.mul(0.5).add(0.5)
rs = real.mul(0.5).add(0.5)
ms = repeat(mask, 'b c h w -> b (n c) h w', n=3)  # 1-channel mask -> 3 channels
img = torch.concat((cs, rs, ms), dim=3)           # join along width

# Stack all samples vertically into one tall (H, W, C) uint8 picture.
img_show = Image.fromarray(
    rearrange(f2i(img), 'b c h w -> (b h) w c').numpy()
)
img_show

## DWT 

In [None]:
# Wavelet configuration: `level` decomposition levels of the Haar wavelet,
# with zero padding at the borders.
level, wavelet = 2, 'haar'
# Build the forward and inverse wavelet transform operators.
dwt = DWTForward(J=level, wave=wavelet, mode='zero')
iwt = DWTInverse(wave=wavelet, mode='zero')  # not used in the cells shown here

# Each forward call yields (lowpass, list_of_highpass); keep the pair and its parts.
ycl, ych = yc = dwt(comp)
yrl, yrh = yr = dwt(real)
yml, ymh = ym = dwt(mask)

(ycl.shape, ych[0].shape, yml.shape, ymh[0].shape)

In [None]:
(ycl.max(), ycl.min())  # value range of the composite's low-pass band, before choosing a display scaling

In [None]:
## Low-frequency (approximation) band: [comp | real | mask] per sample
def _minmax01(t):
    """Rescale ``t`` linearly so its global min maps to 0 and its max to 1.

    A constant tensor (max == min) maps to all zeros instead of producing NaN
    from a division by zero.
    """
    lo, hi = t.min(), t.max()
    span = hi - lo
    if span == 0:
        return torch.zeros_like(t)
    return (t - lo) / span


cs = _minmax01(ycl)
rs = _minmax01(yrl)
# The mask's low-pass band is rescaled exactly like the images. With the
# orthonormal Haar filters the approximation of a constant region is scaled by
# 2 per level, so without this f2i's clip to [0, 1] would saturate every
# partially covered pixel to white and make the mask look dilated.
ms = repeat(_minmax01(yml), 'b c h w -> b (n c) h w', n=3)
img = torch.concat([cs, rs, ms], dim=3)  # join along width
img_show = f2i(img)
img_show = rearrange(img_show, 'b c h w -> (b h) w c')
img_show = Image.fromarray(img_show.numpy())
img_show

In [None]:
# Take the first entry of each high-pass list (per the pytorch_wavelets docs,
# index 0 is the finest decomposition level).
ychh, yrhh, ymhh = ych[0], yrh[0], ymh[0]
(ychh.shape, ymhh.shape)

In [None]:
(ychh.max(), ychh.min())  # value range of the composite's finest-level detail coefficients

In [None]:
## High-frequency (detail) bands at the finest level: [comp | real | mask] per sample
# ychh / yrhh / ymhh are 5-D (b, c, band, h, w). Concatenating them along
# dim=3 would join along height, and the 4-D 'b c h w' rearrange below cannot
# consume a 5-D tensor. So first lay the detail bands side by side along the
# width to get ordinary 4-D images.
def _bands_side_by_side(t):
    """Reshape ``t`` from (b, c, d, h, w) to (b, c, h, d*w)."""
    return rearrange(t, 'b c d h w -> b c h (d w)')


# Map detail coefficients for display so that 0 becomes mid-grey (0.5). The
# mask gets the same mapping, so its negative coefficients are not clipped to
# black while the images' are shown around grey.
cs = 0.5 * _bands_side_by_side(ychh) + 0.5
rs = 0.5 * _bands_side_by_side(yrhh) + 0.5
ms = repeat(0.5 * _bands_side_by_side(ymhh) + 0.5, 'b c h w -> b (n c) h w', n=3)
img = torch.concat([cs, rs, ms], dim=3)  # join along width
img_show = f2i(img)
img_show = rearrange(img_show, 'b c h w -> (b h) w c')
img_show = Image.fromarray(img_show.numpy())
img_show