In [2]:
import os
import numpy as np
from tqdm import tqdm
# Number of consecutive samples_{i}.npz files (i = 0 .. max_files-1) read per
# run directory; lower it (e.g. to 1) for a quick debugging pass.
max_files = 40

def load_and_stack_samples(directory, max_files, key='samples_raw'):
    """Load ``samples_0.npz`` ... ``samples_{max_files-1}.npz`` from ``directory`` and stack them.

    Args:
        directory: Directory containing the ``samples_{i}.npz`` files.
        max_files: Number of consecutive files to read, starting at index 0.
        key: Name of the array to read from each ``.npz`` archive
            (defaults to ``'samples_raw'``).

    Returns:
        One float32 array formed by concatenating the per-file arrays along
        axis 0, or ``None`` if any expected file is missing, any file lacks
        ``key``, or ``max_files`` is not positive. Incomplete runs are reported
        with ``None`` rather than an exception so that callers can skip them.
    """
    samples_list = []
    for i in range(max_files):
        path = os.path.join(directory, f"samples_{i}.npz")
        if not os.path.exists(path):
            return None
        # `with` closes the NpzFile (and its underlying file handle) promptly.
        with np.load(path) as data:
            if key not in data:
                return None
            samples_list.append(data[key].astype(np.float32))
    return np.concatenate(samples_list, axis=0) if samples_list else None

def load_samples_for_all_subdirs(root_dir, max_dirs, max_files, hint=None):
    """Load stacked samples for every run sub-directory of ``root_dir``.

    Args:
        root_dir: Directory whose immediate sub-directories each hold one run's
            ``samples_{i}.npz`` files.
        max_dirs: Upper bound on how many entries of
            ``sorted(os.listdir(root_dir))`` are considered. The cap is applied
            *before* the ``hint`` and is-directory filters, so plain files and
            non-matching names count towards it.
        max_files: Forwarded to ``load_and_stack_samples``.
        hint: If not ``None``, only entries whose name contains this substring
            are loaded.

    Returns:
        Dict mapping sub-directory name to its stacked sample array. Entries for
        which ``load_and_stack_samples`` returns ``None`` (incomplete runs) are
        omitted.
    """
    entries = sorted(os.listdir(root_dir))[:max_dirs]
    if hint is not None:
        # Filter before tqdm so the progress bar counts only entries we try to load.
        entries = [name for name in entries if hint in name]

    result_dict = {}
    for subdir in tqdm(entries):
        subdir_path = os.path.join(root_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        result = load_and_stack_samples(subdir_path, max_files)
        if result is not None:
            result_dict[subdir] = result
    return result_dict

# Usage example: load the reference run that every other run is scored against.
directory_path = "/data/score_sde_outputs/checkpoint_8/lagrange_10k200"
reference = load_and_stack_samples(directory_path, max_files=max_files)
# load_and_stack_samples signals a missing file / missing 'samples_raw' key by
# returning None; fail here with a clear message instead of an opaque
# AttributeError on `.shape` (or later comparisons against None).
if reference is None:
    raise FileNotFoundError(
        f"Could not load {max_files} complete sample files from {directory_path}"
    )
print(reference.shape)


(10240, 3, 32, 32)


In [17]:
# Directory holding one sub-directory per sampler configuration (281 entries
# at the time this cell was last run).
root_dir = '/data/score_sde_outputs/checkpoint_8'
# Only sub-directories whose name contains this substring are loaded; set to
# None to load every run. NOTE(review): the per-step-count tables further down
# (cells In [4]-In [9]) were produced by an earlier, unfiltered load -- their
# outputs list many runs without 'rbf_25'. With this filter, re-running the
# notebook top to bottom makes those cells print nothing.
subdir_hint = 'rbf_25'
# max_dirs=1000 exceeds the number of entries seen so far, i.e. effectively no cap.
result_dict = load_samples_for_all_subdirs(
    root_dir, max_dirs=1000, max_files=max_files, hint=subdir_hint
)
print(len(result_dict))

100%|██████████| 281/281 [00:00<00:00, 288.77it/s]

1





In [18]:
# Score every loaded run against the reference run: mean squared error over all
# elements, multiplied by 1e4 for readability (a scaled MSE, not an L2 norm).
# Runs whose stacked shape differs from the reference are skipped.
pairs = []
for key, arr in result_dict.items():
    if arr.shape == reference.shape:
        # Accumulate in float64: for float32 input np.mean otherwise accumulates
        # in float32, which can lose precision on arrays this large.
        mse_scaled = np.mean((reference - arr) ** 2, dtype=np.float64) * 10000
        pairs.append((key, mse_scaled))

# Alphabetical by run name.
pairs.sort(key=lambda x: x[0])

for k, v in pairs:
    print(k, v)


rbf_25 1.212975475937128


In [4]:
# Score every loaded run against the reference run: mean squared error over all
# elements, multiplied by 1e4 for readability (a scaled MSE, not an L2 norm).
# Runs whose stacked shape differs from the reference are skipped.
pairs = []
for key, arr in result_dict.items():
    if arr.shape == reference.shape:
        # Accumulate in float64: for float32 input np.mean otherwise accumulates
        # in float32, which can lose precision on arrays this large.
        mse_scaled = np.mean((reference - arr) ** 2, dtype=np.float64) * 10000
        pairs.append((key, mse_scaled))

# Ascending score: runs closest to the reference come first.
pairs.sort(key=lambda x: x[1])

for k, v in pairs:
    if k.endswith('_5'):
        print(k, v)


rbf_const_optimal_5 43.037617579102516
rbf_spd_const_5 44.253626838326454
rbf_const_grid_optimal_5 44.49122119694948
rbf_ecp_optimal_5 44.58761774003506
rbf_ecp_same4_5 44.629620388150215
rbf_ecp_same5_5 44.629620388150215
rbf_ecp_same_optimal_5 44.629620388150215
rbf_xt_5 44.689225032925606
rbf_gram_5 45.18920090049505
lagrange_5 45.352024026215076
rbf_plag_ctarget_5 45.75266968458891
rbf_gram_lag_5 45.922999270260334
rbf_spd_xt4_5 49.51573442667723
rbf_spd_xt_5 49.51573442667723
rbf_100_5 49.753179773688316
rbf_dual_5 49.753179773688316
rbf_spd_5 49.8027540743351
rbf_spd_ptm_cxt_5 50.30657630413771
rbf_ptarget_clag_5 50.80373026430607
rbf_spd_clag_tm_5 51.53519567102194
rbf_x0_5 58.08610934764147
rbf_mix_optimal_5 72.29273673146963
rbf_mix_ecp_same_5 107.21026919782162
rbf_inception_lag_5 131.0362946242094
lagrange_mix_5 249.61229413747787
rbf_mix_5 288.63366693258286


In [5]:
# Print (name, score) for every run whose name ends with '_6', keeping the
# current order of `pairs`.
for name, score in filter(lambda item: item[0].endswith('_6'), pairs):
    print(name, score)


rbf_xt_6 27.419179677963257
rbf_spd_6 27.479410637170076
rbf_spd_ptm_cxt_6 27.595635037869215
rbf_spd_xt4_6 27.644920628517866
rbf_spd_xt_6 27.644920628517866
rbf_ecp_optimal_6 28.909866232424974
rbf_ecp_same4_6 28.953503351658583
rbf_ecp_same5_6 28.953503351658583
rbf_ecp_same_optimal_6 28.953503351658583
rbf_spd_const_6 29.753081034868956
rbf_gram_6 29.78525822982192
rbf_mix_optimal_6 30.219706241041422
rbf_gram_lag_6 31.074231956154108
rbf_const_optimal_6 31.661479733884335
rbf_100_6 31.916326843202114
rbf_dual_6 31.981354113668203
rbf_const_grid_optimal_6 32.74877555668354
rbf_ptarget_clag_6 32.843363005667925
rbf_x0_6 41.53312649577856
rbf_plag_ctarget_6 46.82152532041073
lagrange_6 48.88362716883421
rbf_mix_ecp_same_6 52.955253049731255
rbf_inception_lag_6 55.405725724995136
lagrange_mix_6 97.76095859706402
rbf_mix_6 139.80778865516186


In [6]:
# Print (name, score) for every run whose name ends with '_8', keeping the
# current order of `pairs`.
for name, score in filter(lambda item: item[0].endswith('_8'), pairs):
    print(name, score)


rbf_ecp_same4_8 14.12846613675356
rbf_ecp_same5_8 14.12846613675356
rbf_ecp_same_optimal_8 14.619318535551429
rbf_const_grid_optimal_8 14.689426170662045
rbf_100_8 14.793367590755224
rbf_spd_ptm_cxt_8 15.001602005213499
rbf_dual_8 15.237516490742564
rbf_spd_8 15.371313784271479
rbf_const_optimal_8 15.41562844067812
rbf_spd_const_8 15.849696937948465
rbf_ptarget_clag_8 16.59309258684516
lagrange_mix_8 16.822522738948464
rbf_plag_ctarget_8 17.860193038359284
rbf_spd_xt4_8 18.079575384035707
rbf_ecp_optimal_8 18.64523161202669
lagrange_8 19.424563506618142
rbf_xt_8 19.461405463516712
rbf_spd_xt_8 24.432663340121508
rbf_gram_8 25.49838274717331
rbf_x0_8 25.92105185613036
rbf_mix_8 26.705011259764433
rbf_gram_lag_8 27.962825261056423
rbf_mix_ecp_same_8 28.1087146140635
rbf_mix_optimal_8 28.496242593973875
rbf_inception_lag_8 44.014318846166134


In [7]:
# Print (name, score) for every run whose name ends with '_10', keeping the
# current order of `pairs`.
for name, score in filter(lambda item: item[0].endswith('_10'), pairs):
    print(name, score)


rbf_ecp_same5_10 7.755103288218379
lagrange_mix_10 7.891880813986063
rbf_ecp_same_optimal_10 8.437956566922367
rbf_100_10 8.562402799725533
rbf_ecp_same4_10 8.676119614392519
lagrange_10 8.812317973934114
rbf_const_optimal_10 9.179448243230581
rbf_dual_10 9.417074616067111
rbf_plag_ctarget_10 9.535954450257123
rbf_spd_const_10 9.61054116487503
rbf_ptarget_clag_10 9.847191395238042
rbf_spd_10 10.258876718580723
rbf_ecp_optimal_10 10.294857202097774
rbf_spd_xt4_10 10.521096410229802
rbf_xt_10 10.527967242524028
rbf_mix_10 10.955958859995008
rbf_spd_ptm_cxt_10 11.270045069977641
rbf_gram_10 11.446084827184677
rbf_const_grid_optimal_10 12.344027636572719
rbf_gram_lag_10 13.215860817581415
rbf_spd_xt_10 13.388418592512608
rbf_x0_10 14.777167234569788
rbf_inception_lag_10 20.662806928157806
rbf_mix_ecp_same_10 27.27283863350749
rbf_mix_optimal_10 30.673157889395952


In [8]:
# Print (name, score) for every run whose name ends with '_12', keeping the
# current order of `pairs`.
for name, score in filter(lambda item: item[0].endswith('_12'), pairs):
    print(name, score)


lagrange_mix_12 5.904371500946581
rbf_spd_xt_12 6.647897535003722
rbf_ecp_same5_12 6.647955160588026
rbf_ecp_same4_12 6.732998299412429
rbf_ecp_same_optimal_12 6.890277145430446
lagrange_12 6.892408127896488
rbf_spd_xt4_12 6.984025822021067
rbf_const_optimal_12 7.627116283401847
rbf_ecp_optimal_12 8.336942410096526
rbf_mix_12 8.899232489056885
rbf_dual_12 21.084139589220285
rbf_mix_ecp_same_12 26.559706311672926
rbf_mix_optimal_12 31.852393876761198


In [9]:
# Print (name, score) for every run whose name ends with '_15', keeping the
# current order of `pairs`.
for name, score in filter(lambda item: item[0].endswith('_15'), pairs):
    print(name, score)


lagrange_mix_15 3.1746135209687054
lagrange_15 3.4775087260641158
rbf_ecp_same_optimal_15 3.4808882628567517
rbf_spd_xt_15 3.7374728708527982
rbf_ecp_same4_15 3.7628604331985116
rbf_spd_xt4_15 3.8560497341677547
rbf_mix_15 4.43167460616678
rbf_ecp_same5_15 4.682581638917327
rbf_mix_optimal_15 29.88395979627967
