In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
import yaml
from pathlib import Path
import seaborn as sns
from matplotlib import pyplot as plt

##
import sys
sys.path.append("../studiogan")
import studiogan
from maxent_gan.models.studiogans import StudioDis, StudioGen
##

from maxent_gan.utils.general_utils import DotConfig, ROOT_DIR, random_seed
from maxent_gan.models.utils import GANWrapper
from maxent_gan.distribution import Distribution, DiscriminatorTarget
from maxent_gan.mcmc import MCMCRegistry, ula, mala
from maxent_gan.train.loss import LossRegistry


# Registry instance; later cells call it with a sampler name, e.g. mcmc("ula", ...).
mcmc = MCMCRegistry()

# Passed as `batch_size` to DiscriminatorTarget when the reference distribution is built.
batch_size = 32

2022-07-11 15:10:56.412223: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


In [26]:
# NOTE(review): `dataset` is not used in this cell; the config path below
# hard-codes the "cifar-10-" prefix instead of deriving it from this value.
dataset = 'cifar10'
model = 'dcgan'

# Read the YAML config with Path.read_text() so the file handle is closed
# right away (previously an open file object was handed to yaml.safe_load
# and never closed explicitly).
gan_config = yaml.safe_load(Path(ROOT_DIR, f'configs/gan_configs/cifar-10-{model}.yml').read_text())
gan_config = DotConfig(gan_config['gan_config'])

# First CUDA device when available, CPU otherwise.
device = torch.device(0 if torch.cuda.is_available() else 'cpu')

# NOTE(review): eval=False presumably leaves the networks in training mode — confirm in GANWrapper.
gan = GANWrapper(gan_config, device, eval=False)
criterion = LossRegistry.create('JensenNSLoss')
# Target distribution handed to the MCMC sampler in the cells below.
ref_dist = DiscriminatorTarget(gan, batch_size=batch_size)

Transform: Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])


In [27]:
def test():
    """Baseline: one generator -> discriminator forward/backward pass, no MCMC.

    Calls ``random_seed(42)``, draws 10 latents from ``gan.prior``, clears the
    gradients of both networks and backpropagates ``criterion`` applied to the
    discriminator output. Returns nothing; callers inspect the gradients on
    the network parameters afterwards.
    """
    random_seed(42)
    latents = gan.prior.sample((10,))

    # Start from clean gradients so only this pass contributes.
    for net in (gan.gen, gan.dis):
        net.zero_grad()

    fake = gan.gen(latents.to(device))
    criterion(gan.dis(fake)).backward()

# Run the pass once, then show the [0, 0] slice of the gradient of the first
# generator parameter as a small fingerprint of the computed gradients.
test()
print(next(gan.gen.parameters()).grad[0, 0])

tensor([[-0.0043,  0.0036, -0.0014, -0.0002],
        [-0.0695, -0.0025, -0.0022, -0.0017],
        [-0.0103, -0.0009, -0.0072,  0.0007],
        [ 0.0002,  0.0189,  0.0003,  0.0015]], device='cuda:0')


In [28]:
%timeit test()

13.2 ms ± 144 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [29]:
# don't keep graph

def test(verbose=False):
    """One ULA step with ``keep_graph=False``, then a forward/backward pass.

    Calls ``random_seed(42)``, draws 10 latents from ``gan.prior`` and passes
    them to ``mcmc("ula", ...)`` with ``ref_dist`` as target. The last entry of
    the returned points is pushed through generator, discriminator and
    ``criterion``, and the loss is backpropagated. Returns nothing.

    Args:
        verbose: forwarded unchanged to the sampler call.
    """
    random_seed(42)
    init_latents = gan.prior.sample((10,))

    sampler_kwargs = dict(
        n_samples=1,
        burn_in=0,
        step_size=0.1,
        verbose=verbose,
        keep_graph=False,
    )
    chain, _meta = mcmc("ula", init_latents, ref_dist, gan.prior, **sampler_kwargs)

    # Gradients are cleared after sampling, immediately before the measured pass.
    for net in (gan.gen, gan.dis):
        net.zero_grad()

    latents = chain[-1].to(device)
    fake = gan.gen(latents)
    criterion(gan.dis(fake)).backward()

# Run once with verbose=True, then show the [0, 0] slice of the gradient of
# the first generator parameter as a small fingerprint of the computed gradients.
test(True)
print(next(gan.gen.parameters()).grad[0, 0])

100%|██████████| 1/1 [00:00<00:00, 102.60it/s]

tensor([[ 2.3322e-02,  1.4386e-03,  3.6168e-02,  1.2824e-03],
        [ 6.9809e-03,  2.4295e-03,  2.0020e-04,  2.1084e-03],
        [ 3.6275e-03,  1.9395e-02,  1.7134e-02,  3.8359e-03],
        [-2.9274e-03, -5.6348e-05, -4.6420e-02, -1.7411e-02]], device='cuda:0')





In [30]:
%timeit test()

23 ms ± 23.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [31]:
# keep graph, but detach the latent

def test(verbose=False):
    """One ULA step with ``keep_graph=True``, with the latent detached afterwards.

    Calls ``random_seed(42)``, draws 10 latents from ``gan.prior`` and passes
    them to ``mcmc("ula", ...)`` with ``ref_dist`` as target. The last entry of
    the returned points is detached before it is pushed through generator,
    discriminator and ``criterion``, so the backward pass stops at the latent
    and does not reach into the sampler. Returns nothing.

    Args:
        verbose: forwarded unchanged to the sampler call.
    """
    random_seed(42)
    init_latents = gan.prior.sample((10,))

    sampler_kwargs = dict(
        n_samples=1,
        burn_in=0,
        step_size=0.1,
        verbose=verbose,
        keep_graph=True,
    )
    chain, _meta = mcmc("ula", init_latents, ref_dist, gan.prior, **sampler_kwargs)

    # Gradients are cleared after sampling, immediately before the measured pass.
    for net in (gan.gen, gan.dis):
        net.zero_grad()

    # detach() cuts the latent off from whatever graph the sampler built.
    latents = chain[-1].detach().to(device)
    fake = gan.gen(latents)
    criterion(gan.dis(fake)).backward()

# Run once with verbose=True, then show the [0, 0] slice of the gradient of
# the first generator parameter as a small fingerprint of the computed gradients.
test(True)
print(next(gan.gen.parameters()).grad[0, 0])

100%|██████████| 1/1 [00:00<00:00, 109.39it/s]

tensor([[ 2.3322e-02,  1.4387e-03,  3.6168e-02,  1.2823e-03],
        [ 6.9809e-03,  2.4295e-03,  2.0027e-04,  2.1084e-03],
        [ 3.6276e-03,  1.9395e-02,  1.7134e-02,  3.8360e-03],
        [-2.9274e-03, -5.6359e-05, -4.6421e-02, -1.7411e-02]], device='cuda:0')





In [32]:
%timeit test()

22.5 ms ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [33]:
# keep graph

def test(verbose=False):
    """One ULA step with ``keep_graph=True``, backpropagating from an undetached latent.

    Calls ``random_seed(42)``, draws 10 latents from ``gan.prior``, clears the
    gradients of both networks, and only then calls ``mcmc("ula", ...)`` with
    ``ref_dist`` as target. The last entry of the returned points is used
    without detaching, pushed through generator, discriminator and
    ``criterion``, and the loss is backpropagated. Returns nothing.

    Args:
        verbose: forwarded unchanged to the sampler call.
    """
    random_seed(42)
    init_latents = gan.prior.sample((10,))

    # Gradients are cleared *before* sampling and not again afterwards, so
    # anything the sampler call leaves in .grad is not wiped before backward().
    for net in (gan.gen, gan.dis):
        net.zero_grad()

    sampler_kwargs = dict(
        n_samples=1,
        burn_in=0,
        step_size=0.1,
        verbose=verbose,
        keep_graph=True,
    )
    chain, _meta = mcmc("ula", init_latents, ref_dist, gan.prior, **sampler_kwargs)

    latents = chain[-1].to(device)
    fake = gan.gen(latents)
    criterion(gan.dis(fake)).backward()

# Run once with verbose=True, then show the [0, 0] slice of the gradient of
# the first generator parameter as a small fingerprint of the computed gradients.
test(True)
print(next(gan.gen.parameters()).grad[0, 0])

100%|██████████| 1/1 [00:00<00:00, 116.25it/s]

tensor([[-0.0218,  0.0145,  0.0209,  0.0265],
        [ 0.2115, -0.0253, -0.0253, -0.0046],
        [ 0.0988,  0.0197, -0.0153, -0.0073],
        [-0.0124, -0.1605, -0.0477,  0.0010]], device='cuda:0')





In [34]:
%timeit test()

35.2 ms ± 22.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
