In [None]:
import torch
from ddpm import DDPM
import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from unet import UNet
from data import SequencesDataset
from train import train
import torchvision.transforms as transforms
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import tqdm

In [None]:
def save_imgs(
    frames_real: torch.Tensor,
    frames_generation: torch.Tensor,
    path: str
):
    """Save a two-row comparison grid: real frames on top, generated frames below.

    Args:
        frames_real: frames indexed along dim 0, each a CHW tensor whose values
            are mapped back to [0, 255] via ``x * 127.5 + 127.5`` (i.e. the
            frames are expected in roughly [-1, 1]).
        frames_generation: generated frames in the same layout; only the first
            ``len(frames_real)`` entries are drawn.
        path: output image path passed to ``Figure.savefig``.

    Raises:
        ValueError: if ``frames_real`` is empty or ``frames_generation`` has
            fewer frames than ``frames_real``.
    """
    def get_np_img(tensor: torch.Tensor) -> np.ndarray:
        # [-1, 1] -> [0, 255], clamp, then CHW -> HWC uint8 for imshow.
        return (tensor * 127.5 + 127.5).long().clip(0, 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)

    cols = len(frames_real)
    if cols == 0:
        raise ValueError("frames_real must contain at least one frame")
    if len(frames_generation) < cols:
        raise ValueError(
            f"frames_generation has {len(frames_generation)} frames, expected at least {cols}"
        )

    height_row = 5
    col_width = 5
    # squeeze=False keeps `axes` 2-D even for a single column, so axes[row, i] is always valid.
    fig, axes = plt.subplots(2, cols, figsize=(col_width * cols, height_row * 2), squeeze=False)
    try:
        for row, frames in enumerate((frames_real, frames_generation)):
            for i in range(cols):
                axes[row, i].imshow(get_np_img(frames[i]))
        fig.subplots_adjust(wspace=0, hspace=0)

        # Save the combined figure
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        # Close this specific figure even on error so repeated calls don't accumulate open figures.
        plt.close(fig)

In [None]:
EPOCHS = 30

T = 1000  # passed to DDPM as T (presumably the number of diffusion steps)
input_channels = 3  # channels per frame
context_length = 4  # number of previous frames used as model context
actions_count = 5  # passed to UNet as actions_count
batch_size = 1
num_workers = 2

# Seed every RNG the notebook touches so rollout start points (random) and
# sampling noise (torch) are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device preference: Apple MPS (macOS), then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

ROOT_PATH = "./"


def local_path(path):
    """Resolve ``path`` relative to ROOT_PATH."""
    return os.path.join(ROOT_PATH, path)


MODEL_PATH = local_path("model_new.pth")

In [19]:
# Build the diffusion model and restore the trained weights from MODEL_PATH.
ddpm = DDPM(
    T=T,
    eps_model=UNet(
        # Presumably the current frame stacked with `context_length` context frames along channels.
        in_channels=input_channels * (context_length + 1),
        out_channels=input_channels,
        T=T + 1,  # NOTE(review): UNet receives T + 1, presumably to cover timesteps 0..T — confirm
        actions_count=actions_count,
        seq_length=context_length
    ),
    context_length=context_length,
    device=device
)
# This notebook only samples: put dropout/normalization layers (if any) into inference mode.
ddpm.eval()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute code; a state_dict loads fine this way.
ddpm.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))

  ddpm.load_state_dict(torch.load(MODEL_PATH, map_location=device))


<All keys matched successfully>

In [None]:
# ToTensor gives CHW float tensors (scaled to [0, 1] for PIL / uint8 input);
# Normalize with mean 0.5 / std 0.5 then maps each channel to [-1, 1].
transform_to_tensor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
    ]
)

# Dataset over the recorded snapshots and actions; seq_length=context_length
# presumably sets how many previous frames each item returns — confirm in data.SequencesDataset.
dataset = SequencesDataset(
    seq_length=context_length,
    transform=transform_to_tensor,
    images_dir=local_path("snapshots"),
    actions_path=local_path("actions"),
)

In [None]:
# Autoregressive rollouts: for `count` random start points, seed the model with
# ground-truth context frames, then repeatedly sample the next frame from the
# model's own previous outputs (conditioned on the recorded actions) and save a
# real-vs-generated comparison strip as <ROOT_PATH>/<i>.png.
length = len(dataset)
length_session = 10  # frames generated per rollout
count = 3  # number of rollouts (one output image each)
if length < length_session:
    raise ValueError(
        f"dataset has {length} items but each rollout needs length_session={length_session}"
    )

with torch.no_grad(), tqdm.tqdm(total=count * length_session) as pbar:
    for i in range(count):
        # Largest valid start keeps dataset[index + length_session - 1] in range.
        index = random.randint(0, length - length_session)
        _, last_imgs, _ = dataset[index]
        last_imgs = last_imgs.to(device)

        # Both rows begin with the same ground-truth context frames.
        real_frames = list(last_imgs)
        gen_frames = list(last_imgs)
        # Start at j = 0: the target of dataset[index] is the frame right after
        # the seed context (starting at 1 skipped it and misaligned the actions).
        for j in range(length_session):
            img, last_imgs, actions = dataset[index + j]
            img = img.to(device)
            actions = actions.to(device)
            real_frames.append(img)
            gen_context = torch.stack(gen_frames[-context_length:])
            # NOTE(review): the sample appears padded by 2 px per side; crop back to frame size — confirm in DDPM.sample.
            gen_img = ddpm.sample(img.shape, gen_context.unsqueeze(0), actions.unsqueeze(0))[0][:, 2:-2, 2:-2]
            gen_frames.append(gen_img)
            pbar.update(1)

        real_imgs = torch.stack(real_frames)
        gen_imgs = torch.stack(gen_frames)
        save_imgs(real_imgs, gen_imgs, local_path(f"{i}.png"))