In [1]:
import os
os.chdir('../')

In [4]:
import os
import json
from easydict import EasyDict

import torch
import numpy as np
import torch.nn.functional as F


# Load a config.json file and return it as an EasyDict
def load_config(path):
    with open(os.path.join(path, 'config.json'), 'r') as f:
        return EasyDict(json.load(f))

# Check if two config objects match for all specified keys
def match(c1, c2, keys):
    return all(c1[k] == c2[k] for k in keys)

def get_rmse(dir1, dir2):
    rmses = []
    i = 0
    while True:
        file1 = os.path.join(dir1, f"{i}.pt")
        file2 = os.path.join(dir2, f"{i}.pt")
        if not os.path.exists(file1) or not os.path.exists(file2):
            break
        data1 = torch.load(file1, weights_only=True)
        data2 = torch.load(file2, weights_only=True)
        rmse = torch.sqrt(F.mse_loss(data1, data2))
        rmses.append(rmse.item())
        i += 1
    rmse = np.mean(rmses).item()
    return rmse



In [6]:
ref_root = 'samplings/sana/ref'
sam_root = 'samplings/sana/sam'
# Attributes to compare between reference and sample configs
attrs = ['model', 'algorithm_type', 'skip_type', 'flow_shift', 'CFG', 'n_samples']

# Load reference configs
ref_configs = [load_config(os.path.join(ref_root, d)) for d in os.listdir(ref_root)]

# Load sample configs
sam_configs = [load_config(os.path.join(sam_root, d)) for d in os.listdir(sam_root)]

# Build a dictionary mapping each reference config's save_dir
# to the list of sample config save_dirs that match the reference config
config_dict = {
    ref.save_dir: [
        sam.save_dir for sam in sam_configs if match(ref, sam, attrs)
    ]
    for ref in ref_configs
}

for ref_dir in config_dict:
    for sam_dir in config_dict[ref_dir]:
        rmse = get_rmse(ref_dir, sam_dir)
        print(ref_dir, sam_dir, f"{rmse:0.2f}")
    print()


samplings/sana/ref/sanaref_1 samplings/sana/sam/sana_37 0.28
samplings/sana/ref/sanaref_1 samplings/sana/sam/sana_13 0.34
samplings/sana/ref/sanaref_1 samplings/sana/sam/sana_7 0.41
samplings/sana/ref/sanaref_1 samplings/sana/sam/sana_19 0.30
samplings/sana/ref/sanaref_1 samplings/sana/sam/sana_1 0.46
samplings/sana/ref/sanaref_1 samplings/sana/sam/sana_25 0.41
samplings/sana/ref/sanaref_1 samplings/sana/sam/sana_31 0.35

samplings/sana/ref/sanaref_4 samplings/sana/sam/sana_4 1.09
samplings/sana/ref/sanaref_4 samplings/sana/sam/sana_10 1.09
samplings/sana/ref/sanaref_4 samplings/sana/sam/sana_34 1.08
samplings/sana/ref/sanaref_4 samplings/sana/sam/sana_28 1.09
samplings/sana/ref/sanaref_4 samplings/sana/sam/sana_16 1.07
samplings/sana/ref/sanaref_4 samplings/sana/sam/sana_40 1.07
samplings/sana/ref/sanaref_4 samplings/sana/sam/sana_22 1.05

samplings/sana/ref/sanaref_5 samplings/sana/sam/sana_17 0.95
samplings/sana/ref/sanaref_5 samplings/sana/sam/sana_5 1.07
samplings/sana/ref/sanaref

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
