In [1]:
%load_ext autoreload
%autoreload 2

from typing import Optional
import argparse
import datetime
import logging
import subprocess
from pathlib import Path
from collections import defaultdict

import numpy as np
import ruamel.yaml as yaml
import torch
import torchvision
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader

import sys
sys.path.append("../studiogan")  # noqa: E402
sys.path.append('..')

# from vizualization.plot_results import plot_res

from maxent_gan.datasets.utils import get_dataset
from maxent_gan.distribution import Distribution, DistributionRegistry
from maxent_gan.feature import BaseFeature, create_feature, FeatureRegistry
from maxent_gan.models.flow.real_nvp import RNVP  # noqa: F401
from maxent_gan.models.flow.real_nvp_minimal import RealNVPProposal
from maxent_gan.models.utils import GANWrapper
from maxent_gan.sample import MaxEntSampler
from maxent_gan.utils.callbacks import CallbackRegistry
from maxent_gan.utils.general_utils import DotConfig, IgnoreLabelDataset, random_seed, ROOT_DIR
from maxent_gan.utils.metrics.compute_fid_tf import calculate_fid_given_paths
from maxent_gan.utils.metrics.inception_score import (
    MEAN_TRASFORM,
    N_GEN_IMAGES,
    STD_TRANSFORM,
    get_inception_score,
)

from maxent_gan.models.studiogans import (  # noqa: F401, E402  isort: skip
    StudioDis,  # noqa: F401, E402  isort: skip
    StudioGen,  # noqa: F401, E402  isort: skip
)

2022-07-10 16:11:53.102334: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


Instructions for updating:
non-resource variables are not supported in the long term


In [None]:
def evaluate(
    feature, dataset, batch_size: int, device, save_path: Optional[Path] = None
):
    """Average every output of ``feature.apply`` over the whole dataset.

    Args:
        feature: object whose ``apply(batch)`` returns a sequence of tensors;
            the first dimension of each tensor is treated as the sample axis.
        dataset: dataset whose items collate into a tensor batch (``.to`` is
            called on each batch).
        batch_size: DataLoader batch size.
        device: device each batch is moved to before ``feature.apply``.
        save_path: if given, the averaged arrays are written there with
            ``np.savez`` (keys ``arr_0``, ``arr_1``, ... in output order).

    Returns:
        defaultdict mapping output index -> per-sample mean (numpy array).
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    sums = defaultdict(lambda: 0.)
    counts = defaultdict(int)
    for batch in tqdm(dataloader):
        feature_result = feature.apply(batch.to(device))
        for i, feature_res in enumerate(feature_result):
            # Accumulate per-sample sums (not per-batch means) so that a
            # smaller final batch is not over-weighted in the average.
            sums[i] += feature_res.sum(0).detach().cpu().numpy()
            counts[i] += feature_res.shape[0]

    stats = defaultdict(lambda: 0.)
    for i in sums:
        stats[i] = sums[i] / counts[i]

    if save_path is not None:
        # Context manager guarantees the file is flushed and closed.
        with save_path.open("wb") as f:
            np.savez(f, *stats.values())
    return stats

def create_feature(config, gan, dataloader, dataset_stuff, save_dir, device):
    """Build the feature used by the MaxEnt sampler (notebook-local version).

    NOTE: this definition shadows ``maxent_gan.feature.create_feature``
    imported at the top of the notebook.

    Args:
        config: merged experiment DotConfig; reads
            ``callbacks.feature_callbacks``, ``sample_params.feature`` and
            ``batch_size``.
        gan: GANWrapper; injected into callback/feature kwargs that request
            ``gan`` and supplies ``inverse_transform``.
        dataloader: real-data loader; used to build ``np_dataset`` and passed
            to feature kwargs that request ``dataloader``.
        dataset_stuff: dict returned by ``get_dataset``; ``"dataset"`` and,
            when a callback asks for it, ``"modes"`` are read.
        save_dir: directory passed to callbacks that request ``save_dir``.
        device: device used by ``evaluate`` when computing reference stats.

    Returns:
        The feature instance with ``eval`` set to False.
    """
    np_dataset = None  # built lazily, at most once

    feature_callbacks = []
    callbacks = config.callbacks.feature_callbacks
    if callbacks:
        for _, callback in callbacks.items():
            params = callback.params.dict
            # HACK: config values under these keys are placeholders that get
            # replaced with live objects from this notebook.
            if "gan" in params:
                params["gan"] = gan
            if "save_dir" in params:
                params["save_dir"] = save_dir
            if "np_dataset" in params:
                # A full pass over `dataloader` is expensive: compute it once
                # and hand the same array to every callback that asks for it.
                # NOTE(review): assumes callbacks treat np_dataset as read-only.
                if np_dataset is None:
                    np_dataset = np.concatenate(
                        [gan.inverse_transform(batch).numpy() for batch in dataloader], 0
                    )
                params["np_dataset"] = np_dataset
            if "modes" in params:
                params["modes"] = dataset_stuff["modes"]
            feature_callbacks.append(CallbackRegistry.create(callback.name, **params))

    feature_cfg = config.sample_params.feature
    feature_kwargs = feature_cfg.params.dict
    # HACK: same placeholder substitution for the feature's own kwargs.
    if "gan" in feature_cfg.params:
        feature_kwargs["gan"] = gan
    if "dataloader" in feature_cfg.params:
        feature_kwargs["dataloader"] = dataloader

    def _build_feature():
        # Single place that instantiates the feature so both constructions
        # below stay in sync.
        return FeatureRegistry.create(
            feature_cfg.name,
            callbacks=feature_callbacks,
            inverse_transform=gan.inverse_transform,
            **feature_kwargs,
        )

    feature = _build_feature()

    ref_stats_path = feature_cfg.params.ref_stats_path
    if ref_stats_path:
        # Compute reference statistics over the real dataset and write them
        # to `ref_stats_path`; the returned dict itself is not needed here.
        feature.eval = True
        evaluate(
            feature,
            dataset_stuff["dataset"],
            config.batch_size,
            device,
            Path(ref_stats_path),
        )
        # Rebuild the feature after the evaluation pass (presumably so the new
        # instance picks up the freshly written stats — confirm).
        feature = _build_feature()
    feature.eval = False
    return feature

def define_sampler(
    config: DotConfig,
    gan: GANWrapper,
    ref_dist: Distribution,
    feature: BaseFeature,
    save_dir: Path,
):
    """Construct a MaxEntSampler over ``gan.gen`` with configured callbacks.

    Callbacks come from ``config.callbacks.sampler_callbacks`` (may be empty);
    any callback whose params contain ``save_dir`` gets the given directory.
    Extra sampler keyword arguments come from ``config.sample_params.params``.
    """
    def _make_callback(cb_cfg):
        cb_params = cb_cfg.params.dict
        # HACK: replace the placeholder with this run's directory.
        if "save_dir" in cb_params:
            cb_params["save_dir"] = save_dir
        return CallbackRegistry.create(cb_cfg.name, **cb_params)

    cb_configs = config.callbacks.sampler_callbacks
    sampler_callbacks = (
        [_make_callback(cb_cfg) for _, cb_cfg in cb_configs.items()]
        if cb_configs
        else []
    )

    return MaxEntSampler(
        gan.gen,
        ref_dist,
        feature,
        **config.sample_params.params,
        callbacks=sampler_callbacks,
    )

In [3]:
from easydict import EasyDict as edict

args = edict()

# Config files are merged in this order: the experiment config first, then the
# target, GAN, feature, MCMC and shared experiment configs.
configs = [
    'configs/exp_configs/dcgan-inception.yml',
    'configs/targets/discriminator.yml',
    'configs/gan_configs/cifar-10-dcgan.yml',
    'configs/feature_configs/resnet34.yml',
    'configs/mcmc_configs/ula.yml',
    'configs/mcmc_exp.yml',
]
configs = [Path(ROOT_DIR, x).as_posix() for x in configs]
args.configs = configs

params = yaml.round_trip_load(Path(args.configs[0]).read_text(encoding="utf-8"))

# Concatenate the YAML documents in Python instead of piping them through
# `echo "..." | cat -` in bash: inside double quotes the shell expands `$`,
# backticks and `\"` in the dumped YAML, which can corrupt the config or run
# commands. The extra "\n" mirrors echo's trailing newline, and joining the
# files with "\n" keeps a file without a trailing newline from gluing its last
# line onto the next file's first line.
merged_yaml = str(yaml.round_trip_dump(params)) + "\n" + "\n".join(
    Path(path).read_text(encoding="utf-8") for path in args.configs[1:]
)
config = DotConfig(yaml.round_trip_load(merged_yaml))

In [4]:
# Build the real-data dataset with the same normalisation the GAN was trained with.
normalize_cfg = config.gan_config.train_transform.Normalize
dataset_info = get_dataset(
    config.gan_config.dataset.name,
    mean=normalize_cfg.mean,
    std=normalize_cfg.std,
    **config.gan_config.dataset.params,
)
dataset = dataset_info["dataset"]
# Unshuffled loader over the real data.
dataloader = DataLoader(dataset, batch_size=config.data_batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Use the configured device when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(config.device)
else:
    device = torch.device("cpu")

gan = GANWrapper(config.gan_config, device)
# Distribution object named in `sample_params.distribution`, built on top of the GAN.
ref_dist = DistributionRegistry.create(config.sample_params.distribution.name, gan=gan)

Transform: Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])


In [6]:
# Output directory handed to the feature and sampler callbacks.
save_dir = Path("log") / "InceptionFeature_DiscriminatorTarget" / "dcgan_ula"
save_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): presumably the directory `ref_stats_path` points into — confirm
# against the feature config.
Path("stats").mkdir(exist_ok=True)

feature = create_feature(config, gan, dataloader, dataset_info, save_dir, device)
sampler = define_sampler(config, gan, ref_dist, feature, save_dir)

100%|██████████| 120/120 [00:17<00:00,  6.88it/s]


defaultdict(<function evaluate.<locals>.<lambda> at 0x7fcd35579f70>, {0: array([5.36985993e-01, 2.99444944e-01, 3.31968457e-01, 5.00935137e-01,
       7.41526306e-01, 3.05099804e-02, 5.87054994e-03, 1.07310796e+00,
       1.04772735e+00, 2.63970947e+00, 5.86189151e-01, 1.03390467e+00,
       3.84846851e-02, 4.52814728e-01, 3.60871255e-01, 6.20623767e-01,
       1.74343333e-01, 1.20890689e+00, 2.58807361e-01, 1.05563700e+00,
       1.20465660e+00, 9.88488551e-03, 1.04650581e+00, 5.02035618e-01,
       3.33656579e-01, 4.19470698e-01, 1.59293860e-01, 2.31569910e+00,
       1.48171008e-01, 3.64308447e-01, 7.41894722e-01, 5.14527187e-02,
       9.57986861e-02, 1.37698364e+00, 7.66007185e-01, 2.50036359e-01,
       1.83898419e-01, 5.94058298e-02, 9.70878918e-03, 5.02519161e-02,
       1.85960725e-01, 1.60959020e-01, 4.17682230e-01, 2.19345883e-01,
       6.08515501e-01, 6.77864194e-01, 1.16486239e+00, 1.92253792e+00,
       1.02592576e+00, 1.35164214e-02, 6.55760169e-01, 1.76235688e+00,
    

In [7]:
# Initial latents for every chain, kept on the CPU; the sampling loop moves
# them to `device` batch by batch.
start_latents = gan.prior.sample((config.sample_params.total_n,)).cpu()
start_step_id = 0

# Random class labels for the generator. `np.random.randint`'s upper bound is
# exclusive, so pass `n_classes` (not `n_classes - 1`) to make every class,
# including the last one, reachable.
# NOTE(review): no seed is set in this notebook (`random_seed` is imported but
# never called), so latents and labels differ between runs.
n_classes = dataset_info.get("n_classes", 10)
labels = torch.LongTensor(
    np.random.randint(0, n_classes, config.sample_params.total_n)
)

In [8]:
# Run the MaxEnt sampler batch by batch and collect the returned z/x
# trajectories. The accumulators are initialised here so the cell runs on a
# fresh kernel (they were previously never defined).
total_sample_z = []
total_sample_x = []

sample_batch_size = config.sample_params.batch_size
for i, start, label in zip(
    range(0, config.sample_params.total_n, sample_batch_size),
    torch.split(start_latents, sample_batch_size),
    torch.split(labels, sample_batch_size),
):
    print(i)

    # Reset the feature between batches (presumably it accumulates per-run
    # state — confirm).
    if i > 0:
        feature.reset()

    if config.get("flow", None):
        # NOTE(review): a fresh RealNVP proposal is created for every batch —
        # confirm this is intended rather than reusing one across batches.
        gan.gen.prior = RealNVPProposal(gan.gen.z_dim, device=device)

    start = start.to(device)
    label = label.to(device)
    gan.set_label(label)

    zs, xs, _, _ = sampler(start)
    sampler.reset()
    # Clear the input/output attributes on the generator and discriminator
    # before the next batch.
    gan.gen.input = gan.gen.output = gan.dis.input = gan.dis.output = None

    zs = torch.stack(zs, 0).cpu()
    xs = torch.stack(xs, 0).cpu()
    print(zs.shape)
    total_sample_z.append(zs)
    total_sample_x.append(xs)

0


  1%|          | 31/3001 [00:06<10:01,  4.94it/s]


KeyboardInterrupt: 