In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler

from loader.pouring_dataset import Pouring
from models.ConditionalUNet1D import ConditionalUnet1D

In [2]:
# --- Training configuration ---
# Number of DDPM noise levels; passed to the scheduler as num_train_timesteps.
num_diffusion_iters = 100
# Number of full passes over the dataloader.
num_epochs = 2500
# Mini-batch size used by the dataloader.
batch_size = 10
# Per-step feature size given to the network. The training loop keeps the
# first 3 rows of each per-step matrix and flattens them (3 * 4 = 12 for the
# 4x4 matrices shown in the dataloader output).
input_dim = 12
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of crashing when tensors are moved to device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory prefix under which the trained checkpoint is written.
save_path = "./params/pouring_dataset/basic_DDPM/"

In [3]:
# Dataset plus a shuffling mini-batch loader (pinned host memory is requested
# for the host-to-device copies done during training).
dataset = Pouring()
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, pin_memory=True
)

# Sanity check: draw a single batch and show the shape of its first element.
batch = next(iter(dataloader))
print("batch['traj'].shape:", batch[0].shape)

Pouring dataset is ready; # of trajectories: 10
batch['traj'].shape: torch.Size([10, 480, 4, 4])


In [4]:
# Noise-prediction network (global conditioning disabled), moved to the
# training device before anything captures its parameters.
noise_pred_net = ConditionalUnet1D(input_dim=input_dim, global_cond_dim=0)
_ = noise_pred_net.to(device)

# DDPM noise scheduler: squared-cosine beta schedule, sample clipping enabled,
# and the network is trained to predict the added noise (epsilon).
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    beta_schedule='squaredcos_cap_v2',
    clip_sample=True,
    prediction_type='epsilon',
)

# Exponential moving average of the network weights.
ema = EMAModel(parameters=noise_pred_net.parameters(), power=0.75)

# AdamW optimizer over the same parameters.
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(), lr=1e-4, weight_decay=1e-6
)

# Cosine LR schedule with linear warmup; the total step count is one
# scheduler step per batch over all epochs.
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs,
)

In [5]:
with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    for epoch_idx in tglobal:
        # Per-batch losses gathered for this epoch's summary value.
        epoch_loss = []
        with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
            for nbatch in tepoch:
                # Keep the first 3 rows of every per-step matrix and flatten
                # them into a single feature vector per trajectory step.
                ntraj = nbatch[0].to(device)
                B, traj_len, _, _ = ntraj.shape
                ntraj = ntraj[:, :, :3, :].reshape(B, traj_len, -1)

                # Gaussian noise with the same shape as the trajectories.
                noise = torch.randn(ntraj.shape, device=device)

                # One random diffusion step index per sample in the batch.
                timesteps = torch.randint(
                    low=0,
                    high=noise_scheduler.config.num_train_timesteps,
                    size=(B,),
                    device=device,
                    dtype=torch.long,
                )

                # Forward diffusion: blend the noise into the clean data.
                noisy_traj = noise_scheduler.add_noise(ntraj, noise, timesteps)

                # The network estimates the noise that was added.
                noise_pred = noise_pred_net(noisy_traj, timesteps)

                # Mean-squared error between predicted and true noise.
                loss = nn.functional.mse_loss(noise_pred, noise)

                # Backprop, parameter update, gradient reset, LR schedule tick.
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()

                # Update the EMA of the weights.
                ema.step(noise_pred_net.parameters())

                # Report the batch loss on the inner progress bar.
                loss_cpu = loss.item()
                epoch_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)
        # Report the epoch-mean loss on the outer progress bar.
        tglobal.set_postfix(loss=np.mean(epoch_loss))

# Weights of the EMA model.
# NOTE: this is an alias, not a copy -- copy_to() below writes the EMA-averaged
# values into noise_pred_net's own parameters, replacing the raw trained ones.
ema_noise_pred_net = noise_pred_net
ema.copy_to(ema_noise_pred_net.parameters())

# Save EMA model.
# Create the checkpoint directory first so torch.save does not fail with
# FileNotFoundError when the target directory does not exist yet.
os.makedirs(save_path, exist_ok=True)
# The whole module object is pickled (not just a state_dict), so loading it
# later requires the model class to be importable under the same module path.
torch.save(
    ema_noise_pred_net,
    os.path.join(save_path, "model_ep" + str(num_epochs) + ".pt")
)

Epoch: 100%|██████████| 2500/2500 [04:27<00:00,  9.33it/s, loss=0.00666]
