In [None]:
from google.colab import drive
drive.mount("/content/drive")

import os
import sys
sys.path.append("/content/drive/MyDrive/#fastcampus")
drive_project_root = "/content/drive/MyDrive/#fastcampus"
!pip install -r "/content/drive/MyDrive/#fastcampus/requirements.txt"

In [None]:
from datetime import datetime

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from omegaconf import DictConfig
import hydra
from hydra.core.config_store import ConfigStore

import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_addons as tfa

import wandb

In [None]:
from config_utils_tf import flatten_dict
from config_utils_tf import register_config
from config_utils_tf import get_optimizer_element
from config_utils_tf import get_callbacks

In [None]:
# List every physical device visible to TensorFlow on this runtime.
tf.config.list_physical_devices(device_type=None)

In [None]:
# Shell command: show the NVIDIA GPU status of this runtime.
!nvidia-smi

## 모델 정의

In [None]:
class GAN(tf.keras.Model):
    """Convolutional generative adversarial network.

    The generator maps latent noise of shape (B, 1, 1, latent_dim) to images;
    the discriminator maps images to one unnormalized logit per sample.
    Layer hyper-parameters come from ``cfg.model.gen`` and ``cfg.model.dis``.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.latent_dim = cfg.model.latent_dim
        self.discriminator = tf.keras.Sequential(
            [
                tf.keras.layers.Conv2D(**cfg.model.dis.conv1),
                tf.keras.layers.Conv2D(**cfg.model.dis.conv2),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(1),  # real/fake logit
            ],
            name="discriminator",
        )
        self.generator = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(
                    units=cfg.model.gen.in_fc.units,
                    activation=tf.nn.relu
                ),
                tf.keras.layers.Reshape(
                    target_shape=tuple(cfg.model.gen.reshape_shape)
                ),
                tf.keras.layers.Conv2DTranspose(**cfg.model.gen.tr_conv1),
                tf.keras.layers.Conv2DTranspose(**cfg.model.gen.tr_conv2),
                tf.keras.layers.Conv2DTranspose(**cfg.model.gen.tr_conv3),
            ],
            name="generator",
        )
        self.generator_loss_tracker = tf.keras.metrics.Mean(name="generator_loss")
        self.discriminator_loss_tracker = tf.keras.metrics.Mean(name="discriminator_loss")
        # Summed (not averaged) BCE on logits; the loss helpers below divide by
        # the number of samples themselves.
        self.cross_entropy = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.SUM
        )

    @property
    def metrics(self):
        # Trackers whose results train_step / test_step return as logs.
        return [
            self.generator_loss_tracker,
            self.discriminator_loss_tracker,
        ]

    def compile(self, d_optimizer, g_optimizer, **kwargs):
        """Store one optimizer per sub-network; other kwargs go to Model.compile."""
        super().compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer

    @staticmethod
    def _num_samples(logits):
        # Dynamic batch size as a float tensor: the static shape[0] can be None
        # when the step function is traced with an unknown batch dimension.
        return tf.cast(tf.shape(logits)[0], logits.dtype)

    @tf.function
    def get_generator_loss(self, fake_outputs):
        """Mean BCE that rewards the generator when fakes are scored as real (1)."""
        return self.cross_entropy(
            tf.ones_like(fake_outputs), fake_outputs,
        ) / self._num_samples(fake_outputs)

    @tf.function
    def get_discriminator_loss(self, real_outputs, fake_outputs):
        """Mean BCE over real (target 1) and generated (target 0) samples combined."""
        real_loss = self.cross_entropy(tf.ones_like(real_outputs), real_outputs)
        fake_loss = self.cross_entropy(tf.zeros_like(fake_outputs), fake_outputs)
        n_total = self._num_samples(real_outputs) + self._num_samples(fake_outputs)
        return (real_loss + fake_loss) / n_total

    @tf.function
    def sample(self, sample_size=100):
        """Draw standard-normal latent noise of shape (sample_size, 1, 1, latent_dim).

        ``sample_size`` may be a Python int or a scalar integer tensor.
        """
        return tf.random.normal(shape=(sample_size, 1, 1, self.latent_dim))

    def call(self, inputs, training=False):
        """Generate one fake image per input sample and score it.

        Only the batch size of ``inputs`` is used, so the fake batch always
        matches the real batch (train and val batch sizes differ, and under a
        distribute strategy each replica sees only its own slice).

        Returns:
            (generated_images, discriminator logits for the generated images)
        """
        noise = self.sample(sample_size=tf.shape(inputs)[0])
        generated_images = self.generator(noise, training=training)
        discriminator_results = self.discriminator(generated_images, training=training)
        return generated_images, discriminator_results

    @staticmethod
    def _prepare_images(data):
        # images = [B X 28 X 28] -> [B X 28 X 28 X 1], as float32; labels are ignored
        images, _ = data
        return tf.cast(tf.expand_dims(images, -1), tf.float32)

    def _forward_losses(self, images, training):
        """One generator + discriminator pass; returns (fake_images, g_loss, d_loss)."""
        fake_images, fake_logits = self(images, training=training)
        real_logits = self.discriminator(images, training=training)
        g_loss = self.get_generator_loss(fake_logits)
        d_loss = self.get_discriminator_loss(real_logits, fake_logits)
        return fake_images, g_loss, d_loss

    def _record(self, prefix, images, fake_images, g_loss, d_loss):
        """Update loss trackers, write TensorBoard images, return the logs dict."""
        self.generator_loss_tracker.update_state(g_loss)
        self.discriminator_loss_tracker.update_state(d_loss)
        tf.summary.image(f"{prefix}_real_img", images, max_outputs=5)
        tf.summary.image(f"{prefix}_generated_img", fake_images, max_outputs=5)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        images = self._prepare_images(data)

        # Both tapes record the same forward pass; each network is then updated
        # only from its own loss.
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            fake_images, g_loss, d_loss = self._forward_losses(images, training=True)

        g_vars = self.generator.trainable_variables
        d_vars = self.discriminator.trainable_variables
        g_gradients = g_tape.gradient(g_loss, g_vars)
        d_gradients = d_tape.gradient(d_loss, d_vars)
        self.g_optimizer.apply_gradients(zip(g_gradients, g_vars))
        self.d_optimizer.apply_gradients(zip(d_gradients, d_vars))

        return self._record("train", images, fake_images, g_loss, d_loss)

    def test_step(self, data):
        images = self._prepare_images(data)
        # Evaluation runs both networks in inference mode.
        fake_images, g_loss, d_loss = self._forward_losses(images, training=False)
        return self._record("val", images, fake_images, g_loss, d_loss)

In [None]:
class VAE(tf.keras.Model):
    """Convolutional variational autoencoder.

    The encoder's final Dense layer (``cfg.model.enc.out_fc.units`` wide) is
    split in half into (mu, logvar), so its width must be ``2 * latent_dim``.
    The decoder maps a latent vector back to an image whose pixels are squashed
    to (0, 1) by a sigmoid.

    Raises:
        ValueError: if ``cfg.model.enc.out_fc.units != 2 * cfg.model.latent_dim``.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.latent_dim = cfg.model.latent_dim
        # Fail early: a mismatch would otherwise only surface later as a shape
        # error in sample()/decode().
        if cfg.model.enc.out_fc.units != 2 * self.latent_dim:
            raise ValueError(
                "cfg.model.enc.out_fc.units must equal 2 * latent_dim "
                f"(got {cfg.model.enc.out_fc.units}, latent_dim={self.latent_dim})"
            )
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.Conv2D(**cfg.model.enc.conv1),
                tf.keras.layers.Conv2D(**cfg.model.enc.conv2),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(cfg.model.enc.out_fc.units),
            ],
            name="encoder",
        )
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(
                    units=cfg.model.dec.in_fc.units,
                    activation=tf.nn.relu
                ),
                tf.keras.layers.Reshape(
                    target_shape=tuple(cfg.model.dec.reshape_shape)
                ),
                tf.keras.layers.Conv2DTranspose(**cfg.model.dec.tr_conv1),
                tf.keras.layers.Conv2DTranspose(**cfg.model.dec.tr_conv2),
                tf.keras.layers.Conv2DTranspose(**cfg.model.dec.tr_conv3),
            ],
            name="decoder",
        )
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.recon_loss_tracker = tf.keras.metrics.Mean(name="recon_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Trackers whose results train_step / test_step return as logs.
        return [
            self.total_loss_tracker,
            self.recon_loss_tracker,
            self.kl_loss_tracker,
        ]

    @tf.function
    def sample(self, epsilon=None, sample_size=100):
        """Decode latent vectors ``epsilon``; if None, draw ``sample_size`` from N(0, I)."""
        if epsilon is None:
            epsilon = tf.random.normal(shape=(sample_size, self.latent_dim))
        return self.decode(epsilon)

    def encode(self, x, training=False):
        """Return (mu, logvar): the two halves of the encoder output along axis 1."""
        mu, logvar = tf.split(
            self.encoder(x, training=training),
            num_or_size_splits=2,
            axis=1
        )
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Return z = mu + eps * exp(logvar / 2) with eps ~ N(0, I)."""
        # tf.shape (dynamic) rather than mu.shape (static): the static batch
        # dimension can be None when this runs inside a traced tf.function.
        epsilon = tf.random.normal(shape=tf.shape(mu))
        return mu + epsilon * tf.exp(logvar * .5)

    def decode(self, z, training=False):
        """Map latent vectors to images with pixel values in (0, 1)."""
        return tf.sigmoid(self.decoder(z, training=training))

    def call(self, inputs, training=False):
        """Full encode -> sample -> decode pass; returns (output, z, mu, logvar)."""
        mu, logvar = self.encode(inputs, training=training)
        z = self.reparameterize(mu, logvar)
        output = self.decode(z, training=training)
        return output, z, mu, logvar

    @staticmethod
    def _prepare_images(data):
        # images = [B X 28 X 28] -> [B X 28 X 28 X 1], as float32; labels are ignored
        images, _ = data
        return tf.cast(tf.expand_dims(images, -1), tf.float32)

    @staticmethod
    def _compute_losses(images, outputs, z_mu, z_logvar):
        """Return (recon_loss, kl_loss, total_loss), each averaged over the batch."""
        # reconstruction loss: mae reduces the trailing channel axis, leaving a
        # per-pixel error map that is summed per image over axes (1, 2)
        recon_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.keras.losses.mae(images, outputs),
                axis=(1, 2)
            )
        )
        # closed-form KL divergence to N(0, I), summed over latent dimensions
        kl_loss = -0.5 * (1 + z_logvar - tf.square(z_mu) - tf.exp(z_logvar))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        return recon_loss, kl_loss, recon_loss + kl_loss

    def _record(self, prefix, images, outputs, total_loss, recon_loss, kl_loss):
        """Update loss trackers, write TensorBoard images, return the logs dict."""
        self.total_loss_tracker.update_state(total_loss)
        self.recon_loss_tracker.update_state(recon_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        tf.summary.image(f"{prefix}_source_img", images, max_outputs=5)
        tf.summary.image(f"{prefix}_recon_img", outputs, max_outputs=5)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        images = self._prepare_images(data)

        with tf.GradientTape() as tape:
            outputs, _, z_mu, z_logvar = self(images, training=True)
            recon_loss, kl_loss, total_loss = self._compute_losses(
                images, outputs, z_mu, z_logvar
            )

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return self._record("train", images, outputs, total_loss, recon_loss, kl_loss)

    def test_step(self, data):
        images = self._prepare_images(data)
        outputs, _, z_mu, z_logvar = self(images, training=False)
        recon_loss, kl_loss, total_loss = self._compute_losses(
            images, outputs, z_mu, z_logvar
        )
        return self._record("val", images, outputs, total_loss, recon_loss, kl_loss)


## Configuration 정의

In [None]:
# data configuration
data_fashion_mnist_cfg: dict = {
    "n_class": 10,
    "train_val_split": [0.9, 0.1],
    "train_val_shuffle": True,
    "train_val_shuffle_buffer_size": 1024,
    "test_shuffle": False,
    "test_shuffle_buffer_size": 1024,
}

# model configuration
# The VAE encoder's last Dense layer emits [mu, log_var] concatenated (VAE.encode
# splits it in half), so its width is derived from the latent size instead of
# being hard-coded separately.
VAE_LATENT_DIM = 2
model_mnist_vae_cfg: dict = {
    "name": "VAE",
    "data_normalize": True,
    "latent_dim": VAE_LATENT_DIM,
    "enc": {
        "conv1": {
            "filters": 32,
            "kernel_size": 3,
            "strides": [2, 2],
            "activation": "relu",
        },
        "conv2": {
            "filters": 64,
            "kernel_size": 3,
            "strides": [2, 2],
            "activation": "relu",
        },
        "out_fc": {
            "units": 2 * VAE_LATENT_DIM,  # (mu, log_var)
        },
    },
    "dec": {
        "in_fc": {
            "units": 7*7*32,
        },
        "reshape_shape": [7, 7, 32],
        "tr_conv1": {
            "filters": 64,
            "kernel_size": 3,
            "strides": [2, 2],
            "padding": "same",
            "activation": "relu",
        },
        "tr_conv2": {
            "filters": 32,
            "kernel_size": 3,
            "strides": [2, 2],
            "padding": "same",
            "activation": "relu",
        },
        "tr_conv3": {
            "filters": 1,
            "kernel_size": 3,
            "strides": [1, 1],
            "padding": "same",
        },
    }
}
model_mnist_gan_cfg: dict = {
    "name": "GAN",
    "data_normalize": True,
    "latent_dim": 64,
    "dis": {
        "conv1": {
            "filters": 32,
            "kernel_size": 3,
            "strides": [2, 2],
            "activation": "relu",
            "padding": "same",
        },
        "conv2": {
            "filters": 64,
            "kernel_size": 3,
            "strides": [2, 2],
            "activation": "relu",
            "padding": "same",
        },
    },
    "gen": {
        "in_fc": {
            "units": 7*7*32,
        },
        "reshape_shape": [7, 7, 32],
        "tr_conv1": {
            "filters": 64,
            "kernel_size": 3,
            "strides": [2, 2],
            "padding": "same",
            "use_bias": False,
            "activation": "relu",
        },
        "tr_conv2": {
            "filters": 32,
            "kernel_size": 3,
            "strides": [2, 2],
            "padding": "same",
            "use_bias": False,
            "activation": "relu",
        },
        "tr_conv3": {
            "filters": 1,
            "kernel_size": 3,
            "strides": [1, 1],
            "use_bias": False,
            "padding": "same",
            "activation": "sigmoid",
        },
    }
}


# optimizer_configs
adam_warmup_lr_sch_opt_cfg = {
    "optimizer": {
        "name": "Adam",
        "other_kwargs": {},
    },
    "lr_scheduler": {
        "name": "LinearWarmupLRSchedule",
        "kwargs": {
            "lr_peak": 1e-3,
            "warmup_end_steps": 1500,
        }
    }
}
radam_no_lr_sch_opt_cfg = {
    "optimizer": {
        "name": "RectifiedAdam",
        "learning_rate": 1e-3,
        "other_kwargs": {},
    },
    "lr_scheduler": None
}

gan_opt_cfg = {
    "discriminator": {
        "optimizer": {
            "name": "RMSprop",
            "learning_rate": 0.005,
            "other_kwargs": {},
        },
        "lr_scheduler": None,
    },
    "generator": {
        "optimizer": {
            "name": "Adam",
            "learning_rate": 0.001,
            "other_kwargs": {
                "beta_1": 0.5
            },
        },
        "lr_scheduler": None,
    }
}

# train_cfg
train_cfg: dict = {
    "train_batch_size": 256,
    "val_batch_size": 32,
    "test_batch_size": 32,
    "max_epochs": 50,
    "distribute_strategy": "MirroredStrategy",
}

# Named presets: each combines data / model / optimizer / train configs.
_merged_cfg_presets = {
    "vae_fashion_mnist_radam": {
        "data": data_fashion_mnist_cfg,
        "model": model_mnist_vae_cfg,
        "opt": radam_no_lr_sch_opt_cfg,
        "train": train_cfg
    },
    "gan_fashion_mnist": {
        "data": data_fashion_mnist_cfg,
        "model": model_mnist_gan_cfg,
        "opt": gan_opt_cfg,
        "train": train_cfg
    }
}

### hydra composition ###
# clear any global Hydra state (e.g. from a previous run of this cell) before initialize()
hydra.core.global_hydra.GlobalHydra.instance().clear()

# register preset configs (project helper from config_utils_tf)
register_config(_merged_cfg_presets)

# initialization; no config directory is given, so presumably configs are
# resolved from the presets registered above
hydra.initialize(config_path=None)

using_config_key = "gan_fashion_mnist"
cfg = hydra.compose(using_config_key)

# define & override log_cfg
model_name = cfg.model.name
run_dirname = "fastcampus_generative_model_tutorials_tf"
run_name = f"{datetime.now().isoformat(timespec='seconds')}-{using_config_key}-{model_name}"
log_dir = os.path.join(drive_project_root, "runs", run_dirname, run_name)

log_cfg = {
    "run_name": run_name,
    "callbacks": {
        "TensorBoard": {
            "log_dir": log_dir,
            "update_freq": 10,
        },
    },
    "wandb": {
        "project": "fastcampus_generative_model_tutorials_tf",
        "name": run_name,
        "tags": ["fastcampus_generative_model_tutorials_tf"],
        "reinit": True,
        "sync_tensorboard": True
    }
}

# unlock struct of config & set log config
OmegaConf.set_struct(cfg, False)
cfg.log = log_cfg

# relock config
OmegaConf.set_struct(cfg, True)
print(OmegaConf.to_yaml(cfg))

# Save the composed config next to this run's logs so the run is reproducible.
# log_dir does not exist yet at this point, so create it first.
os.makedirs(log_dir, exist_ok=True)
OmegaConf.save(cfg, os.path.join(log_dir, "config.yaml"))


In [None]:
def get_distribute_strategy(strategy_name: str, **kwargs):
    """Instantiate a ``tf.distribute`` strategy from its class name.

    Args:
        strategy_name: attribute name inside ``tf.distribute``, e.g. "MirroredStrategy".
        **kwargs: forwarded unchanged to the strategy constructor.

    Raises:
        AttributeError: if ``tf.distribute`` has no attribute ``strategy_name``.
    """
    strategy_cls = getattr(tf.distribute, strategy_name)
    return strategy_cls(**kwargs)


distribute_strategy = get_distribute_strategy(
    strategy_name=cfg.train.distribute_strategy
)

In [None]:
with distribute_strategy.scope():
    # dataset definition
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

    # normalization: scale pixel values to [0, 1]. float32 is the dtype the
    # models cast to anyway, and it halves memory compared with float64.
    if cfg.model.data_normalize:
        x_train = (x_train / 255.0).astype(np.float32)
        x_test = (x_test / 255.0).astype(np.float32)

    # train/val splits; the ratios are floats, so compare with a tolerance
    assert np.isclose(sum(cfg.data.train_val_split), 1.0), "train_val_split must sum to 1"
    train_size = int(len(x_train) * cfg.data.train_val_split[0])
    val_size = len(x_train) - train_size

    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    if cfg.data.train_val_shuffle:
        # Shuffle ONCE (reshuffle_each_iteration=False) before take()/skip().
        # With the default reshuffling, the order changes every epoch, so the
        # train/val split would be re-drawn each epoch and validation samples
        # would leak into training.
        dataset = dataset.shuffle(
            buffer_size=cfg.data.train_val_shuffle_buffer_size,
            reshuffle_each_iteration=False,
        )
    if cfg.data.test_shuffle:
        test_dataset = test_dataset.shuffle(
            buffer_size=cfg.data.test_shuffle_buffer_size,
        )

    train_dataset = dataset.take(train_size)
    val_dataset = dataset.skip(train_size)
    if cfg.data.train_val_shuffle:
        # Per-epoch reshuffling, applied to the (now fixed) training split only.
        train_dataset = train_dataset.shuffle(
            buffer_size=cfg.data.train_val_shuffle_buffer_size,
        )
    print(len(train_dataset), len(val_dataset), len(dataset), len(test_dataset))

    # dataloader definition
    train_batch_size = cfg.train.train_batch_size
    val_batch_size = cfg.train.val_batch_size
    test_batch_size = cfg.train.test_batch_size

    # prefetch lets input preparation overlap with the model step
    train_dataloader = (
        train_dataset.batch(train_batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )
    val_dataloader = (
        val_dataset.batch(val_batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )
    test_dataloader = (
        test_dataset.batch(test_batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )

sample_example = next(iter(train_dataloader))
# Show per-element shapes/dtypes instead of dumping the whole batch.
print([(t.shape, t.dtype) for t in sample_example])

In [None]:
# model definition
def get_model(cfg: "DictConfig"):
    """Build the model selected by ``cfg.model.name``.

    Args:
        cfg: composed config; only ``cfg.model.name`` is read here, the whole
            config is passed on to the model constructor.

    Returns:
        A ``VAE`` for "VAE" or a ``GAN`` for "GAN".

    Raises:
        NotImplementedError: for any other model name.
    """
    model_name = cfg.model.name
    if model_name == "VAE":
        return VAE(cfg)
    if model_name == "GAN":
        return GAN(cfg)
    raise NotImplementedError(
        f"Unknown model name {model_name!r}; expected 'VAE' or 'GAN'."
    )


def _build_optimizer(opt_cfg):
    """Call get_optimizer_element on a node holding ``optimizer`` and ``lr_scheduler``."""
    return get_optimizer_element(opt_cfg.optimizer, opt_cfg.lr_scheduler)


with distribute_strategy.scope():
    model = get_model(cfg)

    # define optimizer & scheduler
    if cfg.model.name != "GAN":
        # single optimizer for the whole model
        optimizer, scheduler = _build_optimizer(cfg.opt)
        model.compile(optimizer=optimizer)
        model.build((train_batch_size, 28, 28, 1))
        model.summary()
    else:
        # GAN: one optimizer per sub-network
        d_optimizer, d_scheduler = _build_optimizer(cfg.opt.discriminator)
        g_optimizer, g_scheduler = _build_optimizer(cfg.opt.generator)
        model.compile(d_optimizer=d_optimizer, g_optimizer=g_optimizer)

In [None]:
# Build the Keras callbacks from the log section of the config (project helper).
callbacks = get_callbacks(cfg.log)

# Start the wandb run: the flattened config is logged as the run config and
# cfg.log.wandb supplies the remaining wandb.init keyword arguments.
run_config = flatten_dict(cfg)
wandb.init(config=run_config, **cfg.log.wandb)

In [None]:
%reload_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/\#fastcampus/runs/fastcampus_generative_model_tutorials_tf

model.fit(
    train_dataloader,
    validation_data=val_dataloader,
    epochs=cfg.train.max_epochs,
    callbacks=callbacks,
)

## VAE Model testing

In [None]:
def get_latent_img(model, n, single_img_size=28):
    """Decode an n x n grid of 2-D latent points and tile the results into one image.

    Both latent coordinates take the standard-normal quantiles of ``n`` evenly
    spaced probabilities in [0.05, 0.95]. Tile (i, j) shows the decoding of
    latent point (x_j, y_i).

    Args:
        model: a model with ``latent_dim == 2`` whose ``sample(z)`` decodes a
            (N, 2) batch of latent vectors into N images (e.g. the VAE above).
        n: number of grid points per axis.
        single_img_size: side length of each decoded square image.

    Returns:
        np.ndarray of shape (n * single_img_size, n * single_img_size).

    Raises:
        ValueError: if the model's latent space is not 2-dimensional.
    """
    latent_dim = getattr(model, "latent_dim", None)
    if latent_dim != 2:
        raise ValueError(
            f"get_latent_img needs a 2-D latent space, got latent_dim={latent_dim}"
        )

    norm = tfp.distributions.Normal(0, 1)
    grid = norm.quantile(np.linspace(0.05, 0.95, n, dtype=np.float32)).numpy()
    # grid_y varies along rows (i), grid_x along columns (j)
    grid_x, grid_y = np.meshgrid(grid, grid)
    # One batched decode instead of n*n single-sample calls; row-major order,
    # so latent k = i * n + j holds (x_j, y_i).
    z = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)
    decoded = model.sample(z)
    tiles = tf.reshape(decoded, (n, n, single_img_size, single_img_size)).numpy()
    # (row_tile, col_tile, y, x) -> (row_tile, y, col_tile, x) -> 2-D mosaic
    return tiles.transpose(0, 2, 1, 3).reshape(n * single_img_size, n * single_img_size)


# The latent-grid plot only applies to the VAE with a 2-D latent space; skip it
# for other configs (e.g. the GAN preset) instead of crashing "Run All".
if isinstance(model, VAE) and model.latent_dim == 2:
    latent_img = get_latent_img(model, n=20)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(latent_img, cmap="Greys_r")
    ax.axis("off")
    plt.show()
else:
    print(f"Skipping latent-space plot: needs a VAE with latent_dim == 2 (model is {cfg.model.name}).")