In [None]:
""" # Commands to download datset directly from colab
from google.colab import files
files.upload() # upload kaggle.json file
!pip install kaggle
!mkdir -p ~/.kaggle/
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d nitishabharathi/scene-classification
!mkdir Dataset
!unzip scene-classification.zip -d Dataset
"""
# Imports
import pickle
import numpy as np 
import os
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Conv2D, Add, Conv2DTranspose, MaxPooling2D, concatenate, Flatten, Dense
from PIL import Image

# Constants
# Target size passed to tf.image.resize for every image and used as the
# spatial part of the model's input shape. The U-Net in this file applies four
# 2x2 max-poolings and concatenates the upsampled maps with the matching
# encoder maps, so each side should be divisible by 16.
image_size = (256, 256)  # power of two recommended for downsampling
# Number of images per batch in the tf.data pipeline; with one image per batch
# every optimizer step is computed from a single sample.
batch_size = 1  # SGD

# Dataset Function Definitions

def download_dataset_files(folderpath: str, ext: str):
  """
  Recursively collect the paths of all files under a folder whose path ends
  with the given extension.

  folderpath: str:: full or relative path to the dataset folder.
  ext: str:: desired extension (case-sensitive suffix match); for images it
    probably is .jpg or .png.

  Returns a 1-D numpy array of path strings, sorted so that the result (and
  any positional train/test split made from it) is the same on every run and
  every filesystem. The array is empty, but still of string dtype, when
  nothing matches or the folder does not exist.
  """
  image_files = []
  for folder, _, filenames in os.walk(folderpath):
    for filename in filenames:
      full_path = os.path.join(folder, filename)
      if full_path.endswith(ext):
        image_files.append(full_path)
  # os.walk yields entries in an arbitrary, OS-dependent order; sort for
  # reproducibility. dtype=str keeps an empty result from becoming float64.
  return np.array(sorted(image_files), dtype=str)
  

def filter_and_print(paths):
  """
  Keep only the paths that PIL can open and verify as images.

  paths: iterable of path strings (a list or numpy array of paths works).

  Every path that fails to open or verify is reported with print() and
  dropped. Returns a list with the surviving paths in their original order.
  """
  valid_paths = []
  for path in paths:
    try:
      # The context manager releases the file handle even when verify()
      # raises, so a folder full of broken files cannot exhaust descriptors.
      with Image.open(path) as image:
        image.verify()
    except Exception as e:
      # Deliberately broad: this is a best-effort filter and any failure
      # simply means the file is skipped.
      print("couldn't open file {}, got error {}".format(path, e))
    else:
      valid_paths.append(path)
  return valid_paths

def read_image_tf(file):
  """
  Load one JPEG file as a 3-channel image tensor resized to image_size.

  file: path of the image file, as accepted by tf.io.read_file.
  """
  raw_bytes = tf.io.read_file(file)
  decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
  return tf.image.resize(decoded, image_size)

def rgb_to_grayscale(image):
    """
    Convert an RGB image tensor to grayscale and divide the result by 255.
    """
    gray = tf.image.rgb_to_grayscale(image)
    return tf.math.divide(gray, 255)

def download_dataset(paths):
    """
    Build a batched tf.data pipeline of (grayscale, rgb) image pairs.

    paths: sequence of image file paths, passed to from_tensor_slices.

    Each element pairs the output of rgb_to_grayscale (the model input) with
    the image produced by read_image_tf (the target). Elements are grouped
    into batches of batch_size and prefetched.
    """
    dataset = tf.data.Dataset.from_tensor_slices(paths)
    rgb_images = dataset.map(read_image_tf)
    # Emit the pair from a single pipeline. Zipping two datasets that both
    # derive from rgb_images would read, decode and resize every file twice.
    dataset = rgb_images.map(lambda rgb: (rgb_to_grayscale(rgb), rgb))
    dataset = dataset.batch(batch_size)
    # Overlap input preparation with training to speed up computation
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

# Model class Definitions

class ImageColorizerBase(object):
  """
  Wrapper that trains a model and stores its weights and training history.

  model: object exposing fit(...) and save(path), e.g. a compiled Keras model.
  """

  def __init__(self, model):
    self.model = model
    self.history = None  # result of the last model.fit call, set by fit()

  def _download(self, path):
    """
    Best-effort browser download of a saved file through Colab's `files`.

    `files` is only defined when the notebook imported it
    (`from google.colab import files`); anywhere else the download is
    skipped with a message instead of failing.
    """
    try:
      files.download(path)
    except NameError:
      print("colab `files` module is not available, skipping download of {}".format(path))
    except Exception as e:
      print("couldn't download file {}, got error {}".format(path, e))

  def save_results(self, name):
    """
    Save the model to `name + '.h5'` and the pickled training history to
    `name + '_history'`, then try to download both files.

    Every step is best effort: a failure is printed and the remaining steps
    still run. The history file is only written when fit() has produced a
    history, so no empty file is left behind. Only files that were actually
    written are offered for download.
    """
    self.weights_file = name + '.h5'
    self.history_file = name + '_history'

    weights_saved = False
    try:
      self.model.save(self.weights_file)
      weights_saved = True
    except Exception as e:
      print("couldn't save model to {}, got error {}".format(self.weights_file, e))

    history_saved = False
    if self.history is None:
      print("no training history to save, call fit() first")
    else:
      try:
        with open(self.history_file, 'wb') as file:
          pickle.dump(self.history.history, file)
        history_saved = True
      except Exception as e:
        print("couldn't save history to {}, got error {}".format(self.history_file, e))

    if weights_saved:
      self._download(self.weights_file)
    if history_saved:
      self._download(self.history_file)

  def fit(self, train, validation, niter=1):
    """
    Train the wrapped model and keep the result in self.history.

    train: training data passed to model.fit.
    validation: data passed to model.fit as validation_data.
    niter: number of epochs.

    Runs on the first GPU when TensorFlow lists one, otherwise on the CPU.
    """
    physical_devices = tf.config.list_physical_devices('GPU')
    device = "/GPU:0" if physical_devices else "/CPU:0"

    with tf.device(device):
      self.history = self.model.fit(train,
                                    validation_data=validation,
                                    epochs=niter)

class UnetRegressor(ImageColorizerBase):
  """
  U-Net style model mapping a 1-channel image of size image_size to a
  3-channel image, compiled with the Adam optimizer and MSE loss.
  """

  def __init__(self):
    def double_conv(tensor, filters):
      # Two consecutive 3x3 same-padded ReLU convolutions with equal width
      for _ in range(2):
        tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
      return tensor

    inputs = Input(image_size + (1,))

    # Encoder: keep each stage's output for the skip connections, then halve
    # the resolution
    skips = []
    features = inputs
    for filters in (32, 64, 128, 256):
      features = double_conv(features, filters)
      skips.append(features)
      features = MaxPooling2D(pool_size=(2, 2))(features)

    # Bottleneck
    features = double_conv(features, 512)

    # Decoder: double the resolution, concatenate with the matching encoder
    # output along the channel axis, then convolve
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
      upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(features)
      features = concatenate([upsampled, skip], axis=3)
      features = double_conv(features, filters)

    # Final 3-channel output layer
    outputs = Conv2D(3, (3, 3), activation='relu', padding='same')(features)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='MSE')

    super().__init__(model)
  
if __name__ == "__main__":
  # fetch full paths to images
  paths = download_dataset_files('Dataset', '.jpg')
  paths = filter_and_print(paths)

  # split paths to train and test image files
  train_percentage = 0.8
  cutoff = int(train_percentage*len(paths))
  train_paths = paths[:cutoff]
  test_paths = paths[cutoff:]

  # download train and test datasets as tf dataset object
  train_dataset = download_dataset(train_paths)
  test_dataset = download_dataset(test_paths)

  model = UnetRegressor()
  model.fit(train_dataset, test_dataset, niter=30)
  