In [1]:
import os 
os.chdir('../../')
os.environ["DPM_TQDM"] = "False"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

!nvidia-smi

Sun Aug 10 00:48:06 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:19:00.0 Off |                  Off |
| 32%   53C    P8             40W /  450W |    2835MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00

In [2]:
### Imports
import torch
from easydict import EasyDict
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from backbones.dit import DiT
from datasets.pt_dataset import PtDataset
from solvers.dual.static.gdual_pc_loglinear_solver import GDual_PC_LogLinear_Solver

### Config
# Everything tunable lives here; the whole dict is stored in every checkpoint.
config = EasyDict()
config.backbone = 'DiT'
config.train_pt_dir = 'samplings/dit/train/dit_train_0_trajdrop'
config.valid_pt_dir = 'samplings/dit/eval1000/dit_eval1000_3'
config.batch_size = 10
config.CFG = 1.375                 # guidance_scale passed to model.get_model_fn
config.epochs = 10
config.val_every = 100             # validate + checkpoint every N global steps
config.log_dir = "logs/exp6"
config.solver_steps = 5
config.skip_type = "time_uniform"
config.lr = 1e-3
config.seed = 0

# Seed the global torch RNG so loader shuffling and solver parameter init are reproducible.
torch.manual_seed(config.seed)

### Model
if config.backbone == 'DiT':
    model = DiT(trainable=True)
else:
    # Fail loudly here instead of with a NameError on `model` further down.
    raise ValueError(f"Unsupported backbone: {config.backbone!r}")
print(model)


### Dataset
train_dataset = PtDataset(config.train_pt_dir)
valid_dataset = PtDataset(config.valid_pt_dir)
print('len(train_dataset) :', len(train_dataset), 'len(valid_dataset) :', len(valid_dataset))
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True, prefetch_factor=4)
# Pinned host memory lets the non_blocking=True device copies used in validation actually be asynchronous.
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, pin_memory=True)
print('done')

### Solver
noise_schedule = model.get_noise_schedule()
solver = GDual_PC_LogLinear_Solver(noise_schedule, steps=config.solver_steps, skip_type=config.skip_type, param_dim=(1, 1, 1))
solver = solver.to(model.device)
optimizer = torch.optim.AdamW(solver.parameters(), lr=config.lr)
print('done')

  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/transformer: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/transformer.
Defaulting to unsafe serialization. Pass `allow_pickle

<backbones.dit.DiT object at 0x7fb5001affd0>
len(train_dataset) : 10000 len(valid_dataset) : 1000
done
done


In [3]:
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

def get_valid_loss(device, solver):
    """Compute the mean MSE between solver samples and reference samples on the validation set.

    For every batch of the global ``valid_loader``, builds a guided model function via
    ``model.get_model_fn(noise_schedule, pos_conds=..., guidance_scale=config.CFG)``,
    runs ``solver.sample`` from the stored noise and compares it to the stored sample.

    The solver is put in eval mode for the duration and its previous train/eval mode is
    restored afterwards, so calling this from inside a training loop is safe.

    Args:
        device: Device the noise/target tensors are moved to.
        solver: The solver module being evaluated.

    Returns:
        float: Mean MSE over the validation set, weighting each batch by its size
        (``nan`` if the loader yields no batches).
    """
    was_training = solver.training
    solver.eval()
    total_loss = 0.0
    total_count = 0
    try:
        with torch.no_grad():
            for batch in tqdm(valid_loader, desc='valid', leave=False):
                noises = batch['noise'].to(device, non_blocking=True)
                targets = batch['sample'].to(device, non_blocking=True)
                conds = batch['cond']
                model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)
                pred = solver.sample(noises, model_fn)
                batch_size = noises.shape[0]
                # Weight by batch size so a smaller final batch does not bias the average.
                total_loss += F.mse_loss(pred, targets).item() * batch_size
                total_count += batch_size
    finally:
        # Restore whatever mode the caller had the solver in.
        solver.train(was_training)
    return total_loss / total_count if total_count else float('nan')
    
def do_train_loop(device, epoch, writer, solver):
    """Train ``solver`` for one epoch over the global ``train_loader``.

    Every ``config.val_every`` global steps (including the very first step) the
    validation loss is computed with ``get_valid_loss``, printed, logged to
    TensorBoard as ``valid/loss``, and a checkpoint is written to ``config.log_dir``.
    The per-step training loss is logged as ``train/loss``.

    Relies on the notebook globals ``model``, ``noise_schedule``, ``optimizer``,
    ``train_loader`` and ``config``.

    Args:
        device: Device the noise/target tensors are moved to (e.g. ``model.device``).
        epoch: Zero-based epoch index; combined with the step index to form the global step.
        writer: TensorBoard ``SummaryWriter`` used for logging.
        solver: Trainable solver whose parameters ``optimizer`` updates.

    Returns:
        Mean training loss over the epoch (``nan`` if the loader is empty).
    """
    steps_per_epoch = len(train_loader)
    # Gradient-norm clipping threshold; overridable via config.grad_clip.
    max_grad_norm = config.get('grad_clip', 1.0)
    # torch.autocast wants the device *type* ('cuda', 'cpu', ...), not an indexed device.
    autocast_device = torch.device(device).type

    solver.train()
    pbar = tqdm(train_loader, desc=f'epoch {epoch}')
    losses = []
    for step, batch in enumerate(pbar):
        global_step = epoch * steps_per_epoch + step
        if global_step % config.val_every == 0:
            valid_loss = get_valid_loss(device, solver)
            print('step :', global_step, 'valid_loss :', valid_loss)
            writer.add_scalar("valid/loss", valid_loss, global_step)
            save_checkpoint(global_step, config.log_dir, solver, valid_loss)
            # get_valid_loss calls solver.eval(); switch back before the training update.
            solver.train()

        optimizer.zero_grad(set_to_none=True)
        noises = batch['noise'].to(device, non_blocking=True)
        targets = batch['sample'].to(device, non_blocking=True)
        conds = batch['cond']
        model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)

        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            pred = solver.sample(noises, model_fn)
            loss = F.mse_loss(pred, targets)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(solver.parameters(), max_grad_norm)
        optimizer.step()

        # Call .item() once: each call forces a host/device synchronization.
        loss_value = loss.item()
        losses.append(loss_value)
        pbar.set_postfix({'loss': loss_value})
        writer.add_scalar("train/loss", loss_value, global_step)

    return np.mean(losses)

def save_checkpoint(global_step, save_dir, solver, valid_loss):
    """Save the solver state plus run metadata to ``<save_dir>/step_<global_step:08d>.pt``.

    The saved dict has keys ``global_step``, ``solver_state_dict``, ``valid_loss``
    (as a Python float) and ``config`` (a plain-dict copy of the global ``config``).

    Args:
        global_step: Training step the checkpoint corresponds to; used in the file name.
        save_dir: Directory to write into; created if it does not exist.
        solver: Module whose ``state_dict()`` is saved.
        valid_loss: Validation loss at this step.

    Returns:
        str: Path of the written checkpoint file.
    """
    # torch.save does not create parent directories.
    os.makedirs(save_dir, exist_ok=True)
    ckpt = {
        "global_step": global_step,
        "solver_state_dict": solver.state_dict(),
        "valid_loss": float(valid_loss),
        "config": dict(config),
    }
    step_path = os.path.join(save_dir, f"step_{global_step:08d}.pt")
    torch.save(ckpt, step_path)
    return step_path

print("done")  # signals that the helper definitions above have been registered

done


### Train Loop

In [4]:
writer = SummaryWriter(log_dir=config.log_dir)

try:
    for epoch in range(config.epochs):
        train_loss = do_train_loop(model.device, epoch, writer, solver)
        print('train_loss :', train_loss)
finally:
    # Flush and close the TensorBoard event file even if training is interrupted
    # (e.g. KeyboardInterrupt) or raises, so logged scalars are not lost.
    writer.close()

100%|██████████| 100/100 [00:28<00:00,  3.49it/s]


step : 0 valid_loss : 0.11410509288311005


100%|██████████| 100/100 [00:28<00:00,  3.55it/s], loss=0.0857]


step : 100 valid_loss : 0.08776374734938144


100%|██████████| 100/100 [00:28<00:00,  3.56it/s], loss=0.082]   


step : 200 valid_loss : 0.08377187073230744


100%|██████████| 100/100 [00:27<00:00,  3.58it/s], loss=0.111]   


step : 300 valid_loss : 0.08086048521101474


100%|██████████| 100/100 [00:28<00:00,  3.52it/s], loss=0.0923]  


step : 400 valid_loss : 0.07900686115026474


100%|██████████| 100/100 [00:28<00:00,  3.49it/s], loss=0.091]   


step : 500 valid_loss : 0.07829628963023424


100%|██████████| 100/100 [00:28<00:00,  3.48it/s], loss=0.0927]  


step : 600 valid_loss : 0.07768074866384268


100%|██████████| 100/100 [00:28<00:00,  3.49it/s], loss=0.0738]  


step : 700 valid_loss : 0.07762893706560135


100%|██████████| 100/100 [00:29<00:00,  3.44it/s], loss=0.0657]


step : 800 valid_loss : 0.07708251472562551


100%|██████████| 100/100 [00:32<00:00,  3.10it/s], loss=0.0448]


step : 900 valid_loss : 0.07698146138340235


100%|██████████| 1000/1000 [26:01<00:00,  1.56s/it, loss=0.0662]


train_loss : 0.08180694698914885


100%|██████████| 100/100 [00:34<00:00,  2.91it/s]


step : 1000 valid_loss : 0.07677858244627714


100%|██████████| 100/100 [00:38<00:00,  2.61it/s], loss=0.0644]


step : 1100 valid_loss : 0.07681643418967724


100%|██████████| 100/100 [00:42<00:00,  2.37it/s], loss=0.0801]  


step : 1200 valid_loss : 0.07674754127860069


100%|██████████| 100/100 [00:43<00:00,  2.30it/s], loss=0.0822]  


step : 1300 valid_loss : 0.07663685970008373


100%|██████████| 100/100 [00:46<00:00,  2.17it/s], loss=0.0671]  


step : 1400 valid_loss : 0.07661189042031764


100%|██████████| 100/100 [00:48<00:00,  2.08it/s], loss=0.0767]  


step : 1500 valid_loss : 0.07661412976682186


100%|██████████| 100/100 [00:47<00:00,  2.09it/s], loss=0.0604]  


step : 1600 valid_loss : 0.07697153184562922


100%|██████████| 100/100 [00:48<00:00,  2.06it/s], loss=0.0694]  


step : 1700 valid_loss : 0.07669193528592587


 78%|███████▊  | 778/1000 [32:27<09:15,  2.50s/it, loss=0.0626]  


KeyboardInterrupt: 