In [62]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torchmetrics.image.inception import InceptionScore
from modules import Generator

# ─── CONFIG ────────────────────────────────────────────────────────────────────
# Evaluation settings. NZ / IMAGE_SIZE / NC / NGF describe the generator
# architecture and must match the values the checkpoint was trained with,
# otherwise loading WEIGHTS_PATH is expected to fail (or build the wrong model).
DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NZ            = 128           # latent (noise) dimension fed to the generator
IMAGE_SIZE    = 32            # generated image side length (CIFAR-10 is 32×32)
NC            = 3             # number of image channels (RGB)
NGF           = 64            # base channel width of the generator
BATCH_SIZE    = 100           # images per generator forward pass; lower if GPU memory runs out
N_SAMPLES     = 50000         # total generated images scored (50k is the usual IS convention)
SPLITS        = 10            # splits for the IS mean ± std estimate (10 is the usual convention)
WEIGHTS_PATH  = "./mmdgan_experiment/cifar10/netG_50000.pth"  # generator state_dict checkpoint
# ────────────────────────────────────────────────────────────────────────────────

def compute_inception_score_on_fake(netG, device,
                                    batch_size=BATCH_SIZE,
                                    N_samples=N_SAMPLES,
                                    splits=SPLITS,
                                    nz=NZ):
    """
    Generate ``N_samples`` images with ``netG`` and return their Inception Score.

    Images are generated batch by batch in memory (nothing is written to disk)
    and streamed into a torchmetrics ``InceptionScore`` accumulator.

    Args:
        netG: generator module, called as ``netG(z)`` with ``z`` of shape
            ``(B, nz, 1, 1)``. Its outputs are assumed to lie in [-1, 1]
            (e.g. a tanh output layer) -- TODO confirm for this Generator.
        device: device on which latents are sampled and the metric runs.
        batch_size: maximum number of images generated per forward pass.
        N_samples: exact total number of images scored. It need not be a
            multiple of ``batch_size``; the final batch is shortened.
        splits: number of splits used for the IS mean / std estimate.
        nz: latent dimensionality (defaults to the module-level ``NZ``).

    Returns:
        ``(mean_is, std_is)`` as Python floats.

    Raises:
        ValueError: if ``N_samples`` or ``batch_size`` is not positive.

    The generator's train/eval mode is restored before returning.
    """
    if N_samples <= 0:
        raise ValueError(f"N_samples must be positive, got {N_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Remember the caller's mode so evaluation does not leave netG in eval().
    was_training = netG.training
    netG.eval()
    # normalize=True: the metric expects float images in [0, 1] and converts
    # them to uint8 itself before feeding its Inception network.
    is_metric = InceptionScore(splits=splits, normalize=True).to(device)

    try:
        with torch.no_grad():
            # Step through [0, N_samples) so the last (possibly partial) batch
            # is included and exactly N_samples images are scored.
            for start in tqdm(range(0, N_samples, batch_size),
                              desc="Sampling & scoring batches"):
                cur_bs = min(batch_size, N_samples - start)
                z = torch.randn(cur_bs, nz, 1, 1, device=device)
                fakes = netG(z)
                # Map [-1, 1] -> [0, 1]; clamp so any overshoot cannot wrap
                # around during the metric's uint8 conversion.
                fakes = ((fakes + 1) / 2).clamp(0.0, 1.0)
                is_metric.update(fakes)
    finally:
        netG.train(was_training)

    mean_is, std_is = is_metric.compute()
    return mean_is.item(), std_is.item()


def main():
    """
    Load the trained generator from ``WEIGHTS_PATH`` and print its Inception Score.

    All settings (architecture, device, sample count, splits) come from the
    CONFIG constants at the top of this file.
    """
    # 1) Rebuild the generator with its training-time architecture, then load
    #    the saved state_dict onto DEVICE.
    netG = Generator(IMAGE_SIZE, NC, NZ, NGF).to(DEVICE)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs. It avoids executing arbitrary pickled
    # code and silences torch's FutureWarning about the default.
    state = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
    netG.load_state_dict(state)
    print(f"Loaded weights from {WEIGHTS_PATH}")

    # 2) Compute IS with progress bar
    mean_is, std_is = compute_inception_score_on_fake(
        netG, DEVICE,
        batch_size=BATCH_SIZE,
        N_samples=N_SAMPLES,
        splits=SPLITS
    )
    print(f"CIFAR-10 Inception Score ({N_SAMPLES} samples): "
          f"{mean_is:.3f} ± {std_is:.3f}")


if __name__ == "__main__":
    main()


  state = torch.load(WEIGHTS_PATH, map_location=DEVICE)


Loaded weights from ./mmdgan_experiment/cifar10/netG_50000.pth


Sampling & scoring batches: 100%|██████████| 500/500 [1:27:09<00:00, 10.46s/it]


CIFAR-10 Inception Score (50000 samples): 5.388 ± 0.048
