In [None]:
from settings_distributed import *

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, CUDA_DEVICES))

print(os.environ["CUDA_VISIBLE_DEVICES"])

In [2]:
import sys

from functools import partial
from models import attgan
from settings_distributed import *
from tensorflow.keras import losses, optimizers, metrics

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import dataloader
import utils

In [3]:
tf.test.is_gpu_available()

Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.


True

In [4]:
strategy = tf.distribute.MirroredStrategy(devices=[
    f"/gpu:{idx}" for idx in range(len(CUDA_DEVICES))
])

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


In [5]:
train_dloader = dataloader.DataLoader("train", BATCH_SIZE)
valid_dloader = dataloader.DataLoader("valid", BATCH_SIZE)
test_dloader = dataloader.DataLoader("test", BATCH_SIZE)

with strategy.scope():
    train_dloader = tf.data.Dataset.from_generator(
        train_dloader.next_batch, 
        (tf.float32, tf.float32, tf.float32), 
        (
            tf.TensorShape([None, HEIGHT, WIDTH, CHANNEL]), 
            tf.TensorShape([None, NUM_ATT]), 
            tf.TensorShape([None, NUM_ATT])
        )
    ).shuffle(1024)
    train_dloader = strategy.experimental_distribute_dataset(train_dloader)

    valid_dloader = tf.data.Dataset.from_generator(
        valid_dloader.next_batch, 
        (tf.float32, tf.float32, tf.float32), 
        (
            tf.TensorShape([None, HEIGHT, WIDTH, CHANNEL]), 
            tf.TensorShape([None, NUM_ATT]), 
            tf.TensorShape([None, NUM_ATT])
        )
    )
    valid_dloader = strategy.experimental_distribute_dataset(valid_dloader)

    test_dloader = tf.data.Dataset.from_generator(
        test_dloader.next_batch, 
        (tf.float32, tf.float32, tf.float32), 
        (
            tf.TensorShape([None, HEIGHT, WIDTH, CHANNEL]), 
            tf.TensorShape([None, NUM_ATT]), 
            tf.TensorShape([None, NUM_ATT])
        )
    )
    test_dloader = strategy.experimental_distribute_dataset(test_dloader)

In [6]:
with strategy.scope():
    model = attgan.AttGAN()
    
    criterion_MSE = losses.MeanSquaredError(reduction=tf.losses.Reduction.SUM)
    criterion_BCE = losses.BinaryCrossentropy(reduction=tf.losses.Reduction.SUM)
    
    optimizer_gen = optimizers.Adam(learning_rate=ETA)
    optimizer_dis = optimizers.Adam(learning_rate=ETA)
    
    ckpt = tf.train.Checkpoint(model=model)
    # utils.load_model(ckpt, "ckpts/attgan/2020-03-07-20-12-03-Epoch12/ckpt-64")

    loss_mean_gen = metrics.Mean()
    loss_mean_dis = metrics.Mean()
    loss_valid = metrics.Mean()

In [7]:
# with strategy.scope():
#     def criterion_MAE(y_true, y_pred):
#         n = y_true.shape[0]

#         loss = tf.math.abs(y_true - y_pred)
#         loss = tf.reshape(loss, shape=(n, -1))
#         # loss = tf.reduce_sum(loss, axis=-1)
#         loss = tf.reduce_mean(loss)

#         return loss

#     def criterion_BCE(y_true, y_pred):
#         n = y_true.shape[0]

#         loss = - y_true * tf.math.log(y_pred + 1e-6) - (1 - y_true) * tf.math.log(1 - y_pred + 1e-6)
#         loss = tf.reshape(loss, (n, -1))
#         # loss = tf.reduce_sum(loss, axis=1)
#         loss = tf.reduce_mean(loss)

#         return loss

In [8]:
with strategy.scope():
    def generator_loss(x, b, x_rec_a, d_rec_b, c_rec_b, training=False):
        n = x.shape[0]

        loss_rec = criterion_MSE(x, x_rec_a)
        loss_adv_gen = criterion_BCE(tf.ones_like(d_rec_b), d_rec_b)
        loss_att_gen = criterion_BCE(b, c_rec_b)

        loss_gen = 10*loss_rec + loss_adv_gen + loss_att_gen

        return loss_gen

In [9]:
with strategy.scope():
    def discriminator_loss(x, a, x_rec_b, d_x, d_rec_b, c_x, training=False):
        loss_adv_dis = criterion_BCE(tf.ones_like(d_x), d_x) + criterion_BCE(tf.zeros_like(d_rec_b), d_rec_b)
        loss_att_dis = criterion_BCE(a, c_x)

        if training is True:
            gp = utils.wgan_gp(x, x_rec_b, model.disc, num_samples=30)
            loss_dis = loss_adv_dis + loss_att_dis + gp
        else:
            loss_dis = loss_adv_dis + loss_att_dis

        return loss_dis

In [10]:
with strategy.scope():
    def forward_prop(x, a, b, training, mode):
        n = x.shape[0]
        
        x_rec_a, c_rec_a, d_rec_a = model(x, a, training=training)
        x_rec_b, c_rec_b, d_rec_b = model(x, b, training=training)
        c_x, d_x = model.disc(x, training=training)

        # print(c_x)

        if mode == "generator":
            loss_gen = generator_loss(x, b, x_rec_a, d_rec_b, c_rec_b, training=training)
            return x_rec_a, x_rec_b, loss_gen
        elif mode == "discriminator":
            loss_dis = discriminator_loss(x, a, x_rec_b, d_x, d_rec_b, c_x, training=training)
            return x_rec_a, x_rec_b, loss_dis
        else:
            return None
    
    def inference(x, a, b, training):
        if training is True:
            model.train_mode("generator")
            with tf.GradientTape() as tape:
                x_rec_a, x_rec_b, loss_gen = forward_prop(x, a, b, training, "generator")
            
            grad_gen = tape.gradient(loss_gen, [*model.encoder.trainable_variables, *model.decoder.trainable_variables])
            optimizer_gen.apply_gradients(zip(grad_gen, [*model.encoder.trainable_variables, *model.decoder.trainable_variables]))
            loss_mean_gen.update_state(loss_gen)

            model.train_mode("discriminator")
            
            for i in range(5):
                with tf.GradientTape() as tape:
                    x_rec_a, x_rec_b, loss_dis = forward_prop(x, a, b, training, "discriminator")

                grad_dis = tape.gradient(loss_dis, model.disc.trainable_variables)
                optimizer_dis.apply_gradients(zip(grad_dis, model.disc.trainable_variables))
                loss_mean_dis.update_state(loss_dis)
            
        else:
            x_rec_a, x_rec_b, loss_gen = forward_prop(x, a, b, training, "generator")
            loss_valid.update_state(loss_gen)
            
        return x_rec_a, x_rec_b
    
    def concat_distributed_tensor(distributed_tensor):
        tensor_list = []
        
        for device in distributed_tensor.devices:
            tensor_list.append(distributed_tensor.get(device))
            
        concat_tensor = tf.concat(tensor_list, axis=0)
        return concat_tensor
    
    @tf.function
    def distributed_inference(x, a, b, training):
        x_rec_a, x_rec_b = strategy.experimental_run_v2(inference, args=(x, a, b, training))
        
        x = concat_distributed_tensor(x)
        x_rec_a = concat_distributed_tensor(x_rec_a)
        x_rec_b = concat_distributed_tensor(x_rec_b)
        
        return x, x_rec_a, x_rec_b

In [11]:
def get_mean_loss(loss_mean_obj):
    loss = loss_mean_obj.result()
    loss_mean_obj.reset_states()
    return loss

In [12]:
def detransform_image(image):
    image = image*256
    image = np.clip(image, 0, 255)
    image = image.astype(np.uint8)
    
    return image

def plot_images_with_index(images, indices):
    for i, index in enumerate(indices):
        image = images[index].numpy()
        image = detransform_image(image)
        
        plt.subplot(1, len(indices), i+1)
        plt.imshow(image)

def plot_images(x_origin, x_rec_a, x_rec_b):
    indices = np.random.randint(x_origin.shape[0], size=3)
    
    print("Origin:")
    plt.figure(figsize=(12, 3))
    plot_images_with_index(x_origin, indices)
    plt.show()
    
    print("Reconstructed with attribute a:")
    plt.figure(figsize=(12, 3))
    plot_images_with_index(x_rec_a, indices)
    plt.show()
    
    print("Reconstructed with attribute b:")
    plt.figure(figsize=(12, 3))
    plot_images_with_index(x_rec_b, indices)
    plt.show()

In [13]:
less_valid_loss = 323422.78125

In [None]:
with strategy.scope():
    for e in range(EPOCHS):
        print("===============================================================================")
        
        for x_batch, att_a_batch, att_b_batch in train_dloader:
            x, x_rec_a, x_rec_b = distributed_inference(x_batch, att_a_batch, att_b_batch, training=True)

        train_loss_gen = get_mean_loss(loss_mean_gen)
        train_loss_dis = get_mean_loss(loss_mean_dis)
            
        print("=== TRAIN SET ===")
        plot_images(x, x_rec_a, x_rec_b)

        for x_batch, att_a_batch, att_b_batch in valid_dloader:
            x, x_rec_a, x_rec_b = distributed_inference(x_batch, att_a_batch, att_b_batch, training=False)
            
        valid_loss = get_mean_loss(loss_valid)
        
        print("=== VALID SET ===")
        plot_images(x, x_rec_a, x_rec_b)
        
        print(f"Epochs {e+1}/{EPOCHS}")
        print(f"Train generator loss: {train_loss_gen:.8f}")
        print(f"Train discriminator loss: {train_loss_dis:.8f}")
        print(f"Valid loss: {valid_loss:.8f}")

        with open(LOG_FILE, "a") as logfile:
            logfile.write(f"Epochs {e+1}/{EPOCHS}\n")
            logfile.write(f"Train generator loss: {train_loss_gen:.8f}\n")
            logfile.write(f"Train discriminator loss: {train_loss_dis:.8f}\n")
            logfile.write(f"Valid loss: {valid_loss:.8f}\n")

        if less_valid_loss > valid_loss:
            less_valid_loss = valid_loss
            utils.save_model_with_source(model, ckpt, "ckpts/attgan", "models", e+1)

INFO:tensorflow:batch_all_reduce: 38 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:batch_all_reduce: 32 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:batch_all_reduce: 32 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:batch_all_reduce: 32 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:batch_all_reduce: 32 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:batch_all_reduce: 32 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:batch_all_reduce: 38 all-reduces with algorithm = nccl, num_packs = 1, a