In [2]:
# Enable IPython's autoreload so edits to imported project modules (e.g. maxent_gan)
# are picked up live; mode 2 applies it to all modules.
# Re-running this cell only prints an "already loaded" notice; the extension stays active.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
import yaml
from pathlib import Path
import seaborn as sns
from matplotlib import pyplot as plt

# Make the StudioGAN sources importable. The path is relative to the kernel's
# working directory; the membership check keeps re-runs of this cell from
# appending duplicate entries to sys.path.
import sys

STUDIOGAN_DIR = "../studiogan"
if STUDIOGAN_DIR not in sys.path:
    sys.path.append(STUDIOGAN_DIR)

# NOTE(review): `studiogan` is not referenced in the visible cells; presumably it is
# imported so the path setup fails fast / for import side effects -- TODO confirm.
import studiogan
from maxent_gan.models.studiogans import StudioDis, StudioGen

from maxent_gan.utils.general_utils import DotConfig, ROOT_DIR, random_seed
from maxent_gan.models.utils import GANWrapper
from maxent_gan.distribution import Distribution, DiscriminatorTarget
from maxent_gan.mcmc import MCMCRegistry, ula, mala
from maxent_gan.utils.train.loss import LossRegistry


# Batch size handed to DiscriminatorTarget when the target distribution is built.
batch_size = 100

# Sampler registry; the cells below call it as mcmc("mala", start, target, prior, **options).
mcmc = MCMCRegistry()

2022-07-08 00:15:18.434451: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


In [3]:
dataset = 'cifar10'
model = 'dcgan'

# NOTE(review): `dataset` is not used below -- the config path hard-codes "cifar-10".
# TODO: derive the path from `dataset` if other datasets are needed.
gan_config_path = Path(ROOT_DIR, f'configs/gan_configs/cifar-10-{model}.yml')
# read_text() closes the file once it has been read, so no handle is left open.
gan_config = DotConfig(yaml.safe_load(gan_config_path.read_text())['gan_config'])

# First CUDA device when available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE(review): eval=False presumably keeps the networks out of eval mode, since the
# cells below backpropagate through them -- TODO confirm against GANWrapper.
gan = GANWrapper(gan_config, device, eval=False)
criterion = LossRegistry.create('JensenNSLoss')
ref_dist = DiscriminatorTarget(gan, batch_size=batch_size)

Transform: Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])


In [6]:
# Variant 1: run MALA with keep_graph=False, decode the final latent, and print a
# slice of the resulting gradient on the generator's first parameter.
random_seed(42)
start = gan.prior.sample((10,))

sampler_kwargs = dict(
    n_samples=1,
    burn_in=4,
    step_size=0.1,
    verbose=True,
    keep_graph=False,
)
pts, meta = mcmc("mala", start, ref_dist, gan.prior, **sampler_kwargs)

# Clear any gradients left over from earlier cells (generator first, then discriminator).
for net in (gan.gen, gan.dis):
    net.zero_grad()

# Round-trip the final latent through host memory before decoding it.
z = pts[-1].cpu().to(device)
x = gan.gen(z)
score = gan.dis(x)
loss = criterion(score)
loss.backward()

first_gen_param = next(gan.gen.parameters())
print(first_gen_param.grad[0, 0])

100%|██████████| 5/5 [00:00<00:00, 16.69it/s]

tensor([[ 0.0317,  0.0040, -0.0063,  0.0057],
        [-0.0671, -0.0086, -0.0076, -0.0040],
        [-0.0317, -0.0028, -0.0195,  0.0006],
        [-0.0035, -0.0268, -0.0001,  0.0076]], device='cuda:0')





In [7]:
# Variant 2: run MALA with keep_graph=True, but detach the final latent before
# decoding, so backward() stops at the latent instead of following its history.
random_seed(42)
start = gan.prior.sample((10,))

sampler_kwargs = dict(
    n_samples=1,
    burn_in=4,
    step_size=0.1,
    verbose=True,
    keep_graph=True,
)
pts, meta = mcmc("mala", start, ref_dist, gan.prior, **sampler_kwargs)

# Clear any gradients left over from earlier cells (generator first, then discriminator).
for net in (gan.gen, gan.dis):
    net.zero_grad()

# Detach first, then round-trip through host memory before decoding.
z = pts[-1].detach().cpu().to(device)
x = gan.gen(z)
score = gan.dis(x)
loss = criterion(score)
loss.backward()

first_gen_param = next(gan.gen.parameters())
print(first_gen_param.grad[0, 0])

100%|██████████| 5/5 [00:00<00:00, 18.57it/s]

tensor([[ 0.0317,  0.0040, -0.0063,  0.0057],
        [-0.0671, -0.0086, -0.0076, -0.0040],
        [-0.0317, -0.0028, -0.0195,  0.0006],
        [-0.0035, -0.0268, -0.0001,  0.0076]], device='cuda:0')





In [8]:
# Variant 3: run MALA with keep_graph=True and do NOT detach the final latent, so any
# autograd history the sampler attached to it is kept for backward(). The recorded
# output of this cell differs from the two variants above.
random_seed(42)
start = gan.prior.sample((10,))

sampler_kwargs = dict(
    n_samples=1,
    burn_in=4,
    step_size=0.1,
    verbose=True,
    keep_graph=True,
)
pts, meta = mcmc("mala", start, ref_dist, gan.prior, **sampler_kwargs)

# Clear any gradients left over from earlier cells (generator first, then discriminator).
for net in (gan.gen, gan.dis):
    net.zero_grad()

# Round-trip the final latent through host memory before decoding it.
z = pts[-1].cpu().to(device)
x = gan.gen(z)
score = gan.dis(x)
loss = criterion(score)
loss.backward()

first_gen_param = next(gan.gen.parameters())
print(first_gen_param.grad[0, 0])

100%|██████████| 5/5 [00:00<00:00, 19.79it/s]


tensor([[ 0.0538,  0.0057,  0.0025, -0.0018],
        [-0.0694,  0.0061,  0.0059,  0.0024],
        [-0.0842,  0.0003, -0.0204,  0.0061],
        [ 0.0026, -0.0599,  0.0027,  0.0030]], device='cuda:0')
