In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data 
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import random
import cv2
import time

In [None]:
def _elapsed_min_sec(start_time):
    """Return the wall-clock time since *start_time* as whole (minutes, seconds)."""
    elapsed = time.time() - start_time
    return int(elapsed / 60), int(elapsed % 60)


def _prepare_batch(images, device, mpv_device):
    """Move a batch to *device* and build the completion-network input.

    Returns ``(images, region, completion_input)`` where ``images`` is the
    float32 batch on *device*, ``region`` is the second value returned by the
    module-level ``generate_mask`` and ``completion_input`` is the masked
    batch concatenated with the mask along the channel dimension.
    """
    images = images.to(device=device, dtype=torch.float32)
    # NOTE(review): generate_mask is defined elsewhere; it is asked for a
    # one-channel mask matching the batch size and spatial size of `images`.
    mask, region = generate_mask(
        input_shape=(images.shape[0], 1, images.shape[2], images.shape[3]))
    mask = mask.to(device)
    # Where mask == 1 the pixel value is replaced by mpv; elsewhere it is kept.
    masked_images = images - images * mask + mpv_device * mask
    completion_input = torch.cat((masked_images, mask), dim=1)
    return images, region, completion_input


def _evaluate_completion(completion_network, test_loader, device, mpv_device):
    """Return the mean full-image MSE of the completion network over *test_loader*.

    The network is switched to eval mode for the pass and always switched
    back to train mode afterwards.
    """
    total_loss = 0
    completion_network.eval()
    try:
        with torch.no_grad():
            for test_images, _ in test_loader:
                test_images, _, completion_input = _prepare_batch(
                    test_images, device, mpv_device)
                output = completion_network(completion_input)
                total_loss += mseloss(output, test_images).item()
        return total_loss / len(test_loader)
    finally:
        completion_network.train()


def _discriminator_losses(context_discriminators, fake_images, fake_region,
                          real_images, device):
    """Return ``(loss_fake, loss_real)`` for one discriminator update.

    ``loss_fake`` is the BCE between the discriminator output on the completed
    images (target 0) and ``loss_real`` the BCE on the real images (target 1).
    The local crop of the fake images uses *fake_region*; the local crop of
    the real images uses a fresh ``random_patch()``.
    """
    batch_size = len(real_images)

    local_fake = crop(fake_images, fake_region)
    output_fake = context_discriminators(
        (local_fake.to(device), fake_images.to(device)))
    zeros = torch.zeros((batch_size, 1), device=device)
    loss_fake = bceloss(output_fake, zeros)

    real_region = random_patch()
    local_real = crop(real_images, real_region)
    output_real = context_discriminators((local_real, real_images))
    ones = torch.ones((batch_size, 1), device=device)
    loss_real = bceloss(output_real, ones)

    return loss_fake, loss_real


def _run_completion_epoch(completion_network, train_loader, optimizer_C,
                          device, mpv_device):
    """Step 1: train the completion network alone for one epoch; return the mean MSE."""
    running_loss_C = 0
    for images, _ in train_loader:
        images, _, completion_input = _prepare_batch(images, device, mpv_device)

        optimizer_C.zero_grad()
        output_completion = completion_network(completion_input)
        # The MSE is taken over the whole image, not only the masked region.
        loss_C = mseloss(output_completion, images)
        running_loss_C += loss_C.item()

        loss_C.backward()
        optimizer_C.step()
    return running_loss_C / len(train_loader)


def _run_discriminator_epoch(completion_network, context_discriminators,
                             train_loader, optimizer_D, device, mpv_device):
    """Step 2: train the discriminators alone for one epoch.

    Returns ``(mean loss_D, last-batch loss_fake, last-batch loss_real)``.
    """
    running_loss_D = 0
    loss_fake = loss_real = None
    for images, _ in train_loader:
        images, region_c, completion_input = _prepare_batch(
            images, device, mpv_device)

        optimizer_D.zero_grad()

        # The completion network is not updated in this step, so no autograd
        # graph is needed for its forward pass.
        with torch.no_grad():
            output_completion = completion_network(completion_input)

        loss_fake, loss_real = _discriminator_losses(
            context_discriminators, output_completion, region_c, images, device)
        loss_D = 0.5 * (loss_fake + loss_real)

        loss_D.backward()
        optimizer_D.step()
        running_loss_D += loss_D.item()

    mean_loss_D = running_loss_D / len(train_loader)
    return mean_loss_D, loss_fake.item(), loss_real.item()


def _run_joint_epoch(completion_network, context_discriminators, train_loader,
                     optimizer_C, optimizer_D, alpha, device, mpv_device):
    """Step 3: alternate discriminator and completion-network updates for one epoch.

    Returns ``(mean loss_C, mean loss_D, mean joint_loss, last-batch loss_fake,
    last-batch loss_real, last-batch loss_fake_2)``.
    """
    running_loss_C = 0
    running_loss_D = 0
    running_joint_loss = 0
    loss_fake = loss_real = loss_fake_2 = None

    for images, _ in train_loader:
        images, region_c, completion_input = _prepare_batch(
            images, device, mpv_device)

        optimizer_C.zero_grad()
        optimizer_D.zero_grad()

        output_completion = completion_network(completion_input)
        loss_C = mseloss(output_completion, images)

        # Discriminator update: the completed images are detached so this
        # backward pass does not reach the completion network.
        loss_fake, loss_real = _discriminator_losses(
            context_discriminators, output_completion.detach(), region_c,
            images, device)
        loss_D = 0.5 * alpha * (loss_fake + loss_real)
        loss_D.backward()
        optimizer_D.step()

        # Completion-network update: the freshly updated discriminators score
        # the (non-detached) completion, which is pushed towards the "real"
        # label. Gradients this also leaves on the discriminators are cleared
        # by optimizer_D.zero_grad() at the start of the next iteration.
        local_fake = crop(output_completion, region_c)
        output_fake = context_discriminators((local_fake, output_completion))
        ones = torch.ones((len(images), 1), device=device)
        loss_fake_2 = bceloss(output_fake, ones)

        joint_loss = 0.5 * (loss_C + alpha * loss_fake_2)
        joint_loss.backward()
        optimizer_C.step()

        running_loss_C += loss_C.item()
        running_loss_D += loss_D.item()
        running_joint_loss += joint_loss.item()

    n_batches = len(train_loader)
    return (running_loss_C / n_batches, running_loss_D / n_batches,
            running_joint_loss / n_batches,
            loss_fake.item(), loss_real.item(), loss_fake_2.item())


def train(completion_network, context_discriminators, train_loader, test_loader, 
          n_epoch = 20, tc = 5, td = 2, alpha=0.0004,
          test_period = 5):
    """Train the completion network and the context discriminators in three steps.

    Epochs ``t`` (0-based) are scheduled as follows:

    * ``t < tc``: only the completion network is trained (MSE loss).
    * ``tc <= t < tc + td``: only the discriminators are trained (BCE loss).
    * otherwise: both are trained; the discriminator loss and the adversarial
      term of the completion loss are weighted by ``alpha``.

    Both networks are optimised with ``torch.optim.Adadelta`` using its
    default hyper-parameters. Training runs on CUDA when available and on the
    CPU otherwise.

    The function relies on names defined elsewhere in the module:
    ``generate_mask``, ``crop``, ``random_patch``, ``mseloss``, ``bceloss``,
    ``mpv`` and the lists ``completion_losses``, ``discriminator_losses``,
    ``joint_losses`` and ``test_loss_C``, which are appended to in place.

    Args:
        completion_network: module mapping the masked image concatenated with
            its mask to a completed image.
        context_discriminators: module called with a ``(local, global)`` tuple.
        train_loader: iterable of ``(images, _)`` training batches.
        test_loader: iterable of ``(images, _)`` test batches.
        n_epoch: total number of epochs over all three steps.
        tc: number of epochs of step 1.
        td: number of epochs of step 2.
        alpha: weight of the adversarial losses in step 3.
        test_period: in steps 1 and 3 the test loss is computed every
            ``test_period`` epochs (step 2 does not evaluate).

    Returns:
        The tuple ``(completion_losses, discriminator_losses, joint_losses,
        test_loss_C)`` of module-level lists.
    """
    # Pick the device up front so that every later transfer is consistent.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("using cuda")
    else:
        device = torch.device("cpu")
        print("cuda not available")

    completion_network.to(device)
    context_discriminators.to(device)
    # Tensor.cuda()/.to() return a copy instead of moving the tensor in place,
    # so keep a device-local handle on mpv rather than discarding the result.
    # NOTE(review): mpv is defined elsewhere; non-tensor values are used as-is.
    mpv_device = mpv.to(device) if torch.is_tensor(mpv) else mpv

    optimizer_C = torch.optim.Adadelta(completion_network.parameters())
    optimizer_D = torch.optim.Adadelta(context_discriminators.parameters())

    completion_network.train()
    context_discriminators.train()

    announced_step = None  # last step whose banner has been printed

    for t in range(n_epoch):
        start_time = time.time()

        if t < tc:  ###################  Step 1 ###################
            if announced_step != 1:
                print("====== Training completion network ======")
                announced_step = 1

            running_loss_C = _run_completion_epoch(
                completion_network, train_loader, optimizer_C, device, mpv_device)
            completion_losses.append(running_loss_C)

            minute, sec = _elapsed_min_sec(start_time)
            print("Epoch {}: loss_C = {:.5f} - {} min {} s ".format(t+1, running_loss_C, minute, sec))

            if (t+1) % test_period == 0:
                running_loss = _evaluate_completion(
                    completion_network, test_loader, device, mpv_device)
                test_loss_C.append(running_loss)
                print("         test_loss_C = {:.5f}".format(running_loss))

        elif t < tc + td:  ###################  Step 2 ###################
            if announced_step != 2:
                print("====== Training context discriminators ======")
                announced_step = 2

            running_loss_D, last_fake, last_real = _run_discriminator_epoch(
                completion_network, context_discriminators, train_loader,
                optimizer_D, device, mpv_device)
            discriminator_losses.append(running_loss_D)

            minute, sec = _elapsed_min_sec(start_time)
            print("Epoch {}: loss_D = {:.5f} - {} min {} s".format(t+1, running_loss_D, minute, sec))
            print("last batch : loss fake = {:.5f}".format(last_fake), "loss real = {:.5f}".format(last_real))

        else:  ###################  Step 3 ####################
            if announced_step != 3:
                print("====== Training both ======")
                announced_step = 3

            (running_loss_C, running_loss_D, running_joint_loss,
             last_fake, last_real, last_fake_2) = _run_joint_epoch(
                completion_network, context_discriminators, train_loader,
                optimizer_C, optimizer_D, alpha, device, mpv_device)
            completion_losses.append(running_loss_C)
            discriminator_losses.append(running_loss_D)
            joint_losses.append(running_joint_loss)

            minute, sec = _elapsed_min_sec(start_time)
            print("Epoch {}: loss_C = {:.5f}, loss_D = {:.5f}, joint_loss = {:.5f} - {} min {} s".format(t+1, running_loss_C, running_loss_D, 
                                                                                                         running_joint_loss, minute, sec))
            print("last batch : loss fake D = {:.5f}".format(last_fake), "loss real D = {:.5f}".format(last_real),
                 "loss_fake C = {:.5f}".format(last_fake_2))

            if (t+1) % test_period == 0:
                running_loss = _evaluate_completion(
                    completion_network, test_loader, device, mpv_device)
                test_loss_C.append(running_loss)
                print("         test_loss_C = {:.5f}".format(running_loss))

    print("Finished Training")
    return completion_losses, discriminator_losses, joint_losses, test_loss_C