## Initialize Data

In [11]:
from PIL import Image
import SimpleITK as sitk
from setup_data import setup_data, setup_feedback_data
from datetime import datetime

In [2]:
# Run the project's data-preparation helpers (from setup_data.py) over the raw
# images. NOTE(review): presumably these write the .npz archives loaded in the
# AE and GAN sections below — confirm against setup_data.py.
IMAGE_DIR = r"../ImageData/"
TRAINING_DIR = r"../TrainingData/"
setup_data(IMAGE_DIR, TRAINING_DIR, frames=10, parsing=3)
setup_feedback_data(IMAGE_DIR, TRAINING_DIR, frames=5, future=5, parsing=3)

## Setup GPU

In [2]:
# Register the %tensorboard magic used at the end of the notebook.
%load_ext tensorboard

In [5]:
import tensorflow as tf
import numpy as np
import tensorflow.keras as k
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adadelta, RMSprop,SGD,Adam
from tensorflow.python.client import device_lib

# Legacy (TF1-style) session config: log the device each op is placed on and
# cap this process at 90% of GPU memory.
# NOTE(review): `sess` is never handed to Keras; confirm whether the memory cap
# actually applies to the eager Keras models trained below.
config = tf.compat.v1.ConfigProto(
    log_device_placement=True,
    gpu_options=tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.9))
sess = tf.compat.v1.Session(config=config)

# Ask TensorFlow to grow GPU memory on demand, for every visible GPU rather than
# only the first one. A missing GPU is reported explicitly instead of surfacing
# as an IndexError swallowed by a bare `except`.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print('failed to set growth: no GPU visible')
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except (RuntimeError, ValueError) as err:
        # Invalid device, or devices cannot be modified once initialized.
        print('failed to set growth:', err)

# Keras layers treat axis 1 (right after the batch axis) as the channel axis.
K.set_image_data_format('channels_first');

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce RTX 2070 SUPER, pci bus id: 0000:41:00.0, compute capability: 7.5



In [6]:
# Display the installed TensorBoard version (cell's last expression is shown).
import tensorboard
tensorboard.__version__

'2.4.1'

## Train AE
### Load AE Data

In [7]:
# Load the autoencoder train/test splits. np.load on an .npz returns an NpzFile
# that holds the archive open; indexing it reads each array into memory, so the
# context manager can close the file as soon as all four arrays are read.
with np.load(r"../TrainingData/feedback_data.npz") as data:
    x_train = data['x_train']
    y_train = data['y_train']
    x_test = data['x_test']
    y_test = data['y_test']
print(x_train[0].shape,y_train[0].shape)
print(x_test.shape,y_test.shape)

(5, 8, 128, 128) (5, 8, 128, 128)
(159, 5, 8, 128, 128) (159, 5, 8, 128, 128)


### Setup AE

In [26]:
# Stop once training loss has not improved for 5 epochs, then roll back to the
# best weights seen. Only takes effect if passed to fit(callbacks=...).
EarlyStop = k.callbacks.EarlyStopping(monitor='loss',patience=5, restore_best_weights=True)
# Define the Keras TensorBoard callback; every run logs to a fresh timestamped directory.
aelogdir = "logs/ae_fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = k.callbacks.TensorBoard(log_dir=aelogdir)

from FeedbackGenerator import FeedbackGenerator
# Drop a model left over from a previous run of this cell. Only a missing name
# is expected here; any other error should surface.
try:
    del ae
except NameError:
    pass
ae = FeedbackGenerator(do_batch_norm=False,use_noise=True)
ae.compile(loss='huber', optimizer=Adam(learning_rate=.0001), run_eagerly=True);
# Build for batches of 50 clips shaped (5, 8, 128, 128), matching the sample
# shape printed when the data was loaded.
ae.build([50,5,8,128,128]);
ae.summary()

Model: "SimpleU"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Conv3D_0 (Conv3D)            multiple                  680       
_________________________________________________________________
Conv3D_1 (Conv3D)            multiple                  1360      
_________________________________________________________________
Conv3D_2 (Conv3D)            multiple                  4065      
_________________________________________________________________
Conv3D_3 (Conv3D)            multiple                  8120      
_________________________________________________________________
DownActivation_0 (PReLU)     multiple                  163840    
_________________________________________________________________
DownActivation_1 (PReLU)     multiple                  81920     
_________________________________________________________________
DownActivation_2 (PReLU)     multiple                  1536

### Run AE

In [14]:
# Train the autoencoder. EarlyStop was configured in the setup cell but was not
# being passed here, so its patience / restore_best_weights never applied.
ae.fit(x_train, y_train, epochs=100, batch_size=50, shuffle=True,
       callbacks=[tensorboard_callback, EarlyStop])

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x28cb70dae20>

In [15]:
# Persist the trained autoencoder weights (HDF5, chosen by the .h5 suffix).
AE_WEIGHTS_PATH = 'ae_weights.h5'
ae.save_weights(AE_WEIGHTS_PATH)

In [27]:
# Restore autoencoder weights from the file written after training.
ae_weights_file = 'ae_weights.h5'
ae.load_weights(ae_weights_file)

In [28]:
import copy
from numpy.random import shuffle
import matplotlib.pyplot as plt
from PIL import Image

# Pick one test clip at random as the seed sequence, shape (1, 5, 8, 128, 128).
# Sampling an index avoids deep-copying and shuffling the whole test set just
# to read a single element; the fixed seed makes the visualization reproducible.
rng = np.random.default_rng(0)
sample_idx = int(rng.integers(len(x_test)))
vid = x_test[sample_idx:sample_idx + 1]
print(vid.shape)
# Autoregressive rollout along the time axis (axis 1): always feed the model the
# most recent 5 frames and append whatever it predicts. The AE was fit against
# 5-frame targets (y_train), so each step presumably appends 5 frames; the old
# `f:f+5` window only matched the newest frames when a step appends 1 frame.
for f in range(50):
    frame = ae.predict(vid[:, -5:, ...])
    vid = np.concatenate((vid, frame), axis=1)

def save_gif(frames, name: str):
    """Save a sequence of PIL images as an endlessly looping animated GIF.

    Args:
        frames: indexable sequence of PIL images; the first is the base image
            and the rest are appended as further animation frames.
        name: output path; ``.gif`` is appended unless it already ends with
            that extension (case-insensitive).

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if not frames:
        raise ValueError('save_gif needs at least one frame')
    if not name.lower().endswith('.gif'):
        name += '.gif'
    frames[0].save(name, format='GIF', append_images=frames[1:], save_all=True, loop=0)

def _channel_frames(indices):
    """Render channel 4 of the rolled-out clip as PIL images, scaled by 255."""
    return [Image.fromarray(vid[0, idx, 4, :, :] * 255) for idx in indices]

save_gif(_channel_frames(range(10)), 'AE_prediction.gif')
save_gif(_channel_frames(range(vid.shape[1])), 'AE_super_prediction.gif')

(1, 5, 8, 128, 128)


In [18]:
# Free the autoencoder and its datasets before the GAN section. Each name is
# removed independently: a single `del a, b, ...` stops at the first missing
# name, which on a re-run left the remaining large arrays alive.
for _name in ('ae', 'x_train', 'y_train', 'x_test', 'y_test'):
    globals().pop(_name, None)
del _name

## Train GAN
### Load GAN Data

In [19]:
# Load the GAN training clips; the context manager closes the .npz archive once
# the array has been read into memory.
with np.load(r"../TrainingData/data.npz") as data:
    x_train = data['n_frames']

### Setup GAN

In [20]:
# Define the Keras TensorBoard callback.
ganlogdir="logs/gan_fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = k.callbacks.TensorBoard(log_dir=ganlogdir)

from FeedbackGAN import FeedbackGAN
try: del gan
except: pass
gan = FeedbackGAN(channels=5, do_batch_norm=False,use_noise=True)
gan.compile(loss='binary_crossentropy', g_optimizer = Adam(learning_rate=.00001), d_optimizer = Adam(learning_rate=.01),run_eagerly=True);
gan.build([50,5,8,128,128]);
gan.summary()

gan.Generator.load_weights('ae_weights.h5')

Model: "feedback_gan"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Generator (SimpleUGen)       multiple                  1458491   
_________________________________________________________________
Discriminator (discriminator multiple                  36073     
_________________________________________________________________
concatenate (Concatenate)    multiple                  0         
Total params: 1,494,566
Trainable params: 1,494,564
Non-trainable params: 2
_________________________________________________________________


### Run GAN

In [21]:
# Train the GAN for up to 100 epochs, logging to TensorBoard.
# NOTE(review): the same array is passed as inputs and targets — presumably
# FeedbackGAN's training step derives its own targets; confirm.
gan_callbacks = [tensorboard_callback]
gan.fit(x_train, x_train,
        epochs=100,
        batch_size=15,
        shuffle=True,
        callbacks=gan_callbacks);

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [22]:
# Persist the full GAN (generator + discriminator) weights.
GAN_WEIGHTS_PATH = 'gan_weights.h5'
gan.save_weights(GAN_WEIGHTS_PATH)

## Visualize

In [24]:
# Restore the GAN weights written after training.
gan_weights_file = 'gan_weights.h5'
gan.load_weights(gan_weights_file)

In [25]:
from PIL import Image
import matplotlib.pyplot as plt

# Seed with the first 5 frames of training clip 1, then roll the generator
# forward 50 steps; each step feeds it the 5-frame window that starts at `step`
# and appends the prediction along the time axis (axis 1).
vid = x_train[1:2, 0:5, ...]
print(vid.shape)
for step in range(50):
    window = vid[:, step:step + 5, ...]
    vid = np.concatenate((vid, gan.Generator.predict(window)), axis=1)

def save_gif(frames, name: str):
    """Save a sequence of PIL images as an endlessly looping animated GIF.

    Prints the number of frames before saving.

    Args:
        frames: indexable sequence of PIL images; the first is the base image
            and the rest are appended as further animation frames.
        name: output path; ``.gif`` is appended unless it already ends with
            that extension (case-insensitive).

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if not name.lower().endswith('.gif'):
        name += '.gif'
    print(len(frames))
    if not frames:
        raise ValueError('save_gif needs at least one frame')
    frames[0].save(name, format='GIF', append_images=frames[1:], save_all=True, loop=0)
print(vid.shape)
first_ten = [Image.fromarray(vid[0, idx, 4, :, :] * 255) for idx in range(10)]
save_gif(first_ten, 'GAN_prediction.gif')

# Channel 4 across every frame of the rolled-out clip.
print(vid[0, :, 4, :, :].shape)
every_frame = [Image.fromarray(vid[0, idx, 4, :, :] * 255) for idx in range(vid.shape[1])]
save_gif(every_frame, 'GAN_super_prediction.gif')

(1, 5, 8, 128, 128)
(1, 55, 8, 128, 128)
10
(55, 128, 128)
55


In [3]:
# Open TensorBoard over the whole logs/ tree (both logs/ae_fit and logs/gan_fit runs).
%tensorboard --logdir logs

Reusing TensorBoard on port 6006 (pid 23208), started 3:50:46 ago. (Use '!kill 23208' to kill it.)