In [3]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from modules import Generator
from tqdm import tqdm

# ─── CONFIG ────────────────────────────────────────────────────────────────────
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NZ            = 128           
IMAGE_SIZE    = 32            
NC            = 3             
NGF           = 64            
BATCH_SIZE    = 100          
N_SAMPLES     = 50000        
WEIGHTS_PATH  = "./mmdgan_experiment/cifar10/netG_50000.pth"
# ────────────────────────────────────────────────────────────────────────────────

def sanity_main():
    # 1) Build G and load
    netG = Generator(IMAGE_SIZE, NC, NZ, NGF).to(DEVICE)
    state = torch.load(WEIGHTS_PATH, map_location=DEVICE)
    netG.load_state_dict(state)
    netG.eval()
    print(f"Loaded G weights from {WEIGHTS_PATH}")

    # 2) Prepare FID metric (it will normalize internally)
    fid = FrechetInceptionDistance(normalize=True).to(DEVICE)

    # 3) Real‐data loader (no Normalize here!)
    transform_real = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),            # yields [0,1]
    ])
    real_ds = datasets.CIFAR10(
        root="./data", train=True, download=True,
        transform=transform_real
    )
    real_loader = DataLoader(
        real_ds, batch_size=BATCH_SIZE,
        shuffle=False, num_workers=4, pin_memory=True
    )

       # ── SANITY CHECK: REAL vs REAL ────────────────────────────────────
    fid.reset()
    # First pass as “real”
    for imgs, _ in real_loader:
        fid.update(imgs.to(DEVICE), real=True)
    # Second pass as “fake”
    for imgs, _ in real_loader:
        fid.update(imgs.to(DEVICE), real=False)
    print(f"Sanity FID (real vs real): {fid.compute():.3f}")    
    # Then reset before the real-vs-fake run
    fid.reset()
    # ───────────────────────────────────────────────────────────────────

sanity_main()


  state = torch.load(WEIGHTS_PATH, map_location=DEVICE)


Loaded G weights from ./mmdgan_experiment/cifar10/netG_50000.pth
Files already downloaded and verified
Sanity FID (real vs real): -0.000


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from modules import Generator
from tqdm import tqdm

# ─── CONFIG ────────────────────────────────────────────────────────────────────
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NZ            = 128           
IMAGE_SIZE    = 32            
NC            = 3             
NGF           = 64            
BATCH_SIZE    = 100          
N_SAMPLES     = 50000        
WEIGHTS_PATH  = "./mmdgan_experiment/cifar10/netG_50000.pth"
# ────────────────────────────────────────────────────────────────────────────────

def main():
    # 1) Build G and load
    netG = Generator(IMAGE_SIZE, NC, NZ, NGF).to(DEVICE)
    state = torch.load(WEIGHTS_PATH, map_location=DEVICE)
    netG.load_state_dict(state)
    netG.eval()
    print(f"Loaded G weights from {WEIGHTS_PATH}")

    # 2) Prepare FID metric (it will normalize internally)
    fid = FrechetInceptionDistance(normalize=True).to(DEVICE)

    # 3) Real‐data loader (no Normalize here!)
    transform_real = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),            # yields [0,1]
    ])
    real_ds = datasets.CIFAR10(
        root="./data", train=True, download=True,
        transform=transform_real
    )
    real_loader = DataLoader(
        real_ds, batch_size=BATCH_SIZE,
        shuffle=False, num_workers=4, pin_memory=True
    )

    # 4) Accumulate real stats
    with torch.no_grad():
        for imgs, _ in tqdm(real_loader, desc="Real CIFAR-10"):
            imgs = imgs.to(DEVICE)
            fid.update(imgs, real=True)

    # 5) Generate & accumulate fake stats
    n_batches = N_SAMPLES // BATCH_SIZE
    with torch.no_grad():
        for _ in tqdm(range(n_batches), desc="Fake samples"):
            z = torch.randn(BATCH_SIZE, NZ, 1, 1, device=DEVICE)
            fakes = netG(z)
            fakes = (fakes + 1) / 2  # assume netG outputs [-1,1] → [0,1]
            fakes = F.interpolate(
                fakes, size=(299, 299),
                mode="bilinear", align_corners=False
            )
            fid.update(fakes, real=False)

    # 6) Compute & print
    fid_value = fid.compute()
    print(f"FID OF CIFAR-10: {fid_value:.3f}")

if __name__ == "__main__":
    main()


  state = torch.load(WEIGHTS_PATH, map_location=DEVICE)


Loaded G weights from ./mmdgan_experiment/cifar10/netG_50000.pth
Files already downloaded and verified


Real CIFAR-10: 100%|██████████| 500/500 [1:23:01<00:00,  9.96s/it]
Fake samples: 100%|██████████| 500/500 [1:22:18<00:00,  9.88s/it]


Corrected FID: 55.837
