In [11]:
import numpy as np
import matplotlib.pyplot as plt

def show_image_grid(images):
    """
    images: np.array of shape (256, 32, 32, 3)
    """
    assert len(images) == 256, "이미지는 총 256장이어야 합니다."
    n = 16  # 16x16 그리드를 위해
    
    fig, axes = plt.subplots(n, n, figsize=(8, 8))
    # figsize는 상황에 따라 적절히 조정 가능

    # 각 grid 셀에 이미지 하나씩 배치
    for idx, ax in enumerate(axes.flat):
        ax.imshow(images[idx])
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()


euler_npz = '/data/score_sde_outputs/checkpoint_8/euler_1000/samples_0.npz'
data = np.load(euler_npz)
references = data['samples'].astype(np.float32)
print(references.shape)
#show_image_grid(references)

(256, 32, 32, 3)


In [12]:
import os
import numpy as np

base_dir = "../score_sde_outputs/checkpoint_8"

samples_list = []  # 불러온 samples를 저장할 리스트(또는 다른 방법으로 저장)
results = []       # (디렉토리 경로, diff) 튜플을 저장할 리스트

for root, dirs, files in os.walk(base_dir):
    if 'samples_0.npz' in files:
        file_path = os.path.join(root, 'samples_0.npz')
        data = np.load(file_path)
        samples = data['samples'].astype(np.float32)
        
        # references가 이미 정의되어 있다고 가정
        diff = np.mean(np.abs(references - samples))
        
        # diff 값을 저장 (디렉토리 경로, diff)
        results.append((root, diff))
        
        # samples_list에 (디렉토리 경로, samples) 형태로 저장
        samples_list.append((root, samples))
        
# 정렬된 결과 출력
results.sort(key=lambda x: x[0])
for root, diff_value in results:
    print(root, diff_value)


../score_sde_outputs/checkpoint_8/euler_10 11.878929
../score_sde_outputs/checkpoint_8/euler_1000 0.0
../score_sde_outputs/checkpoint_8/euler_20 5.293545
../score_sde_outputs/checkpoint_8/euler_5 23.688494
../score_sde_outputs/checkpoint_8/lagrange_10 4.4752464
../score_sde_outputs/checkpoint_8/lagrange_12 3.9208477
../score_sde_outputs/checkpoint_8/lagrange_15 2.2450092
../score_sde_outputs/checkpoint_8/lagrange_20 1.1959254
../score_sde_outputs/checkpoint_8/lagrange_25 0.68584186
../score_sde_outputs/checkpoint_8/lagrange_30 0.38820902
../score_sde_outputs/checkpoint_8/lagrange_5 11.761868
../score_sde_outputs/checkpoint_8/lagrange_6 13.616203
../score_sde_outputs/checkpoint_8/lagrange_8 8.223865
../score_sde_outputs/checkpoint_8/lagrange_mix_10 3.8764675
../score_sde_outputs/checkpoint_8/lagrange_mix_12 3.369564
../score_sde_outputs/checkpoint_8/lagrange_mix_15 1.9351107
../score_sde_outputs/checkpoint_8/lagrange_mix_20 1.0224546
../score_sde_outputs/checkpoint_8/lagrange_mix_25 0.5

In [43]:
# diff 기준으로 오름차순 정렬
results.sort(key=lambda x: x[1])

# 정렬된 결과 출력
for root, diff_value in results:
    print(root, diff_value)


../score_sde_outputs/checkpoint_8/euler_1000 0.0
../score_sde_outputs/checkpoint_8/lagrange_mix_25 0.59590274
../score_sde_outputs/checkpoint_8/lagrange_mix_20 1.0224546
../score_sde_outputs/checkpoint_8/rbf_unipc_40 1.2456526
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_40 1.2460403
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_35 1.3493181
../score_sde_outputs/checkpoint_8/rbf_unipc_35 1.3507818
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_30 1.4597219
../score_sde_outputs/checkpoint_8/rbf_unipc_30 1.4610735
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_25 1.9045919
../score_sde_outputs/checkpoint_8/rbf_unipc_25 1.9046911
../score_sde_outputs/checkpoint_8/lagrange_mix_15 1.9351107
../score_sde_outputs/checkpoint_8/rbf_mix_25 2.3251152
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_20 2.4121196
../score_sde_outputs/checkpoint_8/rbf_unipc_20 2.412608
../score_sde_outputs/checkpoint_8/rbf_mix_20 3.0182254
../score_sde_outputs/checkpoint_8/lagrange_mix_12 3.369564
../score_sde

In [35]:
# diff 기준으로 오름차순 정렬
results.sort(key=lambda x: x[1])

# 정렬된 결과 출력
for root, diff_value in results:
    if '_5' in root:
        print(root, diff_value)


../score_sde_outputs/checkpoint_8/rbf_const_optimal_5 10.598403
../score_sde_outputs/checkpoint_8/rbf_gram_5 10.702218
../score_sde_outputs/checkpoint_8/rbf_gram_lag_5 10.769958
../score_sde_outputs/checkpoint_8/rbf_ecp_optimal_5 10.860032
../score_sde_outputs/checkpoint_8/rbf_ecp_same_optimal_5 10.870486
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_5 10.887895
../score_sde_outputs/checkpoint_8/rbf_const_grid_optimal_5 11.249458
../score_sde_outputs/checkpoint_8/rbf_dual_5 11.718282
../score_sde_outputs/checkpoint_8/rbf_100_5 11.718299
../score_sde_outputs/checkpoint_8/lagrange_5 11.761868
../score_sde_outputs/checkpoint_8/rbf_ptarget_clag_5 11.873136
../score_sde_outputs/checkpoint_8/rbf_plag_ctarget_5 11.881045
../score_sde_outputs/checkpoint_8/rbf_mix_optimal_5 15.206475
../score_sde_outputs/checkpoint_8/rbf_inception_lag_5 21.68321
../score_sde_outputs/checkpoint_8/rbf_unipc_5 22.74916
../score_sde_outputs/checkpoint_8/lagrange_mix_5 30.612417
../score_sde_outputs/checkpoint_8/r

In [36]:
# 정렬된 결과 출력
for root, diff_value in results:
    if '_6' in root:
        print(root, diff_value)


../score_sde_outputs/checkpoint_8/rbf_gram_6 8.299976
../score_sde_outputs/checkpoint_8/rbf_ecp_optimal_6 8.466301
../score_sde_outputs/checkpoint_8/rbf_gram_lag_6 8.547057
../score_sde_outputs/checkpoint_8/rbf_ecp_same_optimal_6 8.737594
../score_sde_outputs/checkpoint_8/rbf_mix_optimal_6 8.785775
../score_sde_outputs/checkpoint_8/rbf_dual_6 9.157387
../score_sde_outputs/checkpoint_8/rbf_100_6 9.160007
../score_sde_outputs/checkpoint_8/rbf_ptarget_clag_6 9.302528
../score_sde_outputs/checkpoint_8/rbf_const_grid_optimal_6 10.065587
../score_sde_outputs/checkpoint_8/rbf_const_optimal_6 10.453789
../score_sde_outputs/checkpoint_8/rbf_inception_lag_6 12.789939
../score_sde_outputs/checkpoint_8/rbf_plag_ctarget_6 13.420246
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_6 13.589089
../score_sde_outputs/checkpoint_8/rbf_unipc_6 18.627811
../score_sde_outputs/checkpoint_8/lagrange_mix_6 19.855972
../score_sde_outputs/checkpoint_8/rbf_mix_6 23.812675


In [40]:
# 정렬된 결과 출력
for root, diff_value in results:
    if root.endswith('_8'):
        print(root, diff_value)


../score_sde_outputs/checkpoint_8/rbf_100_8 5.6332145
../score_sde_outputs/checkpoint_8/rbf_dual_8 5.862551
../score_sde_outputs/checkpoint_8/rbf_ecp_same_optimal_8 5.9372573
../score_sde_outputs/checkpoint_8/rbf_const_grid_optimal_8 6.125365
../score_sde_outputs/checkpoint_8/rbf_ptarget_clag_8 6.2273345
../score_sde_outputs/checkpoint_8/rbf_const_optimal_8 6.4110336
../score_sde_outputs/checkpoint_8/lagrange_mix_8 6.838885
../score_sde_outputs/checkpoint_8/rbf_ecp_optimal_8 6.9645057
../score_sde_outputs/checkpoint_8/rbf_plag_ctarget_8 7.635797
../score_sde_outputs/checkpoint_8/rbf_gram_8 8.382976
../score_sde_outputs/checkpoint_8/rbf_mix_optimal_8 8.635074
../score_sde_outputs/checkpoint_8/rbf_gram_lag_8 8.875142
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_8 9.51718
../score_sde_outputs/checkpoint_8/rbf_unipc_8 9.8616905
../score_sde_outputs/checkpoint_8/rbf_mix_8 9.925122
../score_sde_outputs/checkpoint_8/rbf_inception_lag_8 11.598744


In [37]:
# 정렬된 결과 출력
for root, diff_value in results:
    if root.endswith('_10'):
        print(root, diff_value)


../score_sde_outputs/checkpoint_8/euler_10 11.878929
../score_sde_outputs/checkpoint_8/lagrange_10 4.4752464
../score_sde_outputs/checkpoint_8/lagrange_mix_10 3.8764675
../score_sde_outputs/checkpoint_8/rbf_100_10 4.6050544
../score_sde_outputs/checkpoint_8/rbf_const_grid_optimal_10 4.953646
../score_sde_outputs/checkpoint_8/rbf_const_optimal_10 4.0616746
../score_sde_outputs/checkpoint_8/rbf_dual_10 4.496086
../score_sde_outputs/checkpoint_8/rbf_ecp_optimal_10 4.6050735
../score_sde_outputs/checkpoint_8/rbf_ecp_same_optimal_10 4.3172927
../score_sde_outputs/checkpoint_8/rbf_gram_10 5.0957246
../score_sde_outputs/checkpoint_8/rbf_gram_lag_10 5.5582986
../score_sde_outputs/checkpoint_8/rbf_inception_lag_10 7.2289176
../score_sde_outputs/checkpoint_8/rbf_mix_10 6.007257
../score_sde_outputs/checkpoint_8/rbf_mix_optimal_10 9.128291
../score_sde_outputs/checkpoint_8/rbf_plag_ctarget_10 4.968899
../score_sde_outputs/checkpoint_8/rbf_ptarget_clag_10 4.5298753
../score_sde_outputs/checkpoint_

In [44]:
# 정렬된 결과 출력
for root, diff_value in results:
    if root.endswith('_25'):
        print(root, diff_value)


../score_sde_outputs/checkpoint_8/lagrange_mix_25 0.59590274
../score_sde_outputs/checkpoint_8/rbf_unipc_bh2_25 1.9045919
../score_sde_outputs/checkpoint_8/rbf_unipc_25 1.9046911
../score_sde_outputs/checkpoint_8/rbf_mix_25 2.3251152


In [36]:
np.exp(np.linspace(-2, 2, 100))

array([0.13533528, 0.14091534, 0.14672548, 0.15277518, 0.15907431,
       0.16563316, 0.17246245, 0.17957331, 0.18697737, 0.19468671,
       0.20271391, 0.21107209, 0.21977488, 0.22883651, 0.23827175,
       0.24809603, 0.25832537, 0.26897649, 0.28006676, 0.2916143 ,
       0.30363796, 0.31615738, 0.32919299, 0.34276607, 0.35689879,
       0.37161423, 0.3869364 , 0.40289032, 0.41950205, 0.4367987 ,
       0.45480852, 0.4735609 , 0.49308648, 0.51341712, 0.53458602,
       0.55662774, 0.57957828, 0.6034751 , 0.62835721, 0.65426525,
       0.68124152, 0.70933005, 0.73857671, 0.76902926, 0.8007374 ,
       0.83375292, 0.86812971, 0.9039239 , 0.94119394, 0.98000067,
       1.02040746, 1.06248028, 1.10628782, 1.1519016 , 1.1993961 ,
       1.24884887, 1.30034064, 1.35395549, 1.40978096, 1.46790818,
       1.52843208, 1.59145146, 1.65706921, 1.72539247, 1.7965328 ,
       1.87060634, 1.94773404, 2.02804182, 2.1116608 , 2.19872751,
       2.2893841 , 2.38377858, 2.48206508, 2.58440408, 2.69096