In [1]:
import os
import math

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageDraw

from scorefield.models.ddpm.denoising_diffusion import Unet
from scorefield.models.ddpm.gaussian_diffusion import Diffusion
from scorefield.models.ddpm.heat_diffusion import HeatDiffusion
from scorefield.utils.rl_utils import load_config
from scorefield.utils.utils import (
    gen_goals, get_url_image, 
    get_url_pretrained, overlay_image, overlay_goal_agent, 
    overlay_images, overlay_goals_agent, log_num_check,
    get_distance,
)
from scorefield.utils.diffusion_utils import bilinear_interpolate, bilinear_interpolate_batch

import matplotlib.pyplot as plt
import itertools
import random


ModuleNotFoundError: No module named 'scorefield.models.ddpm.heat_diffusion'

In [2]:
# Args
# Everything tunable is read from the YAML config; the keys below are the
# ones this notebook consumes.
config_dir = "./scorefield/configs/diffusion.yaml"
args = load_config(config_dir)

device = args['device']
model_path = os.path.join(args['log_path'], args['model_path'])

# map_img = Image.open("map.png")
# Background canvas plus the sprites that are drawn at goal positions.
bg = Image.open('assets/toy_exp/background0.png')
wastes = [
    Image.open(sprite_path)
    for sprite_path in (
        'assets/toy_exp/waste0.png',
        'assets/toy_exp/waste4.png',
        'assets/toy_exp/waste5.png',
    )
]

# Training / evaluation sizes.
epochs, batch_size, goal_num = args['epochs'], args['batch_size'], args['goal_num']
# Sampling bounds passed to gen_goals for goals and for agent start positions.
goal_bounds, agent_bounds = args['goal_bounds'], args['agent_bounds']
eval_samples = args['eval_samples']

NameError: name 'load_config' is not defined

In [3]:
class Unet2D(Unet):
    """U-Net whose dense output map is sampled at given 2D positions.

    The parent `Unet` produces a map from the observation and timestep;
    `forward` then reads that map at `x_t` via `bilinear_interpolate_batch`.
    """

    def __init__(self, dim, out_dim, dim_mults=(1, 2, 4, 8)):
        """Forward all constructor arguments unchanged to `Unet`."""
        super().__init__(dim=dim, out_dim=out_dim, dim_mults=dim_mults)

    def forward(self, obs, x_t, t):
        """Return the network output interpolated at positions `x_t`.

        Args:
            obs: observation passed to the parent U-Net.
            x_t: positions at which the output map is sampled.
            t: diffusion timestep(s) passed to the parent U-Net.

        Returns:
            The interpolated values (noted as shape (B, 2) by the original
            author -- depends on `bilinear_interpolate_batch`; TODO confirm).
        """
        return bilinear_interpolate_batch(super().forward(obs, t), x_t)

# Hyper-parameters for the network, the noise schedule and the optimizer.
img_size, noise_steps, train_lr = args['image_size'], args['noise_steps'], args['train_lr']
beta_start, beta_end = args['beta_start'], args['beta_end']

# Network: two output channels, sampled at 2D positions by Unet2D.forward.
model = Unet2D(dim=img_size, out_dim=2, dim_mults=(1, 2, 4, 8))
model = model.to(device)

# Diffusion process over 2D positions.
diffusion = HeatDiffusion(
    input_size=(2,),
    noise_steps=noise_steps,
    device=device,
    beta_start=beta_start,
    beta_end=beta_end,
)

optim = torch.optim.Adam(model.parameters(), lr=train_lr)

In [None]:
# Train with single goal
#
# Each iteration draws `goal_num` goals, jitters every goal `n` times and
# trains the model to predict the noise added by the forward diffusion.

# Samples per goal. Defined here so the cell runs on a fresh kernel instead of
# relying on an `n` leaked from another cell, and so that x0 has exactly
# `batch_size` rows -- the same number of timesteps sampled below.
assert batch_size % goal_num == 0, 'batch size has to be divided by the goal number'
n = batch_size // goal_num

for iters in tqdm(range(goal_num * epochs)):
    goals = gen_goals(goal_bounds, goal_num)
    # Repeat each goal n times along a new middle axis
    # (assumes gen_goals returns a (goal_num, 2) tensor -- TODO confirm).
    expanded_goals = goals.unsqueeze(1).expand(-1, n, -1)
    optim.zero_grad()

    # Uniform jitter in [-0.1, 0.1) around every goal.
    random_offsets = (torch.rand(*expanded_goals.shape, device=goals.device, dtype=goals.dtype) * 2 - 1.) * 0.1
    x0 = (expanded_goals + random_offsets).view(-1, 2)
    obs = overlay_image(bg, img_size, wastes, x0)
    t = diffusion.sample_timesteps(batch_size).to(device)

    x_noisy, noise = diffusion.forward_diffusion(x0, t)
    noise_pred = model(obs, x_noisy, t)
    loss = F.l1_loss(noise, noise_pred)
#     loss = F.mse_loss(noise, noise_pred)
    loss.backward()
    optim.step()

    if iters % 500 == 0:
        print(f"iter {iters}: {loss.item()}")


In [None]:
# Train with single + multiple goals
#
# The batch is split into groups by how many goals are visible in the
# observation (1 goal, 2 goals, ..., goal_num goals). A k-goal observation
# contributes k target positions, so its image is repeated k times to give
# one observation per row of x0.

train_goal_num = goal_num * (2**(goal_num - 1)) # if 3 goals: 12
# Number of k-goal subsets for k = 1..goal_num. if 3 goals: [3, 3, 1]
train_comb = [len(list(itertools.combinations(range(goal_num), k + 1)))
              for k in range(goal_num)]
# Number of goal positions consumed by each group. if 3 goals: [3, 6, 3]
train_goals = [count * (k + 1) for k, count in enumerate(train_comb)]

assert batch_size % train_goal_num == 0, 'batch size has to be divided by the goal number'
n = batch_size // train_goal_num  # if 3 goals: batch -> n * 12

# Log roughly ten times per `epochs` iterations; never use a zero modulus.
log_every = max(1, epochs // 10)

for iters in tqdm(range(train_goal_num * epochs)):
    goals = gen_goals(goal_bounds, train_goal_num)

    expanded_goals = goals.unsqueeze(1).expand(-1, n, -1)
    optim.zero_grad()

    # Uniform jitter in [-0.1, 0.1) around every goal.
    random_offsets = (torch.rand(*expanded_goals.shape, device=goals.device, dtype=goals.dtype) * 2 - 1.) * 0.1
    goal_pos = (expanded_goals + random_offsets).view(-1, 2)

    # Render one group of observations per subset size and repeat each group
    # (k + 1) times, in the order [singles, pairs, pairs, triples x3, ...].
    # This is the layout the previous hand-written 3-goal version produced,
    # generalised to any goal_num.
    # NOTE(review): assumes overlay_images returns n * train_comb[k] images for
    # group k and that tiling the whole group lines up with the order of
    # goal_pos -- confirm against overlay_images.
    obs_groups = []
    st = 0
    for k, n_pos in enumerate(train_goals):
        chunk = goal_pos[st:st + n * n_pos]
        if k == 0:
            group = overlay_image(bg, img_size, wastes, chunk)
        else:
            group = overlay_images(bg, img_size, wastes, chunk, k + 1)
        obs_groups.append(group.repeat(k + 1, 1, 1, 1))
        st += n * n_pos
    new_obs = torch.cat(obs_groups, dim=0)

    x0 = goal_pos.unsqueeze(1)
    assert new_obs.shape[0] == x0.shape[0], \
        f'observation batch {new_obs.shape[0]} does not match position batch {x0.shape[0]}'

    t = diffusion.sample_timesteps(batch_size).to(device)

    x_noisy, noise = diffusion.forward_diffusion(x0, t)
    noise_pred = model(new_obs, x_noisy, t)
    loss = F.l1_loss(noise, noise_pred)

    loss.backward()
    optim.step()

    if iters % log_every == 0:
        print(f"iter {iters}: {loss.item()}")


In [None]:
from scorefield.utils.utils import log_num_check, get_url_pretrained

# Save the trained weights. The file name encodes the noise schedule, and
# log_num_check post-processes the path before it is used.
model_pth = f"./logs/pretrained/diverse_goalpos_{noise_steps}_{beta_start}~{beta_end}.pt"
model_pth = log_num_check(model_pth)
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(model_pth) or ".", exist_ok=True)
torch.save(model.state_dict(), model_pth)

# url = 'https://drive.google.com/uc?export=download&id=1CtqczM5cry7wg4poiCv_UeKfIrLtrx9V'
# get_url_pretrained(url, 'model.pt')
# model.load_state_dict(torch.load(f'model.pt'))
# model.load_state_dict(torch.load(f'./logs/pretrained/diverse_goalpos_{noise_steps}_{beta_start}~{beta_end}.pt'))

In [None]:
# Evaluation: run the reverse diffusion for every goal separately and plot
# c + 1 snapshots of the sampled positions per goal.

eval_samples=1000  # overrides the value read from the config
new_goals = gen_goals([-.8,.8,-.8,.8], goal_num)       # (1, goal_num, 2)
obs_T = overlay_image(bg, img_size, wastes, new_goals) # (goal_num, 3, H, W)

c = 10
# squeeze=False keeps `axs` 2D even when goal_num == 1, so axs[s, k] always works.
fig,axs = plt.subplots(goal_num, c + 1, figsize=(20,20), squeeze=False)
T = diffusion.noise_steps
# At least 1, so the modulo below never divides by zero when T < c.
step_size = max(1, T // c)
dot_size = 2

model.eval()
# x_trace = []
ims = []  # one rendered frame per reverse step, saved for the video
with torch.no_grad():
    # x_T = torch.tensor([[0., 0.]], device=device, dtype=torch.float32)
    x_T = gen_goals(agent_bounds, n=(goal_num, eval_samples))
    x = x_T    # (eval_samples, 2)

    for i in tqdm(reversed(range(1, noise_steps)), position=0):
        bkg = bg.copy()
        im = overlay_goal_agent(bkg, wastes, new_goals.cpu(), x.cpu(), dot_size)
        ims.append(im)
#         x_trace.append(x)

        elapsed = T - 1 - i
        k = elapsed // step_size
        # The last column is reserved for the final state drawn after the loop;
        # k < c also keeps the index inside the grid when T is not a multiple of c.
        if elapsed % step_size == 0 and k < c:
            backg = bg.copy()
            img_sample = overlay_goal_agent(backg, wastes, new_goals.cpu(), x.cpu(), dot_size)
            for s in range(len(img_sample)):
                axs[s,k].imshow(img_sample[s])
                axs[0,k].set_title(f't = {elapsed}')
                axs[s,k].axis('off')

        t = (torch.ones(1) * i).long().to(device)
        predicted_noise = model(obs_T, x, t)

        alpha = diffusion.alpha[t]
        alpha_hat = diffusion.alpha_hat[t]
        beta = diffusion.beta[t]
        # No noise is added on the last reverse step.
        if i > 1:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) \
                    * predicted_noise) + torch.sqrt(beta) * noise
        # Pull coordinates that left [-1, 1] back to +/-0.9. This used to sit in
        # a `while True` retry loop that always exited after one pass, except
        # when x contained NaN, in which case it never terminated.
        x = torch.where(x.abs() > 1, x.clamp(min=-.9, max=.9), x)

    # Final state goes into the last column.
    backg = bg.copy()
    img_sample = overlay_goal_agent(backg, wastes, new_goals.cpu(),x.cpu(), dot_size)
    for s in range(len(img_sample)):
        axs[s,-1].imshow(img_sample[s])
        axs[0,-1].set_title(f't = {T}')
        axs[s,-1].axis('off')

    bkg = bg.copy()
    im = overlay_goal_agent(bkg, wastes, new_goals.cpu(), x.cpu(), dot_size)       
    ims.append(im)
#     x_trace.append(x)

In [None]:
# Persist the frames collected in `ims`; create the output directory first
# because np.save does not create missing parent directories.
os.makedirs('./results/new_goal_video', exist_ok=True)
np.save('./results/new_goal_video/img_list1.npy', ims)

In [None]:
# Evaluation with all goals rendered in a single observation: run the reverse
# diffusion and plot c + 1 snapshots of the sampled positions.

c = 10
fig,axs = plt.subplots(1, c + 1, figsize=(20,20))
axs = axs.flatten()
T = diffusion.noise_steps
# At least 1, so the modulo below never divides by zero when T < c.
step_size = max(1, T // c)
dot_size = 2

goals = gen_goals(goal_bounds, goal_num)

# Working copies, so the originals stay untouched if goals are removed later.
objs = wastes.copy()
gs = goals.clone()

obs_T = overlay_images(bg, img_size, objs, gs)

imgs = []  # plotted snapshots, saved for the video

model.eval()
with torch.no_grad():
    # x_T = torch.tensor([[0.5, 0.5]], device=device, dtype=torch.float32)
    x_T = gen_goals(agent_bounds, n=(1, eval_samples)).unsqueeze(0)
    x = x_T

    for i in tqdm(reversed(range(1, noise_steps)), position=0):
        elapsed = T - 1 - i
        k = elapsed // step_size
        # The last column is reserved for the final state drawn after the loop;
        # k < c also keeps the index inside the grid when T is not a multiple of c.
        if elapsed % step_size == 0 and k < c:
            img_sample = overlay_goals_agent(bg, objs, gs.cpu(), x.cpu(), dot_size)
            axs[k].imshow(img_sample)
            axs[k].set_title(f't = {elapsed}')
            axs[k].axis('off')
            imgs.append(img_sample)

        t = (torch.ones(1) * i).long().to(device)
        predicted_noise = model(obs_T, x, t)
        alpha = diffusion.alpha[t]
        alpha_hat = diffusion.alpha_hat[t]
        beta = diffusion.beta[t]
        # No noise is added on the last reverse step.
        if i > 1:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) \
                    * predicted_noise) + torch.sqrt(beta) * noise
        # Pull coordinates that left [-1, 1] back to +/-0.9. This used to sit in
        # a `while True` retry loop that always exited after one pass, except
        # when x contained NaN, in which case it never terminated.
        x = torch.where(x.abs() > 1, x.clamp(min=-.9, max=.9), x)

#         exclude_idx = -1
#         for i in range(len(gs)):
#             if get_distance(x[0], gs[i]) < 0.1 and len(gs) > 1:
#                 exclude_idx = i
#                 break
#         if exclude_idx > -1:
#             objs = objs[:i] + objs[i+1:]
#             gs = torch.cat([gs[:i], gs[i+1:]], dim=0)
#             obs_T = overlay_images(bg, img_size, objs, gs)

    # Final state goes into the last column.
    img_sample = overlay_goals_agent(bg, objs, gs.cpu(),x.cpu(), dot_size)
    axs[-1].imshow(img_sample)
    axs[-1].set_title(f't = {T}')
    axs[-1].axis('off')
    imgs.append(img_sample)

In [None]:
# Persist the snapshots collected in `imgs`; create the output directory first
# because np.save does not create missing parent directories.
os.makedirs('./results/new_goals_video', exist_ok=True)
np.save('./results/new_goals_video/img_list.npy', imgs)