In [1]:
import os 
os.chdir('../../')
os.environ["DPM_TQDM"] = "False"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

!nvidia-smi

Sat Aug  9 19:58:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:19:00.0 Off |                  Off |
| 62%   72C    P2            332W /  450W |   14281MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00

In [2]:
### Config
from easydict import EasyDict

# All experiment settings live in one attribute-style dictionary so that the
# later cells (and the checkpoint writer) can read them as `config.<name>`.
config = EasyDict({
    'backbone': 'DiT',
    'train_pt_dir': 'samplings/dit/train/dit_train_0',
    'valid_pt_dir': 'samplings/dit/eval1000/dit_eval1000_3',
    'batch_size': 10,
    'CFG': 1.375,
    'epochs': 10,
    'val_every': 100,
    'log_dir': "logs/exp3",
})

### Model
from backbones.dit import DiT

# Instantiate the backbone named by config.backbone.
# Only 'DiT' is handled here; any other value fails immediately with a clear
# message instead of surfacing later as a NameError on `model`.
if config.backbone == 'DiT':
    # NOTE(review): the exact effect of trainable=True is defined in
    # backbones.dit and is not visible from this notebook.
    model = DiT(trainable=True)
else:
    raise ValueError(f"Unsupported backbone: {config.backbone!r} (expected 'DiT')")
print(model)


### Dataset
from datasets.pt_dataset import PtDataset
from torch.utils.data import DataLoader

# Build the train/validation datasets from the directories named in the config.
train_dataset, valid_dataset = PtDataset(config.train_pt_dir), PtDataset(config.valid_pt_dir)
print('len(train_dataset) :', len(train_dataset), 'len(valid_dataset) :', len(valid_dataset))

# Both loaders share the batch size; only the training loader shuffles.
train_loader, valid_loader = (
    DataLoader(dataset, batch_size=config.batch_size, shuffle=do_shuffle)
    for dataset, do_shuffle in ((train_dataset, True), (valid_dataset, False))
)
print('done')

### Solver
import torch
from solvers.dual.static.gdual_pc_box_solver import GDual_PC_Box_Solver
from torch.utils.tensorboard import SummaryWriter

# The solver is built on the backbone's noise schedule and placed on the same
# device as the backbone.
noise_schedule = model.get_noise_schedule()
solver = GDual_PC_Box_Solver(
    noise_schedule,
    steps=5,
    skip_type="time_uniform",
    param_dim=(1, 1, 1),
).to(model.device)

# Only the solver's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(params=solver.parameters(), lr=1e-3)
print('done')

  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/transformer: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/transformer.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...:  33%|███▎      | 1/3 [00:04<00:08,  4.41s/it]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f

<backbones.dit.DiT object at 0x7fca58392410>
len(train_dataset) : 10000 len(valid_dataset) : 1000
done
done


In [3]:
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

def get_valid_loss(device, solver):
    """Compute the mean validation MSE of the solver's samples.

    Iterates over the notebook-level ``valid_loader``; for each batch the
    'noise' and 'sample' tensors are moved to ``device`` ('cond' is passed
    through as loaded), a guided model function is built with
    ``config.CFG``, and the solver's output is compared to the stored sample
    with ``F.mse_loss``.

    Args:
        device: Target device for the noise and target tensors.
        solver: Solver to evaluate. Assumed to be a ``torch.nn.Module``
            (it exposes ``eval``/``train``) -- confirm against its class.

    Returns:
        The unweighted mean of the per-batch MSE losses (``np.mean``).
    """
    # Remember the caller's mode so that validating in the middle of training
    # does not leave the solver stuck in eval mode afterwards.
    was_training = solver.training
    solver.eval()
    losses = []
    try:
        for batch in valid_loader:
            with torch.no_grad():
                noises, conds, targets = batch['noise'].to(device), batch['cond'], batch['sample'].to(device)
                model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)
                pred = solver.sample(noises, model_fn)
                loss = F.mse_loss(pred, targets)
                losses.append(loss.item())
    finally:
        # Restore whichever mode was active on entry, even if a batch failed.
        solver.train(was_training)
    return np.mean(losses)
    
def do_train_loop(device, epoch, writer, solver):
    """Train the solver for one epoch, validating and checkpointing periodically.

    Uses the notebook-level ``train_loader``, ``optimizer``, ``model``,
    ``noise_schedule`` and ``config``. Whenever the global step is a multiple
    of ``config.val_every`` (including step 0) the validation loss is
    computed, printed, logged to ``writer`` under "valid/loss", and a
    checkpoint is written to ``config.log_dir``.

    Args:
        device: Device the noise and target tensors are moved to.
        epoch: Zero-based epoch index; used to derive the global step.
        writer: TensorBoard writer that receives the validation scalar.
        solver: Solver whose parameters are optimised.

    Returns:
        The mean of the per-step training losses for this epoch (``np.mean``).
    """
    solver.train()
    pbar = tqdm(train_loader)
    losses = []
    for step, batch in enumerate(pbar):
        global_step = epoch * len(train_loader) + step
        if global_step % config.val_every == 0:
            valid_loss = get_valid_loss(device, solver)
            # Validation puts the solver in eval mode; switch back so the
            # optimisation steps below really run in train mode.
            solver.train()
            print('step :', global_step, 'valid_loss :', valid_loss)
            writer.add_scalar("valid/loss", valid_loss, global_step)
            save_checkpoint(global_step, config.log_dir, solver, valid_loss)

        optimizer.zero_grad(set_to_none=True)
        noises, conds, targets = batch['noise'].to(device), batch['cond'], batch['sample'].to(device)
        model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)

        # Forward pass and loss run under bfloat16 autocast; backward and the
        # optimizer step run outside it.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            pred = solver.sample(noises, model_fn)
            loss = F.mse_loss(pred, targets)

        loss.backward()
        # Clip the total gradient norm of the solver parameters to 1.0.
        torch.nn.utils.clip_grad_norm_(solver.parameters(), 1.0)
        optimizer.step()

        # Read the scalar once: each .item() call forces a device sync.
        loss_value = loss.item()
        losses.append(loss_value)
        pbar.set_postfix({'loss': loss_value})
        
    return np.mean(losses)

def save_checkpoint(global_step, save_dir, solver, valid_loss):
    """Write a solver checkpoint for the given step and return its path.

    The checkpoint holds the global step, the solver's ``state_dict``, the
    validation loss as a plain float, and a dict copy of the notebook-level
    ``config``. The optimizer state is not saved.

    Args:
        global_step: Step number; also used (zero-padded to 8 digits) in the
            file name.
        save_dir: Directory to write into; created if it does not exist.
        solver: Object providing ``state_dict()``.
        valid_loss: Validation loss recorded alongside the weights.

    Returns:
        Path of the written file, ``<save_dir>/step_<global_step:08d>.pt``.
    """
    # Do not rely on another component (e.g. the TensorBoard writer) having
    # created the directory already.
    os.makedirs(save_dir, exist_ok=True)

    ckpt = {
        "global_step": global_step,
        "solver_state_dict": solver.state_dict(),
        "valid_loss": float(valid_loss),
        "config": dict(config),
    }
    step_path = os.path.join(save_dir, f"step_{global_step:08d}.pt")
    torch.save(ckpt, step_path)

    return step_path

# Marker so the cell output confirms that the helper definitions above were executed.
print('done')

done


### Train Loop

In [4]:
# TensorBoard events go to the same directory as the checkpoints.
writer = SummaryWriter(log_dir=config.log_dir)

try:
    for epoch in range(config.epochs):
        train_loss = do_train_loop(model.device, epoch, writer, solver)
        print('train_loss :', train_loss)
finally:
    # Runs on interruption too (e.g. KeyboardInterrupt), so pending events
    # are flushed and the event file is closed.
    writer.close()

  0%|          | 0/1000 [00:00<?, ?it/s]

step : 0 valid_loss : 0.11820850789546966


 10%|█         | 100/1000 [08:28<1:21:57,  5.46s/it, loss=0.0862]

step : 100 valid_loss : 0.08689071375876666


 20%|██        | 200/1000 [17:44<1:11:49,  5.39s/it, loss=0.075] 

step : 200 valid_loss : 0.08265698228031397


 30%|███       | 300/1000 [26:59<1:12:36,  6.22s/it, loss=0.0834]

step : 300 valid_loss : 0.08111071195453405


 40%|████      | 400/1000 [36:02<51:39,  5.17s/it, loss=0.0598]  

step : 400 valid_loss : 0.07948151193559169


 50%|█████     | 500/1000 [45:06<50:53,  6.11s/it, loss=0.0861]  

step : 500 valid_loss : 0.07887227360159159


 60%|██████    | 600/1000 [54:21<40:24,  6.06s/it, loss=0.064]   

step : 600 valid_loss : 0.07827108830213547


 70%|███████   | 700/1000 [1:03:09<28:07,  5.63s/it, loss=0.0834]

step : 700 valid_loss : 0.07768917866051198


 80%|████████  | 800/1000 [1:12:03<17:05,  5.13s/it, loss=0.0669]  

step : 800 valid_loss : 0.07723396506160497


 90%|█████████ | 900/1000 [1:18:57<05:37,  3.38s/it, loss=0.0515]

step : 900 valid_loss : 0.07707419849932194


100%|██████████| 1000/1000 [1:28:17<00:00,  5.30s/it, loss=0.0598]


train_loss : 0.08141245606169105


  0%|          | 0/1000 [00:00<?, ?it/s]

step : 1000 valid_loss : 0.07693271547555923


 10%|█         | 100/1000 [05:25<44:37,  2.97s/it, loss=0.0736]

step : 1100 valid_loss : 0.07656478226184844


 20%|██        | 200/1000 [11:04<42:21,  3.18s/it, loss=0.0924]  

step : 1200 valid_loss : 0.07663572147488594


 24%|██▎       | 236/1000 [13:40<44:16,  3.48s/it, loss=0.066]   


KeyboardInterrupt: 