In [1]:
import os
from datetime import datetime

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageDraw

from scorefield.models.ddpm.denoising_diffusion import Unet
from scorefield.models.rrt.rrt import RRTStar
from scorefield.utils.rl_utils import load_config
from scorefield.utils.utils import (
    gen_goals, gen_agents, overlay_goal, overlay_multiple, combine_objects, overlay_images,
    overlay_goal_agent, overlay_goals_agent, log_num_check,
    draw_obstacles_pil, convert_to_obstacle_masks,
    randgen_obstacle_masks, draw_obstacles_pixel,
    vector_field, clip_vectors
)
from scorefield.utils.diffusion_utils import bilinear_interpolate, bilinear_interpolate_samples

import matplotlib.pyplot as plt
import itertools
from typing import Optional
import shutil


# ---------------------------------------------------------------------------
# Experiment configuration
# Everything below is read once from the YAML config and exposed as plain
# module-level names that the later cells use directly.
# ---------------------------------------------------------------------------
config_dir = "./scorefield/configs/rrt.yaml"
args = load_config(config_dir)
device = args['device']

# Sprites used to render observations: one background image and a list of
# goal ("waste") images.
bg = Image.open('assets/toy_exp/background0.png')
wastes = [Image.open('assets/toy_exp/waste0.png')]
# Further goal sprites (assets/toy_exp/waste4.png, waste5.png) are currently disabled.

# Image resolution and planner / sampling settings.
img_size, delta_dist, radius = args['image_size'], args['delta_dist'], args['radius']
time_steps, max_iters, sample_num = args['time_steps'], args['max_iters'], args['sample_num']

# Scene layout: sampling bounds for goals and agents, plus the fixed obstacle spec.
goal_bounds, agent_bounds = args['goal_bounds'], args['agent_bounds']
obstacle_pos = args['obstacles']

# Optimisation settings.
iterations, train_lr, batch_size = args['iterations'], args['train_lr'], args['batch_size']

seed = args['seed']

class Unet2D(Unet):
    """U-Net that returns score values sampled at query points instead of a dense map.

    The parent ``Unet`` turns an observation into a score map; ``forward`` then
    reads that map at the positions ``x_t`` via ``bilinear_interpolate_samples``.
    """

    def __init__(self, dim, out_dim, dim_mults=(1, 2, 4, 8)):
        """Forward ``dim``, ``out_dim`` and ``dim_mults`` to the parent ``Unet``."""
        super().__init__(dim=dim, out_dim=out_dim, dim_mults=dim_mults)

    def forward(self, obs, t, x_t: Optional[torch.Tensor] = None):
        """Evaluate the network on ``obs`` at time ``t`` and sample the result at ``x_t``.

        Args:
            obs: observation passed unchanged to ``Unet.forward``.
            t: time-step tensor passed unchanged to ``Unet.forward``.
            x_t: query positions handed to ``bilinear_interpolate_samples``.
                NOTE(review): the default is ``None`` but it is passed straight
                to the sampler; whether ``None`` is supported depends on that
                helper -- confirm before relying on the default.

        Returns:
            Whatever ``bilinear_interpolate_samples`` returns for the score map
            and ``x_t``. The original author's note gives the shape as (B, 2)
            -- TODO confirm for multi-sample inputs.
        """
        dense_scores = super().forward(obs, t)
        return bilinear_interpolate_samples(dense_scores, x_t)

    
# Score network with out_dim=2, moved to the configured device.
model = Unet2D(dim=img_size, out_dim=2, dim_mults=(1, 2, 4, 8)).to(device)

# RRT* planner configured from the YAML values loaded above.
# NOTE(review): in the visible cells `seed` is only forwarded here; torch/numpy
# RNGs are not seeded explicitly -- confirm whether that happens elsewhere.
rrt = RRTStar(
    image_size=img_size, time_steps=time_steps,
    delta_dist=delta_dist, radius=radius,
    random_seed=seed,
)

# Adam over all network parameters.
optim = torch.optim.Adam(model.parameters(), lr=train_lr)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Train the score model on randomly generated scenes.
# Each batch is split in half: the first half is rendered with two goals per
# scene, the second half with a single goal. The planner target for the
# two-goal half is always the first of the two goals.

current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join('./runs/baselines/', current_time)
writer = SummaryWriter(log_dir)
# Keep a copy of the exact config used next to this run's logs and checkpoints.
shutil.copy(config_dir, writer.log_dir)

assert batch_size % 2 == 0, "batch_size must be even: each batch is split into two halves"
half = batch_size // 2

# `iterations // 10` is 0 when iterations < 10, which would make the modulo in
# the loop raise ZeroDivisionError; clamp so short runs checkpoint every step.
checkpoint_every = max(1, iterations // 10)

# Constant time index fed to the network on every step (loop-invariant).
t = torch.ones(1).long().to(device)

try:
    for iters in tqdm(range(iterations)):
        # Every 100th iteration uses the fixed obstacle layout from the config;
        # all other iterations use randomly generated obstacle masks.
        if iters % 100 == 0:
            background = draw_obstacles_pil(bg, obstacle_pos)
            background = [background[0] for _ in range(batch_size)]
            obstacle_masks = convert_to_obstacle_masks(batch_size, background[0].size, img_size, obstacle_pos)
        else:
            obstacle_masks = randgen_obstacle_masks(batch_size, img_size)
            background = draw_obstacles_pixel(bg, obstacle_masks)

        # Goal positions plus a small uniform jitter in [-0.005, 0.005).
        multi_goals = gen_goals(goal_bounds, (2, half), img_size, obstacles=obstacle_masks[:half])
        multi_goal = (torch.rand(half, 2, 2, device=device, dtype=torch.float32) * 0.2 - 0.1) * 0.05 + multi_goals
        single_goals = gen_goals(goal_bounds, half, img_size, obstacles=obstacle_masks[half:])
        single_goal = (torch.rand(half, 1, 2, device=device, dtype=torch.float32) * 0.2 - 0.1) * 0.05 + single_goals

        obs1 = overlay_goal(background[:half], img_size, wastes, multi_goal)
        obs2 = overlay_goal(background[half:], img_size, wastes, single_goal)

        obs = torch.cat((obs1, obs2), dim=0)
        # Planning target: first goal of each two-goal scene, then the single goals.
        goal = torch.cat((multi_goal[:, 0].unsqueeze(1), single_goal), dim=0)

        # Supervision comes from the planner; no gradients are needed here.
        with torch.no_grad():
            # NOTE(review): start positions are sampled with `goal_bounds` although
            # `agent_bounds` is loaded from the config and never used -- confirm
            # which bounds are intended.
            initials = gen_agents(goal_bounds, (sample_num*time_steps, batch_size), img_size, obstacles=obstacle_masks)
            _, delta = rrt.plan(starts=initials, goals=goal, obstacle_masks=obstacle_masks)

        # The planner returned no result for this batch: skip the update entirely.
        if delta is None:
            continue

        # Keep only the first entry of every planned sample as the regression target.
        deltas = torch.tensor([[sample[0] for sample in delt] for delt in delta]).to(device)

        optim.zero_grad()
        pred_delta = model(obs, t, initials)

        # Mean over all samples of the squared L2 error between prediction and target.
        squared_diff = torch.sum((pred_delta - deltas)**2, dim=-1)
        loss = torch.mean(squared_diff)
        loss.backward()
        optim.step()

        # Diagnostics only: computed on detached, flattened (N, 2) views.
        deltas = deltas.detach().reshape(-1, 2)
        pred_deltas = pred_delta.detach().reshape(-1, 2)

        deltas_norm = torch.norm(deltas, dim=-1)
        pred_deltas_norm = torch.norm(pred_deltas, dim=-1)

        dotprod = torch.sum(deltas * pred_deltas, dim=-1)
        # The epsilon guards against division by zero for zero-length vectors.
        cosine_sim = dotprod / (deltas_norm * pred_deltas_norm + 1e-8)

        writer.add_scalar("Loss/train", loss.item(), iters)
        writer.add_scalar("Norm/target_score", torch.mean(deltas_norm).item(), iters)
        writer.add_scalar("Norm/pred_score", torch.mean(pred_deltas_norm).item(), iters)
        writer.add_scalar("Similarity/dot_product", torch.mean(dotprod).item(), iters)
        writer.add_scalar("Similarity/cosine_similarity", torch.mean(cosine_sim).item(), iters)

        # Periodic checkpoint, roughly ten per run.
        if iters % checkpoint_every == 0:
            torch.save(model.state_dict(), os.path.join(writer.log_dir, f'model_params_{iters}.pt'))

    torch.save(model.state_dict(), os.path.join(writer.log_dir, 'model_params.pt'))
finally:
    # Flush and close the event file even if an iteration raised.
    writer.close()


  0%|          | 13/3000 [04:06<14:52:42, 17.93s/it]



  1%|          | 26/3000 [07:58<12:51:22, 15.56s/it]



  1%|▏         | 43/3000 [13:18<14:12:31, 17.30s/it]



  2%|▏         | 61/3000 [19:12<14:28:34, 17.73s/it]



  2%|▏         | 67/3000 [20:59<14:56:50, 18.35s/it]



  3%|▎         | 95/3000 [30:12<16:22:34, 20.29s/it]



  4%|▍         | 131/3000 [41:43<14:11:40, 17.81s/it]



  4%|▍         | 132/3000 [42:02<14:28:11, 18.16s/it]



  5%|▍         | 140/3000 [44:29<13:17:28, 16.73s/it]



  5%|▌         | 158/3000 [50:24<15:15:48, 19.33s/it]



  6%|▌         | 174/3000 [55:27<13:54:10, 17.71s/it]



  6%|▋         | 191/3000 [1:01:05<17:49:21, 22.84s/it]



  8%|▊         | 230/3000 [1:13:14<11:52:48, 15.44s/it]



  8%|▊         | 243/3000 [1:17:35<14:20:33, 18.73s/it]



  8%|▊         | 253/3000 [1:20:50<14:16:42, 18.71s/it]



  9%|▊         | 258/3000 [1:22:12<12:25:11, 16.31s/it]



  9%|▊         | 260/3000 [1:22:46<12:19:42, 16.20s/it]



  9%|▉         | 266/3000 [1:24:44<14:27:29, 19.04s/it]



  9%|▉         | 271/3000 [1:26:13<12:51:11, 16.96s/it]



  9%|▉         | 274/3000 [1:27:24<16:36:56, 21.94s/it]



  9%|▉         | 279/3000 [1:28:55<14:27:10, 19.12s/it]



 11%|█▏        | 341/3000 [1:49:46<14:22:20, 19.46s/it]



 12%|█▏        | 350/3000 [1:52:53<15:35:53, 21.19s/it]



 12%|█▏        | 357/3000 [1:55:41<17:20:12, 23.61s/it]



 12%|█▏        | 373/3000 [2:01:13<16:02:15, 21.98s/it]



 13%|█▎        | 377/3000 [2:02:30<15:01:16, 20.62s/it]



 13%|█▎        | 381/3000 [2:03:48<14:16:28, 19.62s/it]



 13%|█▎        | 393/3000 [2:07:39<12:20:23, 17.04s/it]



 13%|█▎        | 394/3000 [2:07:55<12:05:25, 16.70s/it]



 14%|█▎        | 408/3000 [2:12:57<13:17:42, 18.47s/it]



 15%|█▍        | 438/3000 [2:23:09<13:38:01, 19.16s/it]



 15%|█▍        | 441/3000 [2:24:17<15:35:19, 21.93s/it]



 16%|█▌        | 465/3000 [2:32:32<14:06:54, 20.05s/it]



 16%|█▌        | 477/3000 [2:36:12<13:42:25, 19.56s/it]



 18%|█▊        | 534/3000 [2:56:02<13:02:48, 19.05s/it]



 19%|█▉        | 565/3000 [3:06:10<12:53:52, 19.07s/it]



 20%|█▉        | 590/3000 [3:14:20<11:53:48, 17.77s/it]



 21%|██        | 620/3000 [3:23:47<12:36:58, 19.08s/it]



 22%|██▏       | 651/3000 [3:33:59<10:19:06, 15.81s/it]



 22%|██▏       | 654/3000 [3:34:55<11:51:04, 18.19s/it]



 23%|██▎       | 676/3000 [3:42:22<12:23:32, 19.20s/it]



 24%|██▎       | 708/3000 [3:52:35<11:34:14, 18.17s/it]



 24%|██▍       | 725/3000 [3:57:52<11:14:53, 17.80s/it]



 25%|██▌       | 753/3000 [4:06:33<12:13:15, 19.58s/it]



 25%|██▌       | 757/3000 [4:07:44<10:42:00, 17.17s/it]



 27%|██▋       | 812/3000 [4:27:05<13:01:28, 21.43s/it]



 27%|██▋       | 816/3000 [4:28:46<14:28:39, 23.86s/it]



 28%|██▊       | 838/3000 [4:35:57<9:59:35, 16.64s/it] 



 28%|██▊       | 847/3000 [4:38:54<11:37:09, 19.43s/it]



 29%|██▉       | 863/3000 [4:44:39<12:25:02, 20.92s/it]



 29%|██▉       | 874/3000 [4:48:09<10:01:33, 16.98s/it]



 30%|███       | 915/3000 [5:01:58<12:00:20, 20.73s/it]



 32%|███▏      | 968/3000 [5:19:17<13:04:34, 23.17s/it]



 33%|███▎      | 979/3000 [5:23:11<11:21:24, 20.23s/it]



 33%|███▎      | 990/3000 [5:26:52<10:54:27, 19.54s/it]



 34%|███▍      | 1019/3000 [5:36:37<11:24:24, 20.73s/it]



 34%|███▍      | 1022/3000 [5:37:45<12:12:39, 22.22s/it]



 35%|███▌      | 1055/3000 [5:48:37<10:29:53, 19.43s/it]



 36%|███▌      | 1068/3000 [5:53:00<9:55:38, 18.50s/it] 



 36%|███▌      | 1078/3000 [5:56:28<11:11:43, 20.97s/it]



 36%|███▌      | 1083/3000 [5:58:08<10:18:56, 19.37s/it]



 36%|███▌      | 1085/3000 [5:58:44<9:38:38, 18.13s/it] 



 36%|███▋      | 1090/3000 [6:00:36<11:59:53, 22.61s/it]



 38%|███▊      | 1125/3000 [6:13:12<13:18:06, 25.54s/it]



 38%|███▊      | 1136/3000 [6:16:54<9:51:35, 19.04s/it] 



 41%|████      | 1217/3000 [6:43:02<9:43:22, 19.63s/it] 



 41%|████      | 1221/3000 [6:44:15<9:08:53, 18.51s/it]



 41%|████      | 1235/3000 [6:49:02<8:53:57, 18.15s/it] 



 42%|████▏     | 1263/3000 [6:58:57<9:51:48, 20.44s/it] 



 42%|████▎     | 1275/3000 [7:02:58<10:35:18, 22.10s/it]



 44%|████▍     | 1316/3000 [7:16:27<9:06:57, 19.49s/it] 



 44%|████▍     | 1323/3000 [7:18:35<7:56:58, 17.07s/it] 



 44%|████▍     | 1326/3000 [7:19:23<7:34:25, 16.29s/it]



 45%|████▌     | 1357/3000 [7:29:23<7:39:35, 16.78s/it]



 49%|████▉     | 1471/3000 [8:08:22<9:55:12, 23.36s/it] 



 49%|████▉     | 1477/3000 [8:10:32<10:14:00, 24.19s/it]



 50%|████▉     | 1492/3000 [8:15:11<6:58:04, 16.63s/it] 



 51%|█████     | 1520/3000 [8:23:59<6:59:46, 17.02s/it]



 52%|█████▏    | 1552/3000 [8:34:19<8:37:06, 21.43s/it]



 52%|█████▏    | 1574/3000 [8:41:29<7:03:44, 17.83s/it]



 53%|█████▎    | 1585/3000 [8:44:57<7:07:35, 18.13s/it]



 53%|█████▎    | 1590/3000 [8:46:27<6:39:23, 17.00s/it]



 53%|█████▎    | 1599/3000 [8:49:33<7:07:22, 18.30s/it]



 55%|█████▌    | 1655/3000 [9:09:12<7:38:37, 20.46s/it]



 56%|█████▌    | 1672/3000 [9:14:46<6:20:18, 17.18s/it]



 56%|█████▌    | 1687/3000 [9:19:39<6:07:16, 16.78s/it]



 58%|█████▊    | 1736/3000 [9:35:23<5:55:58, 16.90s/it]



 59%|█████▉    | 1766/3000 [9:45:38<7:33:30, 22.05s/it]



 59%|█████▉    | 1767/3000 [9:45:49<6:20:03, 18.49s/it]



 59%|█████▉    | 1776/3000 [9:48:26<5:38:01, 16.57s/it]



 62%|██████▏   | 1863/3000 [10:17:05<5:30:52, 17.46s/it]



 62%|██████▏   | 1866/3000 [10:18:12<6:31:20, 20.71s/it]



 63%|██████▎   | 1881/3000 [10:23:25<7:48:33, 25.12s/it]



 63%|██████▎   | 1896/3000 [10:28:24<5:24:28, 17.63s/it]



 64%|██████▎   | 1912/3000 [10:33:26<5:36:26, 18.55s/it]



 65%|██████▌   | 1953/3000 [10:46:45<4:55:54, 16.96s/it]



 65%|██████▌   | 1961/3000 [10:49:18<4:55:23, 17.06s/it]



 66%|██████▌   | 1974/3000 [10:53:21<5:16:38, 18.52s/it]



 68%|██████▊   | 2037/3000 [11:13:54<5:08:09, 19.20s/it]



 69%|██████▉   | 2067/3000 [11:23:50<5:07:01, 19.74s/it]



 69%|██████▉   | 2081/3000 [11:28:12<4:18:05, 16.85s/it]



 72%|███████▏  | 2153/3000 [11:51:42<4:26:23, 18.87s/it]



 72%|███████▎  | 2175/3000 [11:58:36<3:55:01, 17.09s/it]



 73%|███████▎  | 2204/3000 [12:07:56<3:17:34, 14.89s/it]



 74%|███████▍  | 2219/3000 [12:12:32<3:58:11, 18.30s/it]



 74%|███████▍  | 2223/3000 [12:13:27<3:01:33, 14.02s/it]



 75%|███████▌  | 2251/3000 [12:23:04<4:45:26, 22.87s/it]



 75%|███████▌  | 2258/3000 [12:25:16<3:18:01, 16.01s/it]



 76%|███████▌  | 2269/3000 [12:28:40<3:31:09, 17.33s/it]



 76%|███████▌  | 2273/3000 [12:29:56<3:46:00, 18.65s/it]



 76%|███████▌  | 2278/3000 [12:31:35<4:01:28, 20.07s/it]



 76%|███████▌  | 2282/3000 [12:32:42<3:23:29, 17.01s/it]



 76%|███████▋  | 2289/3000 [12:35:04<3:56:21, 19.95s/it]



 76%|███████▋  | 2292/3000 [12:36:16<4:18:42, 21.92s/it]



 77%|███████▋  | 2302/3000 [12:39:24<3:50:36, 19.82s/it]



 77%|███████▋  | 2311/3000 [12:42:23<3:47:25, 19.80s/it]



 77%|███████▋  | 2317/3000 [12:44:08<3:06:20, 16.37s/it]



 77%|███████▋  | 2318/3000 [12:44:33<3:37:20, 19.12s/it]



 77%|███████▋  | 2324/3000 [12:46:17<3:07:52, 16.68s/it]



 78%|███████▊  | 2325/3000 [12:46:42<3:37:23, 19.32s/it]



 78%|███████▊  | 2338/3000 [12:50:56<3:32:51, 19.29s/it]



 78%|███████▊  | 2343/3000 [12:52:29<3:21:20, 18.39s/it]



 79%|███████▉  | 2368/3000 [13:00:26<3:08:26, 17.89s/it]



 80%|███████▉  | 2390/3000 [13:07:18<3:04:42, 18.17s/it]



 81%|████████  | 2428/3000 [13:20:29<3:48:33, 23.97s/it]



 82%|████████▏ | 2470/3000 [13:34:37<3:40:29, 24.96s/it]



 83%|████████▎ | 2486/3000 [13:39:40<2:51:05, 19.97s/it]



 85%|████████▍ | 2542/3000 [13:58:35<3:07:30, 24.57s/it]



 85%|████████▌ | 2555/3000 [14:02:50<2:02:39, 16.54s/it]



 85%|████████▌ | 2561/3000 [14:04:45<2:17:55, 18.85s/it]



 85%|████████▌ | 2562/3000 [14:05:07<2:23:41, 19.68s/it]



 86%|████████▌ | 2570/3000 [14:07:55<2:21:30, 19.75s/it]



 86%|████████▌ | 2584/3000 [14:12:29<2:01:55, 17.59s/it]



 86%|████████▋ | 2589/3000 [14:14:23<2:37:11, 22.95s/it]



 86%|████████▋ | 2593/3000 [14:15:50<2:39:18, 23.49s/it]



 89%|████████▊ | 2661/3000 [14:38:18<1:39:18, 17.58s/it]



 90%|████████▉ | 2687/3000 [14:46:48<1:33:53, 18.00s/it]



 90%|█████████ | 2715/3000 [14:56:27<1:32:24, 19.45s/it]



 91%|█████████ | 2722/3000 [14:58:45<1:41:12, 21.84s/it]



 91%|█████████▏| 2743/3000 [15:05:47<1:24:32, 19.74s/it]



 92%|█████████▏| 2765/3000 [15:12:46<1:17:43, 19.85s/it]



 94%|█████████▎| 2805/3000 [15:25:52<1:01:48, 19.02s/it]



 94%|█████████▍| 2823/3000 [15:31:22<50:29, 17.12s/it]  



 94%|█████████▍| 2828/3000 [15:32:52<52:02, 18.15s/it]



 94%|█████████▍| 2829/3000 [15:33:04<46:32, 16.33s/it]



 94%|█████████▍| 2831/3000 [15:33:40<46:56, 16.66s/it]



 98%|█████████▊| 2928/3000 [16:05:48<22:31, 18.77s/it]  



 98%|█████████▊| 2930/3000 [16:06:30<23:45, 20.37s/it]



 98%|█████████▊| 2935/3000 [16:08:06<21:45, 20.08s/it]



 99%|█████████▉| 2971/3000 [16:19:50<08:45, 18.13s/it]



100%|█████████▉| 2992/3000 [16:26:30<02:25, 18.15s/it]



100%|██████████| 3000/3000 [16:29:17<00:00, 19.79s/it]
