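#### This notebook works with the code from the EDM paper: https://github.com/NVlabs/edm

Run it from the root of that repository, since it imports EDM's `dnnlib`, `torch_utils`, and `training` modules.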


In [2]:
import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
import matplotlib.pyplot as plt

from torch_utils import distributed as dist
from torch_utils import misc

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torchvision

Dataset MNIST
    Number of datapoints: 60000
    Root location: datasets
    Split: Train

In [16]:
# Make an 8x8 grid of MNIST digits (resized to 32x32) for visual comparison with the samples.
GRID_SIZE = 8          # Tiles per grid row and per grid column.
MNIST_RESOLUTION = 32  # Resize target; matches the 32x32 data the network is trained on.

# ToTensor yields float pixels in [0, 1]; Resize then rescales each digit to MNIST_RESOLUTION.
mnist_dataset = torchvision.datasets.MNIST(
    'datasets', download=True, train=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((MNIST_RESOLUTION, MNIST_RESOLUTION)),
    ]),
)
# One batch of exactly GRID_SIZE**2 images fills the grid.
dataloader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=GRID_SIZE * GRID_SIZE)

# First batch; element [0] is the image tensor (element [1] holds the labels).
image = next(iter(dataloader))[0]

# Round instead of truncating: a bare uint8 cast floors, biasing every pixel down by up to
# one intensity level (e.g. 254.99 -> 254).
image = (image * 255).round().clip(0, 255).to(torch.uint8)

# Tile the batch: (gh*gw, C, H, W) -> (gh, gw, C, H, W) -> (gh, H, gw, W, C) -> (gh*H, gw*W, C).
_, tile_channels, tile_h, tile_w = image.shape
image = image.reshape(GRID_SIZE, GRID_SIZE, tile_channels, tile_h, tile_w).permute(0, 3, 1, 4, 2)
image = image.reshape(GRID_SIZE * tile_h, GRID_SIZE * tile_w, tile_channels)
image = image.cpu().numpy()
print(image.shape)
# Drop the singleton channel axis so PIL receives a 2-D grayscale ("L") array.
PIL.Image.fromarray(image.squeeze(), "L").save('mnist-32-32-originals.png')

(256, 256, 1)


In [14]:
# Use CUDA when available; otherwise fall back to CPU so the notebook still runs (slowly).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

c = dnnlib.EasyDict()
# Unconditional dataset (use_labels=False) without horizontal flips; the path name suggests
# 32x32 MNIST images packed for EDM's ImageFolderDataset — TODO confirm how the zip was built.
c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path='datasets/mnist-32x32.zip', use_labels=False, xflip=False, cache=True)
c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=2, prefetch_factor=2)
seed = 0
batch_size = 64  # Total batch size for one training iteration (across all processes).
batch_gpu = 64   # Limit batch size per GPU, None = no limit.

# Resolve the per-process batch size the same way EDM's training loop does: split the total
# batch across processes and cap it at batch_gpu. This also makes batch_gpu=None valid
# (passing None straight to DataLoader would disable batching altogether).
batch_gpu_total = batch_size // dist.get_world_size()
if batch_gpu is None or batch_gpu > batch_gpu_total:
    batch_gpu = batch_gpu_total

# Load dataset.
dist.print0('Loading dataset...')
dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs) # subclass of training.dataset.Dataset
# Seeded sampler sharded across processes via rank / num_replicas.
dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
# Use the resolved per-process batch size instead of a hard-coded 64.
dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **c.data_loader_kwargs))

Loading dataset...


In [13]:
# Show the dataset geometry the network will be built for, as (channels, resolution).
(dataset_obj.num_channels, dataset_obj.resolution)

(1, 32)

In [24]:
# Network hyperparameters. These appear to follow the 'ddpmpp' (SongUNet) preset in EDM's
# train.py, with model_channels reduced from 128 to 32 for small single-channel images.
# `dropout` is a dropout *probability*, not an on/off flag: p=1 zeroes the entire dropout
# input during training, so 0.13 (EDM train.py's default) is used instead.
DROPOUT = 0.13

c.network_kwargs = dnnlib.EasyDict()
c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=32, channel_mult=[2,2,2])
c.network_kwargs.update(dropout=DROPOUT, use_fp16=False)
c.network_kwargs.class_name = 'training.networks.EDMPrecond'

# Construct network.
dist.print0('Constructing network...')
# Image geometry and label dimension come from the dataset so network and data always agree.
interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
net = dnnlib.util.construct_class_by_name(**c.network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
net.train().requires_grad_(True).to(device)
# On rank 0 only, push a dummy batch (zero images, unit sigma, zero labels) through the
# network to print a per-layer summary of parameters and output shapes.
if dist.get_rank() == 0:
    with torch.no_grad():
        images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
        sigma = torch.ones([batch_gpu], device=device)
        labels = torch.zeros([batch_gpu, net.label_dim], device=device)
        misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)

Constructing network...

EDMPrecond                Parameters  Buffers  Output shape      Datatype
---                       ---         ---      ---               ---     
model.map_noise           -           -        [64, 32]          float32 
model.map_layer0          4224        -        [64, 128]         float32 
model.map_layer1          16512       -        [64, 128]         float32 
model.enc.32x32_conv      320         -        [64, 32, 32, 32]  float32 
model.enc.32x32_block0    65984       -        [64, 64, 32, 32]  float32 
model.enc.32x32_block1    82368       -        [64, 64, 32, 32]  float32 
model.enc.32x32_block2    82368       -        [64, 64, 32, 32]  float32 
model.enc.32x32_block3    82368       -        [64, 64, 32, 32]  float32 
model.enc.16x16_down      86528       8        [64, 64, 16, 16]  float32 
model.enc.16x16_block0    99136       -        [64, 64, 16, 16]  float32 
model.enc.16x16_block1    99136       -        [64, 64, 16, 16]  float32 
model.enc.16x

In [20]:
# Display the module hierarchy of the constructed network.
net

EDMPrecond(
  (model): SongUNet(
    (map_noise): PositionalEmbedding()
    (map_layer0): Linear()
    (map_layer1): Linear()
    (enc): ModuleDict(
      (32x32_conv): Conv2d()
      (32x32_block0): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
        (skip): Conv2d()
      )
      (32x32_block1): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
      )
      (32x32_block2): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
      )
      (32x32_block3): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
      )
      (16x16_down): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()