In [8]:
import os
import torch

root_dir = '/data/archive/sd-v1-4'
model_name = 'dpm_solver++'
steps = 200
scale = 1.5
path = os.path.join(root_dir, f"{model_name}_steps{steps}_scale{scale}_clip")

def get_features(pt_dir):
    pt_files = sorted([os.path.join(path, f) for f in os.listdir(path)])
    features = []
    for file in pt_files[:2000]:
        data = torch.load(file)
        image_features = data['image_features'].cpu()
        features.append(image_features)

    features = torch.cat(features)
    return features

ref_features = get_features(path)
ref_features.shape

torch.Size([10000, 512])

In [16]:
import os
import torch
import torch.nn.functional as F

root_dir = '/data/archive/sd-v1-4'
scales = [1.5, 3.5, 5.5, 9.5]
steps_list = [5, 6, 8, 10, 12, 15, 20]
model_names = ['dpm_solver++', 'uni_pc_bh2', 'dpm_solver_v3', 'rbf_order2', 'rbf_order3']

for scale in scales:  
    print(scale)  
    for model_name in model_names:
        cosims = []
        for steps in steps_list:
            path = os.path.join(root_dir, f"{model_name}_steps{steps}_scale{scale}_clip")
            if os.path.exists(path):
                comp_features = get_features(path)
                cosim = torch.mean(F.cosine_similarity(ref_features, comp_features)).item()
                cosims.append(f"{cosim:0.4f}")
        if len(cosims) > 0:
            print(model_name, cosims)
    print()

1.5
dpm_solver++ ['0.8579', '0.8906', '0.9263', '0.9453', '0.9570', '0.9688', '0.9800']
uni_pc_bh2 ['0.8789', '0.9092', '0.9404', '0.9575', '0.9688', '0.9790', '0.9878']
dpm_solver_v3 ['0.9116', '0.9316', '0.9561', '0.9702', '0.9800', '0.9873', '0.9937']
rbf_order2 ['0.8887', '0.9165', '0.9453', '0.9614', '0.9717', '0.9810', '0.9888']
rbf_order3 ['0.8926', '0.9199', '0.9468', '0.9629', '0.9731', '0.9814', '0.9893']

3.5
dpm_solver++ ['0.8291', '0.8369', '0.8418', '0.8428', '0.8428', '0.8428', '0.8423']
uni_pc_bh2 ['0.8340', '0.8389', '0.8408', '0.8413', '0.8413', '0.8408', '0.8408']
rbf_order2 ['0.8359', '0.8394', '0.8408', '0.8408', '0.8403', '0.8398', '0.8398']
rbf_order3 ['0.8364', '0.8394', '0.8403', '0.8403', '0.8398', '0.8389', '0.8394']

5.5
dpm_solver++ ['0.7983', '0.8027', '0.8052', '0.8057', '0.8057', '0.8052', '0.8052']
uni_pc_bh2 ['0.8003', '0.8032', '0.8047', '0.8047', '0.8047', '0.8037', '0.8042']
rbf_order2 ['0.8008', '0.8037', '0.8047', '0.8042', '0.8042', '0.8037', '0.