# Testing Underwater GAN (UGAN) in Colab

## Prepare data

### Clone repo

In [0]:
# clone repo
%cd /content/
!if ! test -d Underwater-Color-Correction/; then git clone https://github.com/yoelrc88/Underwater-Color-Correction.git ; fi  
%cd Underwater-Color-Correction/

### Download dataset and checkpoint

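Use this cell whenever something (models, datasets) needs to be downloaded from a Google Drive link. Here it fetches the `underwater_imagenet` dataset and a checkpoint archive (`old_checkpoint.zip`), then unzips both.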

In [0]:
#install gdrive dependencies
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authentication
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Download file GDrive to Colab

!echo "Downloading underwater_imagenet.zip ..."
file_downloaded = drive.CreateFile({'id': '1LOM-2A1BSLaFjCY2EEK3DA2Lo37rNw-7'})
file_downloaded.GetContentFile('underwater_imagenet.zip') # name of the file n GDrive

!echo "Downloading old_checkpoint.zip ..."
file_downloaded = drive.CreateFile({'id': '1VOpQPHOGi3bYX2C-0AWzSfaJVUBHlacU'})
file_downloaded.GetContentFile('old_checkpoint.zip') # name of the file n GDrive

# uncompress files
!unzip ./underwater_imagenet.zip -d ./
!unzip ./old_checkpoint.zip -d checkpoints/


In [0]:
'''

   Create pickle file

'''

import cPickle as pickle
import tensorflow as tf
from scipy import misc
from tqdm import tqdm
import numpy as np
import argparse
import ntpath
import random
import glob
import time
import sys
import cv2
import os

# my imports
sys.path.insert(0, 'ops/')
sys.path.insert(0, 'nets/')
from tf_ops import *
import data_ops

if __name__ == '__main__':
   parser = argparse.ArgumentParser()
   parser.add_argument('--LEARNING_RATE', required=False,default=1e-4,type=float,help='Learning rate')
   parser.add_argument('--LOSS_METHOD',   required=False,default='wgan',help='Loss function for GAN')
   parser.add_argument('--BATCH_SIZE',    required=False,default=32,type=int,help='Batch size')
   parser.add_argument('--L1_WEIGHT',     required=False,default=100.,type=float,help='Weight for L1 loss')
   parser.add_argument('--IG_WEIGHT',     required=False,default=1.,type=float,help='Weight for image gradient loss')
   parser.add_argument('--NETWORK',       required=False,default='pix2pix',type=str,help='Network to use')
   parser.add_argument('--AUGMENT',       required=False,default=0,type=int,help='Augment data or not')
   parser.add_argument('--EPOCHS',        required=False,default=100,type=int,help='Number of epochs for GAN')
   parser.add_argument('--DATA',          required=False,default='underwater_imagenet',type=str,help='Dataset to use')
   a = parser.parse_args()

   LEARNING_RATE = float(a.LEARNING_RATE)
   LOSS_METHOD   = a.LOSS_METHOD
   BATCH_SIZE    = a.BATCH_SIZE
   L1_WEIGHT     = float(a.L1_WEIGHT)
   IG_WEIGHT     = float(a.IG_WEIGHT)
   NETWORK       = a.NETWORK
   AUGMENT       = a.AUGMENT
   EPOCHS        = a.EPOCHS
   DATA          = a.DATA
   
   EXPERIMENT_DIR  = 'checkpoints/LOSS_METHOD_'+LOSS_METHOD\
                     +'/NETWORK_'+NETWORK\
                     +'/L1_WEIGHT_'+str(L1_WEIGHT)\
                     +'/IG_WEIGHT_'+str(IG_WEIGHT)\
                     +'/AUGMENT_'+str(AUGMENT)\
                     +'/DATA_'+DATA+'/'\

   IMAGES_DIR      = EXPERIMENT_DIR+'images/'

   print
   print 'Creating',EXPERIMENT_DIR
   try: os.makedirs(IMAGES_DIR)
   except: pass
   try: os.makedirs(TEST_IMAGES_DIR)
   except: pass

   # TODO add new things to pickle file - INCLUDING BATCH SIZE AND LEARNING RATE
   # write all this info to a pickle file in the experiments directory
   exp_info = dict()
   exp_info['LEARNING_RATE'] = LEARNING_RATE
   exp_info['LOSS_METHOD']   = LOSS_METHOD
   exp_info['BATCH_SIZE']    = BATCH_SIZE
   exp_info['L1_WEIGHT']     = L1_WEIGHT
   exp_info['IG_WEIGHT']     = IG_WEIGHT
   exp_info['NETWORK']       = NETWORK
   exp_info['AUGMENT']       = AUGMENT
   exp_info['EPOCHS']        = EPOCHS
   exp_info['DATA']          = DATA
   exp_pkl = open(EXPERIMENT_DIR+'info.pkl', 'wb')
   data = pickle.dumps(exp_info)
   exp_pkl.write(data)
   exp_pkl.close()
   
   print
   print 'LEARNING_RATE: ',LEARNING_RATE
   print 'LOSS_METHOD:   ',LOSS_METHOD
   print 'BATCH_SIZE:    ',BATCH_SIZE
   print 'L1_WEIGHT:     ',L1_WEIGHT
   print 'IG_WEIGHT:     ',IG_WEIGHT
   print 'NETWORK:       ',NETWORK
   print 'AUGMENT:       ',AUGMENT
   print 'EPOCHS:        ',EPOCHS
   print 'DATA:          ',DATA
   print


## Train

In [0]:
%cd "/content/Underwater-Color-Correction/"
!python2 train.py --DATA=underwater_imagenet --EPOCHS=100 --NETWORK=pix2pix --L1_WEIGHT=100 --BATCH_SIZE=32 --IG_WEIGHT=0.0 --LEARNING_RATE=1e-4 --LOSS_METHOD=wgan

# %cd "checkpoints/LOSS_METHOD_wgan/NETWORK_pix2pix/L1_WEIGHT_100.0/IG_WEIGHT_1.0/AUGMENT_0/DATA_underwater_imagenet/"
# !ls
# !cat ./info.pkl

## Run 

In [0]:
# Scratch shell commands kept for reference. Every line below is commented out,
# so running this cell does nothing.
# !python2 ./train.py --LEARNING_RATE 0.1
# !mv "underwater_imagenet/" "dataset"
# %cd ..
# !ls -lSs --block-size=M
# !printenv | grep PWD
# !ls -ls

In [0]:
'''

   Evaluation file for only a single image.

'''

import cPickle as pickle
import tensorflow as tf
from scipy import misc
from tqdm import tqdm
import numpy as np
import argparse
import random
import ntpath
import sys
import os
import time
import time
import glob
import cPickle as pickle
from tqdm import tqdm
import cv2

sys.path.insert(0, 'ops/')
sys.path.insert(0, 'nets/')
from tf_ops import *

import data_ops

if __name__ == '__main__':

   if len(sys.argv) < 3:
      print 'You must provide an info.pkl file and an image'
      exit()

   pkl_file = open(sys.argv[1], 'rb')
   a = pickle.load(pkl_file)
   
   LEARNING_RATE = a['LEARNING_RATE']
   LOSS_METHOD   = a['LOSS_METHOD']
   BATCH_SIZE    = a['BATCH_SIZE']
   L1_WEIGHT     = a['L1_WEIGHT']
   IG_WEIGHT     = a['IG_WEIGHT']
   NETWORK       = a['NETWORK']
   AUGMENT       = a['AUGMENT']
   EPOCHS        = a['EPOCHS']
   DATA          = a['DATA']

#    EXPERIMENT_DIR  = 'checkpoints/LOSS_METHOD_'+LOSS_METHOD\
#                      +'/NETWORK_'+NETWORK\
#                      +'/L1_WEIGHT_'+str(L1_WEIGHT)\
#                      +'/IG_WEIGHT_'+str(IG_WEIGHT)\
#                      +'/AUGMENT_'+str(AUGMENT)\
#                      +'/DATA_'+DATA+'/'\

   EXPERIMENT_DIR  = 'experiment_dir/'


   IMAGES_DIR     = EXPERIMENT_DIR+'dataset/test/'

   test_image = sys.argv[2]

   print
   print 'LEARNING_RATE: ',LEARNING_RATE
   print 'LOSS_METHOD:   ',LOSS_METHOD
   print 'BATCH_SIZE:    ',BATCH_SIZE
   print 'L1_WEIGHT:     ',L1_WEIGHT
   print 'IG_WEIGHT:     ',IG_WEIGHT
   print 'NETWORK:       ',NETWORK
   print 'EPOCHS:        ',EPOCHS
   print 'DATA:          ',DATA
   print

   if NETWORK == 'pix2pix': from pix2pix import *
   if NETWORK == 'resnet':  from resnet import *

   # global step that is saved with a model to keep track of how many steps/epochs
   global_step = tf.Variable(0, name='global_step', trainable=False)

   # underwater image
   image_u = tf.placeholder(tf.float32, shape=(1, 256, 256, 3), name='image_u')

   # generated corrected colors
   layers    = netG_encoder(image_u)
   gen_image = netG_decoder(layers)

   saver = tf.train.Saver(max_to_keep=1)

   init = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
   sess = tf.Session()
   sess.run(init)

   ckpt = tf.train.get_checkpoint_state(EXPERIMENT_DIR)
   if ckpt and ckpt.model_checkpoint_path:
      print "Restoring previous model..."
      try:
         saver.restore(sess, ckpt.model_checkpoint_path)
         print "Model restored"
      except:
         print "Could not restore model"
         pass
   
   step = int(sess.run(global_step))

   img_name = ntpath.basename(test_image)
   img_name = img_name.split('.')[0]

   batch_images = np.empty((1, 256, 256, 3), dtype=np.float32)

   #a_img = misc.imread(test_image).astype('float32')
   a_img = cv2.imread(test_image)
   a_img = cv2.cvtColor(a_img, cv2.COLOR_BGR2RGB)
   a_img = a_img.astype('float32')
   a_img = misc.imresize(a_img, (256, 256, 3))
   a_img = data_ops.preprocess(a_img)
   a_img = np.expand_dims(a_img, 0)
   batch_images[0, ...] = a_img

   gen_images = np.asarray(sess.run(gen_image, feed_dict={image_u:batch_images}))

   misc.imsave('./'+img_name+'_real.png', batch_images[0])
   misc.imsave('./'+img_name+'_gen.png', gen_images[0])
