In [1]:
import os
# Run from the parent of the directory the kernel started in; later cells rely on
# relative imports and paths (`backbones`, `solvers`, `prompts/`).
# The starting directory is recorded on first execution so that re-running this
# cell does not climb one more level each time.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))

In [None]:
import numpy as np
from easydict import EasyDict

from pathlib import Path
def get_sampling_dir(config) -> Path:
    """Return the output directory whose name encodes one sampling configuration.

    The directory name is ``config.model`` followed by one parenthesised tag per
    setting, in this order: data, solver, algorithm_type, skip_type, ``FS<flow_shift>``,
    ``NFE<NFE>``, ``CFG<CFG>``, ``ORDER<order>``. The directory itself is not created.

    Args:
        config: Object exposing the attributes listed above plus ``model`` and
            ``save_root``.

    Returns:
        ``Path(config.save_root)`` joined with the generated directory name.
    """
    tags = (
        config.data,
        config.solver,
        config.algorithm_type,
        config.skip_type,
        f"FS{config.flow_shift}",
        f"NFE{config.NFE}",
        f"CFG{config.CFG}",
        f"ORDER{config.order}",
    )
    # Wrap every tag in parentheses and concatenate them without separators.
    suffix = "".join(map("({})".format, tags))
    folder = config.model + suffix
    return Path(config.save_root) / folder

# Settings for one sampling run; EasyDict exposes each key as an attribute.
config = EasyDict({
    'model': 'SANA',
    'solver': 'Euler',
    'algorithm_type': 'vector_prediction',
    'skip_type': 'time_uniform_flow',
    'flow_shift': 1.0,
    'NFE': 10,
    'CFG': 4.5,
    'order': 2,
    'data': 'MSCOCO2017',
    'save_root': '/data/scpark/samplings/',
    'n_samples': 10,
    'batch_size': 5,
})
# The output directory is derived from the settings above, so it is added last.
config.save_dir = get_sampling_dir(config)

# Display the resolved configuration as the cell output.
config

{'model': 'SANA',
 'solver': 'Euler',
 'algorithm_type': 'vector_prediction',
 'skip_type': 'time_uniform_flow',
 'flow_shift': 1.0,
 'NFE': 10,
 'CFG': 4.5,
 'order': 2,
 'data': 'MSCOCO2017',
 'save_root': '/data/scpark/samplings/',
 'n_samples': 10,
 'batch_size': 5,
 'save_dir': PosixPath('/data/scpark/samplings/SANA(MSCOCO2017)(Euler)(vector_prediction)(time_uniform_flow)(FS1.0)(NFE10)(CFG4.5)(ORDER2)')}

In [None]:
def get_model(config):
    """Instantiate the backbone selected by ``config.model``.

    The backbone module is imported inside the matching branch, so it is only
    imported when that backbone is actually requested.

    Args:
        config: Object exposing a ``model`` attribute. ``'SANA'`` is the only
            name handled here.

    Returns:
        A ``SANA`` instance constructed with no arguments.

    Raises:
        ValueError: If ``config.model`` is not a supported name.
    """
    if config.model == 'SANA':
        from backbones.sana import SANA
        return SANA()
    # Fail with a clear message instead of an UnboundLocalError on an unbound local.
    raise ValueError(f"Unsupported model: {config.model!r} (supported: 'SANA')")

def get_solver(config):
    """Return the solver class selected by ``config.solver``.

    The solver module is imported inside the matching branch, so it is only
    imported when that solver is actually requested.

    Args:
        config: Object exposing a ``solver`` attribute. ``'Euler'`` is the only
            name handled here.

    Returns:
        The ``Euler_Solver`` class itself (not an instance); the caller
        constructs it.

    Raises:
        ValueError: If ``config.solver`` is not a supported name.
    """
    if config.solver == 'Euler':
        from solvers.euler_solver import Euler_Solver
        return Euler_Solver
    # Fail with a clear message instead of an UnboundLocalError on an unbound local.
    raise ValueError(f"Unsupported solver: {config.solver!r} (supported: 'Euler')")

def get_data(config):
    """Load the prompt list selected by ``config.data``.

    Args:
        config: Object exposing a ``data`` attribute. ``'MSCOCO2017'`` is the
            only name handled here.

    Returns:
        The ``arr_0`` array of ``prompts/mscoco2017.npz`` converted to a Python
        list. The path is relative to the current working directory.

    Raises:
        ValueError: If ``config.data`` is not a supported name.
    """
    if config.data == 'MSCOCO2017':
        # np.load on an .npz returns an archive object that keeps the file open;
        # the context manager closes it once the array has been copied to a list.
        with np.load('prompts/mscoco2017.npz') as archive:
            return archive['arr_0'].tolist()
    # Fail with a clear message instead of an UnboundLocalError on an unbound local.
    raise ValueError(f"Unsupported data: {config.data!r} (supported: 'MSCOCO2017')")


In [8]:
# Resolve the backbone instance, the solver class and the prompt list chosen in `config`.
model, Solver, data = (
    get_model(config),
    get_solver(config),
    get_data(config),
)
print('done')

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 10.47it/s]
Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  5.35it/s]


done


In [None]:
import os
import math
import torch
from tqdm import tqdm

os.makedirs(config.save_dir, exist_ok=True)

# Each sample index is used both as the prompt index and as the seed, so every
# requested sample needs its own prompt. Fail early instead of handing the model
# a prompt batch that is shorter than its seed list.
if config.n_samples > len(data):
    raise ValueError(
        f"n_samples={config.n_samples} exceeds the {len(data)} available prompts"
    )

# Total number of batches; the last one may hold fewer than batch_size samples.
n_iters = math.ceil(config.n_samples / config.batch_size)

for start in tqdm(range(0, config.n_samples, config.batch_size),
                  total=n_iters,
                  desc="Sampling batches"):
    # Half-open index range [start, end) of the samples in this batch.
    end = min(start + config.batch_size, config.n_samples)

    prompts = data[start:end]
    # Seeds are the global sample indices, so a sample's seed does not depend on batch_size.
    seeds = list(range(start, end))
    model_fn, noise_schedule, latents = model.get_model_fn(
        pos_texts=prompts,
        guidance_scale=config.CFG,
        seeds=seeds
    )

    # A fresh solver is built per batch because model_fn is specific to this batch.
    solver = Solver(
        model_fn,
        noise_schedule,
        algorithm_type=config.algorithm_type
    )
    latent_samples = solver.sample(
        latents,
        steps=config.NFE,
        order=config.order,
        skip_type=config.skip_type,
        flow_shift=config.flow_shift
    )
    # detach() is the supported replacement for the legacy `.data` attribute.
    latent_samples = latent_samples.detach().cpu()
    # NOTE(review): ':' is not a valid file-name character on Windows; the name is
    # kept unchanged so files written by earlier runs are still found.
    torch.save(latent_samples, config.save_dir / f"{start}:{end}.pt")


Sampling batches:   0%|          | 0/2 [00:00<?, ?it/s]

Processing 0 to 5


Sampling batches:  50%|█████     | 1/2 [00:03<00:03,  3.97s/it]

Processing 5 to 10


Sampling batches: 100%|██████████| 2/2 [00:07<00:00,  3.96s/it]
