In [1]:
#loading .dll files for GPU b/c for some reason TF isn't finding them through PATH (Windows 10)
#    -assumes that Cuda Toolkit v10.0.130 is installed in 'C:\Program Files\NVIDIA GPU Computing Toolkit\'
#    -assumes that nvcuda.dll is in 'C:\System32\'
import ctypes
hllDll = ctypes.WinDLL("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0\\bin\\cudart64_100.dll")
tllDll = ctypes.WinDLL("C:\\tools\\cuda\\bin\\cudnn64_7.dll")
cllDll = ctypes.WinDLL("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0\\bin\\cublas64_100.dll")
ullDll = ctypes.WinDLL("C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0\\extras\\CUPTI\\libx64\\cupti64_100.dll")
nllDll = ctypes.WinDLL("C:\\Windows\\System32\\nvcuda.dll")

import tensorflow as tf
from tensorflow import keras
import os
import cv2
from conv import MOE
import random
from data_gen import read_data
import numpy as np
from datetime import datetime
import io
import csv
import ipykernel
import collections
import six
from trainvisout import train_schedule
from tensorflow.keras import backend as K

In [2]:
#image batch shape is (16,128,128,3)
batch_size = 256
def get_params():
    data_start = input("Which dataset to use? (0-318) ")
    data_end = input("Which dataset to use? (0-318) ")
    test_start = input("Which testset to use? (0-35) ")
    test_end = input("Which testset to use? (0-35) ")
    num_experts = input("Number of total experts: ")
    top_k = input("Top K: ")
    optimizer = input("Optimizer (SGD, Adam): ")
    loss = input("Loss (psnr, ssim, mssiml1): ")
    scale = input("Scale (2,3,4): ")
    epochs = input("Number of epochs: ")
    name = input("Name of run: ")
    behavior = input("'train', 'evalu', 'predict', or 'all'? ")
    return {"data_start":data_start,
            "data_end":data_end,
            "test_start":test_start,
            "test_end":test_end,
            "num_experts":num_experts,
            "top_k":top_k,
            "optimizer":optimizer,
            "loss":loss,
            "name":name,
            "epochs":epochs,
            "scale":scale,
            "behavior":behavior}


def ssim_loss(image,label):
    #negative of SSIM b/c TF will try to minimize loss
    return tf.reduce_mean(tf.image.ssim(image,label,1.)) * -1

def mssim_loss(image,label):
    #negative of SSIM b/c TF will try to minimize loss
    return tf.reduce_mean(tf.image.ssim_multiscale(image,label,1.)) * -1

def psnr_loss(image,label):
    #negative of PSNR b/c TF will try to minimize loss
    return tf.reduce_mean(tf.image.psnr(image,label,max_val=1.0)) * -1

def psnr_loss1(image,label):
    def log10(num):
        num = tf.dtypes.cast(num,tf.float32)
        numer = tf.math.log(num)
        denom = tf.math.log(tf.constant(10,dtype=numer.dtype))
        return numer/denom
    mse = tf.reduce_mean(tf.math.squared_difference(image,label))
    #mse = tf.reduce_mean(tf.math.squared_difference(image,label),[-3,-2,-1])
    
    psnr = 20 * log10(1.0) - 10 * log10(mse)
    
    #psnr = tf.math.subtract(20 * tf.math.log(1.0) / tf.math.log(10.0),np.float32(10 / np.log(10)) * tf.math.log(mse))

    return psnr

def mse_loss(image,label):
    return tf.reduce_mean(tf.keras.losses.MSE(image,label))

def load_loss(model_load,full_load):
    model_load = tf.dtypes.cast(model_load,tf.float32)
    full_load = tf.dtypes.cast(full_load,tf.float32)
    model_sum = tf.reduce_mean(model_load,axis=0)
    perfect_sum = tf.reduce_mean(full_load,axis=0)
    return tf.keras.losses.MSE(model_sum,perfect_sum)

def date_time():
    return str(datetime.now().time())

def mssiml1_loss(image,label):
    """see https://arxiv.org/abs/1511.08861"""
    
    def _fspecial_gauss(size, sigma, img1, img2):
        """Function to mimic the 'fspecial' gaussian MATLAB function.
          Copied directly from TF image_ops, low-level ops replaced with functional counterparts
        """
        size = tf.convert_to_tensor(size, tf.int32)
        sigma = tf.convert_to_tensor(sigma)

        coords = tf.dtypes.cast(tf.range(size), sigma.dtype)
        coords -= tf.dtypes.cast(size - 1, sigma.dtype) / 2.0

        g = tf.math.square(coords)
        g *= -0.5 / tf.math.square(sigma)

        g = tf.reshape(g, shape=[1, -1]) + tf.reshape(g, shape=[-1, 1])
        g = tf.reshape(g, shape=[1, -1])  # For tf.nn.softmax().
        g = tf.nn.softmax(g)
        return tf.reshape(g, shape=[size, size, 1, 1])

    def reducer(x,kernel):
        """Copied directly from TF image_ops, low-level ops replaced with functional counterparts"""
        shape = tf.shape(x)
        x = tf.reshape(x, shape=tf.concat([[-1], shape[-3:]], 0))
        y = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='SAME')
        return tf.reshape(
            y, tf.concat([shape[:-3], tf.shape(y)[1:]], 0))
    
    alpha = 0.84
    shape1, shape2 = tf.shape_n([image, label])
    
    kernel = _fspecial_gauss(11,0.2,shape1,shape2)
    kernel = tf.tile(kernel, multiples=[1, 1, shape1[-1], 1])
    l1 = tf.reduce_mean(tf.multiply(tf.math.abs(image-label),reducer(image,kernel)))
    mssim = tf.reduce_mean(tf.image.ssim_multiscale(image,label,1,power_factors=(0.0448,)))
    mixed = alpha*mssim - (1-alpha)*l1
    return -1 * mixed

def load_weights(model_dir):
    moe.train_on_batch(x=tf.ones((moe.batch_size,32,32,3)),y=[tf.ones((moe.batch_size,16,16,3)),tf.ones((moe.batch_size,num_experts))])
    filelist = sorted(filter(os.path.isfile, os.listdir(model_dir)), key=os.path.getmtime)
    filelist.reverse()
    for filename in filelist:
        if filename[-6:]==".index":
            moe.load_weights(filename[:-6])
            print("Loaded weights: ",filename[:-6])
            break
    print("Weights not found")

class TrainCallback(tf.keras.callbacks.CSVLogger):
    def __init__(self,**kwargs):
        self.epoch = 0
        super(TrainCallback,self).__init__(**kwargs)
    
    def on_train_batch_end(self,batch,logs=None):
        if batch==0:
            self.epoch += 1
        epoch = self.epoch
        logs = logs or {}

        def handle_value(k):
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, six.string_types):
                return k
            elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
                return '"[%s]"' % (', '.join(map(str, k)))
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())
        
        if self.model.stop_training:
            # We set NA so that csv parsers do not fail for this last epoch.
            logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ['epoch'] + self.keys

            self.writer = csv.DictWriter(
                self.csv_file,
                fieldnames=fieldnames,
                dialect=CustomDialect)
            if self.append_header:
                self.writer.writeheader()

        row_dict = collections.OrderedDict({'epoch': epoch})
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()
    def on_epoch_end(self,epoch,logs=None):
        pass

class TestCallback(tf.keras.callbacks.CSVLogger):
    def __init__(self,**kwargs):
        super(TestCallback,self).__init__(**kwargs)
        self.epoch = 0
    def on_test_begin(self,logs=None):
        super(TestCallback,self).on_train_begin(logs=logs)
    def on_test_batch_end(self,batch,logs=None):
        if batch==0:
            self.epoch += 1
        super(TestCallback,self).on_epoch_end(self.epoch,logs=logs)
    def on_test_end(self,logs=None):
        super(TestCallback,self).on_train_end(logs=None)


### build model ###
for params in train_schedule:
    K.clear_session()
    
    #params = get_params()

    num_experts = int(params["num_experts"])
    top_k = int(params["top_k"])
    optimizer = params["optimizer"]
    if optimizer=='Adam':
        optimizer = tf.keras.optimizers.Adam()
    elif optimizer=='SGD':
        optimizer = tf.keras.optimizers.SGD()

    x = params["behavior"]

    behavior = {}
    if (x=='all'):
        behavior["train"] = True
        behavior["evalu"] = True
        behavior["predict"] = True
        behavior["evalu_1"] = False
    elif (x=='train'):
        behavior["train"] = True
        behavior["evalu"] = False
        behavior["predict"] = False
        behavior["evalu_1"] = False
    elif (x=='evalu'):
        behavior["train"] = False
        behavior["evalu"] = True
        behavior["predict"] = False
        behavior["evalu_1"] = False
    elif (x=='predict'):
        behavior["train"] = False
        behavior["evalu"] = False
        behavior["predict"] = True
        behavior["evalu_1"] = False
    elif (x=='evalu_1'):
        behavior["train"] = False
        behavior["evalu"] = False
        behavior["predict"] = False
        behavior["evalu_1"] = True

    loss = params["loss"]
    if loss=="ssim":
        loss = ssim_loss
    elif loss=="psnr":
        loss = psnr_loss
    elif loss=="mse":
        loss = keras.losses.MSE
    elif loss=="mssim":
        loss = mssim_loss
    elif loss=="mssiml1":
        loss = mssiml1_loss
    elif loss=="ssiml1":
        loss = ssiml1_loss

    name = params["name"]
    epochs = int(params["epochs"])
    scale = int(params["scale"])

    batch_size = num_experts * 32
    data_start = int(params["data_start"])
    data_end = int(params["data_end"]) + 1
    test_start = int(params["test_start"])
    test_end = int(params["test_end"]) + 1

    train_data_dir = "D:/TestData/train-x" + str(scale)+"/"
    test_data_dir = "D:/TestData/test-x" + str(scale)+"/"
    assert os.path.exists(train_data_dir), "train_data_dir does not exist"
    assert os.path.exists(test_data_dir), "test_data_dir does not exist"

    model_dir = "D:/SR-Research/models-N256/" + name
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    os.chdir(model_dir)

    moe = MOE(num_experts,top_k,batch_size,train_data_dir)
    if loss==ssim_loss or loss==mssim_loss or loss==mssiml1_loss or loss==ssiml1_loss:
        moe.compile(optimizer=optimizer,
                    loss=[loss,load_loss],
                    loss_weights=[1.0,128.0],
                    metrics=[[ssim_loss,psnr_loss1,mse_loss],[]])
    elif loss==psnr_loss:
        moe.compile(optimizer=optimizer,loss=[loss,load_loss],loss_weights=[0.8,0.2])
    elif loss==keras.losses.MSE:
        moe.compile(optimizer=optimizer,loss=[loss,load_loss],loss_weights=[0.8,0.2])

    #init = tf.zeros((batch_size,32,32,3))
    #init = tf.dtypes.cast(init,tf.float32)
    #moe.predict(init)
    moe.build((batch_size,32,32,3))
    if name in os.listdir("D:\SR-Research\models-N256\\"):
            load_weights(model_dir)
    #if behavior["train"]:
    #    while True:
    #        a = input("First time training? ")
    #        if (a=='y'):
    #            break
    #        elif (a=='n'):
    #            load_weights()
    #            break
    #else:
    #    load_weights()
    
    if behavior["train"]:
        print(date_time() + ", Training start: ",params["name"])
        ##code for training on data
        csv_logger = TrainCallback(filename='training.log',append=True)

        train_sets = os.listdir(train_data_dir)
        for i in range(epochs):
            for datanum in range(data_start,data_end):
                train_set = train_sets[datanum]
                images, labels = read_data(train_data_dir + train_set)

                #cuts off images off the end of the batch so that it is divisible by 16
                if (len(images) % batch_size != 0):
                    images = images[:-(len(images)%batch_size)]
                    labels = labels[:-(len(labels)%batch_size)]

                labels = np.array(tf.image.central_crop(labels,1/2))

                load = np.full((len(images),num_experts),(top_k/num_experts))

                moe.fit(x=images,y=[labels,load],batch_size=batch_size,epochs=1,callbacks=[csv_logger],verbose=0)

                #saves weights each training session
                if datanum%25 == 0:
                    moe.save_weights("weights-E"+str(i)+"D"+str(datanum), save_format="tf")
        print(date_time() + ", Training finish: ",params["name"])

    if behavior["evalu"]:
        print(date_time() + ", Eval start: ",params["name"])
        ##code for evaluating on test sets
        csv_logger_test = TestCallback(filename='testing.log',append=True)
        test_sets = os.listdir(test_data_dir)
        for testnum in range(test_start,test_end):
            test_set = test_sets[testnum]
            images, labels = read_data(test_data_dir + test_set)

            #cuts off images off the end of the batch so that it is divisible by 16
            if (len(images) % batch_size != 0):
                images = images[:-(len(images)%batch_size)]
                labels = labels[:-(len(labels)%batch_size)]

            labels = np.array(tf.image.central_crop(labels,1/2))

            load = np.full((len(images),num_experts),(top_k/num_experts))

            moe.evaluate(x=images,y=[labels,load],batch_size=batch_size,callbacks=[csv_logger_test],verbose=0)
        print(date_time() + ", Eval finished: ",params["name"])
    if behavior["evalu_1"]:
        K.clear_session()
        csv_logger_test = TestCallback(filename='testing_1.log',append=True)
        inputs = keras.Input(shape=(batch_size,16,16,3))
        model = keras.Model(inputs=inputs,outputs=inputs)
        model.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=[ssim_loss,psnr_loss1,mse_loss])
        test_sets = os.listdir(test_data_dir)
        for testnum in range(test_start,test_end):
            test_set = test_sets[testnum]
            images,labels = read_data(test_data_dir + test_set)
            images = np.array(tf.image.central_crop(images,1/2))
            labels = np.array(tf.image.central_crop(labels,1/2))
            model.evaluate(x=images,y=labels,batch_size=batch_size,callbacks=[csv_logger_test],verbose=0)
            
###update to 128
    if behavior["predict"]:
        train_sets = os.listdir(train_data_dir)
        def vis_out(num=0):
            images, labels = read_data(train_data_dir + train_sets[num])
            images = images[-1*batch_size:]
            labels = labels[-1*batch_size:]
            predictions, _ = moe.predict(images,batch_size=batch_size,steps=1)
            if not os.path.exists("visual_outputs/small/"):
                os.makedirs("visual_outputs/small/")
            for num in range(batch_size):
                img = np.array(tf.image.central_crop(images[num] * 255,1/2))
                label = np.array(tf.image.central_crop(labels[num] * 255,1/2))
                pred = predictions[num] * 255
                #cv2.imwrite("visual_outputs/small/image"+str(num)+".png",img)
                #cv2.imwrite("visual_outputs/small/label"+str(num)+".png",label)
                cv2.imwrite("visual_outputs/small/prediction"+str(num)+".png",pred)
        def rebuild_img(scale=2,nums=[2,6,51,336,619,150,449,427]): #add 2 back in
            for num in nums:
                num = str(num).zfill(4)
                hrpath = "F:\div2k\DIV2K_train_HR\\"
                hrimg = cv2.imread(hrpath + num + ".png")
                #hrimg = cv2.imread("F:/div2k/DIV2K_train_LR_bicubic/X4/0002x4.png") #testing image
                height, width, channels = hrimg.shape
                
                if height % 256 != 0:
                    height -= height % 256
                height += 16
                if width % 256 != 0:
                    width -= width % 256
                width += 16
                hrimg = hrimg[0:height,0:width]

                height, width, channels = hrimg.shape

                lrimg = cv2.cvtColor(hrimg,cv2.COLOR_BGR2RGB)
                lrimg = np.ndarray.astype(lrimg,'float32')
                lrimg /= 255.
                lrimg = cv2.resize(lrimg,(int(width/2.),int(height/2.)))
                lrimg = cv2.resize(lrimg,(width,height),interpolation=cv2.INTER_CUBIC)
                predlist = []
                patches = np.empty((0,32,32,3))
                for h in range(1,height//16):
                    for w in range(1,width//16):
                        patch = lrimg[(h-1)*16:(h+1)*16,(w-1)*16:(w+1)*16] #32x32 patch
                        patch = np.expand_dims(patch,axis=0)
                        predlist.append(patch)
                        #patches = np.concatenate((patches,patch),axis=0)
                    #print(h," done")
                    #preds = moe.predict(patches,batch_size=16)[0] #shape (16,16,3)
                    #across = np.empty((16,0,3))
                    #for i in range(len(preds)):
                    #    pred = preds[i:i+1]
                    #    pred = pred[0]
                    #    across = np.concatenate((across,pred),axis=1)
                    #newimg = np.concatenate((newimg,across),axis=0)
                    #print(str(round(newimg.shape[0]/(size-32),2)) + "%")
                #patches in shape (num_patches, 32, 32 3)
                patches = np.concatenate(predlist,axis=0)
                counter = len(patches)
                hx = height//16 - 1
                wx = width//16 - 1
                
                if (counter%batch_size)==0:
                    extras = np.empty((batch_size,32,32,3))
                    num_extras = len(extras)
                else:
                    extras = patches[-1*(counter%batch_size):]
                    num_extras = len(extras)
                    extras = np.concatenate((extras,np.tile(np.empty((1,32,32,3)),(batch_size-num_extras,1,1,1))),axis=0)
                    patches = patches[:-1*(counter%batch_size)]
                
                
                patch_pred = moe.predict(patches,batch_size=batch_size)[0]
                extra_pred = moe.predict(extras,batch_size=batch_size)[0]
                extra_pred = extra_pred[:-1*(batch_size-num_extras)]
                
                preds = np.concatenate((patch_pred,extra_pred),axis=0)
                #shape (wx*hx,16,16,32)
                newimg = None
                for h in range(hx):
                    across = across = np.empty((16,0,3))
                    for w in range(wx):
                        
                        x = preds[h * wx + w:h * wx + w + 1]
                        
                        x = x[0]
                        
                        across = np.concatenate((across,x),axis=1)
                    
                    try:
                        if newimg==None:
                            newimg = np.empty((0,across.shape[1],3),dtype=np.float32)
                    except:
                        pass
                    #print(across.shape)
                    #print(newimg.shape)
                    newimg = np.concatenate((newimg,across),axis=0)

                newimg *= 255
                newimg = newimg.astype(np.float32)
                lrimg = lrimg.astype(np.float32)
                newimg = cv2.cvtColor(newimg,cv2.COLOR_BGR2RGB)

                lrimg = cv2.cvtColor(lrimg,cv2.COLOR_BGR2RGB)
                lrimg *= 255
                if not os.path.exists("visual_outputs/large/"):
                    os.makedirs("visual_outputs/large/")
                assert cv2.imwrite("visual_outputs/large/image"+str(num).zfill(4)+".png",lrimg)
                assert cv2.imwrite("visual_outputs/large/label"+str(num).zfill(4)+".png",hrimg)
                assert cv2.imwrite("visual_outputs/large/prediction"+str(num).zfill(4)+".png",newimg)
        vis_out()
        #rebuild_img(scale=scale)
    

Instructions for updating:
This op will be removed after the deprecation date. Please switch to tf.sets.difference().
Instructions for updating:
Use tf.identity instead.
Loaded weights:  weights-E3D225
Weights not found
Loaded weights:  weights-E3D225
Weights not found
Loaded weights:  weights-E3D225
Weights not found
Loaded weights:  weights-E3D225
Weights not found
Loaded weights:  weights-E3D225
Weights not found
