In [1]:
import os
from datetime import datetime

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageDraw

from scorefield.models.ddpm.denoising_diffusion import Unet
from scorefield.models.rrt.rrt import RRTStar
from scorefield.utils.rl_utils import load_config
from scorefield.utils.utils import (
    gen_goals, gen_agents, overlay_goal, overlay_multiple, combine_objects, overlay_images,
    overlay_goal_agent, overlay_goals_agent, log_num_check,
    draw_obstacles_pil, convert_to_obstacle_masks,
    randgen_obstacle_masks, draw_obstacles_pixel,
    vector_field, clip_vectors
)
from scorefield.utils.diffusion_utils import bilinear_interpolate, bilinear_interpolate_samples

import matplotlib.pyplot as plt
import itertools
from typing import Optional
import shutil


# --- Experiment configuration ---------------------------------------------
# Every tunable value is read from the YAML config so a run can be
# reproduced from the copy of that file stored next to its logs.
config_dir = "./scorefield/configs/rrt.yaml"
args = load_config(config_dir)
device = args['device']

# --- Image assets ----------------------------------------------------------
# `bg` is the scene background; `wastes` holds the goal sprite(s).
# More sprites (e.g. waste4.png / waste5.png from the same folder) can be
# added to the list to vary the goal appearance.
bg = Image.open('assets/toy_exp/background0.png')
wastes = [Image.open('assets/toy_exp/waste0.png')]

# --- Planner / sampling hyper-parameters -----------------------------------
img_size, delta_dist, radius = args['image_size'], args['delta_dist'], args['radius']
time_steps, max_iters, sample_num = args['time_steps'], args['max_iters'], args['sample_num']

# --- Scene layout ----------------------------------------------------------
goal_bounds, agent_bounds = args['goal_bounds'], args['agent_bounds']
obstacle_pos = args['obstacles']

# --- Optimisation ----------------------------------------------------------
iterations, train_lr, batch_size = args['iterations'], args['train_lr'], args['batch_size']

class Unet2D(Unet):
    """U-Net whose dense output map is sampled at query positions.

    The parent ``Unet`` is run on the observation to produce a dense map with
    ``out_dim`` channels; ``forward`` then reads that map at the positions in
    ``x_t`` via ``bilinear_interpolate_samples``.

    Args:
        dim: forwarded unchanged to ``Unet``.
        out_dim: number of output channels of the dense map (2 in this
            notebook).
        dim_mults: forwarded unchanged to ``Unet``.
    """

    def __init__(
        self,
        dim,
        out_dim,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__(dim=dim, out_dim=out_dim, dim_mults=dim_mults)

    def forward(self, obs, t, x_t: Optional[torch.Tensor] = None):
        """Predict the score at the query positions ``x_t``.

        Args:
            obs: observation passed to ``Unet.forward``.
            t: time-step tensor passed to ``Unet.forward``.
            x_t: positions at which the dense map is sampled. When omitted
                (``None``) the dense map itself is returned, so the declared
                default is actually usable instead of being handed to the
                interpolation helper.

        Returns:
            The interpolated score at ``x_t``, or the dense score map when
            ``x_t`` is ``None``. In this notebook's recorded output the
            sampled result had shape ``(12, 20, 2)``, i.e. one 2-vector per
            query point per batch element rather than a flat ``(B, 2)``.
            NOTE(review): the exact layout expected for ``x_t`` is defined by
            ``bilinear_interpolate_samples`` -- confirm against that helper.
        """
        score_map = super().forward(obs, t)
        if x_t is None:
            # No query positions supplied: expose the dense field directly.
            return score_map
        return bilinear_interpolate_samples(score_map, x_t)

    
# Score network: a U-Net producing a 2-channel map that is sampled at the
# queried positions, placed on the configured device.
model = Unet2D(dim=img_size, out_dim=2, dim_mults=(1, 2, 4, 8)).to(device)

# RRT* planner that supplies the supervision targets during training.
rrt = RRTStar(
    device=device,
    image_size=img_size,
    time_steps=time_steps,
    radius=radius,
    delta_dist=delta_dist,
)

# Adam over every parameter of the score network.
optim = torch.optim.Adam(model.parameters(), lr=train_lr)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Train with random single goal
#
# Every batch is built from two halves that share one ordering throughout:
#   first half  -> scenes with two goals (the first goal is the target),
#   second half -> scenes with a single goal.
# `obstacle_masks`, `goal` and `obs` must all follow this ordering, because
# the RRT* targets computed from (`goal`, `obstacle_masks`) are regressed
# against predictions made from `obs`.

current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join('./runs/baselines/', current_time)
writer = SummaryWriter(log_dir)
# Keep a snapshot of the exact config used for this run next to its logs.
shutil.copy(config_dir, writer.log_dir)

assert batch_size % 2 == 0, "batch_size must be even: it is split into multi-goal and single-goal halves"

# background = draw_obstacles_pil(bg, obstacle_pos)
# obstacle_masks = convert_to_obstacle_masks(batch_size, background[0].size, img_size, obstacle_pos)

# The time-step input is the same constant for every iteration, so build it once.
t = torch.ones(1, dtype=torch.long, device=device)

try:
    for iters in tqdm(range(iterations)):
        # Fresh random obstacles (and matching rendered backgrounds) per batch.
        obstacle_masks = randgen_obstacle_masks(batch_size, img_size)
        background = draw_obstacles_pixel(bg, obstacle_masks)
        mask_half = len(obstacle_masks) // 2

        # Goals are placed w.r.t. the obstacles of their own half, then jittered
        # by uniform noise in [-0.01, 0.01).
        multi_goals = gen_goals(goal_bounds, (2, batch_size // 2), img_size, obstacles=obstacle_masks[:mask_half])
        multi_goal = (torch.rand(batch_size // 2, 2, 2, device=device, dtype=torch.float32) * 0.2 - 0.1) * 0.1 + multi_goals
        single_goals = gen_goals(goal_bounds, batch_size // 2, img_size, obstacles=obstacle_masks[mask_half:])
        single_goal = (torch.rand(batch_size // 2, 1, 2, device=device, dtype=torch.float32) * 0.2 - 0.1) * 0.1 + single_goals

        # NOTE(review): the full-batch `background` is passed to both calls even
        # though each renders only half of the batch -- confirm that
        # `overlay_goal` pairs each half with the backgrounds of its own
        # obstacle masks.
        obs_single = overlay_goal(background, img_size, wastes, single_goal)
        obs_multi = overlay_goal(background, img_size, wastes, multi_goal)

        # Multi-goal scenes first, single-goal scenes second -- the same order
        # as `goal` and as the obstacle-mask halves used above. (Concatenating
        # the observations the other way round pairs every image with the goal
        # and RRT* target of a different scene.)
        obs = torch.cat((obs_multi, obs_single), dim=0)
        goal = torch.cat((multi_goal[:, 0].unsqueeze(1), single_goal), dim=0)

        optim.zero_grad()

        # The planner only produces regression targets; no gradients needed.
        with torch.no_grad():
            # NOTE(review): start positions are drawn from `goal_bounds` while
            # `agent_bounds` is loaded from the config but never used --
            # confirm this is intended.
            initials = gen_agents(goal_bounds, (sample_num * time_steps, batch_size), img_size, obstacles=obstacle_masks)
            paths, delta = rrt.plan(initials, goal, obstacle_masks)

        # Only index 0 along the third axis of the planner output is supervised.
        pred_delta = model(obs, t, paths[:, :, 0])

        # Mean squared L2 error between predicted and planner-provided deltas.
        squared_diff = torch.sum((pred_delta - delta[:, :, 0]) ** 2, dim=-1)
        loss = torch.mean(squared_diff)
        loss.backward()
        optim.step()

        # Diagnostics on flattened (N, 2) vectors, detached from the graph.
        deltas = delta[:, :, 0].detach().reshape(-1, 2)
        pred_deltas = pred_delta.detach().reshape(-1, 2)

        deltas_norm = torch.norm(deltas, dim=-1)
        pred_deltas_norm = torch.norm(pred_deltas, dim=-1)

        dotprod = torch.sum(deltas * pred_deltas, dim=-1)
        # The epsilon guards against division by zero for zero-length vectors.
        cosine_sim = dotprod / (deltas_norm * pred_deltas_norm + 1e-8)

        writer.add_scalar("Loss/train", loss.item(), iters)
        writer.add_scalar("Norm/target_score", torch.mean(deltas_norm).item(), iters)
        writer.add_scalar("Norm/pred_score", torch.mean(pred_deltas_norm).item(), iters)
        writer.add_scalar("Similarity/dot_product", torch.mean(dotprod).item(), iters)
        writer.add_scalar("Similarity/cosine_similarity", torch.mean(cosine_sim).item(), iters)

        # Mid-run checkpoint.
        if iters == iterations // 2:
            torch.save(model.state_dict(), os.path.join(writer.log_dir, 'model_params_half.pt'))

    # Final checkpoint, written only when the loop completed.
    torch.save(model.state_dict(), os.path.join(writer.log_dir, 'model_params.pt'))
finally:
    # Flush and close the event file even if an iteration raised.
    writer.close()


  0%|          | 0/8000 [00:07<?, ?it/s]


AttributeError: 'list' object has no attribute 'reshape'

In [3]:
# Sanity check: prediction and regression target must have identical shapes.
print(*(tensor.shape for tensor in (pred_delta, delta[:, :, 0])))

torch.Size([12, 20, 2]) torch.Size([12, 20, 2])
