In [8]:
import os
import torch
import torch.nn.functional as F

def get_features(pt_dir):
    pt_files = sorted([os.path.join(pt_dir, f) for f in os.listdir(pt_dir)])
    features = []
    for file in pt_files[:2000]:
        data = torch.load(file)
        image_features = data['image_features'].cpu()
        features.append(image_features)

    features = torch.cat(features)
    return features

root_dir = '/data/archive/sd-v1-4'
scales = [1.5, 3.5, 5.5, 9.5]
steps_list = [5, 6, 8, 10, 12, 15, 20]
model_names = ['dpm_solver++', 'uni_pc_bh2', 'dpm_solver_v3', 'rbf_order2', 'rbf_order3']

for scale in scales:  
    print(scale)
    path = os.path.join(root_dir, f"dpm_solver++_steps200_scale{scale}_clip")
    ref_features = get_features(path)

    for model_name in model_names:
        cosims = []
        for steps in steps_list:
            path = os.path.join(root_dir, f"{model_name}_steps{steps}_scale{scale}_clip")
            if os.path.exists(path):
                comp_features = get_features(path)
                cosim = torch.mean(F.cosine_similarity(ref_features, comp_features)).item()
                cosims.append(f"{cosim:0.4f}")
        if len(cosims) > 0:
            print(model_name, cosims)
    print()


1.5
dpm_solver++ ['0.8579', '0.8906', '0.9263', '0.9453', '0.9570', '0.9688', '0.9800']
uni_pc_bh2 ['0.8789', '0.9092', '0.9404', '0.9575', '0.9688', '0.9790', '0.9878']
dpm_solver_v3 ['0.9116', '0.9316', '0.9561', '0.9702', '0.9800', '0.9873', '0.9937']
rbf_order2 ['0.8887', '0.9165', '0.9453', '0.9614', '0.9717', '0.9810', '0.9888']
rbf_order3 ['0.8926', '0.9199', '0.9468', '0.9629', '0.9731', '0.9814', '0.9893']

3.5
dpm_solver++ ['0.8774', '0.8989', '0.9233', '0.9380', '0.9478', '0.9585', '0.9702']
uni_pc_bh2 ['0.8877', '0.9067', '0.9287', '0.9434', '0.9536', '0.9648', '0.9771']
rbf_order2 ['0.8916', '0.9097', '0.9302', '0.9448', '0.9551', '0.9663', '0.9785']
rbf_order3 ['0.8931', '0.9106', '0.9302', '0.9443', '0.9546', '0.9663', '0.9780']

5.5
dpm_solver++ ['0.8794', '0.8960', '0.9170', '0.9297', '0.9399', '0.9502', '0.9619']
uni_pc_bh2 ['0.8843', '0.8994', '0.9189', '0.9316', '0.9419', '0.9526', '0.9658']
rbf_order2 ['0.8853', '0.9004', '0.9194', '0.9321', '0.9424', '0.9531', '0.

In [7]:
!ls /data/archive/sd-v1-4

dpm_solver++_steps10_scale1.5	     rbf_order2_steps6_scale5.5
dpm_solver++_steps10_scale1.5_clip   rbf_order2_steps6_scale5.5_clip
dpm_solver++_steps10_scale1.5_fid    rbf_order2_steps6_scale9.5
dpm_solver++_steps10_scale3.5	     rbf_order2_steps6_scale9.5_clip
dpm_solver++_steps10_scale3.5_clip   rbf_order2_steps8_scale1.5
dpm_solver++_steps10_scale5.5	     rbf_order2_steps8_scale1.5_clip
dpm_solver++_steps10_scale5.5_clip   rbf_order2_steps8_scale3.5
dpm_solver++_steps10_scale9.5	     rbf_order2_steps8_scale3.5_clip
dpm_solver++_steps10_scale9.5_clip   rbf_order2_steps8_scale5.5
dpm_solver++_steps12_scale1.5	     rbf_order2_steps8_scale5.5_clip
dpm_solver++_steps12_scale1.5_clip   rbf_order2_steps8_scale9.5
dpm_solver++_steps12_scale3.5	     rbf_order2_steps8_scale9.5_clip
dpm_solver++_steps12_scale3.5_clip   rbf_order3_steps10_scale1.5
dpm_solver++_steps12_scale5.5	     rbf_order3_steps10_scale1.5_clip
dpm_solver++_steps12_scale5.5_clip   rbf_order3_steps10_scale3.5
dpm_solver++_ste

In [None]:
!ls /dat