In [None]:
import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision import utils as vutils

In [None]:
# ---- Configuration -------------------------------------------------------
# Every value a reader might want to tune lives here instead of inline.
SEED = 0
DATA_ROOT = './sngan/datasets'
LOG_DIR = './sngan-reg/log/cifar10'
BATCH_SIZE = 64
NUM_WORKERS = 4
LR = 2e-4
BETAS = (0.0, 0.9)
N_DIS = 5            # passed to Trainer as n_dis (discriminator steps per G step)
NUM_STEPS = 100000   # total training steps; later cells also use 100000

# Seed torch's global RNG (CPU and CUDA). This covers weight init, the
# DataLoader shuffle order and any torch.randn noise drawn during training.
# NOTE(review): bit-exact GPU reproducibility would additionally need
# deterministic cuDNN settings, which are not enabled here.
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# ---- Data ----------------------------------------------------------------
dataset = mmc.datasets.load_dataset(root=DATA_ROOT, name='cifar10')
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True,
                                         num_workers=NUM_WORKERS)

# ---- Models and optimizers -----------------------------------------------
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), lr=LR, betas=BETAS)
optG = optim.Adam(netG.parameters(), lr=LR, betas=BETAS)

# ---- Trainer -------------------------------------------------------------
trainer = mmc.training.Trainer(netD=netD,
                               netG=netG,
                               optD=optD,
                               optG=optG,
                               n_dis=N_DIS,
                               num_steps=NUM_STEPS,
                               lr_decay='linear',
                               dataloader=dataloader,
                               log_dir=LOG_DIR,
                               device=device)

In [None]:
# Launch the training loop using the Trainer built in the setup cell.
# This is by far the most expensive cell: it runs for the trainer's
# configured num_steps (100000 here).
trainer.train()

In [None]:
# Score the trained generator with FID at step 100000 of the run in log_dir.
# The arguments are collected into one dict so the evaluation settings can
# be read (and tweaked) at a glance.
# NOTE(review): 1000 real / 1000 fake samples gives a quick but strongly
# biased FID estimate. It is not comparable to published CIFAR-10 FIDs,
# which typically use 10k-50k samples. Confirm this is intentional.
fid_config = {
    'metric': 'fid',
    'log_dir': './sngan-reg/log/cifar10',
    'netG': netG,
    'dataset': 'cifar10',
    'num_real_samples': 1000,
    'num_fake_samples': 1000,
    'evaluate_step': 100000,
    'num_runs': 1,
    'device': device,
}
mmc.metrics.evaluate(**fid_config)

In [None]:
# Generate a grid of 64 samples from netG and hand it to a torch_mimicry
# Logger writing into the training log directory.
logger = mmc.training.Logger(
    log_dir='./sngan-reg/log/cifar10',
    num_steps=100000,
    # Use the size of the dataset the model was actually trained on.
    # The previous hard-coded 60000 is CIFAR-10 train + test combined
    # (50000 + 10000). The loaded dataset is a single split, presumably
    # the 50000-image training split.
    dataset_size=len(dataset),
    device=device,
)

# global_step labels this visualisation; it matches the training num_steps.
logger.vis_images(netG=netG, global_step=100000, num_images=64)