In [8]:
import os, torch, numpy as np
from tqdm import tqdm

def calc_fid_statistics(pt_dir: str, key: str = "inception_feats"):
    """
    Compute the mean (mu) and covariance (sigma) of saved feature vectors.

    Every ``*.pt`` file directly inside ``pt_dir`` is loaded with
    ``torch.load``; the entry stored under ``key`` is converted to float32 and
    all entries are concatenated along the sample axis before the statistics
    are taken.

    Args:
        pt_dir: Directory containing the ``*.pt`` feature files.
        key: Key under which each file stores its feature tensor. Each tensor
            is expected to be ``(N_i, D)``; a 1-D tensor is treated as a
            single sample and extra trailing axes are flattened into ``D``.

    Returns:
        A ``(mu, sigma)`` tuple of float32 tensors with shapes ``(D,)`` and
        ``(D, D)``. The original notes give ``D = 2048``.

    Raises:
        FileNotFoundError: If ``pt_dir`` contains no ``*.pt`` files.
        KeyError: If a file has no entry named ``key``.
        ValueError: If fewer than two samples were collected, in which case
            the unbiased covariance is undefined.
    """
    # os.listdir returns entries in arbitrary order; sorting fixes the
    # concatenation order so the floating-point result is reproducible.
    files = sorted(
        os.path.join(pt_dir, name)
        for name in os.listdir(pt_dir)
        if name.endswith(".pt")
    )
    if not files:
        raise FileNotFoundError(f"No .pt files in {pt_dir}")

    feats = []
    for path in tqdm(files, desc="Loading features"):
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code -- only point this function at trusted feature dumps.
        payload = torch.load(path, map_location="cpu")
        try:
            data = payload[key]
        except KeyError as err:
            # Re-raise with the offending file so a bad dump is easy to find.
            raise KeyError(f"Key {key!r} not found in {path}") from err

        data = data.float()
        if data.dim() == 1:
            # A bare feature vector is one sample, not D scalar samples.
            data = data.unsqueeze(0)                            # (1, D)
        elif data.dim() > 2:
            # e.g. pooled features saved as (N, D, 1, 1).
            data = data.flatten(start_dim=1)                    # (N_i, D)
        feats.append(data)
    feats = torch.cat(feats, dim=0)                             # (N_total, D)

    if feats.shape[0] < 2:
        # torch.cov divides by N - 1 and would silently return NaN here.
        raise ValueError(
            f"Need at least 2 feature vectors to estimate a covariance, "
            f"got {feats.shape[0]} from {pt_dir}"
        )

    mu = feats.mean(dim=0)                                      # (D,)
    sigma = torch.cov(feats.T)                                  # (D, D); rows of the input are variables
    return mu, sigma

# Gather the mean/covariance of the features saved in this directory, then
# leave the pair as the final expression so the notebook cell displays it.
mu, sigma = calc_fid_statistics(
    pt_dir='/data/archive/sd-v1-4/dpm_solver++_steps200_scale1.5_fid',
)
(mu, sigma)

Loading features: 100%|██████████| 2006/2006 [00:00<00:00, 13749.95it/s]


(tensor([0.3827, 0.3682, 0.3793,  ..., 0.3910, 0.3928, 0.3497]),
 tensor([[ 0.1114,  0.0109,  0.0161,  ...,  0.0047, -0.0055,  0.0012],
         [ 0.0109,  0.0798, -0.0061,  ..., -0.0091, -0.0051,  0.0043],
         [ 0.0161, -0.0061,  0.1206,  ..., -0.0036,  0.0170,  0.0121],
         ...,
         [ 0.0047, -0.0091, -0.0036,  ...,  0.1493,  0.0297, -0.0035],
         [-0.0055, -0.0051,  0.0170,  ...,  0.0297,  0.1531, -0.0046],
         [ 0.0012,  0.0043,  0.0121,  ..., -0.0035, -0.0046,  0.1085]]))