In [1]:
import os

# Work from the directory two levels above this notebook; the later
# `from main import ...` / `from datasets import ...` and relative config paths
# presumably rely on it. The root is resolved only on the first run so that
# re-running this cell does not climb two more directories each time.
if 'REPO_ROOT' not in globals():
    REPO_ROOT = os.path.abspath('../../')
os.chdir(REPO_ROOT)

In [5]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import random

from main import parse_args_and_config, Diffusion
from datasets import inverse_data_transform

# Seed every RNG this notebook touches (Python, NumPy, PyTorch CPU and all CUDA
# devices) so the noise and class labels drawn via np.random are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# otherwise pick different algorithms from run to run.
# (The flag lives on torch.backends.cudnn; assigning torch.backends.benchmark
# merely creates an unused attribute.)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

###############################################################################
# 1) Fake a command line in the notebook so main.py's argparse setup can be reused.
#    Each tuple is one option followed by its value (if it takes one); the
#    tuples are flattened, in order, after the program name.
###############################################################################
_cli_options = (
    ("--config", "imagenet128_guided.yml"),  # config file to use
    ("--sample",),
    ("--eta", "0"),
    ("--sample_type", "lagrangesolver"),
    ("--dpm_solver_type", "data_prediction"),
    ("--dpm_solver_order", "1"),
    ("--timesteps", "10"),
    ("--skip_type", "logSNR"),
    ("--scale", "0.0"),
    ("--thresholding",),
    ("--ni",),
)
sys.argv = ["main.py"] + [token for option in _cli_options for token in option]
del _cli_options  # leave the notebook namespace as a plain literal assignment would

###############################################################################
# 2) Load arguments / config
#    parse_args_and_config() takes no arguments here, so it presumably parses the
#    sys.argv list assigned above — TODO confirm in main.py.
###############################################################################
args, config = parse_args_and_config()

###############################################################################
# 3) Build the Diffusion object -> load the model
#    prepare_model() presumably populates `diffusion.model` (it logs
#    "[prepare_model] Model is ready."); eval() then puts the network in
#    inference mode, assuming it is a torch nn.Module — verify.
###############################################################################
diffusion = Diffusion(args, config, rank=0)
diffusion.prepare_model()
diffusion.model.eval()

###############################################################################
# 4) Device used by the sampling cells below
#    NOTE(review): this header previously promised "sample a batch of 25 at
#    once -> 5x5 grid visualization (no margins)", but this cell only reads the
#    device and no grid is drawn in the cells visible here.
###############################################################################
device = diffusion.device

INFO - main.py - 2025-03-30 12:39:11,260 - Using device: cuda
INFO - main.py - 2025-03-30 12:39:11,260 - Using device: cuda


[prepare_model] Model is ready.


In [6]:
from tqdm import tqdm

def sample_in_batches(diffusion, config, device, total_samples=1024, batch_size=16):
    """Draw (noise, sample) pairs from ``diffusion`` in mini-batches.

    For each batch, float32 Gaussian noise and uniform class labels are drawn
    from NumPy's global RNG (so a global ``np.random.seed`` makes them
    reproducible). The noise and labels are passed to
    ``diffusion.sample_image(noise, diffusion.model, classes=labels)``, and the
    first element of the returned tuple is kept as the generated batch.

    Args:
        diffusion: object exposing ``model`` and ``sample_image`` as used above.
        config: config namespace providing ``data.channels``,
            ``data.image_size`` and ``data.num_classes``.
        device: torch device on which noise and labels are created.
        total_samples: number of pairs to generate; 0 yields empty tensors.
        batch_size: maximum number of samples per ``sample_image`` call (>= 1).

    Returns:
        ``(pairs, classes)``:
        - ``pairs`` has shape ``(total_samples, 2, C, S, S)``. ``pairs[:, 0]``
          is the input noise and ``pairs[:, 1]`` is the sampler output.
        - ``classes`` is an int64 tensor of shape ``(total_samples,)``.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``total_samples`` < 0.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if total_samples < 0:
        raise ValueError(f"total_samples must be >= 0, got {total_samples}")

    channels = config.data.channels
    size = config.data.image_size
    num_classes = config.data.num_classes

    pairs = []
    class_list = []
    with torch.no_grad():
        for start in tqdm(range(0, total_samples, batch_size)):
            bs = min(batch_size, total_samples - start)
            # Draw noise before labels so the NumPy RNG stream is consumed in
            # the same order on every run.
            noise_np = np.random.randn(bs, channels, size, size).astype(np.float32)
            noise = torch.tensor(noise_np, device=device)
            labels_np = np.random.randint(0, num_classes, size=(bs,))
            # Force int64: NumPy's default int dtype is platform-dependent,
            # while class-index lookups generally expect torch.long.
            labels = torch.tensor(labels_np, dtype=torch.long, device=device)
            data, _ = diffusion.sample_image(noise, diffusion.model, classes=labels)
            pairs.append(torch.stack([noise, data], dim=1))
            class_list.append(labels)

    if not pairs:
        # torch.cat rejects an empty list; return correctly shaped empty tensors.
        return (
            torch.empty((0, 2, channels, size, size), device=device),
            torch.empty((0,), dtype=torch.long, device=device),
        )
    return torch.cat(pairs, dim=0), torch.cat(class_list, dim=0)

# Generate a small set of (noise, sample) pairs, four per sampler call, and show
# the resulting tensor shape followed by the drawn class labels.
total_samples = 16
pairs, classes = sample_in_batches(
    diffusion,
    config,
    device,
    total_samples=total_samples,
    batch_size=4,
)
print(pairs.shape, classes, sep="\n")

100%|██████████| 4/4 [00:01<00:00,  2.20it/s]

torch.Size([16, 2, 3, 128, 128])
tensor([348, 940, 965,  52,  47, 108, 901, 711, 168, 530,  67, 646, 855, 978,
         50, 578], device='cuda:0')





In [10]:
# Persist the (noise, sample) pairs and their class labels, then reload the file
# once to check that the round trip preserved shapes and dtypes.
# NOTE(review): the file name says "euler_NFE=1000", but the sampler configured
# above uses `--sample_type lagrangesolver --dpm_solver_order 1 --timesteps 10`.
# Confirm which label downstream consumers expect before renaming the file.
SAVE_DIR = '/data/optimization/'
os.makedirs(SAVE_DIR, exist_ok=True)  # pure-Python equivalent of `!mkdir -p`
save_file = os.path.join(SAVE_DIR, f'euler_NFE=1000_N={total_samples}_imagenet128.pt')
torch.save({'pairs': pairs.detach().cpu(),
            'classes': classes.detach().cpu()
            }, save_file)
# The payload is a plain dict of tensors, so the restricted (safe) loader suffices.
pairs_load = torch.load(save_file, weights_only=True)
assert pairs_load['pairs'].shape == pairs.shape
assert pairs_load['pairs'].dtype == pairs.dtype
assert pairs_load['classes'].shape == classes.shape
assert pairs_load['classes'].dtype == classes.dtype
print(pairs_load['pairs'].shape, pairs_load['classes'].shape)

torch.Size([16, 2, 3, 128, 128]) torch.Size([16])
