In [1]:
import os 
os.chdir('../../')
os.environ["DPM_TQDM"] = "False"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

!nvidia-smi

Sun Aug 10 02:56:47 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:19:00.0 Off |                  Off |
| 88%   81C    P2            383W /  450W |   14279MiB /  24564MiB |     52%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00

In [2]:
### Config
from easydict import EasyDict

# Every tunable of this experiment lives on `config`; the whole object is also
# stored inside each checkpoint by save_checkpoint.
config = EasyDict()
config.backbone = 'DiT'
config.train_pt_dir = 'samplings/dit/train/dit_train_0_trajdrop'
config.valid_pt_dir = 'samplings/dit/eval1000/dit_eval1000_3'
config.batch_size = 10
config.CFG = 1.375            # passed as guidance_scale to model.get_model_fn
config.epochs = 10
config.val_every = 100        # validate and checkpoint every N global steps
config.log_dir = "logs/exp12"
config.solver_steps = 5       # `steps` argument of the solver
config.skip_type = "time_uniform"
config.lr = 1e-3              # AdamW learning rate for the solver parameters

### Model
from backbones.dit import DiT

if config.backbone == 'DiT':
    model = DiT(trainable=True)
else:
    # Fail with a clear message instead of a NameError on `model` below.
    raise ValueError(f"Unsupported backbone: {config.backbone!r}")
print(model)


### Dataset
from datasets.pt_dataset import PtDataset
from torch.utils.data import DataLoader

train_dataset = PtDataset(config.train_pt_dir)
valid_dataset = PtDataset(config.valid_pt_dir)
print('len(train_dataset) :', len(train_dataset), 'len(valid_dataset) :', len(valid_dataset))
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)
# Validation is loaded in the main process and in a fixed order.
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False)
print('done')

### Solver
import torch
from solvers.dual.static.data_coeff2_solver import Data_Coeff2_Solver
from torch.utils.tensorboard import SummaryWriter

noise_schedule = model.get_noise_schedule()
solver = Data_Coeff2_Solver(noise_schedule, steps=config.solver_steps, skip_type=config.skip_type)
solver = solver.to(model.device)
# Only the solver's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(solver.parameters(), lr=config.lr)
print('done')

  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...:  67%|██████▋   | 2/3 [00:00<00:00, 13.07it/s]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/transformer: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fba

<backbones.dit.DiT object at 0x7fbf64fdfc50>
len(train_dataset) : 10000 len(valid_dataset) : 1000
done
done


In [3]:
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

def get_valid_loss(device, solver):
    """Return the mean per-batch MSE of ``solver`` over ``valid_loader``.

    For every validation batch the 'noise' and 'sample' tensors are moved to
    ``device``, a model function is built from the batch's 'cond' entry with
    ``config.CFG`` as guidance scale, and ``solver.sample`` is compared to the
    'sample' target with ``F.mse_loss``. No gradients are recorded.

    The solver is switched to eval mode for the duration of the call and then
    returned to whatever mode it was in before, so calling this in the middle
    of a training epoch does not leave the solver in eval mode.

    Args:
        device: Device the noise and target tensors are moved to.
        solver: Module exposing ``sample(noises, model_fn)``.

    Returns:
        Mean of the per-batch losses; NaN when the loader yields no batches.
    """
    was_training = solver.training
    solver.eval()
    losses = []
    try:
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                noises = batch['noise'].to(device, non_blocking=True)
                conds = batch['cond']
                targets = batch['sample'].to(device, non_blocking=True)
                model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)
                pred = solver.sample(noises, model_fn)
                losses.append(F.mse_loss(pred, targets).item())
    finally:
        # Restore the caller's mode even if validation raised.
        solver.train(was_training)

    if not losses:
        # Same value np.mean([]) would give, without the RuntimeWarning.
        return float('nan')
    return np.mean(losses)
    
def do_train_loop(device, epoch, writer, solver):
    """Train ``solver`` for one pass over ``train_loader``.

    Before every step whose global index is a multiple of ``config.val_every``
    the solver is validated, the loss is printed and written to ``writer``
    under "valid/loss", and a checkpoint is saved to ``config.log_dir``.

    Each step builds a model function from the batch's 'cond' entry, runs
    ``solver.sample`` on the batch's 'noise' under bfloat16 autocast, takes the
    MSE against the batch's 'sample', clips the gradient norm of the solver
    parameters to 1.0 and applies the module-level ``optimizer``.

    Args:
        device: Device the noise and target tensors are moved to.
        epoch: Zero-based epoch index, used to compute the global step.
        writer: Object with ``add_scalar(tag, value, step)``.
        solver: Module being optimised.

    Returns:
        Mean training loss over the steps of this epoch.
    """
    solver.train()
    pbar = tqdm(train_loader)
    steps_per_epoch = len(train_loader)
    losses = []
    for step, batch in enumerate(pbar):
        global_step = epoch * steps_per_epoch + step
        if global_step % config.val_every == 0:
            valid_loss = get_valid_loss(device, solver)
            print('step :', global_step, 'valid_loss :', valid_loss)
            writer.add_scalar("valid/loss", valid_loss, global_step)
            save_checkpoint(global_step, config.log_dir, solver, valid_loss)
            # Validation runs the solver in eval mode; re-assert train mode so
            # the remaining steps of the epoch are not trained in eval mode.
            solver.train()

        optimizer.zero_grad(set_to_none=True)
        noises, conds, targets = batch['noise'].to(device, non_blocking=True), batch['cond'], batch['sample'].to(device, non_blocking=True)
        model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)

        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            pred = solver.sample(noises, model_fn)
            loss = F.mse_loss(pred, targets)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(solver.parameters(), 1.0)
        optimizer.step()

        # Read the scalar once; each .item() call synchronises with the device.
        loss_value = loss.item()
        losses.append(loss_value)
        pbar.set_postfix({'loss': loss_value})

    return np.mean(losses)

def save_checkpoint(global_step, save_dir, solver, valid_loss):
    """Write a checkpoint for ``solver`` and return the file path.

    The file is named ``step_<global_step, zero-padded to 8 digits>.pt`` inside
    ``save_dir`` and holds the global step, the solver's ``state_dict()``, the
    validation loss as a Python float, and a plain-dict copy of ``config``.

    Args:
        global_step: Step index stored in the file and used in its name.
        save_dir: Target directory; created if it does not exist yet.
        solver: Module whose ``state_dict()`` is stored.
        valid_loss: Validation loss associated with this step.

    Returns:
        Path of the written checkpoint file.
    """
    # Do not rely on another component having created the directory already.
    os.makedirs(save_dir, exist_ok=True)

    ckpt = {
        "global_step": global_step,
        "solver_state_dict": solver.state_dict(),
        "valid_loss": float(valid_loss),
        "config": dict(config),
    }
    step_path = os.path.join(save_dir, f"step_{global_step:08d}.pt")
    torch.save(ckpt, step_path)

    return step_path

# Completion marker: printed once the helper definitions in this cell have run.
print('done')

done


### Train Loop

In [None]:
# Run the full training schedule, logging to TensorBoard under config.log_dir.
writer = SummaryWriter(log_dir=config.log_dir)

try:
    for epoch in range(config.epochs):
        train_loss = do_train_loop(model.device, epoch, writer, solver)
        print('train_loss :', train_loss)

    # do_train_loop validates *before* a step, so the solver state after the
    # very last step has not been evaluated or saved yet. Do that here.
    final_step = config.epochs * len(train_loader)
    valid_loss = get_valid_loss(model.device, solver)
    print('step :', final_step, 'valid_loss :', valid_loss)
    writer.add_scalar("valid/loss", valid_loss, final_step)
    save_checkpoint(final_step, config.log_dir, solver, valid_loss)
finally:
    # Close the writer even when training is interrupted or fails.
    writer.close()

100%|██████████| 100/100 [00:26<00:00,  3.81it/s]


step : 0 valid_loss : 0.8954178720712662


100%|██████████| 100/100 [00:42<00:00,  2.38it/s], loss=0.928]


step : 100 valid_loss : 0.8722517251968384


100%|██████████| 100/100 [00:49<00:00,  2.02it/s], loss=0.945]  


step : 200 valid_loss : 0.7222370794415474


100%|██████████| 100/100 [00:50<00:00,  1.99it/s], loss=0.448]  


step : 300 valid_loss : 0.5820503744482994


100%|██████████| 100/100 [00:52<00:00,  1.92it/s], loss=0.405]  


step : 400 valid_loss : 0.41418111145496367


100%|██████████| 100/100 [00:51<00:00,  1.94it/s], loss=0.357]  


step : 500 valid_loss : 0.3524515289068222


100%|██████████| 100/100 [00:54<00:00,  1.82it/s], loss=0.335]  


step : 600 valid_loss : 0.3205217844247818


100%|██████████| 100/100 [00:56<00:00,  1.76it/s], loss=0.259]  


step : 700 valid_loss : 0.2958448077738285


100%|██████████| 100/100 [00:52<00:00,  1.92it/s], loss=0.27]   


step : 800 valid_loss : 0.27345385566353797


 89%|████████▉ | 893/1000 [42:43<04:30,  2.53s/it, loss=0.271]  