In [3]:
import os
from os import listdir, path
import numpy as np
import pickle, argparse
from glob import glob
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Uncomment to restrict TensorFlow to the GPU with index 1.
#os.environ["CUDA_VISIBLE_DEVICES"]="1"
# Log the device each op is placed on. This is what produces the long
# "Executing op ... in device ..." listings in the cell outputs of this notebook.
tf.debugging.set_log_device_placement(True)
print(tf.__version__)
# Query the XLA_GPU devices TensorFlow can see (shown as the cell result).
tf.config.list_physical_devices('XLA_GPU')

2.1.0


[PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:1', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:2', device_type='XLA_GPU')]

In [4]:
from easydict import EasyDict

# Run configuration with attribute-style access (args.lr, args.batch_size, ...).
args = EasyDict({
    'data_root': 'LipGAN_dataset_local',  # root folder searched for training frames
    'batch_size': 192,                    # samples per training batch
    'lr': 1e-3,                           # learning rate passed to both Adam optimizers
    'img_size': 96,                       # faces are resized to img_size x img_size
    'logdir': 'logs_ipynb',               # folder holding the cached file list
    'all_images': 'filenames.pkl',        # name of the cached file list inside logdir
})
print(args)

{'data_root': 'LipGAN_dataset_local', 'batch_size': 192, 'lr': 0.001, 'img_size': 96, 'logdir': 'logs_ipynb', 'all_images': 'filenames.pkl'}


### Dataset

In [5]:
import itertools 

# Number of frames subtracted from the centre frame id to locate the start of
# the audio window (see get_audio_segment).
half_window_size = 4
# Number of spectrogram columns in each audio segment fed to the models.
mel_step_size = 27
    
def frame_id(fname):
    """Return the integer frame number encoded in a file name.

    The directory part is dropped and everything from the first '.' onwards
    is ignored, so 'some/dir/12.jpg' gives 12.
    """
    stem, _, _ = os.path.basename(fname).partition('.')
    return int(stem)

def choose_ip_frame(frames, gt_frame):
    """Pick a random input frame that is not too close to the ground-truth frame.

    Args:
        frames: file names of candidate frames.
        gt_frame: file name of the ground-truth frame.

    Returns:
        One element of `frames`, drawn with np.random.choice. Only frames whose
        id differs from the ground-truth id by at least 6 are eligible; if none
        qualifies, the draw is made from all of `frames`.
    """
    far_enough = []
    for candidate in frames:
        if abs(frame_id(gt_frame) - frame_id(candidate)) >= 6:
            far_enough.append(candidate)

    pool = far_enough if far_enough else frames
    return np.random.choice(pool)

def get_audio_segment(center_frame, spec):
    """Cut the spectrogram window that belongs to a video frame.

    Args:
        center_frame: file name of the frame; its id is read with frame_id().
        spec: 2-D array indexed as spec[:, time].

    Returns:
        spec[:, start:start + mel_step_size], or None when that window does not
        lie completely inside the spectrogram (it starts before column 0 or ends
        after the last column).
    """
    start_frame_id = frame_id(center_frame) - half_window_size

    # 25 is fps of LRS2; 81 is presumably the number of spectrogram columns
    # per second -- TODO confirm against the preprocessing script.
    start_idx = int((81./25.) * start_frame_id)
    end_idx = start_idx + mel_step_size

    # A negative start would be interpreted by NumPy as "count from the end"
    # and silently return a wrong-sized (often empty) window for the first
    # few frames of a clip, so reject it explicitly.
    if start_idx < 0 or end_idx > spec.shape[1]:
        return None

    return spec[:, start_idx : end_idx]

def bgr2rgb(x):
    """Swap channels 0 and 2 of an (H, W, C) image in place and return it.

    The input array itself is modified; the returned object is the same array.
    """
    # The right-hand side uses index-array indexing, which yields a copy, so
    # the two channels are exchanged without overwriting each other.
    x[:, :, [0, 2]] = x[:, :, [2, 0]]
    return x

# Make sure the log directory exists (also creates missing parent folders and
# does not fail if the folder is already there).
os.makedirs(args.logdir, exist_ok=True)

# The list of training images is expensive to build (one glob over the whole
# dataset), so it is cached as a pickle inside the log directory.
filelist_cache = path.join(args.logdir, args.all_images)

if path.exists(filelist_cache):
    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # that were written by this notebook.
    with open(filelist_cache, 'rb') as cache_file:
        all_images = pickle.load(cache_file)
else:
    all_images = glob("{}/train/*/*/*.jpg".format(args.data_root))
    with open(filelist_cache, 'wb') as cache_file:
        pickle.dump(all_images, cache_file, protocol=pickle.HIGHEST_PROTOCOL)

print ("Will be training on {} images".format(len(all_images)))

# Shuffle in place; `batches` is the list the sample generator draws from.
np.random.shuffle(all_images)
batches = all_images

def _load_face(fname):
    """Read an image file as RGB floats scaled by 1/255 and resized to
    args.img_size x args.img_size; return None if OpenCV cannot read it."""
    img = cv2.imread(fname)
    if img is None:
        return None

    img = bgr2rgb(img) / 255.0
    return cv2.resize(img, (args.img_size, args.img_size))


def gen():
    """Endlessly yield random training samples.

    For every sample a random image is drawn from `batches`; the draw is
    repeated until a usable sample is found. A sample is rejected when

    * its folder holds fewer than 12 jpg frames,
    * no complete, NaN-free spectrogram window exists for the frame, or
    * one of the two images cannot be read.

    Yields:
        tuple (img_gt, img_gt_masked, img_ip, mel):
            img_gt        -- the drawn frame, RGB, resized, scaled by 1/255
            img_gt_masked -- copy of img_gt with the lower half set to 0
            img_ip        -- another frame of the same folder (see choose_ip_frame)
            mel           -- spectrogram window with mel_step_size columns
    """
    while True:
        # Get a frame
        img_name = batches[np.random.randint(0, len(batches))]
        gt_fname = os.path.basename(img_name)
        dir_name = os.path.dirname(img_name)
        frames = glob(os.path.join(dir_name, '*.jpg'))

        if len(frames) < 12:
            continue

        # Get a melspectrogram; the context manager closes the .npz archive
        # once the array has been read.
        with np.load(os.path.join(dir_name, 'mels.npz')) as archive:
            spec = archive['spec']
        mel = get_audio_segment(gt_fname, spec)

        if mel is None or mel.shape[1] != mel_step_size:
            continue

        if np.isnan(mel).any():
            continue

        # Ground truth and its masked version
        img_gt = _load_face(img_name)
        if img_gt is None:
            continue

        img_gt_masked = img_gt.copy()
        img_gt_masked[args.img_size//2:] = 0

        # Input (identity/pose) frame from the same folder
        img_ip = _load_face(choose_ip_frame(frames, gt_fname))
        if img_ip is None:
            continue

        yield (img_gt, img_gt_masked, img_ip, mel)
        
# Wrap the sample generator in a tf.data pipeline: all four outputs are cast to
# float32 and grouped into batches of args.batch_size.
dataset = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.float32,) * 4,
).batch(args.batch_size)

Will be training on 1678240 images
Executing op TensorDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op FlatMapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0


### Generator

In [6]:
from tensorflow.keras import layers

class Generator(tf.keras.Model):
    """Face generator conditioned on a masked face, a reference face and audio.

    A face encoder produces a pyramid of feature maps, an audio encoder produces
    a single audio feature, and a decoder up-samples the concatenated features
    while concatenating the matching encoder map at every scale (skip
    connections). The result is passed through a sigmoid.

    Every `call` accepts an optional `training` flag and forwards it to the
    BatchNormalization layers. Leaving it at None keeps the previous behaviour
    (Keras decides); pass `training=True` in a custom training loop so that the
    batch statistics are used and the moving averages are updated.
    """
    def __init__(self):
        super().__init__()
        
        class ConvBlock(tf.keras.Model):
            """Conv2D -> BatchNormalization -> optional ReLU."""
            def __init__(self, num_filters, kernel_size=3, strides=1, padding='same', act=True):
                super().__init__()

                self.conv = layers.Conv2D(filters=num_filters, kernel_size=kernel_size, strides=strides, padding=padding)
                self.batch_norm = layers.BatchNormalization(momentum=0.8)
                self.act = act

            def call(self, x, training=None):
                x = self.conv(x)
                x = self.batch_norm(x, training=training)
                if self.act:
                    x = tf.nn.relu(x)

                return x

        class TransposedConvBlock(tf.keras.Model):
            """Conv2DTranspose -> BatchNormalization -> ReLU (up-sampling block)."""
            def __init__(self, num_filters, kernel_size=3, strides=2, padding='same'):
                super().__init__()

                self.conv = layers.Conv2DTranspose(filters=num_filters, kernel_size=kernel_size, strides=strides, padding=padding)
                self.batch_norm = layers.BatchNormalization(momentum=0.8)

            def call(self, x, training=None):
                x = self.conv(x)
                x = self.batch_norm(x, training=training)
                x = tf.nn.relu(x)

                return x


        class FaceEncoder(tf.keras.Model):
            """Stack of ConvBlocks that returns the output of every stage."""
            def __init__(self):
                super().__init__()
                self.convs = [ConvBlock(32, 11),
                              ConvBlock(64, 7, 2),
                              ConvBlock(128, 5, 2),
                              ConvBlock(256, 3, 2),
                              ConvBlock(512, 3, 2),
                              ConvBlock(512, 3, 2),
                              ConvBlock(512, 3, 1, padding='valid'),
                              ConvBlock(256, 1, 1)]

            def call(self, x, training=None):
                # x: (B, H=96, W=96, C=6)
                outputs = []
                for conv in self.convs:
                    x = conv(x, training=training)
                    outputs.append(x)

                # One feature map per stage; the decoder uses them as skips.
                return outputs

        class AudioEncoder(tf.keras.Model):
            """Reduces a spectrogram window to a single audio feature."""
            def __init__(self):
                super().__init__()
                self.convs = tf.keras.Sequential([ConvBlock(32),
                                              ConvBlock(32),
                                              ConvBlock(32),

                                              ConvBlock(64, strides=3), #27x9
                                              ConvBlock(64),
                                              ConvBlock(64),

                                              ConvBlock(128, strides=(3, 1)), #9x9
                                              ConvBlock(128),
                                              ConvBlock(128),

                                              ConvBlock(256, strides=3), #3x3
                                              ConvBlock(256),
                                              ConvBlock(256),

                                              ConvBlock(512, strides=1, padding='valid'), #1x1
                                              ConvBlock(512, 1, 1)])

            def call(self, x, training=None):
                # x: (B, H=80, W=27)

                # (B, H, W, 1)
                x = tf.expand_dims(x, axis=3)

                # Sequential forwards `training` to every block that accepts it.
                outputs = self.convs(x, training=training)

                return outputs

        class FaceDecoder(tf.keras.Model):
            """Up-samples the fused face/audio feature back to an image."""
            def __init__(self):
                super().__init__()
                self.convs = [tf.keras.Sequential([ConvBlock(512, 1),
                                                   TransposedConvBlock(512, 3, 3)]),
                              TransposedConvBlock(512),
                              TransposedConvBlock(256),
                              TransposedConvBlock(128),
                              TransposedConvBlock(64),
                              TransposedConvBlock(32),                    
                              tf.keras.Sequential([ConvBlock(16),
                                                   ConvBlock(16),
                                                   layers.Conv2D(filters=3, kernel_size=1, strides=1, padding='same')])]

            def call(self, face_encoded, audio_encoded, training=None):
                # face_encoded: 0(B, 96, 96, 32)
                #               1(B, 48, 48, 64)
                #               2(B, 24, 24, 128)
                #               3(B, 12, 12, 256)
                #               4(B, 6, 6, 512)
                #               5(B, 3, 3, 512)
                #               6(B, 1, 1, 512)
                #               7(B, 1, 1, 256)
                # audio_encoded: (B, 1, 1, 512)

                # (B, 1, 1, 768)
                x = tf.concat([face_encoded[7], audio_encoded], axis=3)

                # (B, 3, 3, 512)
                x = self.convs[0](x, training=training)

                # (B, 3, 3, 1024)
                x = tf.concat([face_encoded[5], x], axis=3)

                # (B, 6, 6, 512)
                x = self.convs[1](x, training=training)
                # (B, 6, 6, 1024)
                x = tf.concat([face_encoded[4], x], axis=3)

                # (B, 12, 12, 256)
                x = self.convs[2](x, training=training)
                # (B, 12, 12, 512)
                x = tf.concat([face_encoded[3], x], axis=3)

                # (B, 24, 24, 128)
                x = self.convs[3](x, training=training)
                # (B, 24, 24, 256)
                x = tf.concat([face_encoded[2], x], axis=3)

                # (B, 48, 48, 64)
                x = self.convs[4](x, training=training)
                # (B, 48, 48, 128)
                x = tf.concat([face_encoded[1], x], axis=3)

                # (B, 96, 96, 32)
                x = self.convs[5](x, training=training)
                # (B, 96, 96, 64)
                x = tf.concat([face_encoded[0], x], axis=3)

                # (B, 96, 96, 3)
                x = self.convs[6](x, training=training)
                x = tf.nn.sigmoid(x)

                return x

        self.face_encoder = FaceEncoder()
        self.audio_encoder = AudioEncoder()
        self.face_decoder = FaceDecoder()
        
    def call(self, img_gt_masked, img_ip, mel, training=None):
        """Generate a face image.

        Args:
            img_gt_masked: (B, H=96, W=96, C=3) target face with the lower half zeroed.
            img_ip: (B, H=96, W=96, C=3) reference face.
            mel: (B, H=80, W=27) spectrogram window.
            training: optional bool forwarded to all BatchNormalization layers.

        Returns:
            (B, 96, 96, 3) tensor with values in (0, 1) (sigmoid output).
        """
        # (B, H=96, W=96, C=6)
        input_face = tf.concat([img_gt_masked, img_ip], axis=3)
        
        # (B, 96, 96, 32)
        # (B, 48, 48, 64)
        # (B, 24, 24, 128)
        # (B, 12, 12, 256)
        # (B, 6, 6, 512)
        # (B, 3, 3, 512)
        # (B, 1, 1, 512)
        # (B, 1, 1, 256)
        face_encoded = self.face_encoder(input_face, training=training)
        
        # (B, 1, 1, 512)
        audio_encoded = self.audio_encoder(mel, training=training)
                
        # (B, 96, 96, 3)
        pred = self.face_decoder(face_encoded, audio_encoded, training=training)
        
        return pred
  


### Discriminator

In [7]:
import tensorflow_addons as tfa

class Discriminator(tf.keras.Model):
    """Scores how well a face image matches a spectrogram window.

    A face encoder and an audio encoder each map their input to a 512-d vector;
    the model returns the Euclidean distance between the two vectors.
    """
    def __init__(self):
        super().__init__()
        
        class ConvBlock(tf.keras.Model):
            """Conv2D -> optional InstanceNormalization -> LeakyReLU."""
            def __init__(self, num_filters, kernel_size=3, strides=2, padding='same', norm=True):
                super().__init__()

                self.conv = layers.Conv2D(filters=num_filters, kernel_size=kernel_size, strides=strides, padding=padding)
                # Instance normalisation computes its statistics over the
                # spatial positions of each sample/channel. It must be
                # disabled (norm=False) for blocks whose output is 1x1.
                self.norm = tfa.layers.InstanceNormalization() if norm else None

            def call(self, x):
                x = self.conv(x)
                if self.norm is not None:
                    x = self.norm(x)
                x = tf.nn.leaky_relu(x)

                return x

        class FaceEncoder(tf.keras.Model):
            """Maps a face image to an L2-normalised 512-d embedding."""
            def __init__(self):
                super().__init__()
                self.convs = tf.keras.Sequential([ConvBlock(64, 7),
                                                 ConvBlock(128, 5),
                                                 ConvBlock(256, 3),
                                                 ConvBlock(512, 3),
                                                 ConvBlock(512, 3),
                                                 layers.Conv2D(filters=512, kernel_size=3, strides=1, padding='valid'),
                                                 layers.Flatten()])

            def call(self, x):
                # x: (B, H=96, W=96, C=3)
                output = self.convs(x)
                output = tf.math.l2_normalize(output, axis=1)
                return output

        class AudioEncoder(tf.keras.Model):
            """Maps a spectrogram window to a 512-d embedding."""
            def __init__(self):
                super().__init__()
                # The last two blocks work on 1x1 feature maps. With a single
                # spatial position the instance-normalised value is always 0,
                # which would make the embedding independent of the input, so
                # normalisation is switched off for them.
                self.convs = tf.keras.Sequential([ConvBlock(32, strides=1),
                                                  ConvBlock(64, strides=3),
                                                  ConvBlock(128, strides=(3, 1)),
                                                  ConvBlock(256, strides=3),
                                                  ConvBlock(512, strides=1, padding='valid', norm=False),
                                                  ConvBlock(512, 1, strides=1, norm=False),
                                                  layers.Flatten()])

            def call(self, x):
                # x: (B, H=80, W=27)

                # (B, H, W, 1)
                x = tf.expand_dims(x, axis=3)

                output = self.convs(x)
                # NOTE(review): unlike the face embedding, the audio embedding
                # is not L2-normalised -- confirm this asymmetry is intended.
                #output = tf.math.l2_normalize(output, axis=1)

                return output

        self.face_encoder = FaceEncoder()
        self.audio_encoder = AudioEncoder()
        
    def call(self, img, mels):
        """Return the per-sample distance between face and audio embeddings.

        Args:
            img: (B, H=96, W=96, C=3) face images.
            mels: (B, H=80, W=27) spectrogram windows.

        Returns:
            (B,) tensor of Euclidean distances.
        """
        # (B, 512)
        face_embedding = self.face_encoder(img)
        
        # (B, 512)
        audio_embedding = self.audio_encoder(mels)
                
        distance = tf.norm(face_embedding - audio_embedding, axis=1, keepdims=False)
        
        return distance
  


In [8]:
import os
from tensorboardX import SummaryWriter

class Logger(SummaryWriter):
    """Thin SummaryWriter wrapper exposing a single `log` method for scalars."""

    def __init__(self, logdir):
        """Open a summary writer that writes into `logdir`."""
        super().__init__(logdir)

    def log(self, log_string, value, iteration):
        """Record scalar `value` under the tag `log_string` at step `iteration`."""
        self.add_scalar(log_string, value, iteration)

# Folder that receives the TensorBoard event files written by `logger`; the
# checkpoint manager created later in this notebook writes to the same folder.
save_dir = 'save/LipGAN_tf2'
logger = Logger(save_dir)

# Shell escape: list what is currently stored in the folder.
!ls $save_dir

checkpoint
ckpt-1.data-00000-of-00001
ckpt-1.index
ckpt-2.data-00000-of-00001
ckpt-2.index
ckpt-3.data-00000-of-00001
ckpt-3.index
events.out.tfevents.1586931164.scpark-X299-WU8
events.out.tfevents.1586931291.scpark-X299-WU8
events.out.tfevents.1586931347.scpark-X299-WU8
events.out.tfevents.1586931350.scpark-X299-WU8
events.out.tfevents.1586931352.scpark-X299-WU8
events.out.tfevents.1586931799.scpark-X299-WU8
events.out.tfevents.1586936434.scpark-X299-WU8
events.out.tfevents.1586938531.scpark-X299-WU8
events.out.tfevents.1586938600.scpark-X299-WU8
events.out.tfevents.1586939453.scpark-X299-WU8
events.out.tfevents.1586942041.scpark-X299-WU8


In [9]:
# Build both networks and one Adam optimizer per network.
generator = Generator()
discriminator = Discriminator()
print(generator, discriminator)

gen_opt = tf.keras.optimizers.Adam(args.lr)
disc_opt = tf.keras.optimizers.Adam(args.lr)

# Everything listed here is saved/restored together. `step` is the global
# training step counter used by the training loop.
# NOTE: the `discriminator` entry must point at the discriminator. Checkpoints
# written while it pointed at the generator contain no discriminator weights
# and may not restore cleanly into this layout -- NOTE(review): consider
# starting from a fresh save_dir.
ckpt = tf.train.Checkpoint(step=tf.Variable(1),
                           generator_optimizer=gen_opt,
                           discriminator_optimizer=disc_opt,
                           generator=generator,
                           discriminator=discriminator)
manager = tf.train.CheckpointManager(ckpt, save_dir, max_to_keep=5)

# restore(None) is a no-op, so this is safe when no checkpoint exists yet.
ckpt.restore(manager.latest_checkpoint)

if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

<__main__.Generator object at 0x7fc9f3566cd0> <__main__.Discriminator object at 0x7fc9f3566c90>
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarIsInitializedOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Assert in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RestoreV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RestoreV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Assert in device

In [10]:
def contrastive_loss(y_true, y_pred):
    """Contrastive loss on a distance `y_pred` with margin 1.

    With y_true == 0 the loss is the squared distance (pulls the pair
    together); with y_true == 1 it is max(0, margin - distance) squared
    (pushes the pair at least `margin` apart). The mean over all elements
    is returned.
    """
    margin = 1.
    pull_term = (1. - y_true) * y_pred ** 2.
    push_term = y_true * tf.math.maximum(0., margin - y_pred) ** 2.

    return tf.math.reduce_mean(pull_term + push_term)

@tf.function
def train_step(img_gt, img_gt_masked, img_ip, mel):
    """Run one generator update followed by one discriminator update.

    Args:
        img_gt: ground-truth faces.
        img_gt_masked: ground-truth faces with the lower half zeroed.
        img_ip: reference faces.
        mel: spectrogram windows belonging to img_gt.

    Returns:
        (generated, l1_loss, generated_loss, sync_loss, unsync_loss); the
        returned generated_loss is the one computed in the discriminator pass.

    NOTE(review): both models are called without a `training` argument, so the
    BatchNormalization layers presumably run in inference mode -- confirm.
    NOTE(review): with contrastive_loss as defined in this notebook, label 1
    pushes distances apart and label 0 pulls them together; here real synced
    pairs get label 1 and generated/unsynced pairs get label 0 -- confirm this
    is the intended convention.
    """
    # ---- generator: L1 reconstruction minus the discriminator term ----
    with tf.GradientTape() as gen_tape:
        generated = generator(img_gt_masked, img_ip, mel)
        l1_loss = tf.reduce_mean(tf.abs(img_gt - generated))

        generated_loss = contrastive_loss(0., discriminator(generated, mel))

        gen_loss = l1_loss - generated_loss

    gen_vars = generator.trainable_variables
    gen_opt.apply_gradients(zip(gen_tape.gradient(gen_loss, gen_vars), gen_vars))

    # ---- discriminator: generated, synced-real and unsynced-real pairs ----
    with tf.GradientTape() as disc_tape:
        generated_loss = contrastive_loss(0., discriminator(generated, mel))

        sync_loss = contrastive_loss(1., discriminator(img_gt, mel))

        # Rolling the batch axis by 10 pairs every face with another sample's audio.
        shifted_mel = tf.roll(mel, shift=10, axis=0)
        unsync_loss = contrastive_loss(0., discriminator(img_gt, shifted_mel))

        disc_loss = generated_loss + sync_loss + unsync_loss

    disc_vars = discriminator.trainable_variables
    disc_opt.apply_gradients(zip(disc_tape.gradient(disc_loss, disc_vars), disc_vars))

    return generated, l1_loss, generated_loss, sync_loss, unsync_loss

In [11]:
# Training loop. Runs until interrupted; the dataset generator is endless, the
# outer `while` only restarts iteration should the dataset ever be exhausted.
while(True):
    for data in dataset:
        step = ckpt.step.numpy()
        img_gt, img_gt_masked, img_ip, mel = data
        generated, l1_loss, generated_loss, sync_loss, unsync_loss = train_step(img_gt, img_gt_masked, img_ip, mel)

        # Pull the scalar values out of the tensors once and reuse them for
        # printing and logging.
        losses = {'l1': l1_loss.numpy(),
                  'gen': generated_loss.numpy(),
                  'sync': sync_loss.numpy(),
                  'unsync': unsync_loss.numpy()}

        print(step)
        print('l1:', losses['l1'], 'gen:', losses['gen'])
        print('sync:', losses['sync'], 'unsync:', losses['unsync'])
        
        # TensorBoard scalars every 10 steps, one curve per loss term.
        if step % 10 == 0:
            for loss_name, loss_value in losses.items():
                logger.log('loss/' + loss_name, loss_value, step)
        
        # Checkpoint every 1000 steps.
        if step % 1000 == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(step, save_path))
            print("l1 {:1.4f} gen {:1.4f} sync {:1.4f} unsync {:1.4f}".format(
                losses['l1'], losses['gen'], losses['sync'], losses['unsync']))
        
        # Every 100 steps clear the cell output and show the first sample of
        # the batch: ground truth, generator output and the reference frame.
        if step % 100 == 0:
            clear_output()
            
            _img_gt = img_gt.numpy()
            _generated = generated.numpy()
            _img_ip = img_ip.numpy()
            
            plt.figure(figsize=[10, 4])
            plt.subplot(1, 3, 1)
            plt.title('Ground Truth')
            plt.imshow(_img_gt[0])
            
            plt.subplot(1, 3, 2)
            plt.title('Predicted')
            plt.imshow(_generated[0])
            
            plt.subplot(1, 3, 3)
            plt.title('Unsynced')
            plt.imshow(_img_ip[0])
            plt.show()
        
        ckpt.step.assign_add(1)

Executing op OptimizeDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ModelDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op IteratorGetNextSync in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RandomUniform in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Sub in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Mul in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Add in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AssignVariableOp in device 

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/scpark/anaconda3/envs/ai/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-11-e79838b3a21c>", line 5, in <module>
    generated, l1_loss, generated_loss, sync_loss, unsync_loss = train_step(img_gt, img_gt_masked, img_ip, mel)
  File "/home/scpark/anaconda3/envs/ai/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
    result = self._call(*args, **kwds)
  File "/home/scpark/anaconda3/envs/ai/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 632, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/scpark/anaconda3/envs/ai/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2362, in __call__
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
  File "/home/scpark/anaconda3/envs/ai/lib/python

KeyboardInterrupt: 

In [None]:
# Manual checkpoint, e.g. after interrupting the training loop. Uses `step` and
# the loss tensors left behind by the last executed training iteration.
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(step, save_path))
print("l1 {:1.4f} gen {:1.4f} sync {:1.4f} unsync {:1.4f}".format(
    l1_loss.numpy(), generated_loss.numpy(), sync_loss.numpy(), unsync_loss.numpy()))

In [3]:
# Display the installed TensorFlow version.
tf.__version__

'2.1.0'

In [12]:
# Report whether this TensorFlow build was compiled with CUDA support.
print(tf.test.is_built_with_cuda()) 

True


In [26]:
# Smoke test: create a constant inside a '/GPU:0' device scope and print it.
with tf.device('/GPU:0'):
    c = tf.constant(0)
    print(c)

tf.Tensor(0, shape=(), dtype=int32)
