In [1]:
import os
import math

import torch
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image, ImageDraw

from scorefield.models.ddpm.denoising_diffusion import Unet
from scorefield.models.ddpm.denoising_diffusion_1d import Unet1D, GaussianDiffusion1D, Dataset1D, Trainer1D
from scorefield.models.ddpm.gaussian_diffusion import Diffusion
from scorefield.utils.rendering import Maze2dRenderer
from scorefield.utils.rl_utils import load_config
from scorefield.utils.utils import log_num_check, imshow, gen_goals, random_batch, eval_batch, prepare_input
from scorefield.utils.diffusion_utils import bilinear_interpolate


  from .autonotebook import tqdm as notebook_tqdm
No module named 'flow'
No module named 'carla'
pybullet build time: May 20 2022 19:44:17


In [3]:
# Args
sac_config_dir = "./scorefield/configs/diffusion.yaml"
args = load_config(sac_config_dir)
device = args['device']

model_path = os.path.join(args['log_path'], args['model_path'])

map_img = Image.open("map.png")

In [3]:
class Unet2D(Unet):
    def __init__(
        self, 
        dim, 
        out_dim, 
        dim_mults=(1, 2, 4, 8)
    ):
        super().__init__(dim=dim, out_dim=out_dim, dim_mults=dim_mults)
        
    def forward(self, obs, x, t):
        score_map = super().forward(obs, t)
        score = bilinear_interpolate(score_map, x)    # output: (B,2)
        return score

img_size = args['image_size']
noise_steps = args['noise_steps']
train_lr = args['train_lr']
    
model = Unet2D(
    dim=img_size,
    out_dim = 2,
    dim_mults = (1, 2, 4, 8),
).to(device)

diffusion = Diffusion(
    input_size = (2,), 
    noise_steps= noise_steps,
    device=device,
)

optim = torch.optim.Adam(params=model.parameters(), lr=train_lr)

In [4]:
epochs = args['epochs']
batch_size = args['batch_size']

for iters in range(epochs):
    x0 = (torch.rand(batch_size, 2, device=device, dtype=torch.float32)* 2 - 1.) * 0.1
    obs = prepare_input(map_img, goal_pos=x0)
    t = diffusion.sample_timesteps(batch_size).to(device)
    
    x_t = diffusion.noise_goal(x0, t).to(device)
    target_score = -(x_t - x0)  # It has to be divided by sigma^2, but sigma=1 in this case
    predicted_score = model(obs, x_t, t)
        
    loss = F.mse_loss(predicted_score, target_score)
    
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if iters % 100 == 0:
        print(f"iter {iters}: {loss.item()}")
    
    

iter 0: 17.886507034301758
iter 100: 4.396085739135742
iter 200: 7.471299171447754
iter 300: 4.569242477416992
iter 400: 10.439207077026367
iter 500: 4.316651344299316
iter 600: 3.6395556926727295
iter 700: 7.083784103393555
iter 800: 8.623814582824707
iter 900: 5.286238670349121
iter 1000: 5.017495155334473
iter 1100: 7.583033084869385
iter 1200: 8.246065139770508
iter 1300: 7.141334056854248
iter 1400: 5.840270042419434
iter 1500: 3.9974639415740967
iter 1600: 7.456106662750244
iter 1700: 8.055895805358887
iter 1800: 6.145851135253906
iter 1900: 6.427086353302002
iter 2000: 4.475827693939209
iter 2100: 8.126260757446289
iter 2200: 7.245627403259277
iter 2300: 5.594959259033203
iter 2400: 3.522162437438965
iter 2500: 4.685199737548828
iter 2600: 4.983145713806152
iter 2700: 3.5634384155273438
iter 2800: 2.674853563308716
iter 2900: 2.604846477508545
iter 3000: 5.335857391357422
iter 3100: 3.776165723800659
iter 3200: 3.601693630218506
iter 3300: 5.649759292602539
iter 3400: 4.13551139

In [None]:
torch.save(model.state_dict(), "./logs/pretrained/denoising.pt")

In [6]:
model = model.eval()

eval_batch_size = args['eval_batch_size']
init_state = args['init_state']

obs = eval_batch(renderer, map_img, init_state, eval_batch_size, device=device)
obs = torch.tensor(obs, dtype=torch.float32).to(device)
x = torch.tensor(init_state).to(device)

dt = 0.01
trajectory = [x]

for t in tqdm(range(noise_steps)):
    with torch.no_grad():
        score = model(obs, x, noise_steps-t)
    z_t = torch.randn_like(x)   # z_t ~ N(0,I)
        
    x = x + score * dt / 2. + math.sqrt(dt) * z_t
    trajectory.append(x)
trajectory = torch.tensor(trajectory)

KeyError: 'eval_batch_size'