In [1]:
import os
os.chdir('../../')

### Reference Files

In [2]:
import os

ref_dirs = {}
root_dir = 'samplings'
for cfg in [1.5, 3.5, 5.5, 7.5, 9.5]:    
    dir = f'SANA(MSCOCO2017)(Euler)(data_prediction)(time_uniform_flow)(FS1.0)(NFE200)(CFG{cfg})(ORDER2)'
    ref_dirs[cfg] = str(os.path.join(root_dir, dir))

ref_dirs

{1.5: 'samplings/SANA(MSCOCO2017)(Euler)(data_prediction)(time_uniform_flow)(FS1.0)(NFE200)(CFG1.5)(ORDER2)',
 3.5: 'samplings/SANA(MSCOCO2017)(Euler)(data_prediction)(time_uniform_flow)(FS1.0)(NFE200)(CFG3.5)(ORDER2)',
 5.5: 'samplings/SANA(MSCOCO2017)(Euler)(data_prediction)(time_uniform_flow)(FS1.0)(NFE200)(CFG5.5)(ORDER2)',
 7.5: 'samplings/SANA(MSCOCO2017)(Euler)(data_prediction)(time_uniform_flow)(FS1.0)(NFE200)(CFG7.5)(ORDER2)',
 9.5: 'samplings/SANA(MSCOCO2017)(Euler)(data_prediction)(time_uniform_flow)(FS1.0)(NFE200)(CFG9.5)(ORDER2)'}

### Comparison Files

In [3]:
import torch
import numpy as np
import torch.nn.functional as F

def get_rmse(dir1, dir2):
    rmses = []
    for i in range(1000):
        comp_file = os.path.join(dir1, f"{i}.pt")
        ref_file = os.path.join(dir2, f"{i}.pt")
        comp_data = torch.load(comp_file, weights_only=True)
        ref_data = torch.load(ref_file, weights_only=True)
        rmse = torch.sqrt(F.mse_loss(comp_data, ref_data))
        rmses.append(rmse.item())
    rmse = np.mean(rmses).item()
    return rmse


In [5]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

save_dir = 'reports/sana/euler_dpm_unipc'
os.makedirs(save_dir, exist_ok=True)

for solver in ['Euler', 'DPM-Solver', 'UniPC']:
    for cfg in [1.5, 3.5, 5.5, 7.5, 9.5]:
        for NFE in [5, 6, 8, 10]:
            print(solver, cfg, NFE)
            dir = f'SANA(MSCOCO2017)({solver})(data_prediction)(time_uniform_flow)(FS1.0)(NFE{NFE})(CFG{cfg})(ORDER2)'
            comp_dir = os.path.join(root_dir, dir)
            ref_dir = ref_dirs[cfg]
            rmse = get_rmse(comp_dir, ref_dir)
            save_file = os.path.join(save_dir, f'{solver}_{cfg}_{NFE}_FS1.0.txt')
            with open(save_file, 'w') as f:
                f.write(f"{rmse}")

print('done')

Euler 1.5 5
Euler 1.5 6
Euler 1.5 8
Euler 1.5 10
Euler 3.5 5
Euler 3.5 6
Euler 3.5 8
Euler 3.5 10
Euler 5.5 5
Euler 5.5 6
Euler 5.5 8
Euler 5.5 10
Euler 7.5 5
Euler 7.5 6
Euler 7.5 8
Euler 7.5 10
Euler 9.5 5


KeyboardInterrupt: 

In [None]:
print("done")

done
