In [None]:
import os 
os.chdir('../../')
os.environ["DPM_TQDM"] = "False"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

!nvidia-smi

In [2]:
### Config
from easydict import EasyDict

config = EasyDict()
config.backbone = 'DiT'                                       # only 'DiT' is supported below
config.train_pt_dir = 'samplings/dit/train/dit_train_0_trajdrop'
config.valid_pt_dir = 'samplings/dit/eval1000/dit_eval1000_3'
config.batch_size = 10
config.CFG = 1.375                                            # guidance_scale passed to model.get_model_fn
config.epochs = 10
config.val_every = 100                                        # validate + checkpoint every N global steps
config.log_dir = "logs/exp13"                                 # TensorBoard events and step_*.pt checkpoints

### Model
from backbones.dit import DiT

if config.backbone == 'DiT':
    model = DiT(trainable=True)
else:
    # Fail loudly: otherwise `model` is undefined (or silently reused from an
    # earlier run of this cell) and later cells break in confusing ways.
    raise ValueError(f"Unsupported backbone: {config.backbone!r}")
print(model)


### Dataset
from datasets.pt_dataset import PtDataset
from torch.utils.data import DataLoader

# Paths are relative to the working directory set in the first cell.
train_dataset = PtDataset(config.train_pt_dir)
valid_dataset = PtDataset(config.valid_pt_dir)
print('len(train_dataset) :', len(train_dataset), 'len(valid_dataset) :', len(valid_dataset))
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True, prefetch_factor=4)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False)
print('done')

### Solver
import torch
from solvers.dual.static.time.gdual_time_pc_loglinear_solver import GDual_Time_PC_LogLinear_Solver
from torch.utils.tensorboard import SummaryWriter

noise_schedule = model.get_noise_schedule()
solver = GDual_Time_PC_LogLinear_Solver(noise_schedule, steps=5, skip_type="time_uniform", param_dim=(4, 1, 1))
solver = solver.to(model.device)
# Only the solver's parameters are optimised; the DiT backbone is not passed to the optimizer.
optimizer = torch.optim.AdamW(solver.parameters(), lr=1e-3)
print('done')

Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...:  33%|███▎      | 1/3 [00:00<00:00,  9.11it/s]An error occurred while trying to fetch /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/transformer: Error no file named diffusion_pytorch_model.safetensors found in directory /home/scpark/.cache/huggingface/hub/models--facebook--DiT-XL-2-256/snapshots/eab87f77abd5aef071a632f08807fbaab0b704d0/transformer.
Defaulting to unsafe serial

<backbones.dit.DiT object at 0x7f4cf6260890>
len(train_dataset) : 10000 len(valid_dataset) : 1000
done
done


In [3]:
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

def get_valid_loss(device, solver):
    """Return the mean per-sample MSE between the solver's samples and the targets in `valid_loader`.

    Runs without gradients and with the solver in eval mode. The solver's
    previous train/eval mode is restored afterwards, because this function is
    called from inside the training loop and must not leave training in eval mode.

    Note: validation runs in full precision while training uses bf16 autocast,
    so train and valid losses are not computed under identical numerics.

    Args:
        device: device the noise/target tensors are moved to.
        solver: the solver being evaluated (assumed to be a torch ``nn.Module``).

    Returns:
        Sample-weighted mean MSE over the whole validation set, or ``nan`` if
        ``valid_loader`` yields no batches.
    """
    was_training = solver.training
    solver.eval()
    total_loss = 0.0
    total_count = 0
    try:
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                noises = batch['noise'].to(device, non_blocking=True)
                targets = batch['sample'].to(device, non_blocking=True)
                conds = batch['cond']
                model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)
                pred = solver.sample(noises, model_fn)
                batch_size = targets.shape[0]
                # Weight each batch mean by its size so a short final batch
                # does not skew the average.
                total_loss += F.mse_loss(pred, targets).item() * batch_size
                total_count += batch_size
    finally:
        solver.train(was_training)
    return total_loss / total_count if total_count else float('nan')
    
def do_train_loop(device, epoch, writer, solver):
    """Run one training epoch over `train_loader`, optimising the solver's parameters.

    Every `config.val_every` global steps, checked *before* the step is taken
    (so including step 0 of every epoch), the solver is evaluated on
    `valid_loader`. The loss is logged to TensorBoard and a checkpoint is
    written to `config.log_dir`.

    Args:
        device: device the noise/target tensors are moved to.
        epoch: zero-based epoch index, used to derive the global step.
        writer: TensorBoard SummaryWriter receiving "valid/loss".
        solver: the learnable solver being trained.

    Returns:
        Mean training loss over the epoch, or ``nan`` if `train_loader` is empty.
    """
    solver.train()
    pbar = tqdm(train_loader)
    losses = []
    for step, batch in enumerate(pbar):
        global_step = epoch * len(train_loader) + step
        if global_step % config.val_every == 0:
            valid_loss = get_valid_loss(device, solver)
            # get_valid_loss switches the solver to eval mode; switch back so the
            # remaining steps of this epoch really train in train mode.
            solver.train()
            print('step :', global_step, 'valid_loss :', valid_loss)
            writer.add_scalar("valid/loss", valid_loss, global_step)
            save_checkpoint(global_step, config.log_dir, solver, valid_loss)

        optimizer.zero_grad(set_to_none=True)
        noises = batch['noise'].to(device, non_blocking=True)
        targets = batch['sample'].to(device, non_blocking=True)
        conds = batch['cond']  # forwarded as-is to get_model_fn
        model_fn = model.get_model_fn(noise_schedule, pos_conds=conds, guidance_scale=config.CFG)

        # Forward pass under bf16 autocast; backward runs outside the context.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            pred = solver.sample(noises, model_fn)
            loss = F.mse_loss(pred, targets)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(solver.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()  # one host sync per step
        losses.append(loss_value)
        pbar.set_postfix({'loss': loss_value})

    return np.mean(losses) if losses else float('nan')

def save_checkpoint(global_step, save_dir, solver, valid_loss):
    """Save the solver state and run metadata to ``<save_dir>/step_<global_step:08d>.pt``.

    The checkpoint dict holds the global step, the solver ``state_dict``, the
    validation loss, and a plain-dict copy of the notebook's ``config``.

    The file is first written to a temporary path and then renamed into place.
    An interrupted save (e.g. KeyboardInterrupt) therefore never leaves a
    truncated checkpoint under the final name.

    Args:
        global_step: training step the checkpoint corresponds to.
        save_dir: directory to write into; created if missing.
        solver: module whose ``state_dict()`` is saved.
        valid_loss: validation loss at this step (stored as a Python float).

    Returns:
        Path of the written checkpoint file.
    """
    # Don't rely on another component (e.g. SummaryWriter) having created the directory.
    os.makedirs(save_dir, exist_ok=True)
    ckpt = {
        "global_step": global_step,
        "solver_state_dict": solver.state_dict(),
        "valid_loss": float(valid_loss),
        "config": dict(config),
    }
    step_path = os.path.join(save_dir, f"step_{global_step:08d}.pt")
    tmp_path = step_path + ".tmp"
    torch.save(ckpt, tmp_path)
    os.replace(tmp_path, step_path)  # atomic rename on the same filesystem
    return step_path


print('done')

done


### Train Loop

In [4]:
writer = SummaryWriter(log_dir=config.log_dir)

try:
    for epoch in range(config.epochs):
        train_loss = do_train_loop(model.device, epoch, writer, solver)
        print('train_loss :', train_loss)

    # do_train_loop validates *before* each step. Without this final pass, the
    # updates made after the last in-loop validation would never be evaluated
    # or checkpointed.
    final_step = config.epochs * len(train_loader)
    valid_loss = get_valid_loss(model.device, solver)
    print('step :', final_step, 'valid_loss :', valid_loss)
    writer.add_scalar("valid/loss", valid_loss, final_step)
    save_checkpoint(final_step, config.log_dir, solver, valid_loss)
finally:
    # Flush and close TensorBoard event files even if training is interrupted
    # (e.g. KeyboardInterrupt), which otherwise skips this call.
    writer.close()

100%|██████████| 100/100 [00:27<00:00,  3.64it/s]


step : 0 valid_loss : 0.1140930750221014


100%|██████████| 100/100 [00:26<00:00,  3.73it/s], loss=0.065] 


step : 100 valid_loss : 0.08482537154108286


100%|██████████| 100/100 [00:26<00:00,  3.74it/s], loss=0.0923]  


step : 200 valid_loss : 0.07908805496990681


100%|██████████| 100/100 [00:26<00:00,  3.74it/s], loss=0.088]   


step : 300 valid_loss : 0.07573707185685635


100%|██████████| 100/100 [00:26<00:00,  3.77it/s], loss=0.0646]  


step : 400 valid_loss : 0.07373451173305512


100%|██████████| 100/100 [00:26<00:00,  3.74it/s], loss=0.089]   


step : 500 valid_loss : 0.07156141437590122


100%|██████████| 100/100 [00:26<00:00,  3.73it/s], loss=0.0679]  


step : 600 valid_loss : 0.07193665198981762


100%|██████████| 100/100 [00:26<00:00,  3.75it/s], loss=0.0632]  


step : 700 valid_loss : 0.07042043197900057


100%|██████████| 100/100 [00:26<00:00,  3.73it/s], loss=0.0902]


step : 800 valid_loss : 0.06998531956225634


100%|██████████| 100/100 [00:26<00:00,  3.72it/s], loss=0.0909]


step : 900 valid_loss : 0.0704512320831418


100%|██████████| 1000/1000 [25:10<00:00,  1.51s/it, loss=0.0412]


train_loss : 0.07615531568974257


100%|██████████| 100/100 [00:26<00:00,  3.72it/s]


step : 1000 valid_loss : 0.07047237001359463


100%|██████████| 100/100 [00:26<00:00,  3.75it/s], loss=0.0731]


step : 1100 valid_loss : 0.06935835927724839


100%|██████████| 100/100 [00:26<00:00,  3.75it/s], loss=0.0571]  


step : 1200 valid_loss : 0.06914363890886306


100%|██████████| 100/100 [00:26<00:00,  3.77it/s], loss=0.0555]  


step : 1300 valid_loss : 0.07009186170995235


100%|██████████| 100/100 [00:26<00:00,  3.72it/s], loss=0.085]   


step : 1400 valid_loss : 0.07001779638230801


100%|██████████| 100/100 [00:26<00:00,  3.81it/s], loss=0.0706]  


step : 1500 valid_loss : 0.07003508467227221


100%|██████████| 100/100 [00:26<00:00,  3.73it/s], loss=0.0637]  


step : 1600 valid_loss : 0.06997304819524289


100%|██████████| 100/100 [00:26<00:00,  3.79it/s], loss=0.0625]  


step : 1700 valid_loss : 0.06976748995482922


 77%|███████▋  | 767/1000 [19:23<05:53,  1.52s/it, loss=0.0788]


KeyboardInterrupt: 