In [None]:
# imports
# -------

import numpy as np
import random
import copy
import math
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import csv
import cv2
import h5py
import time
from tqdm import tqdm
from multiprocessing import Pool
from multiprocessing.dummy import Pool as ThreadPool
import os
from os import listdir
from os.path import isfile, join
from PIL import Image
import pandas as pd
import multiprocessing
from scipy.ndimage import label
from skimage import measure
import json

# torch related imports
# ---------------------
import torch
from torch.autograd import Variable
import torch.autograd
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torchvision.utils import save_image
import torch.optim as optim

# local settings
# --------------
import warnings
import sys
if not sys.warnoptions:
    warnings.simplefilter("ignore")
    
    
# storage related imports
# -----------------------
import tempfile
from tempfile import TemporaryFile
from google.cloud import storage

# Point the Google Cloud client at the service-account key file, but never
# clobber credentials the environment already configures.
# NOTE(review): absolute, user-specific path — consider making it configurable.
os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", '/home/venkateshmadhava/datasets/ven-ml-project-387fdf3f596f.json')

%matplotlib inline
%env JOBLIB_TEMP_FOLDER=/tmp


# Code

### Helper Functions

In [None]:
# getting file from google cloud storage
# ---------------------------------------

def get_file_from_google_storage(file_name, bucket_name='gpu_datatset_bucket'):
    
    '''
    Download a blob from Google Cloud Storage into a local temporary file.

    Parameters
    ----------
    file_name : str
        Name (path) of the blob inside the bucket.
    bucket_name : str, optional
        Bucket to read from; defaults to the project's dataset bucket.

    Returns
    -------
    str
        Path of the local temporary file. It is created with ``delete=False``,
        so the caller is responsible for removing it.
    '''
    
    # 0. initialising bucket
    # ----------------------
    print('0. initialising bucket..')
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    
    # 1. retrieving blob
    # ------------------
    print('1. retrieving blob..')
    blob = bucket.blob(file_name)
    
    # 2. downloading blob to temp file
    # --------------------------------
    # The temp file only reserves a unique path. It is closed *before* the
    # download so the client can reopen it for writing on every platform
    # (Windows refuses to open a file that is still held open).
    print('2. downloading blob to temp file, this may take a while..')
    with tempfile.NamedTemporaryFile(delete=False) as gcs_tempfile:
        local_path = gcs_tempfile.name
    try:
        blob.download_to_filename(local_path)
    except Exception:
        # don't leave an empty orphaned temp file behind on failure
        os.remove(local_path)
        raise
        
    # 3. final return
    # ---------------
    print('Done.')
    return local_path

In [None]:
# Pool function to create a dataset from input folder
# ---------------------------------------------------

def create_dataset_from_folder_all(infolder,n_h,n_w,n_threads=5):
    
    '''
    Load every ``.jpg`` image in ``infolder`` into one resized RGB array.

    Parameters
    ----------
    infolder : str
        Folder to scan (non-recursive); files whose lower-cased name contains
        '.jpg' are used.
    n_h, n_w : int
        Target height and width of every image.
    n_threads : int, optional
        Size of the thread pool used to read/resize images (default 5).

    Returns
    -------
    np.ndarray
        uint8 array of shape (m, n_h, n_w, 3), one row per image, in the
        (arbitrary) order returned by ``os.listdir``.

    Raises
    ------
    AssertionError
        If the folder contains no matching images.
    '''
    
    # 0. initialisations
    # ------------------
    image_list = [f for f in listdir(infolder) if isfile(join(infolder, f)) and '.jpg' in f.lower()]
    assert len(image_list) > 0, 'No images found in the folder'
    xout = np.zeros((len(image_list),n_h,n_w,3), dtype='uint8')
    
    # 1. building args - each worker writes into its own row of the shared xout
    # --------------------------------------------------------------------------
    all_args = [(i,xout,infolder,image_list,n_h,n_w) for i in range(xout.shape[0])]
    
    # 2. run across a thread pool; the context manager shuts the pool down so
    #    threads are not leaked on every call (starmap blocks until done)
    # ------------------------------------------------------------------------
    with ThreadPool(n_threads) as pool:
        pool.starmap(create_dataset_from_folder_single, all_args)
    print('Done creating a dataset with ' + str(xout.shape[0]) + ' images.')
    
    return xout

    
# FUNCTION 2
# GENERIC FUNCTION - to read, convert and resize a single image from disk
# -----------------------------------------------------------------------
def create_dataset_from_folder_single(i,xout,infolder,image_list,n_h,n_w):
    
    '''
    Read ``image_list[i]`` from ``infolder``, convert BGR->RGB, resize to
    (n_h, n_w) and store it in ``xout[i]`` (in place; returns None).

    Raises
    ------
    IOError
        If OpenCV cannot read the file. ``cv2.imread`` returns None instead of
        raising, which would otherwise surface as an opaque cvtColor error.
    '''
    path = join(infolder, image_list[i])
    img = cv2.imread(path)
    if img is None:
        raise IOError('Could not read image: ' + path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    xout[i] = cv2.resize(img, (n_w,n_h))   # cv2 takes (width, height)



In [None]:
# Pool function to resize images
# ------------------------------

def resize_all(ximgs,n_h,n_w,n_threads=5):
    
    '''
    Resize a batch of images to (n_h, n_w) using a thread pool.

    Parameters
    ----------
    ximgs : np.ndarray
        Batch of shape (m, h, w, c).
    n_h, n_w : int
        Target height and width.
    n_threads : int, optional
        Size of the thread pool (default 5).

    Returns
    -------
    np.ndarray
        uint8 array of shape (m, n_h, n_w, c).
    '''
    
    # 0. initialisations
    # ------------------
    # Allocate with the requested size (the arguments), not with any module
    # globals of similar name, which may be undefined or stale.
    c = ximgs.shape[3]
    x_images_resized = np.zeros((ximgs.shape[0],n_h,n_w,c), dtype='uint8')
    
    # 1. building args - each worker writes into its own row of the output
    # ---------------------------------------------------------------------
    all_args = [(i,ximgs,x_images_resized,n_h,n_w) for i in range(ximgs.shape[0])]
    
    # 2. run across a thread pool that is shut down afterwards
    # --------------------------------------------------------
    with ThreadPool(n_threads) as pool:
        pool.starmap(resize_image_single, all_args)
    print('Done resizing ' + str(ximgs.shape[0]) + ' images.')
    
    return x_images_resized

    
# FUNCTION 2
# GENERIC FUNCTION - resizes the i-th image of x_in into x_out[i]
# ---------------------------------------------------------------
def resize_image_single(i,x_in,x_out,new_h,new_w):
    '''
    Resize ``x_in[i]`` to (new_h, new_w) and write it into ``x_out[i]``.

    Works in place on ``x_out`` and returns None; it is mapped over indices
    by ``resize_all``'s thread pool.
    '''
    target_size = (new_w, new_h)  # OpenCV expects (width, height)
    x_out[i] = cv2.resize(x_in[i], target_size)

In [None]:
# SIMPLE FUNCTION TO CONVERT RGB TO GRAYSCALE
# -------------------------------------------
def rgb2gray(x):
    '''
    Convert RGB images to grayscale with the weights
    gray = 0.2989 * R + 0.5870 * G + 0.1140 * B.

    Parameters
    ----------
    x : np.ndarray
        Array whose LAST axis holds the R, G, B channels, e.g. a batch of
        shape (m, h, w, 3) or a single image of shape (h, w, 3).
        The input array is not modified.

    Returns
    -------
    np.ndarray
        Float array with the channel axis removed, e.g. (m, h, w).
    '''
    x = np.asarray(x)
    # Weight each channel independently and never write back into x: the
    # caller's array must stay intact (and uint8 storage would truncate).
    return x[..., 0] * 0.2989 + x[..., 1] * 0.5870 + x[..., 2] * 0.1140

In [None]:
# function to build the ordered list of all box categories
# --------------------------------------------------------

# THIS HAS TO BE CONSISTENT BETWEEN TRAIN, VAL AND TEST SETS

def build_class_list(jsonin_url):
    '''
    Collect every object category that has a 2D bounding box in a label JSON.

    Parameters
    ----------
    jsonin_url : str
        Path to a JSON file holding a list of records; each record may have a
        ``'labels'`` list whose entries carry ``'category'`` and, for boxed
        objects, a ``'box2d'`` key. Records without labels are skipped.

    Returns
    -------
    list of str
        The unique categories, sorted. Sorting makes the category -> channel
        index mapping identical across runs and across train/val/test sets;
        ``list(set(...))`` order depends on string-hash randomisation and can
        differ between Python processes.
    '''
    
    # 1. import JSON file and load it - loads into a list of dicts
    # -------------------------------------------------------------
    print('1. loading jsons..')
    with open(jsonin_url) as f:
        data = json.load(f)
    print('2. total jsons loaded: ' + str(len(data)))

    # 2. collect unique boxed categories in a deterministic order
    # -----------------------------------------------------------
    print('3. looping and build classes set..')
    all_cats = sorted({
        each_l['category']
        for each in data
        for each_l in (each.get('labels') or [])
        if 'box2d' in each_l
    })
    print('4. there are ' + str(len(all_cats)) + ' classes of objects in this dataset with bounding box - \n*****')
    print(all_cats)
    del data
    
    # 3. final return
    # ---------------
    return all_cats



In [None]:
# main function that builds the source/target arrays for the h5 dataset
# ---------------------------------------------------------------------

def build_datasets(all_classes,jsonin_url,img_folder_in,n_h,n_w):
    '''
    Build image and per-class box-mask arrays for every labelled image.

    Results are NOT returned: they are published as module globals because
    the thread-pool worker ``build_datasets_single`` reads and writes them.

    Parameters
    ----------
    all_classes : list of str
        Ordered category list (e.g. from ``build_class_list``); the index of a
        category is its channel in ``x_target``.
    jsonin_url : str
        Path to the label JSON (list of records with ``'name'`` and ``'labels'``).
    img_folder_in : str
        Folder containing the ``.jpg`` images named by the records.
    n_h, n_w : int
        Output height / width.

    Side effects
    ------------
    Sets globals ``img_folder``, ``json_dict``, ``img_list``, ``x_source``
    (uint8, (m, n_h, n_w, 3)), ``x_target`` (uint8, (m, n_h, n_w, len(all_classes))),
    ``new_h``, ``new_w`` and ``gb_all_classes``.
    '''
    global img_folder, json_dict, img_list, x_source, x_target, new_h, new_w, gb_all_classes

    img_folder = img_folder_in
    
    # 1. importing json
    # -----------------
    print('1. loading jsons..')
    with open(jsonin_url) as f:
        data = json.load(f)
    print('2. total jsons loaded: ' + str(len(data)))
    
    # 2. keep only records with at least one 2D box, keyed by image name
    #    (a later record with the same name replaces an earlier one)
    # ------------------------------------------------------------------
    print('3. building custom dict for easy parsing..')
    json_dict = {
        each['name']: each
        for each in data
        if any('box2d' in each_l for each_l in (each.get('labels') or []))
    }
    del data
    
    # 3. images on disk that also have box labels; membership is tested on
    #    the dict itself (O(1)) rather than on a re-built list of its keys
    # ---------------------------------------------------------------------
    img_list = [imgs for imgs in os.listdir(img_folder) if '.jpg' in imgs.lower() and imgs in json_dict]
    m = len(img_list)
    
    # 4. shared output buffers and settings read by the pool workers
    # --------------------------------------------------------------
    x_source = np.zeros((m,n_h,n_w,3), dtype = 'uint8')
    x_target = np.zeros((m,n_h,n_w,len(all_classes)), dtype = 'uint8')
    new_h = n_h
    new_w = n_w
    gb_all_classes = all_classes
    
    # 5. threaded fill; the context manager shuts the pool down afterwards
    # --------------------------------------------------------------------
    print('4. starting threading function..')
    with ThreadPool(5) as pool:
        pool.map(build_datasets_single, range(m))
    print('Done with ' + str(x_source.shape[0]) + ' images. Access them at global x_source amd x_target.')
    
    
    
# a single worker function for the above
# --------------------------------------
    
def build_datasets_single(i):
    '''
    Thread-pool worker for ``build_datasets``: fill ``x_source[i]`` and
    ``x_target[i]`` for image ``img_list[i]``.

    Target encoding: one channel per class (its index in ``gb_all_classes``);
    pixels inside a box of that class are 0, everything else 1. The mask is
    drawn at full resolution, resized with ``cv2.resize`` and re-binarised at
    a 0.75 threshold.

    Only reads the globals set by ``build_datasets`` and item-assigns into the
    shared arrays, so no ``global`` statements are needed.

    Raises
    ------
    IOError
        If OpenCV cannot read the image file.
    '''
    imgname = img_list[i]
    path = join(img_folder, imgname)
    img = cv2.imread(path)
    if img is None:
        raise IOError('Could not read image: ' + path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # target is sized like this source image (images differ in size)
    # --------------------------------------------------------------
    h, w, _ = img.shape
    curr_target = np.ones((h, w, len(gb_all_classes)))

    # setting 0 inside every box, in that box's class channel
    # -------------------------------------------------------
    for each_label in json_dict[imgname]['labels']:
        if 'box2d' not in each_label:
            continue
        box = each_label['box2d']
        curr_index = gb_all_classes.index(each_label['category'])
        # clamp at 0: a negative index would make the slice wrap around to
        # the opposite image edge and mark the wrong region
        curr_x1 = max(int(box['x1']), 0)
        curr_x2 = max(int(box['x2']), 0)
        curr_y1 = max(int(box['y1']), 0)
        curr_y2 = max(int(box['y2']), 0)
        curr_target[curr_y1:curr_y2, curr_x1:curr_x2, curr_index] = 0

    # resizing; reshape restores the channel axis that cv2.resize drops when
    # there is a single class
    # ----------------------------------------------------------------------
    img = cv2.resize(img, (new_w, new_h))
    curr_target = cv2.resize(curr_target, (new_w, new_h)).reshape(new_h, new_w, -1)

    # re-binarise after interpolation (0 below 0.75, 1 otherwise) and store
    # ---------------------------------------------------------------------
    x_source[i] = img
    x_target[i] = curr_target >= 0.75
    
    


# COCO dataset build functions

In [None]:
# need a way to create datasets from input dict after resizing
# ------------------------------------------------------------

def create_single_dataset_from_indict(d,n_h,n_w):
    
    '''
    Stack the images of a dict into one resized uint8 numpy array.

    Parameters
    ----------
    d : dict
        ``d[key] = img`` where every img is an (h, w, c) array sharing the
        same channel count c (spatial sizes may differ).
    n_h, n_w : int
        Output height / width.

    Returns
    -------
    np.ndarray
        uint8 array of shape (len(d), n_h, n_w, c), rows in ``d``'s key order.
        NOTE(review): values are cast to uint8, so float images in [0, 1]
        lose their fractional part — confirm this is intended for masks.

    Raises
    ------
    ValueError
        If ``d`` is empty (the channel count cannot be inferred).
    '''
    if len(d) == 0:
        raise ValueError('create_single_dataset_from_indict: input dict is empty')
    
    # 0. initialisations
    # ------------------
    keys_list = list(d.keys())
    m = len(keys_list)
    c = d[keys_list[0]].shape[2]
    xout = np.zeros((m,n_h,n_w,c), dtype = 'uint8')
    
    # 1. resize each image; reshape keeps the channel axis when c == 1
    #    (cv2.resize drops it for single-channel input)
    # ----------------------------------------------------------------
    for i, key in enumerate(keys_list):
        xout[i] = cv2.resize(d[key], (n_w,n_h)).reshape(n_h, n_w, c)
        print('Done with image ' + str(i+1) + ' of around ' + str(m) + '..', end = '\r')
    
    # 2. final return
    # ---------------
    return xout

In [None]:
# specific function to print per-category image counts for COCO
# -------------------------------------------------------------

def run_stats_coco(data):
    '''
    Print how many distinct images contain each category and super-category.

    Parameters
    ----------
    data : dict
        COCO-style annotation dict with ``'categories'`` (each having
        ``'id'``, ``'name'``, ``'supercategory'``) and ``'annotations'``
        (each having ``'category_id'`` and ``'image_id'``).

    Prints two lists of ``(image_count, name)`` tuples, largest first.
    Returns None.
    '''
    # 0. lookups: category id -> name, category id -> super-categories
    # ----------------------------------------------------------------
    cats = {}
    supers_of = {}
    for each in data['categories']:
        cats[each['id']] = each['name']
        supers_of.setdefault(each['id'], []).append(each['supercategory'])

    # 1. distinct image ids per category / super-category. Sets keep this
    #    linear; de-duplicating a list after every append was quadratic.
    # -------------------------------------------------------------------
    images_per_cat = {}
    images_per_super = {}
    for ann in data['annotations']:
        cid = ann['category_id']
        img_id = ann['image_id']
        for sup in supers_of.get(cid, ()):
            images_per_super.setdefault(sup, set()).add(img_id)
        images_per_cat.setdefault(cid, set()).add(img_id)

    # 2. (count, name) pairs, sorted ascending and printed in reverse
    # ---------------------------------------------------------------
    cat_counts = sorted((len(ids), cats[cid]) for cid, ids in images_per_cat.items())
    super_counts = sorted((len(ids), sup) for sup, ids in images_per_super.items())

    # 3. final prints
    # ---------------
    print('Printing cat wise count:')
    for each in reversed(cat_counts):
        print(each)

    print('\n\nPrinting super cat wise count:')
    for each in reversed(super_counts):
        print(each)


In [None]:
# COCO training set creation
# --------------------------

def _coco_image_index(img_names):
    '''Map a COCO image id (trailing digits of the file stem, leading zeros stripped) to matching file names.'''
    index = {}
    for name in img_names:
        stem = os.path.splitext(name)[0]
        start = len(stem)
        while start > 0 and stem[start - 1] in '0123456789':
            start -= 1
        digits = stem[start:]
        if digits:
            index.setdefault(digits.lstrip('0') or '0', []).append(name)
    return index


def build_coco_dataset_all(json_url, img_folder, mode, interested_cats):
    
    ''' 
    Build source images and per-class box masks from a COCO annotation file.

    1. this function requires the interested categories to be provided by the user

    2. default interested cats that can be used - 
    ['person','chair','car','dining table','cup','bottle','bowl']

    3. default interested super cats that can be used - 
    ['person','furniture','vehicle','animal','food','electronic']

    Parameters
    ----------
    json_url : str
        COCO instances JSON path.
    img_folder : str
        Folder with the ``.jpg`` images.
    mode : str
        'cats' matches annotations on category names; anything else matches
        on super-category names.
    interested_cats : list of str
        Names to keep; the list index is the target channel.

    Images are matched to annotations by exact image id: the trailing digits
    of each file stem with leading zeros ignored, so '000000000139.jpg' maps
    to id 139 (a plain substring match would also hit '000000001390.jpg').
    Annotations whose image is missing or ambiguous are skipped, as are
    annotations whose category has no super-category in super mode.

    Returns
    -------
    (dict, dict)
        ``src[str(image_id)]`` = RGB image; ``tgt[str(image_id)]`` = float
        (h, w, len(interested_cats)) mask, 0 inside boxes of the channel's
        class and 1 elsewhere.

    Raises
    ------
    IOError
        If a matched image file cannot be read by OpenCV.
    '''
    
    # 0. loading json
    # ---------------
    with open(json_url) as f:
        data = json.load(f)
    
    # 0.1 initialisations
    # -------------------
    no_cats = len(interested_cats)
    xcoco_src_dict = {}
    xcoco_tgt_dict = {}
    img_list_all = [imgs for imgs in os.listdir(img_folder) if '.jpg' in imgs.lower()]
    images_by_id = _coco_image_index(img_list_all)
    
    # 0.2 category id -> name, and category id -> first super-category that
    #     lists it (in order of first appearance)
    # ----------------------------------------------------------------------
    cats_dict = {}
    super_d = {}
    for each in data['categories']:
        cats_dict[each['id']] = each['name']
        super_d.setdefault(each['supercategory'], []).append(each['id'])
    cat_to_super = {}
    for sup, ids in super_d.items():
        for cid in ids:
            cat_to_super.setdefault(cid, sup)
        
    # 1. main for loop
    # ----------------
    annotations = data['annotations']
    n_ann = len(annotations)
    for i, ann in enumerate(annotations):
        
        print('At annotation ' + str(i+1) + ' of around ' + str(n_ann) + '...', end='\r')

        curr_name = str(ann['image_id'])
        if mode == 'cats':
            curr_cat_name = cats_dict[ann['category_id']]
        else:
            # resolved fresh for every annotation: None (skipped below) when
            # the id has no super-category, never a stale previous value
            curr_cat_name = cat_to_super.get(ann['category_id'])

        if curr_cat_name not in interested_cats:
            continue

        curr_channel = interested_cats.index(curr_cat_name)

        # box as [x, y, width, height]; clamp at 0 so a negative start does
        # not wrap the slice to the far image edge
        bbox = ann['bbox']
        curr_w1 = int(bbox[0])
        curr_h1 = int(bbox[1])
        curr_w2 = max(curr_w1 + int(bbox[2]), 0)
        curr_h2 = max(curr_h1 + int(bbox[3]), 0)
        curr_w1 = max(curr_w1, 0)
        curr_h1 = max(curr_h1, 0)

        # 1.1 image already loaded: just mark the box in its target
        # ---------------------------------------------------------
        if curr_name in xcoco_src_dict:
            xcoco_tgt_dict[curr_name][curr_h1:curr_h2,curr_w1:curr_w2,curr_channel] = 0
            continue

        # 1.2 first box for this image: load it only if exactly one file matches
        # ------------------------------------------------------------------------
        matches = images_by_id.get(curr_name.lstrip('0') or '0', [])
        if len(matches) != 1:
            continue

        path = join(img_folder, matches[0])
        img = cv2.imread(path)
        if img is None:
            raise IOError('Could not read image: ' + path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        xcoco_src_dict[curr_name] = img

        h,w,_ = img.shape
        curr_target = np.ones((h,w,no_cats))
        curr_target[curr_h1:curr_h2,curr_w1:curr_w2,curr_channel] = 0
        xcoco_tgt_dict[curr_name] = curr_target

    # final return
    # ------------
    return xcoco_src_dict,xcoco_tgt_dict

# NN related code

In [None]:
# Function to check accuracy on regression model
# ----------------------------------------------

def check_regression_accuracy(x_in,model_ae,direct_mode,model_num_face_net,y_target_in,proximity_percent,use_cuda=None):
    '''
    Score a regression model by per-sample ratio similarity to the target.

    similarity = min(pred, target) / max(pred, target); a sample counts as
    correct when similarity >= proximity_percent. The ratio is presumably only
    meaningful for positive predictions/targets — TODO confirm.

    Parameters
    ----------
    x_in : torch.Tensor
        Model input; first dim is the number of samples.
    model_ae : nn.Module
        Passed to ``chunk_pass`` with latent=True to produce the regression
        head's input when ``direct_mode == False``.
    direct_mode : bool
        Anything other than False feeds ``x_in`` straight to ``model_num_face_net``.
    model_num_face_net : nn.Module
        Regression head.
    y_target_in : array-like
        Targets (must support ``.reshape``), reshaped to the prediction shape.
    proximity_percent : float
        Similarity threshold, as a fraction (e.g. 0.9).
    use_cuda : bool, optional
        Run on GPU. When omitted, falls back to a module-level ``use_cuda``
        variable (what this function always relied on), else False.

    Returns
    -------
    (np.ndarray, int, float)
        Predictions, number of samples within the threshold, and the mean
        similarity of those samples (NaN when none qualify).
    '''
    if use_cuda is None:
        use_cuda = globals().get('use_cuda', False)
    
    # 0. predictions
    # --------------
    if direct_mode == False:
        latent_xin = chunk_pass(x_in,model_ae,True,use_cuda,1)
        pred_xin = chunk_pass(latent_xin,model_num_face_net,False,use_cuda,1)
    else:
        pred_xin = chunk_pass(x_in,model_num_face_net,False,use_cuda,1)
    y_out_np = pred_xin.data.cpu().numpy()
    y_target_np = y_target_in.reshape(y_out_np.shape)
    
    # 1. similarity ops
    # -----------------
    similarity = np.minimum(y_out_np,y_target_np)/np.maximum(y_out_np,y_target_np)
    sim_thresheld = (similarity >= proximity_percent).astype('int')
    total_got_right = np.sum(sim_thresheld)
    if total_got_right > 0:
        avg_distance = np.sum(sim_thresheld*similarity)/total_got_right
    else:
        avg_distance = float('nan')  # no qualifying samples: avoid 0/0
    percent_got_right = round((total_got_right/x_in.size()[0])*100,2)
    
    # info print
    # ----------
    print('The accuracy for given distance threshold is ' + str(percent_got_right) + ' %.')
    
    # 2. final return
    # ---------------
    return y_out_np, total_got_right, avg_distance
    
    
    

In [None]:
# a simple function to run a model on a slice of the inputs
# ---------------------------------------------------------

def generate_output(xin,model,start_ind,end_ind,print_images,use_cuda):
    '''
    Run ``model`` on ``xin[start_ind:end_ind]`` and return (originals, outputs).

    Parameters
    ----------
    xin : torch.Tensor
        Channels-first input batch.
    model : nn.Module
        Switched to eval mode before the forward pass.
    start_ind, end_ind : int
        Slice of ``xin`` to process.
    print_images : bool
        When True, plot every original / generated pair with matplotlib.
    use_cuda : bool
        Forwarded to ``chunk_pass``.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Channels-last originals cast to uint8, and channels-last model
        outputs in whatever dtype/range the model produced.
    '''
    def _plot(image):
        plt.figure(figsize=(5,5))
        plt.imshow(image)
        plt.show()

    batch = xin[start_ind:end_ind]
    generated = chunk_pass(batch, model.eval(), False, use_cuda, 1)

    xout_gen = to_numpy_image(generated.cpu().data)
    xout_orig = to_numpy_image(batch.cpu().data).astype('uint8')

    if print_images == True:
        for idx in range(len(xout_orig)):
            print('Example ' + str(idx) + '..')
            print('----------------------')
            print('Original - ')
            _plot(xout_orig[idx])
            print('Generated - ')
            _plot(xout_gen[idx])
            print('\n----------------\n')

    return xout_orig, xout_gen


In [None]:
# GENERIC function to calculate conv outsize
# ------------------------------------------
def outsize_conv(n_H,n_W,f,s,pad):
    '''
    Spatial output size of a convolution (dilation 1).

    Uses floor division, matching ``torch.nn.Conv2d``:
    out = floor((n + 2*pad - f) / s) + 1.

    Parameters: n_H, n_W input height/width; f kernel size; s stride;
    pad padding. Returns (h, w).
    '''
    h = (n_H - f + 2*pad)//s + 1
    w = (n_W - f + 2*pad)//s + 1
    return h, w
    
    
# GENERIC function to calculate upconv (transposed conv) outsize
# --------------------------------------------------------------
def outsize_upconv(h,w,f,s,p):
    '''
    Spatial output size of a transposed convolution:
    out = (n - 1) * s - 2 * p + f   (no output_padding, dilation 1).

    Returns (hout, wout).
    '''
    def _one_dim(n):
        return (n - 1)*s - 2*p + f
    return _one_dim(h), _one_dim(w)



def chunker(seq, size):
    '''
    Split a sliceable sequence into consecutive chunks of ``size`` items;
    the last chunk holds the remainder. Returns a lazy generator.

    Adapted from http://stackoverflow.com/a/434328
    '''
    # the start offsets are computed eagerly, at call time
    starts = range(0, len(seq), size)
    return (seq[start:start + size] for start in starts)



# GENERIC - initialises weights for a NN (use with module.apply)
# --------------------------------------------------------------
def weights_init(m):
    '''
    Initialise one module; intended for ``model.apply(weights_init)``.

    * class name contains 'Conv':      weight ~ N(0, 0.02)
    * class name contains 'BatchNorm': weight ~ N(1, 0.02), bias = 0
    * everything else (e.g. Linear) is left untouched.

    A missing or None weight/bias (e.g. ``BatchNorm(affine=False)``, or a
    custom container whose name merely contains 'Conv') is skipped instead
    of raising AttributeError.
    '''
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
        
        
# GENERIC - change a torch image batch to a numpy image batch
# -----------------------------------------------------------
def to_numpy_image(xin):
    '''
    Convert a channels-first batch (N, C, H, W) to channels-last (N, H, W, C).

    Parameters
    ----------
    xin : torch.Tensor or np.ndarray
        4-D batch. Tensors are detached and copied to the CPU first, so
        tensors that require grad or live on the GPU are accepted.

    Returns
    -------
    np.ndarray
        The axes-permuted array (a view of the CPU data, no copy).
    '''
    if isinstance(xin, torch.Tensor):
        xin = xin.detach().cpu().numpy()
    else:
        xin = np.asarray(xin)
    return np.transpose(xin, (0, 2, 3, 1))



# GENERIC - converts numpy images (N, H, W, C) to a float torch tensor (N, C, H, W)
# ---------------------------------------------------------------------------------
def setup_image_tensor(xin):
    '''
    Convert a channels-last numpy batch to a channels-first float32 tensor.

    (N, H, W, C) -> (N, C, H, W). ``torch.from_numpy`` shares memory with the
    array, so for float32 input the result is a view of ``xin``.
    '''
    channels_first = np.transpose(xin, (0, 3, 1, 2))
    return torch.from_numpy(channels_first).float()


# A function to get line markings (model output as uint8 images)
# --------------------------------------------------------------

def get_ae_output_image(x,net,use_cuda):
    '''
    Run ``net`` on channels-last images and return its output as uint8 images.

    Parameters
    ----------
    x : np.ndarray
        Batch (N, H, W, C) in 0-255; ``chunk_pass`` rescales 4-D input to 0-1.
    net : nn.Module
        Model presumably emitting values in [0, 1] — TODO confirm.
    use_cuda : bool
        Forwarded to ``chunk_pass``.

    Returns
    -------
    np.ndarray
        uint8 batch (N, H, W, C_out). Outputs are clipped to [0, 1] before
        scaling so values slightly outside that range saturate at 0/255
        instead of overflowing the uint8 cast.
    '''
    # setup_image_tensor already yields a float tensor (Variable is a no-op)
    x_t = setup_image_tensor(x)
    xout = chunk_pass(x_t,net.eval(),False,use_cuda,1)
    xout = to_numpy_image(xout.cpu().data)
    return (np.clip(xout, 0, 1) * 255).astype('uint8')
    
    
# GENERIC class that wraps any nn.Module (typically nn.Sequential) as a model
# ---------------------------------------------------------------------------
class Net(nn.Module):
    '''
    Thin model wrapper whose ``forward`` delegates to the wrapped module.

    The wrapped module is stored as ``self.forwardpass``, so its parameters
    are registered and appear in ``state_dict`` under the ``forwardpass.``
    prefix.
    '''

    def __init__(self, sequencelist):
        super().__init__()
        self.forwardpass = sequencelist

    def forward(self, x):
        return self.forwardpass(x)

In [None]:
# Function to build a linear FC model
# -----------------------------------

def linear_fc(layers, nw_activations, target_activation, dropout_p):
    '''
    Build a fully-connected network wrapped in ``Net``.

    Parameters
    ----------
    layers : list of int
        Layer widths; the first value is the input dimension, the last the
        output dimension. ``[64, 32, 1]`` -> Linear(64, 32), Linear(32, 1).
    nw_activations : str
        Hidden activation: 'relu', 'lrelu' (slope 0.2), 'sigmoid' or 'tanh';
        anything else falls back to ReLU.
    target_activation : str or None
        'sigmoid' or 'softmax' (over the last dim) applied to the output;
        anything else means no output activation.
    dropout_p : float
        Dropout probability after each hidden layer.

    Returns
    -------
    Net
        Hidden layers are Linear -> BatchNorm1d -> activation -> Dropout; the
        last layer is Linear (+ optional target activation). Weights go
        through ``weights_init`` and the model is in train mode.
    '''
    
    # setting N/W activations (one instance shared by all hidden layers)
    # -------------------------------------------------------------------
    if nw_activations == 'lrelu':
        nw_act = nn.LeakyReLU(0.2)
    elif nw_activations == 'sigmoid':
        nw_act = nn.Sigmoid()
    elif nw_activations == 'tanh':
        nw_act = nn.Tanh()
    else:
        nw_act = nn.ReLU()
    
    # setting target activations
    # --------------------------
    if target_activation == 'sigmoid':
        target_act = nn.Sigmoid()
    elif target_activation == 'softmax':
        # explicit dim: the implicit-dim form is deprecated and picks dim 0
        # (the batch axis) for 3-D inputs
        target_act = nn.Softmax(dim=-1)
    else:
        target_act = None
    
    # 1. consecutive (in, out) width pairs
    # ------------------------------------
    network = list(zip(layers[:-1], layers[1:]))
    
    # 2. constructing the layer list
    # ------------------------------
    seq_list = []
    for i, (n_in, n_out) in enumerate(network):
        seq_list.append(nn.Linear(n_in, n_out))
        if i + 1 == len(network):
            # at last layer: only the optional output activation
            if target_act is not None:
                seq_list.append(target_act)
        else:
            seq_list.append(nn.BatchNorm1d(n_out))
            seq_list.append(nw_act)
            seq_list.append(nn.Dropout(p = dropout_p))
    
    # 3. returning model
    # ------------------
    seq = nn.Sequential(*seq_list)
    seq.apply(weights_init)

    model = Net(seq)
    return model.train()
            
            


In [None]:
# GENERIC model function to train the networks
# --------------------------------------------

def _select_criterion(loss_mode):
    '''Map a loss_mode string to a torch loss; unknown modes fall back to MSE.'''
    if loss_mode == 'BCE':
        return nn.BCELoss()
    if loss_mode == 'NLL':
        return nn.NLLLoss()
    if loss_mode == 'crossentropy':
        return nn.CrossEntropyLoss()
    return nn.MSELoss()


def model_train(xin,yin,xval,yval,load_mode,model,epochs,mbsize,loss_mode,use_cuda,save_state,path):
    '''
    Train ``model`` with Adam on (xin, yin) in minibatches, optionally
    resuming from and/or saving to a checkpoint.

    Parameters
    ----------
    xin, yin : torch.Tensor
        Training inputs / targets (first dim = samples). A 4-D ``xin`` is
        treated as images in 0-255 and divided by 255 per minibatch; if
        ``yin`` is 4-D too it is scaled the same way. Non-image inputs are
        used as-is, which matches what ``chunk_pass`` does at inference.
    xval, yval : None
        Validation is not implemented; both must be None.
    load_mode : str
        'from saved' resumes model, optimizer and loss_mode from ``path`` via
        ``load_saved_model_function``; anything else trains ``model`` fresh.
    model : nn.Module
    epochs : int
        Number of epochs to run (counted on from the saved epoch when resuming).
    mbsize : int
        Minibatch size; the last batch holds the remainder.
    loss_mode : str
        'MSE', 'BCE', 'NLL' or 'crossentropy' (targets converted to class
        indices via argmax over dim 1, so presumably one-hot); else MSE.
    use_cuda : bool
    save_state : bool
        After training, save a checkpoint dict to ``path`` and the full model
        to ``path.replace('.tar', '_MODEL.tar')``.
    path : str

    Returns
    -------
    nn.Module
        The trained model (moved to the CPU if it was just saved).
    '''
    
    # 0. initialisations
    # ------------------
    loss_train = []
    norm_flag = 0   # 0: no scaling, 1: scale input and target, 2: scale input only

    # decide what gets scaled to 0-1 (the division happens per minibatch)
    # ---------------------------------------------------------------------
    if len(xin.size()) > 3:
        if len(yin.size()) > 3:
            assert torch.mean(xin).item() > 1 and torch.mean(yin).item() > 1, 'Input data is already in range 0-1. Not consistent with flow.'
            norm_flag = 1
            print('Input and Output dataset will be normalised to 0-1')
        else:
            assert torch.mean(xin).item() > 1, 'Input data is already in range 0-1. Not consistent with flow.'
            norm_flag = 2
            print('Input dataset will be normalised to 0-1')
    
    # ensuring xval and yval are None (`is`: `tensor == None` is not a plain bool)
    # ----------------------------------------------------------------------------
    assert xval is None and yval is None, 'xval and yval provided, but there is no code to normalise'
    
    if load_mode == 'from saved':
        
        # loading from saved; resume at the epoch AFTER the saved one and run
        # exactly `epochs` more
        # -------------------------------------------------------------------
        model,optimizer,saved_epoch,saved_loss,saved_loss_mode = load_saved_model_function(path,use_cuda)
        model = model.train()
        loss_mode = saved_loss_mode
        print('Loading model from saved state...')
        print('Last saved loss - ' + str(saved_loss))
        print('Last saved epoch - ' + str(saved_epoch))
        start_epoch = int(saved_epoch) + 1
        epochs += int(saved_epoch)
        
    else:
        
        # building new
        # ------------
        start_epoch = 1
        model = model.train()
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,model.parameters()))
    
    # model set up as per cuda
    # ------------------------
    if use_cuda == True:
        torch.cuda.empty_cache()
        model = model.cuda()
    
    # setting loss criterion
    # ----------------------
    criterion = _select_criterion(loss_mode)
    if loss_mode == 'crossentropy':
        yin = torch.max(yin.long(),1)[1]
    
    # 1. minibatch start offsets: ceil(m / mbsize) batches, the last one may
    #    be short (this also covers m < mbsize)
    # ----------------------------------------------------------------------
    m = xin.size()[0]
    assert m > 0, 'xin is empty'
    batch_starts = range(0, m, mbsize)
    n_batches = len(batch_starts)
        
    # 2. Actual iters
    # ----------------
    for i in range(start_epoch,epochs+1):
        for p, start_index in enumerate(batch_starts):
            end_index = min(start_index + mbsize, m)
            Xin_mb = xin[start_index:end_index]
            Yin_mb = yin[start_index:end_index]
            
            if use_cuda == True:
                Xin_mb = Xin_mb.cuda()
                Yin_mb = Yin_mb.cuda()
                
            # normalising ops - division creates new tensors, so the full
            # dataset is never modified
            # -----------------------------------------------------------
            if norm_flag == 1:
                Xin_mb = Xin_mb/255
                Yin_mb = Yin_mb/255
            elif norm_flag == 2:
                Xin_mb = Xin_mb/255
                
            # Network ops
            # -----------
            model_out = model(Xin_mb)
            optimizer.zero_grad()
            loss = criterion(model_out, Yin_mb) # loss(output, target)
            loss.backward()
            optimizer.step()
            loss_train.append(loss.item())
            
            # release GPU memory held by this minibatch
            # -----------------------------------------
            if use_cuda == True:
                del Xin_mb, Yin_mb, model_out
                torch.cuda.empty_cache()
            
            print('Epoch ' + str(i) + ', minibatch ' + str(p+1) + ' of '  +  str(n_batches) + ' -- Model loss: ' + str(loss.item()))

    # 3. outside for loop saving model state (only if at least one epoch ran)
    # -----------------------------------------------------------------------
    if save_state == True and epochs+1 > start_epoch:
        save_dict = {
            'epoch': str(i),
            'model_state_dict': model.cpu().state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': str(loss.cpu().item()),
            'loss_mode': loss_mode,
        }
        torch.save(save_dict,path)
        
        # saving full model to initialise a new model later on
        # ----------------------------------------------------
        torch.save(model.cpu(),path.replace('.tar','_MODEL.tar'))
        print('Saved.')
    
    # 4. return model in order to use elsewhere in the code
    # -----------------------------------------------------
    return model
        


In [None]:
# a function to load a saved model
# --------------------------------

def _torch_load_trusted(path, map_location=None):
    
    '''
    ``torch.load`` for checkpoints this notebook wrote itself.

    PyTorch >= 2.6 defaults ``weights_only=True``, which refuses to unpickle a
    whole ``nn.Module`` (the ``_MODEL.tar`` file). This passes
    ``weights_only=False`` when the installed torch supports the argument and
    omits it on older versions that do not know it.

    SECURITY: ``weights_only=False`` runs pickle, which can execute arbitrary
    code - only use this on files you created / trust.
    '''
    
    import inspect
    
    kwargs = {}
    if map_location is not None:
        kwargs['map_location'] = map_location
    try:
        supports_weights_only = 'weights_only' in inspect.signature(torch.load).parameters
    except (TypeError, ValueError):
        supports_weights_only = False
    if supports_weights_only:
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


def load_saved_model_function(path, use_cuda):
    
    
    '''
    Reload a model and an Adam optimizer from a saved training checkpoint.

    path = /folder1/folder2/model_ae.tar format

    Two files are read:
      * ``path.replace('.tar','_MODEL.tar')`` - the whole pickled model
      * ``path`` - a dict with keys 'epoch', 'model_state_dict',
        'optimizer_state_dict', 'loss', 'loss_mode'

    use_cuda : True  -> checkpoint tensors keep their saved device and the
                        optimizer state tensors are moved to the GPU.
               False -> all storages are mapped to the CPU.
               The model itself is not moved here; call ``model.cuda()`` before
               training on the GPU.

    Returns (model, optimizer, epoch, loss, loss_mode), with epoch / loss /
    loss_mode exactly as stored in the checkpoint.
    '''
    
    # map every storage to the CPU unless we are going to use the GPU
    cpu_map_location = None if use_cuda == True else (lambda storage, loc: storage)
    
    # 1. loading full model (needs full unpickling, see _torch_load_trusted)
    # ---------------------
    model = _torch_load_trusted(path.replace('.tar','_MODEL.tar'), map_location=cpu_map_location)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,model.parameters()))
    
    # 2. loading checkpoint and applying state dicts
    # ----------------------------------------------
    checkpoint = _torch_load_trusted(path, map_location=cpu_map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    if use_cuda == True:
        # optimizer state (e.g. Adam moments) must live on the same device as the params
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
    
    # 3. loading other stuff
    # ----------------------
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    loss_mode = checkpoint['loss_mode']
    
    return model, optimizer, epoch, loss, loss_mode
    
    


In [None]:
# simple function to do a forward pass by chunks
# ----------------------------------------------

def chunk_pass(xin,model,latent,use_cuda,chunksize):
    
    '''
    Run a forward pass over ``xin`` in chunks and return the stacked outputs.

    1. switches ``model`` to eval mode (in place) and, if use_cuda, moves it
       to the GPU
    2. if xin has more than 3 dims, asserts its mean is > 1 (i.e. it still
       looks like 0-255 data) and divides it by 255
    3. feeds every chunk from ``chunker(xin, chunksize)`` through
       ``model.latent`` (latent == True) or ``model`` (otherwise) with autograd
       disabled
    4. returns all chunk outputs concatenated along dim 0, on the CPU

    Raises
    ------
    AssertionError : >3-dim input whose mean is already <= 1
    ValueError     : xin has no samples
    '''
    
    # 0. some initialisations
    # -----------------------
    model = model.eval()
    if use_cuda == True:
        torch.cuda.empty_cache()
        model = model.cuda()
    
    if xin.size()[0] == 0:
        raise ValueError('chunk_pass: xin is empty, nothing to forward.')
    
    # sanity assertion + normalising (only for >3-dim input)
    # ----------------
    if len(xin.size()) > 3:
        # integer tensors (e.g. uint8 images) are cast first: mean() rejects them
        # on current PyTorch and '/' floor-divides them on old PyTorch
        if not xin.dtype.is_floating_point:
            xin = xin.float()
        # .item() instead of .data[0]: indexing a 0-dim tensor raises on PyTorch >= 0.5
        assert torch.mean(xin).item() > 1, 'Input data is already in range 0-1. Not consistent with flow.'
        xin = xin/255
        print("Normalised data to 0-1")
    
    forward_fn = model.latent if latent == True else model
    outputs = []
    
    # 1. chunk loop
    # -------------
    # torch.no_grad() replaces Variable(..., volatile=True), which PyTorch >= 0.4
    # ignores - without it every chunk's autograd graph would be kept alive.
    # assumes chunker yields consecutive slices of xin along dim 0 - see its definition
    with torch.no_grad(), tqdm(total=xin.size()[0]) as pbar:
        for chunk_data in chunker(xin, chunksize):
            chunk_data = chunk_data.cuda() if use_cuda == True else chunk_data.cpu()
            # results go to the CPU right away so the GPU only ever holds one chunk
            outputs.append(forward_fn(chunk_data).cpu())
            # real chunk length, so the bar does not overshoot on a short last chunk
            pbar.update(chunk_data.size()[0])
            if use_cuda == True:
                del chunk_data
                torch.cuda.empty_cache()
    
    # 2. return - a single concatenation (no silent reset on concat errors)
    # ---------
    return torch.cat(outputs, 0)
    
    

In [None]:
# FCN class copied from image search notebook which worked
# --------------------------------------------------------

class fcn_ae_4_layer(nn.Module):
    
    '''
    Convolutional autoencoder.

    Encoder : three 3x3 / stride-2 convs (3 -> 32 -> 64 -> 128 channels); the
              third stage uses a channel-wise Softmax2d instead of
              BatchNorm + ReLU. A 3x3 / stride-3 max pool follows.
    Decoder : four transposed convs (128 -> 128 -> 64 -> 32 -> 3) ending in a
              Sigmoid.

    ``latent`` returns flattened encoder features; call ``set_pool_mode``
    first. Submodule names and order must stay stable: saved state_dicts
    depend on them.
    '''
    
    def __init__(self):
        super().__init__()
        
        # Initialising N/W here
        # ---------------------
        # a single ReLU and a single Dropout2d instance are shared by the stages
        nw_activation_conv = nn.ReLU() #nn.LeakyReLU(0.2) # nn.Tanh() nn.Softmax2d()
        f = 3
        s = 2
        dropout_prob = 0.1
        dropout_node = nn.Dropout2d(p=dropout_prob)
        
        #  what worked - 3,64,128,256,128,64,3
        
        # Conv 1
        ###
        conv1 = 32
        ct1 = nn.Conv2d(3,conv1,f,stride = s)
        cb1 = nn.BatchNorm2d(conv1)
        ca1 = nw_activation_conv
        cl1 = [ct1,cb1,ca1,dropout_node]
        self.convl1 = nn.Sequential(*cl1)
        
        # Conv 2
        ###
        conv2 = 64
        ct2 = nn.Conv2d(conv1,conv2,f,stride = s)
        cb2 = nn.BatchNorm2d(conv2)
        ca2 = nw_activation_conv
        cl2 = [ct2,cb2,ca2,dropout_node]
        self.convl2 = nn.Sequential(*cl2)
        
        # Conv 3 - Softmax2d over channels, no BatchNorm in the stage
        ###
        conv3 = 128
        ct3 = nn.Conv2d(conv2,conv3,f,stride = s)
        # cb3 is built but intentionally not placed in cl3; it is still constructed so
        # module creation order is unchanged (some older PyTorch versions draw random
        # numbers when initialising BatchNorm, which would shift seeded inits)
        cb3 = nn.BatchNorm2d(conv3)
        ca3 = nn.Softmax2d() #nw_activation_conv
        cl3 = [ct3,ca3,dropout_node]
        self.convl3 = nn.Sequential(*cl3)
        
        # NOTE: this model has no fourth conv stage (no convl4).
        
        # Pooling layer
        mxpl =  [nn.MaxPool2d((3,3), stride=3)]
        self.pool_net = nn.Sequential(*mxpl)
        
        # Upconv layer 1 - kernel 3 / stride 3 upsamples by 3, mirroring the pool
        ###
        up_conv1 = 128
        t1 = nn.ConvTranspose2d(conv3,up_conv1,3,stride = 3)
        b1 = nn.BatchNorm2d(up_conv1)
        a1 = nw_activation_conv
        l1 = [t1,b1,a1,dropout_node]
        self.upcl1 = nn.Sequential(*l1)
        
        # Upconv layer 2
        ###
        up_conv2 = 64
        t2 = nn.ConvTranspose2d(up_conv1,up_conv2,f,stride = s)
        b2 = nn.BatchNorm2d(up_conv2)
        a2 = nw_activation_conv
        l2 = [t2,b2,a2,dropout_node]
        self.upcl2 = nn.Sequential(*l2)
        
        # Upconv layer 3
        ###
        up_conv3 = 32
        t3 = nn.ConvTranspose2d(up_conv2,up_conv3,f,stride = s)
        b3 = nn.BatchNorm2d(up_conv3)
        a3 = nw_activation_conv
        l3 = [t3,b3,a3,dropout_node]
        self.upcl3 = nn.Sequential(*l3)
        
        # Upconv layer 4 - back to 3 channels in [0, 1]
        ###
        t4 = nn.ConvTranspose2d(up_conv3,3,f,stride = s)
        a4 = nn.Sigmoid()
        l4 = [t4,a4]
        self.upcl4 = nn.Sequential(*l4)
        

    def forward(self, x):
        
        '''Encode x (3 conv stages + max pool) and decode it back to 3 channels.'''
        
        c1_out = self.convl1(x)
        c2_out = self.convl2(c1_out)
        c3_out = self.convl3(c2_out)
        c5_out = self.pool_net(c3_out)
        
        f2_out = self.upcl1(c5_out)
        f3_out = self.upcl2(f2_out)
        f4_out = self.upcl3(f3_out)
        f5_out = self.upcl4(f4_out)
        
        return f5_out

    
    def latent(self, x):
        
        '''
        Return flattened encoder features for x, shape (batch, -1).

        layer_mode : 'deep_minus_2' -> output of convl2
                     'deep'         -> not available (this model has no
                                       convl4); raises ValueError
                     anything else  -> output of convl3
        Pooling follows ``pool_mode`` (see ``_pool_features``).
        Requires ``set_pool_mode`` to have been called.
        '''
        
        # 0. validate the mode before doing any work
        # ------------------------------------------
        if self.layer_mode == 'deep':
            raise ValueError("layer_mode 'deep' is not supported by fcn_ae_4_layer: "
                             "the model has no convl4. Use 'deep_minus_2' or another "
                             "value (convl3 output).")
        
        # 1. forward prop up to the requested layer
        # -----------------------------------------
        c2_out = self.convl2(self.convl1(x))
        if self.layer_mode == 'deep_minus_2':
            forward_out = c2_out
        else:
            forward_out = self.convl3(c2_out)
        
        # 2. pooling + flatten
        # --------------------
        return self._pool_features(forward_out)
    
    
    def _pool_features(self, forward_out):
        
        '''
        Pool a (batch, C, H, W) feature map per ``self.pool_mode`` and flatten
        it to (batch, -1).

        'avg' / 'max' : a single pooling op
        'both'        : avg and max results flattened and concatenated on dim 1
        anything else : no pooling
        Window is the whole map with stride 1, unless layer_dims_set == True,
        in which case it is layer_f x layer_f with stride layer_s.
        '''
        
        if self.layer_dims_set == True:
            window, pool_stride = (self.layer_f, self.layer_f), self.layer_s
        else:
            window, pool_stride = (forward_out.size()[2], forward_out.size()[3]), 1
        
        batch = forward_out.size()[0]
        if self.pool_mode == 'avg':
            latent_out = F.avg_pool2d(forward_out, window, stride=pool_stride)
        elif self.pool_mode == 'max':
            latent_out = F.max_pool2d(forward_out, window, stride=pool_stride)
        elif self.pool_mode == 'both':
            latent_out_avg = F.avg_pool2d(forward_out, window, stride=pool_stride).view(batch, -1)
            latent_out_max = F.max_pool2d(forward_out, window, stride=pool_stride).view(batch, -1)
            latent_out = torch.cat((latent_out_avg, latent_out_max), 1)
        else:
            latent_out = forward_out
        
        return latent_out.view(batch, -1)
    
     
    def set_pool_mode(self, pool_mode, layer_mode, layer_dims_set, layer_f, layer_s):
        
        '''
        Store the options read by ``latent``.

        pool_mode      : 'avg' | 'max' | 'both' | anything else (no pooling)
        layer_mode     : which encoder layer ``latent`` reads (see ``latent``)
        layer_dims_set : True -> use a layer_f x layer_f window with stride
                         layer_s instead of a whole-map window
        '''
        
        self.pool_mode = pool_mode
        self.layer_mode = layer_mode
        self.layer_dims_set = layer_dims_set
        self.layer_f = layer_f
        self.layer_s = layer_s
        
        print('Modes set.')


In [None]:
# FCN class copied from image search notebook which worked
# --------------------------------------------------------

class fcn_ae_3_layer(nn.Module):
    
    '''
    Convolutional autoencoder with three 3x3 / stride-2 conv stages
    (3 -> 64 -> 128 -> 256 channels, the last using a channel-wise Softmax2d
    with no BatchNorm) and three 3x3 / stride-2 transposed-conv stages
    (256 -> 128 -> 64 -> 3) ending in a Sigmoid.

    Call ``set_pool_mode`` before ``latent``.
    '''
    
    def __init__(self):
        super().__init__()
        
        kernel, stride = 3, 2
        relu = nn.ReLU()                   # one shared, parameter-free activation
        drop = nn.Dropout2d(p=0.1)         # one shared dropout module
        
        def stage(layer, channels, activation, batchnorm=True):
            # The BatchNorm is always created right after its layer so module
            # creation order is the same in every stage; it is only inserted
            # into the Sequential when batchnorm=True.
            norm = nn.BatchNorm2d(channels)
            if batchnorm:
                return nn.Sequential(layer, norm, activation, drop)
            return nn.Sequential(layer, activation, drop)
        
        # encoder
        self.convl1 = stage(nn.Conv2d(3, 64, kernel, stride=stride), 64, relu)
        self.convl2 = stage(nn.Conv2d(64, 128, kernel, stride=stride), 128, relu)
        self.convl3 = stage(nn.Conv2d(128, 256, kernel, stride=stride), 256,
                            nn.Softmax2d(), batchnorm=False)
        
        # decoder (names upcl2..upcl4 kept for checkpoint compatibility)
        self.upcl2 = stage(nn.ConvTranspose2d(256, 128, kernel, stride=stride), 128, relu)
        self.upcl3 = stage(nn.ConvTranspose2d(128, 64, kernel, stride=stride), 64, relu)
        self.upcl4 = nn.Sequential(nn.ConvTranspose2d(64, 3, kernel, stride=stride),
                                   nn.Sigmoid())
    
    
    def forward(self, x):
        
        '''Encode x through convl1..convl3, then decode through upcl2..upcl4.'''
        
        encoded = self.convl3(self.convl2(self.convl1(x)))
        return self.upcl4(self.upcl3(self.upcl2(encoded)))
    
    
    def latent(self, x):
        
        '''
        Flattened encoder features of x, shape (batch, -1).

        layer_mode 'deep_minus_1' reads convl2's output, any other value reads
        convl3's output. pool_mode: 'avg', 'max', 'both' (avg and max
        concatenated) or anything else for no pooling. The window is the full
        map with stride 1 unless layer_dims_set == True, then it is
        layer_f x layer_f with stride layer_s.
        '''
        
        features = self.convl2(self.convl1(x))
        if self.layer_mode != 'deep_minus_1':
            features = self.convl3(features)
        
        if self.layer_dims_set == True:
            window = (self.layer_f, self.layer_f)
            step = self.layer_s
        else:
            window = (features.size()[2], features.size()[3])
            step = 1
        
        mode = self.pool_mode
        if mode == 'avg':
            pooled = nn.AvgPool2d(window, stride=step)(features)
        elif mode == 'max':
            pooled = nn.MaxPool2d(window, stride=step)(features)
        elif mode == 'both':
            rows = features.size()[0]
            avg_part = nn.AvgPool2d(window, stride=step)(features).view(rows, -1)
            max_part = nn.MaxPool2d(window, stride=step)(features).view(rows, -1)
            pooled = torch.cat((avg_part, max_part), 1)
        else:
            pooled = features
        
        return pooled.view(pooled.size()[0], -1)
    
    
    def set_pool_mode(self, pool_mode, layer_mode, layer_dims_set, layer_f, layer_s):
        
        '''Record the layer / pooling options that ``latent`` reads.'''
        
        settings = (('pool_mode', pool_mode),
                    ('layer_mode', layer_mode),
                    ('layer_dims_set', layer_dims_set),
                    ('layer_f', layer_f),
                    ('layer_s', layer_s))
        for attr_name, value in settings:
            setattr(self, attr_name, value)
        
        print('Modes set.')


In [None]:
# FCN class copied from image search notebook which worked
# --------------------------------------------------------

class fcn_ae_deep(nn.Module):
    
    '''
    Convolutional autoencoder with five 3x3 / stride-2 conv stages
    (3 -> 32 -> 64 -> 128 -> 256 -> 512; the last uses BatchNorm + a
    channel-wise Softmax2d) and five 3x3 / stride-2 transposed-conv stages
    (512 -> 256 -> 128 -> 64 -> 32 -> 3) ending in a Sigmoid.

    Call ``set_pool_mode`` before ``latent``.
    '''
    
    def __init__(self):
        super().__init__()
        
        # Initialising N/W here
        # ---------------------
        # a single ReLU and a single Dropout2d instance are shared by the stages
        nw_activation_conv = nn.ReLU() #nn.LeakyReLU(0.2) # nn.Tanh() nn.Softmax2d()
        f = 3
        s = 2
        dropout_prob = 0.2
        dropout_node = nn.Dropout2d(p=dropout_prob)
        
        #  what worked - 3,64,128,256,128,64,3
        
        # Conv 1
        ###
        conv1 = 32
        ct1 = nn.Conv2d(3,conv1,f,stride = s)
        cb1 = nn.BatchNorm2d(conv1)
        ca1 = nw_activation_conv
        cl1 = [ct1,cb1,ca1,dropout_node]
        self.convl1 = nn.Sequential(*cl1)
        
        # Conv 2
        ###
        conv2 = 64
        ct2 = nn.Conv2d(conv1,conv2,f,stride = s)
        cb2 = nn.BatchNorm2d(conv2)
        ca2 = nw_activation_conv
        cl2 = [ct2,cb2,ca2,dropout_node]
        self.convl2 = nn.Sequential(*cl2)
        
        # Conv 3
        ###
        conv3 = 128
        ct3 = nn.Conv2d(conv2,conv3,f,stride = s)
        cb3 = nn.BatchNorm2d(conv3)
        ca3 = nw_activation_conv #nw_activation_conv
        cl3 = [ct3,cb3,ca3,dropout_node]
        self.convl3 = nn.Sequential(*cl3)
        
        # Conv 4
        ###
        conv4 = 256
        ct4 = nn.Conv2d(conv3,conv4,f,stride = s)
        cb4 = nn.BatchNorm2d(conv4)
        ca4 = nw_activation_conv
        cl4 = [ct4,cb4,ca4,dropout_node]
        self.convl4 = nn.Sequential(*cl4) 
        
        # Conv 5 - Softmax2d over channels
        ###
        conv5 = 512
        ct5 = nn.Conv2d(conv4,conv5,f,stride = s)
        cb5 = nn.BatchNorm2d(conv5)
        ca5 = nn.Softmax2d()
        cl5 = [ct5,cb5,ca5,dropout_node]
        self.convl5 = nn.Sequential(*cl5) 
        
        # Upconv layer 0
        ###
        up_conv0 = 256
        t0 = nn.ConvTranspose2d(conv5,up_conv0,f,stride = s)
        b0 = nn.BatchNorm2d(up_conv0)
        a0 = nw_activation_conv
        l0 = [t0,b0,a0,dropout_node]
        self.upcl0 = nn.Sequential(*l0)
        
        # Upconv layer 1
        ###
        up_conv1 = 128
        t1 = nn.ConvTranspose2d(up_conv0,up_conv1,f,stride = s)
        b1 = nn.BatchNorm2d(up_conv1)
        a1 = nw_activation_conv
        l1 = [t1,b1,a1,dropout_node]
        self.upcl1 = nn.Sequential(*l1)
        
        # Upconv layer 2
        ###
        up_conv2 = 64
        t2 = nn.ConvTranspose2d(up_conv1,up_conv2,f,stride = s)
        b2 = nn.BatchNorm2d(up_conv2)
        a2 = nw_activation_conv
        l2 = [t2,b2,a2,dropout_node]
        self.upcl2 = nn.Sequential(*l2)
        
        # Upconv layer 3
        ###
        up_conv3 = 32
        t3 = nn.ConvTranspose2d(up_conv2,up_conv3,f,stride = s)
        b3 = nn.BatchNorm2d(up_conv3)
        a3 = nw_activation_conv
        l3 = [t3,b3,a3,dropout_node]
        self.upcl3 = nn.Sequential(*l3)
        
        # Upconv layer 4 - back to 3 channels in [0, 1]
        ###
        t4 = nn.ConvTranspose2d(up_conv3,3,f,stride = s)
        a4 = nn.Sigmoid()
        l4 = [t4,a4]
        self.upcl4 = nn.Sequential(*l4)
        

    
    def forward(self, x):
        
        '''Encode x through convl1..convl5 and decode through upcl0..upcl4.'''
        
        c1_out = self.convl1(x)
        c2_out = self.convl2(c1_out)
        c3_out = self.convl3(c2_out)
        c4_out = self.convl4(c3_out)
        c5_out = self.convl5(c4_out)
        
        f1_out = self.upcl0(c5_out)
        f2_out = self.upcl1(f1_out)
        f3_out = self.upcl2(f2_out)
        f4_out = self.upcl3(f3_out)
        f5_out = self.upcl4(f4_out)
        
        return f5_out

    
    def latent(self, x):
        
        '''
        Return flattened encoder features for x, shape (batch, -1).

        layer_mode : 'deep'         -> output of convl5
                     'deep_minus_2' -> output of convl3
                     anything else  -> output of convl4
        Pooling follows ``pool_mode`` (see ``_pool_features``).
        Requires ``set_pool_mode`` to have been called.
        '''
        
        # 1. forward prop only as deep as the requested layer
        # ---------------------------------------------------
        c3_out = self.convl3(self.convl2(self.convl1(x)))
        if self.layer_mode == 'deep_minus_2':
            forward_out = c3_out
        else:
            c4_out = self.convl4(c3_out)
            forward_out = self.convl5(c4_out) if self.layer_mode == 'deep' else c4_out
        
        # 2. pooling + flatten
        # --------------------
        return self._pool_features(forward_out)
    
    
    def _pool_features(self, forward_out):
        
        '''
        Pool a (batch, C, H, W) feature map per ``self.pool_mode`` and flatten
        it to (batch, -1).

        'avg' / 'max' : a single pooling op
        'both'        : avg and max results flattened and concatenated on dim 1
        anything else : no pooling
        Window is the whole map with stride 1, unless layer_dims_set == True,
        in which case it is layer_f x layer_f with stride layer_s.
        '''
        
        # getattr: models configured with the two-argument set_pool_mode (or
        # unpickled from such a save) have no layer_dims_set attribute
        if getattr(self, 'layer_dims_set', False) == True:
            window, pool_stride = (self.layer_f, self.layer_f), self.layer_s
        else:
            window, pool_stride = (forward_out.size()[2], forward_out.size()[3]), 1
        
        batch = forward_out.size()[0]
        if self.pool_mode == 'avg':
            latent_out = F.avg_pool2d(forward_out, window, stride=pool_stride)
        elif self.pool_mode == 'max':
            latent_out = F.max_pool2d(forward_out, window, stride=pool_stride)
        elif self.pool_mode == 'both':
            latent_out_avg = F.avg_pool2d(forward_out, window, stride=pool_stride).view(batch, -1)
            latent_out_max = F.max_pool2d(forward_out, window, stride=pool_stride).view(batch, -1)
            latent_out = torch.cat((latent_out_avg, latent_out_max), 1)
        else:
            latent_out = forward_out
        
        return latent_out.view(batch, -1)
    
     
    def set_pool_mode(self, pool_mode, layer_mode, layer_dims_set=False, layer_f=None, layer_s=None):
        
        '''
        Store the options read by ``latent``.

        Same options as the sibling autoencoder classes; the last three are
        optional so existing two-argument calls keep working.

        pool_mode      : 'avg' | 'max' | 'both' | anything else (no pooling)
        layer_mode     : which encoder layer ``latent`` reads (see ``latent``)
        layer_dims_set : True -> use a layer_f x layer_f window with stride
                         layer_s instead of a whole-map window

        Raises ValueError when layer_dims_set is True but layer_f / layer_s
        are missing.
        '''
        
        if layer_dims_set == True and (layer_f is None or layer_s is None):
            raise ValueError('layer_f and layer_s are required when layer_dims_set is True.')
        
        self.pool_mode = pool_mode
        self.layer_mode = layer_mode
        self.layer_dims_set = layer_dims_set
        self.layer_f = layer_f
        self.layer_s = layer_s
        print('Modes set.')


In [None]:
# FCN class copied from image search notebook which worked
# --------------------------------------------------------

class fcn_ae_6_layer_WNET(nn.Module):
    
    '''
    Encoder/decoder with skip connections (called "WNET" in this notebook).

    Encoder    : six 3x3 / stride-2 conv stages
                 (in_channels -> 16 -> 32 -> 64 -> 128 -> 256 -> 512),
                 each Conv -> BatchNorm -> ReLU -> Dropout2d(0.2).
    Bottleneck : 2x2 / stride-1 average pool, followed by a channel-wise
                 Softmax2d when latent_softmax == True.
    Decoder    : upcl0 (2x2 transposed conv, 512 channels kept), then
                 upcl1..upcl6, each a 3x3 / stride-2 transposed conv fed the
                 previous output concatenated on channels with the matching
                 encoder output; upcl6 emits 3 channels through a Sigmoid.

    Spatial sizes for a 191x191 input:
    191 -> 95 -> 47 -> 23 -> 11 -> 5 -> 2 -> pool 1 -> 2 -> 5 -> 11 -> 23
    -> 47 -> 95 -> 191. The channel concatenations need each decoder output
    to match its encoder map spatially, which holds for that input size.
    NOTE(review): output always has 3 channels, whatever in_channels is -
    confirm this is intended for non-RGB inputs.
    '''
    
    def __init__(self, in_channels, latent_softmax):
        super().__init__()
        
        kernel, stride = 3, 2
        relu = nn.ReLU()                  # shared by every hidden stage
        drop = nn.Dropout2d(p=0.2)        # shared by every hidden stage
        
        encoder_widths = [16, 32, 64, 128, 256, 512]
        
        # encoder stages convl1 .. convl6
        width_in = in_channels
        for depth, width in enumerate(encoder_widths, start=1):
            stage = nn.Sequential(nn.Conv2d(width_in, width, kernel, stride=stride),
                                  nn.BatchNorm2d(width), relu, drop)
            setattr(self, 'convl%d' % depth, stage)
            width_in = width
        
        # bottleneck: average pool, optionally followed by softmax over channels
        bottleneck = [nn.AvgPool2d((2, 2), stride=1)]
        if latent_softmax == True:
            bottleneck.append(nn.Softmax2d())
        self.pool_net = nn.Sequential(*bottleneck)
        
        # upcl0: 1x1 -> 2x2, channel count unchanged
        deepest = encoder_widths[-1]
        self.upcl0 = nn.Sequential(nn.ConvTranspose2d(deepest, deepest, 2, stride=1),
                                   nn.BatchNorm2d(deepest), relu, drop)
        
        # upcl1 .. upcl5: input channels = previous decoder width + skip width
        decoder_widths = [256, 128, 64, 32, 16]
        skip_widths = encoder_widths[::-1]          # 512, 256, 128, 64, 32, 16
        width_in = deepest
        for depth, (skip, width) in enumerate(zip(skip_widths, decoder_widths), start=1):
            stage = nn.Sequential(nn.ConvTranspose2d(width_in + skip, width, kernel, stride=stride),
                                  nn.BatchNorm2d(width), relu, drop)
            setattr(self, 'upcl%d' % depth, stage)
            width_in = width
        
        # upcl6: final stage, concatenated with convl1's output, 3 output channels
        self.upcl6 = nn.Sequential(nn.ConvTranspose2d(width_in + encoder_widths[0], 3, kernel, stride=stride),
                                   nn.Sigmoid())
    
    
    def _encode(self, x):
        
        '''Run the six encoder stages and return all of their outputs, shallowest first.'''
        
        activations = []
        for stage in (self.convl1, self.convl2, self.convl3,
                      self.convl4, self.convl5, self.convl6):
            x = stage(x)
            activations.append(x)
        return activations
    
    
    def forward(self, x):
        
        '''Encode, pool to the bottleneck, then decode with skip concatenations.'''
        
        skips = self._encode(x)
        out = self.upcl0(self.pool_net(skips[-1]))
        decoder = (self.upcl1, self.upcl2, self.upcl3,
                   self.upcl4, self.upcl5, self.upcl6)
        # upcl1 pairs with convl6's output, ..., upcl6 pairs with convl1's
        for stage, skip in zip(decoder, reversed(skips)):
            out = stage(torch.cat((out, skip), 1))
        return out
    
    
    def latent(self, x):
        
        '''Bottleneck (pool_net) output for x, flattened to (batch, -1).'''
        
        pooled = self.pool_net(self._encode(x)[-1])
        return pooled.view(pooled.size()[0], -1)
        
    


In [None]:
# CNN regressor: six strided conv blocks -> average pool -> linear head (1 output)
# --------------------------------------------------------------------------------

class standard_cnn_6_layer(nn.Module):
    '''
    Six-layer convolutional regressor producing one scalar per image.

    Spatial sizes for a (191, 191) input (kernel 3, stride 2, no padding):
        conv1 (95, 95) -> conv2 (47, 47) -> conv3 (23, 23) -> conv4 (11, 11)
        -> conv5 (5, 5) -> conv6 (2, 2) -> AvgPool2d((2, 2), stride=1) -> (1, 1)

    The pooled 512-channel map is flattened and passed through
    Linear(512, 256) -> Linear(256, 128) -> Linear(128, 64) -> Linear(64, 1).

    Args:
        in_channels: number of channels of the input images.
        latent_softmax: if True, a Softmax2d (softmax over channels) is applied
            after the average pool.

    Note:
        The head assumes conv6 outputs (2, 2), i.e. a (191, 191) input; with
        larger inputs more than one position survives the pool and the
        flattened size no longer matches Linear(512, ...).
    '''

    def __init__(self, in_channels, latent_softmax):
        super().__init__()

        # shared settings
        # ---------------
        # a single ReLU instance is reused by every block (it has no state)
        nw_activation_conv = nn.ReLU() #nn.LeakyReLU(0.2) # nn.Tanh() nn.Softmax2d()
        f = 3
        s = 2
        dropout_prob = 0.2
        # channel-wise dropout for the 4-D conv feature maps
        dropout_node = nn.Dropout2d(p=dropout_prob)
        # element-wise dropout for the 2-D (batch, features) linear activations;
        # Dropout2d on 2-D input is deprecated in PyTorch and emits a warning
        linear_dropout = nn.Dropout(p=dropout_prob)

        def conv_block(c_in, c_out):
            # Conv(f, stride s, no padding) -> BatchNorm2d -> ReLU -> Dropout2d
            return nn.Sequential(nn.Conv2d(c_in, c_out, f, stride = s),
                                 nn.BatchNorm2d(c_out),
                                 nw_activation_conv,
                                 dropout_node)

        def linear_block(n_in, n_out):
            # Linear -> BatchNorm1d -> ReLU -> Dropout
            return nn.Sequential(nn.Linear(n_in, n_out),
                                 nn.BatchNorm1d(n_out),
                                 nw_activation_conv,
                                 linear_dropout)

        # CONV Down layers (output sizes for a 191 x 191 input)
        # ----------------
        conv1, conv2, conv3, conv4, conv5, conv6 = 16, 32, 64, 128, 256, 512
        self.convl1 = conv_block(in_channels, conv1)  # (95, 95)
        self.convl2 = conv_block(conv1, conv2)        # (47, 47)
        self.convl3 = conv_block(conv2, conv3)        # (23, 23)
        self.convl4 = conv_block(conv3, conv4)        # (11, 11)
        self.convl5 = conv_block(conv4, conv5)        # (5, 5)
        self.convl6 = conv_block(conv5, conv6)        # (2, 2)

        # Pooling layer (+ optional softmax over channels): (2, 2) -> (1, 1)
        # ----------------------------------------------------------------
        avpl = [nn.AvgPool2d((2,2), stride=1)]
        if latent_softmax:
            avpl.append(nn.Softmax2d())
        self.pool_net = nn.Sequential(*avpl)

        # Linear head
        # -----------
        self.linear1 = linear_block(conv6, 256)
        self.linear2 = linear_block(256, 128)
        self.linear3 = linear_block(128, 64)
        self.linear4 = nn.Sequential(nn.Linear(64, 1))

    def forward(self, x):
        '''Return a (batch, 1) prediction for the image batch x.'''

        # Conv pass
        # ---------
        out = x
        for block in (self.convl1, self.convl2, self.convl3,
                      self.convl4, self.convl5, self.convl6):
            out = block(out)

        # pooling, then flatten to (batch, features)
        # -------
        latent_out = self.pool_net(out)
        out = latent_out.view(latent_out.size(0), -1)

        # linear out
        # ----------
        for block in (self.linear1, self.linear2, self.linear3, self.linear4):
            out = block(out)
        return out


        
    


In [None]:
# CNN regressor: four strided conv blocks -> flatten -> linear head (1 output)
# ----------------------------------------------------------------------------

class standard_cnn_4_layer(nn.Module):
    '''
    Four-layer convolutional regressor producing one scalar per image.

    Spatial sizes for the default (191, 191) input (kernel 3, stride 2, no padding):
        conv1 (95, 95) -> conv2 (47, 47) -> conv3 (23, 23) -> conv4 (11, 11)
    There is no pooling: the (512, 11, 11) map is flattened directly into
    Linear(512 * 11 * 11, 1024) -> Linear(1024, 512) -> Linear(512, 256) -> Linear(256, 1).

    Args:
        in_channels: number of channels of the input images.
        latent_softmax: accepted but not used by this model (it has no pooled
            latent layer).
        in_hw: (height, width) of the input images, used to size the first
            linear layer. The default (191, 191) gives the original
            512 * 11 * 11 input features.

    Raises:
        ValueError: if in_hw is too small to survive the four convolutions.
    '''

    def __init__(self, in_channels, latent_softmax, in_hw=(191, 191)):
        super().__init__()

        # shared settings
        # ---------------
        # a single ReLU instance is reused by every block (it has no state)
        nw_activation_conv = nn.ReLU() #nn.LeakyReLU(0.2) # nn.Tanh() nn.Softmax2d()
        f = 3
        s = 2
        dropout_prob = 0.2
        # channel-wise dropout for the 4-D conv feature maps
        dropout_node = nn.Dropout2d(p=dropout_prob)
        # element-wise dropout for the 2-D (batch, features) linear activations;
        # Dropout2d on 2-D input is deprecated in PyTorch and emits a warning
        linear_dropout = nn.Dropout(p=dropout_prob)

        def conv_block(c_in, c_out):
            # Conv(f, stride s, no padding) -> BatchNorm2d -> ReLU -> Dropout2d
            return nn.Sequential(nn.Conv2d(c_in, c_out, f, stride = s),
                                 nn.BatchNorm2d(c_out),
                                 nw_activation_conv,
                                 dropout_node)

        def linear_block(n_in, n_out):
            # Linear -> BatchNorm1d -> ReLU -> Dropout
            return nn.Sequential(nn.Linear(n_in, n_out),
                                 nn.BatchNorm1d(n_out),
                                 nw_activation_conv,
                                 linear_dropout)

        # CONV Down layers
        # ----------------
        conv1, conv2, conv3, conv4 = 64, 128, 256, 512
        self.convl1 = conv_block(in_channels, conv1)
        self.convl2 = conv_block(conv1, conv2)
        self.convl3 = conv_block(conv2, conv3)
        self.convl4 = conv_block(conv3, conv4)

        # spatial size after the four unpadded convs: out = (in - f) // s + 1
        out_h, out_w = in_hw
        for _ in range(4):
            out_h, out_w = (out_h - f) // s + 1, (out_w - f) // s + 1
        if out_h < 1 or out_w < 1:
            raise ValueError('in_hw %r is too small for four stride-%d, kernel-%d convolutions'
                             % (tuple(in_hw), s, f))

        # Adding linear layers
        # --------------------
        self.linear1 = linear_block(conv4 * out_h * out_w, 1024)
        self.linear2 = linear_block(1024, 512)
        self.linear3 = linear_block(512, 256)
        self.linear4 = nn.Sequential(nn.Linear(256, 1))

    def forward(self, x):
        '''Return a (batch, 1) prediction for the image batch x.'''

        # Conv pass
        # ---------
        out = x
        for block in (self.convl1, self.convl2, self.convl3, self.convl4):
            out = block(out)

        # flatten to (batch, features), then the linear head
        # -------
        out = out.view(out.size(0), -1)
        for block in (self.linear1, self.linear2, self.linear3, self.linear4):
            out = block(out)
        return out


        
    


In [None]:
# U-Net autoencoder with one sigmoid output channel per requested channel
# ------------------------------------------------------------------------

class fcn_ae_6_layer_UNET_multiple_outchannels(nn.Module):
    '''
    Six-level U-Net: strided-conv encoder, average-pooled bottleneck and a
    transposed-conv decoder with skip connections (concatenation along dim 1).

    Spatial sizes for a (191, 191) input (kernel 3, stride 2, no padding):
        encoder  191 -> 95 -> 47 -> 23 -> 11 -> 5 -> 2
        pool     AvgPool2d((2, 2), stride=1): 2 -> 1
        decoder  upcl0 (kernel (2, 2), stride 1): 1 -> 2, then
                 upcl1..upcl6: 2 -> 5 -> 11 -> 23 -> 47 -> 95 -> 191
    Each skip concatenation needs the decoder size to equal the encoder size.
    This holds whenever the input of every encoder conv is odd-sized,
    e.g. 191 or 319. For other sizes the torch.cat in forward fails.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map; each passes through a Sigmoid.
        latent_softmax: if True, apply Softmax2d (softmax over channels) after
            the pool.
    '''

    def __init__(self, in_channels, out_channels, latent_softmax):
        super().__init__()

        # shared settings
        # ---------------
        # a single ReLU instance is reused by every block (it has no state)
        nw_activation_conv = nn.ReLU() #nn.LeakyReLU(0.2) # nn.Tanh() nn.Softmax2d()
        f = 3
        s = 2
        dropout_prob = 0.2
        # channel-wise dropout, shared by every (4-D) encoder/decoder block
        dropout_node = nn.Dropout2d(p=dropout_prob)

        def conv_block(c_in, c_out):
            # Conv(f, stride s, no padding) -> BatchNorm2d -> ReLU -> Dropout2d
            return nn.Sequential(nn.Conv2d(c_in, c_out, f, stride = s),
                                 nn.BatchNorm2d(c_out),
                                 nw_activation_conv,
                                 dropout_node)

        def up_block(c_in, c_out, kernel, stride):
            # ConvTranspose -> BatchNorm2d -> ReLU -> Dropout2d
            return nn.Sequential(nn.ConvTranspose2d(c_in, c_out, kernel, stride = stride),
                                 nn.BatchNorm2d(c_out),
                                 nw_activation_conv,
                                 dropout_node)

        # CONV Down layers (output sizes for a 191 x 191 input)
        # ----------------
        conv1, conv2, conv3, conv4, conv5, conv6 = 16, 32, 64, 128, 256, 512
        self.convl1 = conv_block(in_channels, conv1)  # (95, 95)
        self.convl2 = conv_block(conv1, conv2)        # (47, 47)
        self.convl3 = conv_block(conv2, conv3)        # (23, 23)
        self.convl4 = conv_block(conv3, conv4)        # (11, 11)
        self.convl5 = conv_block(conv4, conv5)        # (5, 5)
        self.convl6 = conv_block(conv5, conv6)        # (2, 2)

        # Pooling layer (+ optional softmax over channels)
        # ------------------------------------------------
        pool_dims = (2,2)
        avpl = [nn.AvgPool2d(pool_dims, stride=1)]
        if latent_softmax:
            avpl.append(nn.Softmax2d())
        self.pool_net = nn.Sequential(*avpl)

        # Transconv layers
        # ----------------
        # upcl0 uses the pool's kernel with stride 1, so it restores conv6's size
        up_conv0 = conv6
        self.upcl0 = up_block(conv6, up_conv0, pool_dims, 1)  # (2, 2)

        # every later block consumes [previous decoder output, matching encoder output]
        up_conv1, up_conv2, up_conv3, up_conv4, up_conv5 = 256, 128, 64, 32, 16
        self.upcl1 = up_block(up_conv0 + conv6, up_conv1, f, s)  # (5, 5)
        self.upcl2 = up_block(up_conv1 + conv5, up_conv2, f, s)  # (11, 11)
        self.upcl3 = up_block(up_conv2 + conv4, up_conv3, f, s)  # (23, 23)
        self.upcl4 = up_block(up_conv3 + conv3, up_conv4, f, s)  # (47, 47)
        self.upcl5 = up_block(up_conv4 + conv2, up_conv5, f, s)  # (95, 95)

        # FINAL LAYER: no BatchNorm/dropout, sigmoid output
        self.upcl6 = nn.Sequential(
            nn.ConvTranspose2d(up_conv5 + conv1, out_channels, f, stride = s),  # (191, 191)
            nn.Sigmoid())

    def _encode(self, x):
        '''Run convl1..convl6 and pool_net; return ([c1_out, ..., c6_out], latent_out).'''
        skips = []
        out = x
        for block in (self.convl1, self.convl2, self.convl3,
                      self.convl4, self.convl5, self.convl6):
            out = block(out)
            skips.append(out)
        return skips, self.pool_net(out)

    def forward(self, x):
        '''Encode, then decode with skip connections; returns upcl6's sigmoid output.'''
        skips, latent_out = self._encode(x)

        # Transconv pass: upcl1 pairs with c6_out, ..., upcl6 with c1_out
        # --------------
        out = self.upcl0(latent_out)
        for block, skip in zip((self.upcl1, self.upcl2, self.upcl3,
                                self.upcl4, self.upcl5, self.upcl6),
                               reversed(skips)):
            out = block(torch.cat((out, skip), 1))
        return out

    def latent(self, x):
        '''Return the pooled bottleneck flattened to (batch, -1); (batch, 512) for 191 x 191.'''
        _, latent_out = self._encode(x)
        return latent_out.view(latent_out.size(0), -1)
        
    


In [None]:
# END OF CODE

# 0. preparing dataset - one off task

# 0. preparing COCO dataset - one off task

# 1. loading dataset

### 1.1 setting up model related variables

In [None]:
## SET UP CUDA OR NOT HERE + OTHER SET UPS
##########################################

dev_env = 'local' # 'gpu' or 'local'

##########################################
##########################################

# Setting CUDA
# ------------
# CUDA is used only when dev_env == 'gpu'; any other value runs on the CPU
use_cuda = dev_env == 'gpu'
if use_cuda:
    # release unoccupied cached GPU memory held by this kernel
    torch.cuda.empty_cache()


# SET FILE SPECIFIC NAMES HERE
# ----------------------------
# Both directories end with '/' because later cells build paths by plain
# string concatenation (e.g. save_path + file name, parent_url + h5 name).
if dev_env == 'gpu':
    save_path = '/home/venkateshmadhava/codes/pmate2_localgpuenv/models/'
    parent_url = '/home/venkateshmadhava/datasets/images/'
else:
    save_path = '/Users/venkateshmadhava/Documents/pmate2/pmate2_env/models/'
    parent_url = '/Users/venkateshmadhava/Documents/projects/vision/object_detection/coco/'


# Setting up final classes as well - for BDD dataset
# --------------------------------------------------
#final_classes = ['rider', 'train', 'person', 'traffic light', 'bus', 'truck', 'motor', 'bike', 'car', 'traffic sign']
#print('\n***')
#print(final_classes)


# displaying save path
# --------------------
print('***\n')
print(save_path)

### 1.2 loading BDD dataset

### 1.3 loading COCO dataset

In [None]:
# loading COCO dataset
# --------------------

source_h5name = 'val_x_src_super_cats_person_furniture_vehicle_food_dataset.h5'
tgt_h5name = 'val_x_tgt_super_cats_person_furniture_vehicle_food_dataset.h5'
final_classes = ['person','furniture','vehicle','food']

# reading h5 files
# ----------------
# The 'images' dataset of each file is copied fully into memory and the file
# is closed by the `with` block. Indexing with ['images'] raises KeyError when
# the dataset is missing, rather than silently yielding np.array(None).
# os.path.join works whether or not parent_url ends with '/'.
with h5py.File(os.path.join(parent_url, source_h5name), 'r') as hfr:
    xtrn_src_main = np.array(hfr['images'])

with h5py.File(os.path.join(parent_url, tgt_h5name), 'r') as hfr:
    xtrn_tgt_main = np.array(hfr['images'])

# later cells slice and index source and target with the same indices,
# so both must hold the same number of samples
assert xtrn_src_main.shape[0] == xtrn_tgt_main.shape[0], (
    'source/target sample counts differ: %d vs %d'
    % (xtrn_src_main.shape[0], xtrn_tgt_main.shape[0]))

print(xtrn_src_main.shape)
print(xtrn_tgt_main.shape)

In [None]:
# splitting set on local machine
# ------------------------------
# first 3500 source/target pairs are used for training, the rest for testing

if dev_env == 'local':
    xtrn_src, xtst_src = xtrn_src_main[:3500], xtrn_src_main[3500:]
    xtrn_tgt, xtst_tgt = xtrn_tgt_main[:3500], xtrn_tgt_main[3500:]
    
    


In [None]:
# Resizing images if required
# ----------------------------
resize_images = True
new_h,new_w = 191,191

if resize_images:
    # resize_all is applied to train src, train tgt, test src, test tgt (in that order)
    xtrn_src, xtrn_tgt, xtst_src, xtst_tgt = (
        resize_all(arr, new_h, new_w)
        for arr in (xtrn_src, xtrn_tgt, xtst_src, xtst_tgt)
    )

# printing shapes for sanity (one shape per line)
# --------------------------
print(xtrn_src.shape, xtrn_tgt.shape, xtst_src.shape, xtst_tgt.shape, sep='\n')

In [None]:
# Visualising images for sanity
# -----------------------------
# The same random index j is used for BOTH the train and test arrays, so it is
# drawn below the smaller of the two sizes. Drawing it from the test size alone
# can index past the end of the training set.
n_pairs = min(xtrn_src.shape[0], xtst_src.shape[0])
randrange = list(np.random.randint(n_pairs, size=(1, 1))[0,:])


def _show_source_and_masks(src, tgt, idx):
    '''For every class: print its name, show src[idx], then show tgt[idx, :, :, class].'''
    for i, class_name in enumerate(final_classes):
        print(class_name)
        plt.imshow(src[idx])
        plt.show()
        plt.imshow(tgt[idx,:,:,i])
        plt.show()
        print('------------')


for j in randrange:

    print('>> Showing a training image..')
    _show_source_and_masks(xtrn_src, xtrn_tgt, j)

    print('>> Showing a test image..')
    _show_source_and_masks(xtst_src, xtst_tgt, j)


# 2. setting up & training models

In [None]:
# snippet to work out filter sizes
# --------------------------------
# kernel size, stride, padding and number of layers used by the conv models
f, s, pad = 3, 2, 0
layers = 6

# Conv (down) path: apply outsize_conv `layers` times to the input size
# ---------------------------------------------------------------------
h, w = 191, 191  # alternative input size tried earlier: 255
print('Showing conv down sizes - ')
print('--------------------------')
for _ in range(layers):
    h, w = outsize_conv(h, w, f, s, pad)
    print((h, w))

# Transposed-conv (up) path: apply outsize_upconv `layers` times from 1 x 1
# -------------------------------------------------------------------------
# NOTE(review): in the U-Net models, upcl0 (kernel (2, 2), stride 1) first maps
# the 1 x 1 latent to 2 x 2 before the six stride-2 up-convs, so a walk that
# starts at 1 x 1 may not reproduce those models' decoder sizes - confirm.
h, w = 1, 1
print('\nShowing conv up sizes - ')
print('--------------------------')
dims = []
for _ in range(layers):
    h, w = outsize_upconv(h, w, f, s, pad)
    dims.append((h, w))
    print((h, w))


In [None]:
# Set up train data
# -----------------
# convert the image arrays with setup_image_tensor and cast to float32
xin_train, xout_train = (Variable(setup_image_tensor(arr)).float()
                         for arr in (xtrn_src, xtrn_tgt))
# NOTE(review): only the targets are multiplied by 255 here; inference cells
# divide inputs by 255, so presumably model_train rescales both - confirm.
xout_train = xout_train * 255

print(xin_train.size())
print(xout_train.size())

### 2.2 start of model setup and training

In [None]:
# Build the model: U-Net autoencoder with one output channel per class
# --------------------------------------------------------------------
# (historical note: training fcn_ae_6_layer_WNET for 65 epochs brought the loss to 0.014)

# Drop a model left over from an earlier run of this cell. A NameError only
# means no model exists yet; a bare `except:` would also hide unrelated
# errors such as KeyboardInterrupt.
try:
    del model_ae
    print('Old model deleted.')
except NameError:
    pass
model_ae = fcn_ae_6_layer_UNET_multiple_outchannels(3,len(final_classes),False)
model_ae.apply(weights_init)
model_ae

In [None]:
# training the model
# ------------------
# USE -1 AS EPOCHS TO LOAD SAVED MODEL WITHOUT TRAINING
#
# Signature as recorded in this notebook:
# model_train(xin,yin,xval,yval,load_mode,model,epochs,mbsize,loss_mode,flatten,use_cuda,save_state,path)
# NOTE(review): the call below passes 12 positional arguments while the recorded
# signature lists 13 (including `flatten`); confirm the argument order against
# the actual model_train definition.

cn_file_name = 'fcn_coco_object_detection_6_layer_UNET_multichannelout_nonsoftmax_512.tar'
cn_save_path = save_path + cn_file_name
print(cn_save_path)

model_ae = model_train(
    xin_train, xout_train,  # training inputs / targets
    None, None,             # no validation set
    'from saved',           # load_mode
    model_ae,
    -1,                     # epochs: -1 loads the checkpoint without training
    64,                     # mbsize
    'mse',                  # loss_mode
    use_cuda,
    True,
    cn_save_path,
)

### 2.3 visualising results

In [None]:
# random sampling
# ---------------
# one random test index, kept as a one-element list so that
# xtst_src[randrange] stays a batch (leading axis of length 1)
randrange = list(np.random.randint(xtst_src.shape[0], size=(1, 1))[0,:])
sample_src = xtst_src[randrange]

# sampling
# --------
nx_trn = Variable(setup_image_tensor(sample_src)).float()
# Inference only: no_grad avoids building an autograd graph. Note that
# eval() and cpu() act on model_ae in place (it stays in eval mode on the CPU).
with torch.no_grad():
    n_output = model_ae.eval().cpu()((nx_trn/255).cpu())
np_out = to_numpy_image(n_output.cpu().data)

for i in range(np_out.shape[0]):
    print(str(i))
    for j in range(len(final_classes)):
        print('Original - ')
        plt.imshow(sample_src[i])
        plt.show()
        print(final_classes[j])
        plt.imshow(np_out[i,:,:,j])
        plt.show()
        print('------------------------')
    print('#####################################################')
    

In [None]:
# reading images from local folder
# --------------------------------

if dev_env == 'local':

    in_folder = '/Users/venkateshmadhava/Desktop/test'
    x_local = create_dataset_from_folder_all(in_folder,new_h,new_w)
    x_local_trn = Variable(setup_image_tensor(x_local)).float()
    print(x_local_trn.size())

    # forward pass ops (inference only: no autograd graph is built;
    # eval() and cpu() act on model_ae in place)
    # ----------------
    with torch.no_grad():
        n_output = model_ae.eval().cpu()((x_local_trn/255).cpu())
    np_out = to_numpy_image(n_output.cpu().data)

    # showing results
    # ---------------
    for i in range(np_out.shape[0]):
        print(str(i))
        for j in range(len(final_classes)):
            print('Original - ')
            plt.imshow(x_local[i])
            plt.show()
            print(final_classes[j])
            plt.imshow(np_out[i,:,:,j])
            plt.show()
            print('------------------------')
        print('#####################################################')
    

In [None]:
# END OF TRAINING FCN

# Rough

In [None]:
#jsonurl = '/Users/venkateshmadhava/Documents/projects/vision/object_detection/coco/annotations/captions_val2017.json'

#with open(jsonurl) as f:
#    data = json.load(f)

In [None]:
#data.keys()

In [None]:
#data['annotations'][3651]