In [2]:
import os
from datetime import datetime

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm
from PIL import Image

from denoisingheat.models.ddpm.denoising_diffusion import Unet
from denoisingheat.models.heat.heat_diffusion import HeatDiffusion_Revised
from denoisingheat.utils.utils import (
    gen_goals, overlay_goal, randgen_obstacle_masks, draw_obstacles_pixel, load_config
)
from denoisingheat.utils.diffusion_utils import bilinear_interpolate_samples
import matplotlib.pyplot as plt
from typing import Optional
import shutil


# ---------------------------------------------------------------------------
# Experiment configuration
# Everything below is read once from the YAML config and exposed as
# module-level names that the later cells use.
# ---------------------------------------------------------------------------
config_dir = "./denoisingheat/configs/heat_diffusion.yaml"
args = load_config(config_dir)
device = args['device']

# Image assets: one background plus a list of goal ("waste") sprites.
# Further sprites can be enabled by appending them to `wastes`:
#   wastes.append(Image.open('assets/toy_exp/waste4.png'))
#   wastes.append(Image.open('assets/toy_exp/waste5.png'))
bg = Image.open('assets/toy_exp/background0.png')
wastes = [Image.open('assets/toy_exp/waste0.png')]

# Scene layout.
img_size, goal_bounds = args['image_size'], args['goal_bounds']
goal_num = len(wastes)  # one goal per loaded sprite
agent_bounds, obstacle_pos = args['agent_bounds'], args['obstacles']

# Location of the model parameters, resolved relative to the log directory.
model_path = os.path.join(args['log_path'], args['model_path'])

# Heat-diffusion kernel settings.
u0, min_heat_step, max_heat_step = args['u0'], args['min_heat_step'], args['max_heat_step']
noise_steps, sample_num, time_type = args['noise_steps'], args['sample_num'], args['time_type']

# Optimisation settings.
iterations, train_lr, batch_size = args['iterations'], args['train_lr'], args['batch_size']

  from .autonotebook import tqdm as notebook_tqdm


## Training Setup: Model & Heat-Inspired Diffusion Kernel

In [None]:
class Unet2D(Unet):
    """`Unet` subclass that exposes its output as a score field.

    The parent network is used unchanged; only `forward` is extended so that
    the predicted field can optionally be sampled at given positions.
    """

    def __init__(self, dim, out_dim, dim_mults=(1, 2, 4, 8)):
        """Forward `dim`, `out_dim` and `dim_mults` to the parent `Unet`."""
        super().__init__(dim=dim, out_dim=out_dim, dim_mults=dim_mults)

    def forward(self, obs, t, x_t: Optional[torch.Tensor] = None):
        """Predict the score field and, optionally, the score at `x_t`.

        Args:
            obs: Observation passed to the parent `Unet.forward`.
            t: Time-step input passed to the parent `Unet.forward`.
            x_t: Optional sample positions. When given, the raw field is
                sampled at these positions with `bilinear_interpolate_samples`.

        Returns:
            If `x_t` is None: the field with axis 1 moved to the last position
            (presumably (B, C, H, W) -> (B, H, W, C); confirm against `Unet`).
            Otherwise: a tuple `(score_at_x_t, permuted_field)`.
        """
        raw_field = super().forward(obs, t)
        # permute returns a view, so computing it up front is free.
        permuted_field = raw_field.permute(0, 2, 3, 1)

        if x_t is None:
            return permuted_field

        # Interpolation works on the un-permuted network output.
        sampled_score = bilinear_interpolate_samples(raw_field, x_t)
        return sampled_score, permuted_field
    
# Score network: outputs a 2-component vector field and lives on `device`.
model = Unet2D(dim=img_size, out_dim=2, dim_mults=(1, 2, 4, 8)).to(device)

# Heat-inspired diffusion kernel, configured from the values loaded above.
diffusion = HeatDiffusion_Revised(
    image_size=img_size,
    u0=u0,
    noise_steps=noise_steps,
    min_heat_step=min_heat_step,
    max_heat_step=max_heat_step,
    time_type=time_type,
    device=device,
)

# Adam over all network parameters.
optim = torch.optim.Adam(model.parameters(), lr=train_lr)

# Uncomment to resume from previously saved parameters:
# model.load_state_dict(torch.load(f'./runs/heat/model_params.pt'))

## Training

1. In each iteration, randomly generate obstacles and goals. Half of each batch contains scenes with two goals and the other half contains scenes with a single goal.
2. Train the model using our heat-inspired diffusion kernel, making sure to non-dimensionalize the score throughout the training process (the target score is scaled by the standard deviation of the corresponding noise step).

In [None]:
# Train on randomly generated scenes. The first half of every batch contains
# two goals per scene, the second half a single goal per scene.
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join('./runs/heat/', current_time)
writer = SummaryWriter(log_dir)
# Keep a copy of the exact config next to this run's logs and checkpoints.
shutil.copy(config_dir, writer.log_dir)

# NOTE(review): the half/half split below assumes `batch_size` is even;
# with an odd value the two halves would not add up to `batch_size`.
half_batch = batch_size // 2

try:
    for iters in tqdm(range(iterations)):
        # 1. Generate obstacles and goals randomly.
        obstacle_masks = randgen_obstacle_masks(batch_size, img_size)
        background = draw_obstacles_pixel(bg, obstacle_masks)

        # Goals are perturbed by uniform jitter in [-0.005, 0.005).
        multi_goals = gen_goals(goal_bounds, (2, half_batch), img_size, obstacles=obstacle_masks[:half_batch])
        multi_goal = (torch.rand(half_batch, 2, 2, device=device, dtype=torch.float32) * 0.2 - 0.1) * 0.05 + multi_goals
        single_goals = gen_goals(goal_bounds, half_batch, img_size, obstacles=obstacle_masks[half_batch:])
        single_goal = (torch.rand(half_batch, 1, 2, device=device, dtype=torch.float32) * 0.2 - 0.1) * 0.05 + single_goals

        obs1 = overlay_goal(background[:half_batch], img_size, wastes, multi_goal)
        obs2 = overlay_goal(background[half_batch:], img_size, wastes, single_goal)
        obs = torch.cat((obs1, obs2), dim=0)
        # Only the first goal of each two-goal scene is passed to the diffusion
        # kernel; single-goal scenes contribute their only goal.
        goal = torch.cat((multi_goal[:, 0].unsqueeze(1), single_goal), dim=0)

        optim.zero_grad()

        losses = []
        target_scores = []
        pred_scores = []

        # 2. Train the model: accumulate the loss over every noise step and
        #    take a single optimizer step per iteration.
        for i in range(1, noise_steps + 1):
            t = (torch.ones(1) * i).long().to(device)
            with torch.no_grad():
                # forward_diffusion returns four values; only the 2nd (score)
                # and the 4th (x_t, passed to the model as sample positions)
                # are used here.
                _, score, _, x_t = diffusion.forward_diffusion(t.repeat(batch_size), goal, sample_num, obstacle_masks)
            pred_score, _ = model(obs, t, x_t)
            # Non-dimensionalize the target by the std of this noise step
            # (steps are 1-based, `diffusion.std` is indexed from 0).
            target_score = score * diffusion.std[i - 1]

            squared_diff = torch.sum((pred_score - target_score) ** 2, dim=-1)

            losses.append(torch.mean(squared_diff))
            # The stored copies are only used for logging, so detach them now.
            target_scores.append(target_score.detach())
            pred_scores.append(pred_score.detach())

        loss = sum(losses) / len(losses)
        loss.backward()
        optim.step()

        # Diagnostics over all noise steps of this iteration.
        target_scores = torch.stack(target_scores)
        pred_scores = torch.stack(pred_scores)

        target_score_norm = torch.norm(target_scores, dim=-1)
        pred_score_norm = torch.norm(pred_scores, dim=-1)

        dotprod = torch.sum(target_scores * pred_scores, dim=-1)
        # 1e-8 guards against division by zero for vanishing scores.
        cosine_sim = dotprod / (target_score_norm * pred_score_norm + 1e-8)

        writer.add_scalar("Loss/train", loss.item(), iters)
        writer.add_scalar("Norm/target_score", torch.mean(target_score_norm).item(), iters)
        writer.add_scalar("Norm/pred_score", torch.mean(pred_score_norm).item(), iters)
        writer.add_scalar("Similarity/dot_product", torch.mean(dotprod).item(), iters)
        writer.add_scalar("Similarity/cosine_similarity", torch.mean(cosine_sim).item(), iters)

        # Intermediate checkpoints at 1/2 and 3/4 of training. Each gets its
        # own file so the later one does not overwrite the earlier one.
        if iters == iterations // 2:
            torch.save(model.state_dict(), os.path.join(writer.log_dir, 'model_params_half.pt'))
        elif iters == int(iterations * 3 / 4):
            torch.save(model.state_dict(), os.path.join(writer.log_dir, 'model_params_three_quarters.pt'))

    # Final parameters are written only when training ran to completion.
    torch.save(model.state_dict(), os.path.join(writer.log_dir, 'model_params.pt'))
finally:
    # Flush and release the event file even if training fails or is interrupted.
    writer.close()
