In [4]:
import os
import math

import torch
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image, ImageDraw

from scorefield.models.ddpm.denoising_diffusion import Unet
from scorefield.models.ddpm.denoising_diffusion_1d import Unet1D, GaussianDiffusion1D, Dataset1D, Trainer1D
from scorefield.models.ddpm.gaussian_diffusion import Diffusion
# from scorefield.utils.rendering import Maze2dRenderer
from scorefield.utils.rl_utils import load_config
from scorefield.utils.utils import log_num_check, imshow, gen_goals, random_batch, eval_batch, prepare_input, search_limit
from scorefield.utils.diffusion_utils import bilinear_interpolate


In [5]:
# Args
config_dir = "./scorefield/configs/diffusion.yaml"
args = load_config(config_dir)
device = args['device']

model_path = os.path.join(args['log_path'], args['model_path'])

map_img = Image.open("map.png")

In [3]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class Unet2D(Unet):
    def __init__(
        self, 
        dim, 
        out_dim, 
        dim_mults=(1, 2, 4, 8)
    ):
        super().__init__(dim=dim, out_dim=out_dim, dim_mults=dim_mults)

    def forward(self, obs, x_t, t):
        score_map = super().forward(obs, t)
        score = bilinear_interpolate(score_map, x_t)    # output: (B,2)
        return score

img_size = args['image_size']
noise_steps = args['noise_steps']
train_lr = args['train_lr']
beta = args['beta']
    
model = Unet2D(
    dim=img_size,
    out_dim = 2,
    dim_mults = (1, 2, 4, 8),
).to(device)

diffusion = Diffusion(
    input_size = (2,), 
    noise_steps= noise_steps,
    device=device,
)

optim = torch.optim.Adam(params=model.parameters(), lr=train_lr)

In [None]:
epochs = args['epochs']
batch_size = args['batch_size']

for iters in range(epochs):
    optim.zero_grad()
    
    x0 = (torch.rand(batch_size, 2, device=device, dtype=torch.float32)*2 -1.) *0.1
    obs = prepare_input(args, map_img, goal_pos=x0, circle_rad=2)
    t = diffusion.sample_timesteps(batch_size).to(device)
    
    x_noisy, noise = diffusion.forward_diffusion(x0, t)
    noise_pred = model(obs, x_noisy, t)
    loss =  F.l1_loss(noise, noise_pred)
    loss.backward()
    optim.step()
    
    if iters % 100 == 0:
        print(f"iter {iters}: {loss.item()}")
    
    

iter 0: 987869.75
iter 100: 79110816.0
iter 200: 5173071.0
iter 300: 2677835.0
iter 400: 4500973.0
iter 500: 2762461.25
iter 600: 2260131.0
iter 700: 2583560.5
iter 800: 3016361.5
iter 900: 1603401.25
iter 1000: 2751329.0
iter 1100: 1565580.0
iter 1200: 3748703.75
iter 1300: 3752191.5
iter 1400: 2670843.5
iter 1500: 2140690.0


In [None]:
torch.save(model.state_dict(), "./logs/pretrained/denoising.pt")

In [6]:
model = model.eval()

eval_batch_size = args['eval_batch_size']
init_state = args['init_state']

obs = eval_batch(renderer, map_img, init_state, eval_batch_size, device=device)
obs = torch.tensor(obs, dtype=torch.float32).to(device)
x = torch.tensor(init_state).to(device)

dt = 0.01
trajectory = [x]

for t in tqdm(range(noise_steps)):
    with torch.no_grad():
        score = model(obs, x, noise_steps-t)
    z_t = torch.randn_like(x)
        
    x = x + score * dt / 2. + math.sqrt(dt) * z_t
    trajectory.append(x)
trajectory = torch.tensor(trajectory)

KeyError: 'eval_batch_size'