In [1]:
import os

# Remember the directory the kernel started in. os.chdir with a relative path
# is not idempotent: without this anchor, re-running the cell would climb two
# more directory levels every time.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
# Work from two levels above the notebook directory (presumably the project
# root that the relative data paths and local packages below are resolved from).
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '../../'))

# Set environment flags before the libraries that read them are imported.
os.environ["DPM_TQDM"] = "False"  # NOTE(review): presumably disables project-side progress bars — confirm
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to CUDA

!nvidia-smi

Sun Aug 10 02:50:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:19:00.0 Off |                  Off |
|  0%   75C    P8             56W /  450W |      11MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00

In [2]:
### Config
from easydict import EasyDict

# All experiment settings in one attribute-accessible mapping.
config = EasyDict({
    'backbone': 'DiT',                                          # which model class to build
    'train_pt_dir': 'samplings/dit/train/dit_train_0_trajdrop', # training data directory
    'valid_pt_dir': 'samplings/dit/eval1000/dit_eval1000_3',    # validation data directory
    'batch_size': 10,                                           # used by both data loaders
    'CFG': 1.375,                                               # passed as guidance_scale
    'epochs': 10,                                               # passes over the training set
    'val_every': 100,                                           # validate/checkpoint every N global steps
    'log_dir': "logs/exp11",                                    # TensorBoard events and checkpoints
})

### Model
from backbones.dit import DiT

# Build the backbone selected in the config. Fail fast on an unknown name
# rather than leaving `model` undefined and hitting a NameError further down.
if config.backbone == 'DiT':
    model = DiT(trainable=True)
else:
    raise ValueError(f"Unsupported backbone: {config.backbone!r}")
print(model)


### Dataset
from datasets.pt_dataset import PtDataset
from torch.utils.data import DataLoader

# One dataset per directory named in the config (train first, then validation).
train_dataset, valid_dataset = (
    PtDataset(config.train_pt_dir),
    PtDataset(config.valid_pt_dir),
)
print('len(train_dataset) :', len(train_dataset), 'len(valid_dataset) :', len(valid_dataset))

# Training loader: shuffled, fed by persistent background workers.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)
# Validation loader: fixed order, loaded in the main process.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=config.batch_size,
    shuffle=False,
)
print('done')

### Solver
import torch
from solvers.dual.static.gdual_coeff2_solver import GDual_Coeff2_Solver
from torch.utils.tensorboard import SummaryWriter

# The solver is built on the model's own noise schedule and placed on the
# model's device; only the solver's parameters are handed to the optimizer.
noise_schedule = model.get_noise_schedule()
solver = GDual_Coeff2_Solver(
    noise_schedule,
    steps=5,
    skip_type="time_uniform",
).to(model.device)
optimizer = torch.optim.AdamW(
    solver.parameters(),
    lr=1e-3,
)
print('done')

  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...:  33%|███▎      | 1/3 [00:00<00:00,  6.91it/s]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/transformer: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fba

<backbones.dit.DiT object at 0x7fd0842f1cd0>
len(train_dataset) : 10000 len(valid_dataset) : 1000
done
done


In [3]:
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

def get_valid_loss(device, solver):
    """Return the mean per-batch MSE of ``solver`` over ``valid_loader``.

    For every validation batch the solver samples from ``batch['noise']`` with
    a model function built from ``batch['cond']`` (guidance scale
    ``config.CFG``) and the result is compared to ``batch['sample']``.

    The solver is put in eval mode for the duration of the call and its
    previous train/eval mode is restored on exit, so calling this from inside
    a training loop does not leave the solver in eval mode.

    Args:
        device: Device the noise and target tensors are moved to.
        solver: Object exposing ``eval()``, ``train()`` and
            ``sample(noises, model_fn)``.

    Returns:
        The unweighted mean of the per-batch losses (NaN if the loader
        yields no batches).
    """
    # Remember the mode so it can be restored; objects without a `training`
    # flag are treated as not training.
    was_training = getattr(solver, "training", False)
    solver.eval()
    losses = []
    try:
        # No gradients are needed anywhere in validation, so disable them once
        # around the whole loop.
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                noises = batch['noise'].to(device, non_blocking=True)
                conds = batch['cond']  # passed through as loaded
                targets = batch['sample'].to(device, non_blocking=True)
                model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)
                pred = solver.sample(noises, model_fn)
                loss = F.mse_loss(pred, targets)
                losses.append(loss.item())
    finally:
        if was_training:
            solver.train()
    return np.mean(losses)
    
def do_train_loop(device, epoch, writer, solver):
    """Run one pass over ``train_loader``, updating the solver's parameters.

    Every ``config.val_every`` global steps (including step 0) the solver is
    validated first; the validation loss is printed, written to ``writer``
    under ``valid/loss`` and a checkpoint is saved to ``config.log_dir``.

    Each training step samples with the solver under bfloat16 autocast,
    takes the MSE against the stored target sample, clips the gradient norm
    to 1.0 and steps the module-level ``optimizer``.

    Args:
        device: Device the noise and target tensors are moved to.
        epoch: Zero-based epoch index, used to derive the global step.
        writer: TensorBoard-style writer exposing ``add_scalar``.
        solver: Trainable solver exposing ``train()``, ``sample()`` and
            ``parameters()``.

    Returns:
        Mean of the per-step training losses for this epoch.
    """
    solver.train()
    pbar = tqdm(train_loader)
    losses = []
    steps_per_epoch = len(train_loader)
    for step, batch in enumerate(pbar):
        global_step = epoch * steps_per_epoch + step
        if global_step % config.val_every == 0:
            valid_loss = get_valid_loss(device, solver)
            # Validation switches the solver to eval mode; go back to train
            # mode so the remaining steps are not run in eval mode.
            solver.train()
            print('step :', global_step, 'valid_loss :', valid_loss)
            writer.add_scalar("valid/loss", valid_loss, global_step)
            save_checkpoint(global_step, config.log_dir, solver, valid_loss)

        optimizer.zero_grad(set_to_none=True)
        noises = batch['noise'].to(device, non_blocking=True)
        conds = batch['cond']  # passed through as loaded
        targets = batch['sample'].to(device, non_blocking=True)
        model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)

        # Forward pass and loss run under bfloat16 autocast; backward and the
        # optimizer step run outside it.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            pred = solver.sample(noises, model_fn)
            loss = F.mse_loss(pred, targets)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(solver.parameters(), 1.0)
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        loss_value = loss.item()
        losses.append(loss_value)
        pbar.set_postfix({'loss': loss_value})

    return np.mean(losses)

def save_checkpoint(global_step, save_dir, solver, valid_loss):
    """Write a checkpoint for ``solver`` and return the file path.

    The checkpoint is a dict holding the global step, the solver's
    ``state_dict()``, the validation loss as a float and a plain-dict copy of
    the module-level ``config``. It is saved with ``torch.save`` as
    ``step_<global_step zero-padded to 8 digits>.pt`` inside ``save_dir``.

    Args:
        global_step: Training step the checkpoint belongs to.
        save_dir: Target directory; created (with parents) if missing.
        solver: Object exposing ``state_dict()``.
        valid_loss: Validation loss to record alongside the weights.

    Returns:
        Path of the written checkpoint file.
    """
    # Do not rely on some other component having created the directory.
    os.makedirs(save_dir, exist_ok=True)

    ckpt = {
        "global_step": global_step,
        "solver_state_dict": solver.state_dict(),
        "valid_loss": float(valid_loss),
        "config": dict(config),
    }
    step_path = os.path.join(save_dir, f"step_{global_step:08d}.pt")
    torch.save(ckpt, step_path)

    return step_path

print('done')

done


### Train Loop

In [None]:
writer = SummaryWriter(log_dir=config.log_dir)

try:
    # One do_train_loop call per epoch; validation and checkpointing happen
    # inside it every config.val_every global steps.
    for epoch in range(config.epochs):
        train_loss = do_train_loop(model.device, epoch, writer, solver)
        print('train_loss :', train_loss)
finally:
    # Close the writer even if training is interrupted or raises, so the
    # event file is not left open.
    writer.close()

100%|██████████| 100/100 [00:25<00:00,  3.87it/s]


step : 0 valid_loss : 0.8954178720712662


100%|██████████| 100/100 [00:24<00:00,  4.05it/s], loss=0.958]


step : 100 valid_loss : 0.8613730448484421


100%|██████████| 100/100 [00:24<00:00,  4.05it/s], loss=0.464]  


step : 200 valid_loss : 0.5256346797943116


100%|██████████| 100/100 [00:25<00:00,  3.98it/s], loss=0.233]  


step : 300 valid_loss : 0.23803266108036042


100%|██████████| 100/100 [00:25<00:00,  3.90it/s], loss=0.155]  


step : 400 valid_loss : 0.1702463436871767


100%|██████████| 100/100 [00:26<00:00,  3.82it/s], loss=0.13]   


step : 500 valid_loss : 0.14142782285809516


100%|██████████| 100/100 [00:27<00:00,  3.61it/s], loss=0.1]    


step : 600 valid_loss : 0.12069486267864704


100%|██████████| 100/100 [00:28<00:00,  3.52it/s], loss=0.119]   


step : 700 valid_loss : 0.10985560283064842


100%|██████████| 100/100 [00:29<00:00,  3.39it/s], loss=0.1]   


step : 800 valid_loss : 0.10481371074914932


100%|██████████| 100/100 [00:29<00:00,  3.35it/s], loss=0.102] 


step : 900 valid_loss : 0.1008191865682602


100%|██████████| 1000/1000 [26:47<00:00,  1.61s/it, loss=0.095]


train_loss : 0.2859144819751382


100%|██████████| 100/100 [00:30<00:00,  3.30it/s]


step : 1000 valid_loss : 0.09905460372567176


100%|██████████| 100/100 [00:30<00:00,  3.27it/s], loss=0.0671]


step : 1100 valid_loss : 0.09643468622118234


100%|██████████| 100/100 [00:30<00:00,  3.26it/s], loss=0.0998]  


step : 1200 valid_loss : 0.09486889738589525


100%|██████████| 100/100 [00:30<00:00,  3.25it/s], loss=0.109]   


step : 1300 valid_loss : 0.0934385309368372


100%|██████████| 100/100 [00:30<00:00,  3.25it/s], loss=0.0786]  


step : 1400 valid_loss : 0.09254572451114655


100%|██████████| 100/100 [00:31<00:00,  3.19it/s], loss=0.0792]  


step : 1500 valid_loss : 0.0919451680406928


100%|██████████| 100/100 [00:31<00:00,  3.16it/s], loss=0.0885]  


step : 1600 valid_loss : 0.09115485396236181


100%|██████████| 100/100 [00:32<00:00,  3.11it/s], loss=0.0983]  


step : 1700 valid_loss : 0.0905314028263092


100%|██████████| 100/100 [00:32<00:00,  3.11it/s], loss=0.107] 


step : 1800 valid_loss : 0.08994674693793059


100%|██████████| 100/100 [00:31<00:00,  3.14it/s], loss=0.0776]


step : 1900 valid_loss : 0.08939592998474837


 92%|█████████▏| 920/1000 [28:54<02:11,  1.64s/it, loss=0.117] 