In [None]:
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from omegaconf import OmegaConf

# from src.model import Model
from utils.parser import parse_arguments
from models.decomposer import Decomposer
from data.siar_data import SIARDataModule

In [None]:
args = parse_arguments()
config = OmegaConf.load(args.config)
# wandb_logger = WandbLogger(config=config, project="HTCV")

siar = SIARDataModule(config.data.dir, config.train.batch_size)
siar.setup("train", config.train.debug)

model = Decomposer(config=config.model) # --- output: (B, C(768), D(5), H(8), H(8))
### add a new layer to the model to predict gaussian noise (B, 3, 5, 8, 8)


In [22]:
import torch
from diffusers import UNet2DModel, DDPMScheduler

device = "cpu"

diffuser = UNet2DModel(
    sample_size=config.model.sample_size,  # the target image resolution
    in_channels=config.model.input_dim,  # the number of input channels, 3 for RGB images
    out_channels=config.model.output_dim,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(
        64,
        128,
        128,
    ),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",  # a regular ResNet downsampling block
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "UpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
).to(device)
# scheduler = DDPMScheduler()

In [19]:
from PIL import Image

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    grid_w, grid_h = grid.size
    
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

In [21]:
num_images = 1

width = 512
height = 512

generator = torch.Generator(device=device)

latents = None
seeds = []
for _ in range(num_images):
    # Get a new random seed, store it and use it as the generator state
    seed = generator.seed()
    seeds.append(seed)
    generator = generator.manual_seed(seed)
    
    image_latents = torch.randn(
        (1, diffuser.in_channels, height // 8, width // 8),
        generator = generator,
        device = device
    )
    latents = image_latents if latents is None else torch.cat((latents, image_latents))
    
# latents should have shape (4, 4, 64, 64) in this case
latents.shape

  (1, diffuser.in_channels, height // 8, width // 8),


torch.Size([1, 3, 64, 64])