In [8]:
import torch
from ddpm.ddpm import DDPM
import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
import ddpm.modules_v2 as modules_v2
import ddpm.modules_v3 as modules_v3
from data import SequencesDataset
from train import train
import torchvision.transforms as transforms
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import tqdm
from typing import List

In [9]:
def save_imgs(
    frames_real: torch.Tensor,
    frames_gen: List[torch.Tensor],
    path: str
):
    def get_np_img(tensor: torch.Tensor) -> np.ndarray:
        return (tensor * 127.5 + 127.5).long().clip(0,255).permute(1,2,0).detach().cpu().numpy().astype(np.uint8)

    height_row = 5
    col_width = 5
    cols = len(frames_real)
    rows = 1 + len(frames_gen)
    fig, axes = plt.subplots(rows, cols, figsize=(col_width * cols, height_row * rows))
    for row in range(rows):
        frames = frames_real if row == 0 else frames_gen[row - 1]
        for i in range(len(frames_real)):
            axes[row, i].imshow(get_np_img(frames[i]))
            
    plt.subplots_adjust(wspace=0, hspace=0)
    
    # Save the combined figure
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close()

In [10]:
EPOCHS = 30

T = 1000
input_channels = 3
context_length = 4
actions_count = 5
batch_size = 1
num_workers = 2
device = "cuda" if torch.cuda.is_available() else "cpu"
# For Mac OS
if torch.backends.mps.is_available():
    device = "mps"
ROOT_PATH = "../"
def local_path(path):
    return os.path.join(ROOT_PATH, path)
MODEL_PATH = local_path("test_models/diffusion/model_14_v2.pth")

In [11]:
ddpm = DDPM(
    T = T,
    eps_model=modules_v2.UNet(
        in_channels=input_channels * (context_length + 1),
        out_channels=3,
        T=T+1,
        actions_count=actions_count,
        seq_length=context_length
    ),
    context_length=context_length,
    device=device
)
ddpm.load_state_dict(torch.load(MODEL_PATH, map_location=device)["model"])

  ddpm.load_state_dict(torch.load(MODEL_PATH, map_location=device)["model"])


<All keys matched successfully>

In [12]:
transform_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5,.5,.5), (.5,.5,.5))
])

dataset = SequencesDataset(
    images_dir=local_path("training_data/snapshots"),
    actions_path=local_path("training_data/actions"),
    seq_length=context_length,
    transform=transform_to_tensor
)

In [17]:
length = len(dataset)
length_session = 20
count = 1
pbar = tqdm.tqdm(total=count * length_session)
for i in range(count):
    index = random.randint(0, length - 1)
    img, last_imgs, actions = dataset[index]

    img = img.to(device)
    last_imgs = last_imgs.to(device)
    actions = actions.to(device)

    real_imgs = last_imgs.clone()
    gen_2_imgs = last_imgs.clone()
    gen_10_imgs = last_imgs.clone()
    gen_5_imgs = last_imgs.clone()
    for j in range(1, length_session):
        img, last_imgs, actions = dataset[index + j]
        img = img.to(device)
        last_imgs = last_imgs.to(device)
        actions = actions.to(device)
        real_imgs = torch.concat([real_imgs, img[None, :, :, :]], dim=0)
        gen_img = ddpm.ddim_sample(img.shape, gen_10_imgs[-context_length:].unsqueeze(0), actions.unsqueeze(0), steps=10)[0][:, 2:-2, 2:-2]
        gen_10_imgs = torch.concat([gen_10_imgs, gen_img[None, :, :, :]], dim=0)
        gen_img = ddpm.ddim_sample(img.shape, gen_5_imgs[-context_length:].unsqueeze(0), actions.unsqueeze(0), steps=5)[0][:, 2:-2, 2:-2]
        gen_5_imgs = torch.concat([gen_5_imgs, gen_img[None, :, :, :]], dim=0)
        gen_img = ddpm.ddim_sample(img.shape, gen_2_imgs[-context_length:].unsqueeze(0), actions.unsqueeze(0), steps=2)[0][:, 2:-2, 2:-2]
        gen_2_imgs = torch.concat([gen_2_imgs, gen_img[None, :, :, :]], dim=0)
        pbar.update(1)

    save_imgs(real_imgs, [gen_10_imgs, gen_5_imgs, gen_2_imgs], f"{MODEL_PATH.split('/')[-1]}.png")

 95%|█████████▌| 19/20 [01:03<00:03,  3.35s/it]
 95%|█████████▌| 19/20 [00:32<00:01,  1.70s/it]