# **1. Before Running**

**Before running any code**, download all of the .zip files from the links below and upload them to Colab. The training and testing code in `train.py` and `test.py` needs these files. Four data sets are currently in use: the cropped training images, the cropped validation images, and their respective masks. Note that the training images may take a bit longer to upload.

- [cropped_train_img.zip](https://drive.google.com/open?id=1175apAFzI-1g11XkRwqD3jVaQzUwXTyf)

- [cropped_valid_img.zip](https://drive.google.com/open?id=1kkRmPTaiw9ejnhf_vT14r-vQUHZPNe5U)


- [mask_train_img.zip](https://drive.google.com/open?id=1cN3GAX7xM5f20ZT3pKIRAVWm6Ho9CjCT)

- [mask_valid_img.zip](https://drive.google.com/open?id=14Z7TqCE28tkVVmR76BUUozk_QVPLr6W0)

**Note**: The newest masks are in the folder linked below:

- [Models_and_Masks_and_Images_Folder](https://drive.google.com/drive/folders/1b4JQlOfVfJhmETd8JzUXS42l5xWO7M_D?usp=sharing)





# **2. Original Data Set Links and Explanation**

The original data sets required for our final project can be found on the [ICME19 Inpainting Challenge](https://icme19inpainting.github.io/) page. If the links on that page don't work, look here instead: [Inpainting Challenge Data Sets on BitaHub](https://forum.bitahub.com/views/activity-detail-en.html?activityId=_2e9eb0c6eba94190beb6430941ff4c13://).

The training and validation sets of these original data sets have the following names:
- valid_img
- train_img

The valid_img data set that you download actually contains **two types of validation images**: one for **error concealment (EC)** and one for **object removal (OR)**. More information is available at the ICME19 Inpainting Challenge link above.

# **3. Our Current Data Sets For The Final Project**
The code we based our work on, bobqywei's GitHub code, works only on 256 x 256 images. However, the images in the data set provided for our final project are not 256 x 256, so we cropped all of them. The cropped images are stored in the folders titled:

- cropped_train_img
- cropped_valid_img
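The cropping script itself isn't in this notebook. As a minimal sketch, assuming each image is cut into five 256 x 256 crops (four corners plus the center, which would be consistent with the `{image_num}_{0..4}.jpg` names used by the commented-out custom-data branch of test.py), the crops could be taken from a NumPy `(H, W, C)` array like this. The helper name `five_crop` is hypothetical; `torchvision.transforms.FiveCrop` does the same thing on PIL images.

```python
import numpy as np


def five_crop(image, size=256):
    """Return the four corner crops and the center crop of an (H, W, C) array.

    Hypothetical helper: assumes the image is at least size x size.
    """
    h, w = image.shape[:2]
    if h < size or w < size:
        raise ValueError(f"image {h}x{w} is smaller than the {size}x{size} crop")
    top = (h - size) // 2
    left = (w - size) // 2
    return [
        image[:size, :size],                       # top-left
        image[:size, w - size:],                   # top-right
        image[h - size:, :size],                   # bottom-left
        image[h - size:, w - size:],               # bottom-right
        image[top:top + size, left:left + size],   # center
    ]


# Example: a dummy 480 x 640 RGB image yields five crops
dummy = np.zeros((480, 640, 3), dtype=np.uint8)
crops = five_crop(dummy)
print([c.shape for c in crops])  # five (256, 256, 3) shapes
```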


However, **the data on the ICME webpage only provides ground truth images**. As a result, we need to generate images with holes ourselves. This is where the **mask data sets** come in: the masks we created are 256 x 256 to match the crop size. There are mask data sets for both the validation and the training ground truth images:

- mask_train_img
- mask_valid_img
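The script that produced these masks isn't part of this notebook either. As a rough, hypothetical illustration of what a 256 x 256 hole mask can look like, the sketch below draws a few random rectangles into an all-ones array. It follows the convention used by `Places2Data` and `CalculateLoss`: 1 marks a valid (known) pixel and 0 marks a hole, so `gt_img * mask` blanks out the holes. The rectangle strategy and the name `random_rect_mask` are assumptions, not how our actual masks were generated.

```python
import numpy as np


def random_rect_mask(size=256, num_holes=3, max_hole=96, seed=None):
    """Return a (size, size) float32 mask: 1 = valid pixel, 0 = hole (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    mask = np.ones((size, size), dtype=np.float32)
    for _ in range(num_holes):
        # random hole height/width, then a top-left corner that keeps the hole inside the image
        h, w = rng.integers(16, max_hole + 1, size=2)
        top = rng.integers(0, size - h + 1)
        left = rng.integers(0, size - w + 1)
        mask[top:top + h, left:left + w] = 0.0
    return mask


mask = random_rect_mask(seed=0)
print(mask.shape, 1.0 - mask.mean())  # shape and fraction of pixels that are holes
```

To store one as a PNG like the files in `mask_train_img`, `Image.fromarray((mask * 255).astype(np.uint8))` from PIL could be saved directly.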

# **4. How to Run the Code on this Colab**

1. **Download** the mask and cropped data sets from the links mentioned above.
2. **Unzip** all the data sets into Colab using the code below. This may also take some time.
3. **Pip Install** all necessary packages using the code below.
4. **Run** all code cells **in order** (i.e., first run places2_train.py, then partial_conv_net.py, then loss.py, etc.). The files depend on one another: for example, `train.py` and `test.py` rely on the `loss.py` functions, so the `loss.py` cell must be run first.
5. (Optional) **Test** `inpaint.py`. If you guys already have a Python environment set up on your own computers, it shouldn't be hard to test `inpaint.py` yourselves. The code is commented out at the end of the Colab. `inpaint.py` cannot run on Colab, since it needs to run on a local machine. To run it, download one of the models in the `model` folder; depending on which one you choose, you may have to change the model name that is hardcoded in `inpaint.py`.
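For step 2, if the .zip files are stored in Google Drive rather than in the Colab working directory (an `unzip: cannot find or open` error means the archives aren't in the current folder), they can also be extracted with Python's standard `zipfile` module. This is a sketch: `extract_all` is a hypothetical helper, and the Drive folder in the usage comment is a placeholder, not the project's real location.

```python
import os
import zipfile

ZIP_NAMES = ["cropped_train_img.zip", "cropped_valid_img.zip",
             "mask_train_img.zip", "mask_valid_img.zip"]


def extract_all(src_dir, dest_dir, names=ZIP_NAMES):
    """Extract each archive in `names` from src_dir into dest_dir; return the names that were missing."""
    missing = []
    for name in names:
        path = os.path.join(src_dir, name)
        if not os.path.isfile(path):
            missing.append(name)
            continue
        with zipfile.ZipFile(path) as archive:
            archive.extractall(dest_dir)
    return missing


# Usage in Colab (placeholder Drive path; adjust to wherever you stored the zips):
# missing = extract_all("/content/drive/MyDrive/inpainting_data", "/content")
# print("missing archives:", missing)
```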

**Note:** If we run the cells in the correct order, we don't need import statements for our own files (ex: `from places2_train import Places2Data`). Those imports are only necessary when the code is split into separate files. Since we copied all functions from all the files into a single Colab notebook, **`import` of our own modules isn't needed if you run everything in the correct order.**

**Note2:** This means you guys shouldn't have to download all the files from bobqywei's GitHub. Instead, just make a copy of this Colab, make your own modifications, then copy them back into this notebook once you think the code is working correctly.



# **5. Things We Need To Do**
1. Fix our model so that `inpaint.py` and `test.py` generate good results.
2. Create a larger data set using data augmentation, as mentioned in class. If possible, it should contain more than 5000 images; currently we have been provided only 1500.
3. Make `test.py` output more quantifiable performance metrics.
4. Read more on *Partial Convolutions* and start writing the paper as we go along.
5. Explore different network models besides VGG16.
6. Double-check that `partial_conv_net` is correctly implemented and follows the paper.

**Note:** `test.py` actually generates an output .jpg file showing the results of the inpainting called `test.jpg`. You can see how good/bad the inpainting is based on that image.
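For item 2, one simple way to grow the 1500 training crops past 5000 is to save several deterministic variants of each crop. The NumPy sketch below produces the eight flip/rotation variants of a square crop (90-degree rotations keep the 256 x 256 size, unlike the 20-degree `RandomRotation` in `Places2Data`), so 1500 crops would become 12000. Only the ground truth images need augmenting, since masks are paired with images at random. The name `dihedral_variants` is hypothetical, and this is one option, not the method we settled on.

```python
import numpy as np


def dihedral_variants(image):
    """Return the 8 flip/rotation variants of a square (H, W, C) image array."""
    variants = []
    for base in (image, np.fliplr(image)):
        for k in range(4):
            # np.rot90 rotates in the plane of the first two axes (H, W) by k * 90 degrees
            variants.append(np.rot90(base, k))
    return variants


crop = np.arange(256 * 256 * 3, dtype=np.uint32).reshape(256, 256, 3)
augmented = dihedral_variants(crop)
print(len(augmented))          # 8
print(1500 * len(augmented))   # 12000 images from the current 1500 crops
```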

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# unzips all data sets into Google Colab
!unzip cropped_train_img.zip
!unzip cropped_valid_img.zip
!unzip mask_train_img.zip
!unzip mask_valid_img.zip


In [None]:
# pip installation
!pip install tensorboardcolab
# !pip install PyQt5==5.9.2



# **places2_train.py**


In [None]:
import random
import torch
import os
import glob
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision import utils


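# mean and std channel values: the standard ImageNet statistics, which are also
# the normalization expected by the pretrained VGG16 used in loss.py
MEAN = [0.485, 0.456, 0.406]
STDDEV = [0.229, 0.224, 0.225]


# reverses the earlier normalization applied to the image to prepare output
# x has shape (N, 3, H, W); returns a new tensor and leaves x itself unchanged
def unnormalize(x):
	mean = torch.tensor(MEAN, dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
	std = torch.tensor(STDDEV, dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
	return x * std + mean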


class Places2Data(torch.utils.data.Dataset):

	def __init__(self, path_to_data, path_to_mask):
		super().__init__()

		# path_to_data / path_to_mask are absolute folder paths (train.py passes cwd + relative path)
		self.img_paths = glob.glob(os.path.join(path_to_data, "**", "*.jpg"), recursive=True)
		self.mask_paths = glob.glob(os.path.join(path_to_mask, "*.png"))
		self.num_masks = len(self.mask_paths)
		self.num_imgs = len(self.img_paths)
		# data augmentation pipeline, currently unused (see "Things We Need To Do").
		# No RandomCrop here: RandomCrop needs an explicit output size, and the
		# images are already resized to 256 x 256 by the first step.
		self.img_transform1 = transforms.Compose([
			transforms.Resize((256, 256)),
			transforms.ColorJitter(hue=.05, saturation=.05),
			transforms.RandomVerticalFlip(),
			transforms.RandomHorizontalFlip(),
			transforms.RandomRotation(20)])
		# normalizes the image: (img - MEAN) / STD and converts to tensor
		self.img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STDDEV)])
		self.mask_transform = transforms.ToTensor()

	def __len__(self):
		return self.num_imgs

	def __getitem__(self, index):
		gt_img = Image.open(self.img_paths[index])
	
		gt_img = self.img_transform(gt_img.convert('RGB'))

		mask = Image.open(self.mask_paths[random.randint(0, self.num_masks - 1)])
		mask = self.mask_transform(mask.convert('RGB'))

		return gt_img * mask, mask, gt_img
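Since `unnormalize` is meant to be the exact inverse of `transforms.Normalize(MEAN, STDDEV)`, a quick NumPy check of the per-channel round trip can help when debugging the images test.py writes out. The sketch reshapes the statistics to `(1, 3, 1, 1)` so they broadcast over an `(N, 3, H, W)` batch, the layout the DataLoader produces. The `_np` names are hypothetical and chosen so they don't shadow the notebook's own `unnormalize`.

```python
import numpy as np

MEAN_NP = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
STDDEV_NP = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)


def normalize_np(x):
    # what transforms.Normalize does per channel: (x - mean) / std
    return (x - MEAN_NP) / STDDEV_NP


def unnormalize_np(x):
    # inverse: x * std + mean, broadcast over batch, height and width
    return x * STDDEV_NP + MEAN_NP


rng = np.random.default_rng(0)
batch = rng.random((2, 3, 4, 4))     # values in [0, 1), like ToTensor() output
restored = unnormalize_np(normalize_np(batch))
print(np.allclose(restored, batch))  # True
```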




# **partial_conv_net.py**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PartialConvLayer(nn.Module):

	def __init__(self, in_channels, out_channels, bn=True, bias=False, sample="none-3", activation="relu"):
		super().__init__()
		self.bn = bn

		if sample == "down-7":
			# Kernel Size = 7, Stride = 2, Padding = 3
			self.input_conv = nn.Conv2d(in_channels, out_channels, 7, 2, 3, bias=bias)
			self.mask_conv = nn.Conv2d(in_channels, out_channels, 7, 2, 3, bias=False)

		elif sample == "down-5":
			self.input_conv = nn.Conv2d(in_channels, out_channels, 5, 2, 2, bias=bias)
			self.mask_conv = nn.Conv2d(in_channels, out_channels, 5, 2, 2, bias=False)

		elif sample == "down-3":
			self.input_conv = nn.Conv2d(in_channels, out_channels, 3, 2, 1, bias=bias)
			self.mask_conv = nn.Conv2d(in_channels, out_channels, 3, 2, 1, bias=False)

		else:
			self.input_conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=bias)
			self.mask_conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)

		nn.init.constant_(self.mask_conv.weight, 1.0)

		# "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
		# negative slope of leaky_relu set to 0, same as relu
		# "fan_in" preserved variance from forward pass
		nn.init.kaiming_normal_(self.input_conv.weight, a=0, mode="fan_in")

		for param in self.mask_conv.parameters():
			param.requires_grad = False

		if bn:
			# Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
			# Applying BatchNorm2d layer after Conv will remove the channel mean
			self.batch_normalization = nn.BatchNorm2d(out_channels)

		if activation == "relu":
			# Used between all encoding layers
			self.activation = nn.ReLU()
		elif activation == "leaky_relu":
			# Used between all decoding layers (Leaky RELU with alpha = 0.2)
			self.activation = nn.LeakyReLU(negative_slope=0.2)

	def forward(self, input_x, mask):
		# output = W^T dot (X .* M) + b
		output = self.input_conv(input_x * mask)

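		# the mask convolution is fixed (all-ones weights, requires_grad = False), so skip autograd
		with torch.no_grad():
			# output_mask = sum(M) over each sliding window (all-ones kernel, no bias)
			output_mask = self.mask_conv(mask)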

		if self.input_conv.bias is not None:
			# spreads existing bias values out along 2nd dimension (channels) and then expands to output size
			output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output)
		else:
			output_bias = torch.zeros_like(output)

		# mask_sum is the sum of the binary mask at every partial convolution location
		mask_is_zero = (output_mask == 0)
		# sets zero sums to one so the division below is safe (these locations are zeroed afterwards)
		mask_sum = output_mask.masked_fill_(mask_is_zero, 1.0)

		# output at each location as follows:
		# output = (W^T dot (X .* M) + b - b) / M_sum + b ; if M_sum > 0
		# output = 0 ; if M_sum == 0
		output = (output - output_bias) / mask_sum + output_bias
		output = output.masked_fill_(mask_is_zero, 0.0)

		# mask is updated at each location
		new_mask = torch.ones_like(output)
		new_mask = new_mask.masked_fill_(mask_is_zero, 0.0)

		if self.bn:
			output = self.batch_normalization(output)

		if hasattr(self, 'activation'):
			output = self.activation(output)

		return output, new_mask


class PartialConvUNet(nn.Module):

	# 256 x 256 image input, 256 = 2^8
	def __init__(self, input_size=256, layers=7):
		if 2 ** (layers + 1) != input_size:
			raise AssertionError

		super().__init__()
		self.freeze_enc_bn = False
		self.layers = layers

		# ======================= ENCODING LAYERS =======================
		# 3x256x256 --> 64x128x128
		self.encoder_1 = PartialConvLayer(3, 64, bn=False, sample="down-7")

		# 64x128x128 --> 128x64x64
		self.encoder_2 = PartialConvLayer(64, 128, sample="down-5")

		# 128x64x64 --> 256x32x32
		self.encoder_3 = PartialConvLayer(128, 256, sample="down-3")

		# 256x32x32 --> 512x16x16
		self.encoder_4 = PartialConvLayer(256, 512, sample="down-3")

		# 512x16x16 --> 512x8x8 --> 512x4x4 --> 512x2x2
		for i in range(5, layers + 1):
			name = "encoder_{:d}".format(i)
			setattr(self, name, PartialConvLayer(512, 512, sample="down-3"))

		# ======================= DECODING LAYERS =======================
		# dec_7: UP(512x2x2) + 512x4x4(enc_6 output) = 1024x4x4 --> 512x4x4
		# dec_6: UP(512x4x4) + 512x8x8(enc_5 output) = 1024x8x8 --> 512x8x8
		# dec_5: UP(512x8x8) + 512x16x16(enc_4 output) = 1024x16x16 --> 512x16x16
		for i in range(5, layers + 1):
			name = "decoder_{:d}".format(i)
			setattr(self, name, PartialConvLayer(512 + 512, 512, activation="leaky_relu"))

		# UP(512x16x16) + 256x32x32(enc_3 output) = 768x32x32 --> 256x32x32
		self.decoder_4 = PartialConvLayer(512 + 256, 256, activation="leaky_relu")

		# UP(256x32x32) + 128x64x64(enc_2 output) = 384x64x64 --> 128x64x64
		self.decoder_3 = PartialConvLayer(256 + 128, 128, activation="leaky_relu")

		# UP(128x64x64) + 64x128x128(enc_1 output) = 192x128x128 --> 64x128x128
		self.decoder_2 = PartialConvLayer(128 + 64, 64, activation="leaky_relu")

		# UP(64x128x128) + 3x256x256(original image) = 67x256x256 --> 3x256x256(final output)
		self.decoder_1 = PartialConvLayer(64 + 3, 3, bn=False, activation="", bias=True)
	
	def forward(self, input_x, mask):
		encoder_dict = {}
		mask_dict = {}

		key_prev = "h_0"
		encoder_dict[key_prev], mask_dict[key_prev] = input_x, mask

		for i in range(1, self.layers + 1):
			encoder_key = "encoder_{:d}".format(i)
			key = "h_{:d}".format(i)
			# Passes input and mask through encoding layer
			encoder_dict[key], mask_dict[key] = getattr(self, encoder_key)(encoder_dict[key_prev], mask_dict[key_prev])
			key_prev = key

		# Gets the final output data and mask from the encoding layers
		# 512 x 2 x 2
		out_key = "h_{:d}".format(self.layers)
		out_data, out_mask = encoder_dict[out_key], mask_dict[out_key]

		for i in range(self.layers, 0, -1):
			encoder_key = "h_{:d}".format(i - 1)
			decoder_key = "decoder_{:d}".format(i)

			# Upsample to 2 times scale, matching dimensions of previous encoding layer output
			out_data = F.interpolate(out_data, scale_factor=2)
			out_mask = F.interpolate(out_mask, scale_factor=2)

			# concatenate upsampled decoder output with encoder output of same H x W dimensions
			# s.t. final decoding layer input will contain the original image
			out_data = torch.cat([out_data, encoder_dict[encoder_key]], dim=1)
			# also concatenate the masks
			out_mask = torch.cat([out_mask, mask_dict[encoder_key]], dim=1)
			
			# feed through decoder layers
			out_data, out_mask = getattr(self, decoder_key)(out_data, out_mask)

		return out_data

	def train(self, mode=True):
		super().train(mode)
		if self.freeze_enc_bn:
			for name, module in self.named_modules():
				if isinstance(module, nn.BatchNorm2d) and "enc" in name:
					# Sets batch normalization layers to evaluation mode
					module.eval()
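For to-do item 6, one difference worth checking: NVIDIA's reference partial convolution implementation (and later versions of the Liu et al. formulation) re-weight each window by sum(1)/sum(M), where sum(1) is the number of elements in the window (kernel height x kernel width x input channels). `PartialConvLayer.forward` above divides by sum(M) alone, so, ignoring the bias, its outputs in fully valid windows are smaller by a constant factor of sum(1). The single-channel NumPy sketch below implements the sum(1)/sum(M) version for a stride-1, unpadded convolution, including the mask update, so small cases can be checked by hand. It is a simplified illustration with a hypothetical function name, not a drop-in replacement for the layer.

```python
import numpy as np


def partial_conv2d_single(x, m, w, b=0.0):
    """Single-channel partial convolution (stride 1, no padding).

    x, m: (H, W) image and binary mask (1 = valid, 0 = hole); w: (k, k) kernel.
    Returns the output map and the updated mask, each of shape (H-k+1, W-k+1).
    """
    k = w.shape[0]
    out_h, out_w = x.shape[0] - k + 1, x.shape[1] - k + 1
    out = np.zeros((out_h, out_w))
    new_m = np.zeros((out_h, out_w))
    window_size = w.size  # sum(1): number of elements in each window
    for i in range(out_h):
        for j in range(out_w):
            window_m = m[i:i + k, j:j + k]
            m_sum = window_m.sum()  # sum(M)
            if m_sum > 0:
                # x' = W^T (X * M) * sum(1) / sum(M) + b
                out[i, j] = (w * x[i:i + k, j:j + k] * window_m).sum() * window_size / m_sum + b
                new_m[i, j] = 1.0  # window saw at least one valid pixel
            # otherwise the output and the updated mask stay 0
    return out, new_m


x = np.ones((3, 3))
w = np.ones((3, 3))
full = np.ones((3, 3))
holey = full.copy()
holey[0, :] = 0.0  # top row is a hole
print(partial_conv2d_single(x, full, w)[0])   # [[9.]]
print(partial_conv2d_single(x, holey, w)[0])  # [[9.]]
```

For a constant image the re-weighted output is the same with or without the hole. Dividing by sum(M) alone, as `PartialConvLayer` does, would give 1.0 in both cases instead of 9.0; batch normalization may absorb part of that scale difference in most layers, but `encoder_1` and `decoder_1` are built with `bn=False`.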

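The 256 x 256 restriction from section 3 (and the `2 ** (layers + 1) != input_size` check) comes from the seven stride-2 encoder layers: the decoder upsamples by exactly 2 at each step and concatenates with the matching encoder output, so every encoder output must be exactly half its input. The plain-Python sketch below applies the standard convolution output-size formula, floor((H + 2p - k) / s) + 1, with the kernel/stride/padding values from `PartialConvLayer`; `trace_sizes` is a hypothetical helper name.

```python
def conv_out(size, kernel, stride, padding):
    # standard Conv2d output size (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for encoder_1 .. encoder_7, from the "sample" options above
ENCODER_SPECS = [(7, 2, 3), (5, 2, 2)] + [(3, 2, 1)] * 5


def trace_sizes(size):
    """Spatial size after each encoder layer, starting from the input size."""
    sizes = [size]
    for kernel, stride, padding in ENCODER_SPECS:
        size = conv_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print(trace_sizes(256))  # [256, 128, 64, 32, 16, 8, 4, 2]
print(trace_sizes(300))  # [300, 150, 75, 38, 19, 10, 5, 3]
```

A 300 x 300 input shrinks to 3 x 3, and upsampling that map gives 6 x 6, which no longer matches the 5 x 5 output of `encoder_6`, so the `torch.cat` in the decoder would fail.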
# **Loss.py** 

In [None]:
import torch
import os
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
import pdb

from torchvision import models
from torchvision import transforms
from PIL import Image

# from places2_train import Places2Data, MEAN, STDDEV

LAMBDAS = {"valid": 1.0, "hole": 6.0, "tv": 2.0, "perceptual": 0.05, "style": 240.0}


def gram_matrix(feature_matrix):
	(batch, channel, h, w) = feature_matrix.size()
	feature_matrix = feature_matrix.view(batch, channel, h * w)
	feature_matrix_t = feature_matrix.transpose(1, 2)

	# batch matrix multiplication * normalization factor K_n
	# (batch, channel, h * w) x (batch, h * w, channel) ==> (batch, channel, channel)
	gram = torch.bmm(feature_matrix, feature_matrix_t) / (channel * h * w)

	# size = (batch, channel, channel)
	return gram


def perceptual_loss(h_comp, h_out, h_gt, l1):
	loss = 0.0

	for i in range(len(h_comp)):
		loss += l1(h_out[i], h_gt[i])
		loss += l1(h_comp[i], h_gt[i])

	return loss


def style_loss(h_comp, h_out, h_gt, l1):
	loss = 0.0

	for i in range(len(h_comp)):
		loss += l1(gram_matrix(h_out[i]), gram_matrix(h_gt[i]))
		loss += l1(gram_matrix(h_comp[i]), gram_matrix(h_gt[i]))

	return loss


# computes TV loss over entire composed image since gradient will not be passed backward to input
def total_variation_loss(image, l1):
    # shift by one pixel and take the L1 difference (in both the x and y directions)
    loss = l1(image[:, :, :, :-1], image[:, :, :, 1:]) + l1(image[:, :, :-1, :], image[:, :, 1:, :])
    return loss


class VGG16Extractor(nn.Module):
	def __init__(self):
		super().__init__()
		vgg16 = models.vgg16(pretrained=True)
		self.max_pooling1 = vgg16.features[:5]
		self.max_pooling2 = vgg16.features[5:10]
		self.max_pooling3 = vgg16.features[10:17]

		for i in range(1, 4):
			for param in getattr(self, 'max_pooling{:d}'.format(i)).parameters():
				param.requires_grad = False

	# feature extractor at each of the first three pooling layers
	def forward(self, image):
		results = [image]
		for i in range(1, 4):
			func = getattr(self, 'max_pooling{:d}'.format(i))
			results.append(func(results[-1]))
		return results[1:]


class CalculateLoss(nn.Module):
	def __init__(self):
		super().__init__()
		self.vgg_extract = VGG16Extractor()
		self.l1 = nn.L1Loss()

	def forward(self, input_x, mask, output, ground_truth):
		composed_output = (input_x * mask) + (output * (1 - mask))

		fs_composed_output = self.vgg_extract(composed_output)
		fs_output = self.vgg_extract(output)
		fs_ground_truth = self.vgg_extract(ground_truth)

		loss_dict = dict()

		loss_dict["hole"] = self.l1((1 - mask) * output, (1 - mask) * ground_truth) * LAMBDAS["hole"]
		loss_dict["valid"] = self.l1(mask * output, mask * ground_truth) * LAMBDAS["valid"]
		loss_dict["perceptual"] = perceptual_loss(fs_composed_output, fs_output, fs_ground_truth, self.l1) * LAMBDAS["perceptual"]
		loss_dict["style"] = style_loss(fs_composed_output, fs_output, fs_ground_truth, self.l1) * LAMBDAS["style"]
		loss_dict["tv"] = total_variation_loss(composed_output, self.l1) * LAMBDAS["tv"]

		return loss_dict
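To make the loss terms above concrete, the NumPy sketch below mirrors `gram_matrix` (normalized by C·H·W), the mean-reduced L1 `total_variation_loss`, and the weighted sum that `CalculateLoss` and train.py produce together (each term scaled by its `LAMBDAS` weight, then all terms added). The feature maps are random stand-ins rather than VGG16 activations, and the raw term values in `raw` are placeholders, not results from this project. The `_np` function names are hypothetical.

```python
import numpy as np

LAMBDAS = {"valid": 1.0, "hole": 6.0, "tv": 2.0, "perceptual": 0.05, "style": 240.0}


def gram_matrix_np(features):
    """(B, C, H, W) -> (B, C, C), normalized by C * H * W like gram_matrix above."""
    b, c, h, w = features.shape
    flat = features.reshape(b, c, h * w)
    return flat @ flat.transpose(0, 2, 1) / (c * h * w)


def total_variation_np(image):
    """Mean absolute difference between horizontal plus vertical neighbours (L1, mean-reduced)."""
    dx = np.abs(image[:, :, :, :-1] - image[:, :, :, 1:]).mean()
    dy = np.abs(image[:, :, :-1, :] - image[:, :, 1:, :]).mean()
    return dx + dy


rng = np.random.default_rng(0)
feats = rng.random((2, 4, 8, 8))
print(gram_matrix_np(feats).shape)      # (2, 4, 4)

flat_image = np.full((1, 3, 16, 16), 0.5)
print(total_variation_np(flat_image))   # 0.0 for a constant image

# placeholder raw term values, weighted and summed as in CalculateLoss + train.py
raw = {"valid": 0.02, "hole": 0.05, "tv": 0.01, "perceptual": 0.8, "style": 0.001}
total = sum(LAMBDAS[key] * value for key, value in raw.items())
print(total)
```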



# **Train.py**

In [None]:
import argparse
import os
import torch
import numpy as np
import easydict
from torch.utils import data
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import pdb


class SubsetSampler(data.sampler.Sampler):
	def __init__(self, start_sample, num_samples):
		self.num_samples = num_samples
		self.start_sample = start_sample

	def __iter__(self):
		return iter(range(self.start_sample, self.num_samples))

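	def __len__(self):
		# number of indices yielded by __iter__ (num_samples is used as the exclusive end index)
		return self.num_samples - self.start_sample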


def requires_grad(param):
	return param.requires_grad


if __name__ == '__main__':

  args = easydict.EasyDict({
      "train_path": "/cropped_train_img",
      "mask_path": "/mask_train_img",
      "val_path": "/valid_img/data_png/error_concealment_valid",
      "log_dir": "/training_logs",
      "save_dir": "/model",
      "load_model": "",
      "lr": 5e-3,
      "fine_tune_lr": 5e-3,
      "batch_size": 16,
      "epochs": 15,
      "fine_tune": 0,
      "gpu": 0,
      "num_workers": 0,
      "log_interval": 10,
      "save_interval": 5000
  })
  cwd = os.getcwd()
  print(f"cwd: {cwd}")
  


  # TensorBoard SummaryWriter setup
  if not os.path.exists(cwd + args.log_dir):
      os.makedirs(cwd + args.log_dir)

  writer = SummaryWriter(cwd + args.log_dir)

  if not os.path.exists(cwd + args.save_dir):
      os.makedirs(cwd + args.save_dir)

  # use the requested GPU if CUDA is available, otherwise fall back to the CPU
  if args.gpu >= 0 and torch.cuda.is_available():
      device = torch.device("cuda:{}".format(args.gpu))
  else:
      device = torch.device("cpu")

  print(f"Path to training ground truth images: {cwd+args.train_path}")
  print(f"Path to training mask images: {cwd+args.mask_path}")

  data_train = Places2Data(cwd+args.train_path, cwd+args.mask_path)
  data_size = len(data_train)
  print("Loaded training dataset with {} samples and {} masks".format(data_size, data_train.num_masks))

  # assert(data_size % args.batch_size == 0)
  iters_per_epoch = data_size // args.batch_size

  # data_val = Places2Data(args.val_path, args.mask_path)
  # print("Loaded validation dataset...")

  # Move model to gpu prior to creating optimizer, since parameters become different objects after loading
  model = PartialConvUNet().to(device)
  print("Loaded model to device...")

  # Set the fine-tune learning rate if necessary
  if args.fine_tune:
      lr = args.fine_tune_lr
      model.freeze_enc_bn = True
  else:
      lr = args.lr

  # Adam optimizer proposed in: "Adam: A Method for Stochastic Optimization"
  # filters the model parameters for those with requires_grad == True
  optimizer = torch.optim.Adam(filter(requires_grad, model.parameters()), lr=lr)
  print("Setup Adam optimizer...")

  # Loss function
  # Moves vgg16 model to gpu, used for feature map in loss function
  loss_func = CalculateLoss().to(device)
  print("Setup loss function...")

  # Resume training on model
  if args.load_model:
      assert os.path.isfile(cwd + args.save_dir + args.load_model)

      filename = cwd + args.save_dir + args.load_model
      # map_location lets a checkpoint saved on one device be loaded on another
      checkpoint_dict = torch.load(filename, map_location=device)

      model.load_state_dict(checkpoint_dict["model"])
      optimizer.load_state_dict(checkpoint_dict["optimizer"])

      print("Resume training on model:{}".format(args.load_model))

      # Load all parameters to the device
      model = model.to(device)
      for state in optimizer.state.values():
          for key, value in state.items():
              if isinstance(value, torch.Tensor):
                  state[key] = value.to(device)

  for epoch in range(0, args.epochs):

      iterator_train = iter(data.DataLoader(data_train,
                                            batch_size=args.batch_size,
                                            num_workers=args.num_workers,
                                            sampler=SubsetSampler(0, data_size)))

      # TRAINING LOOP
      print("\nEPOCH:{} of {} - starting training loop from iteration:0 to iteration:{}\n".format(epoch, args.epochs, iters_per_epoch))

      for i in tqdm(range(0, iters_per_epoch)):

        # Sets model to train mode
        model.train()

        # Gets the next batch of images
        image, mask, gt = [x.to(device) for x in next(iterator_train)]

        # Forward-propagates images through net
        # The mask is also propagated; its holes usually shrink away by the deepest encoding layers
        output = model(image, mask)

        loss_dict = loss_func(image, mask, output, gt)
        loss = 0.0

        # sums up each (already weighted) loss value
        for key, value in loss_dict.items():
          loss += value
          if (i + 1) % args.log_interval == 0:
            writer.add_scalar(key, value.item(), (epoch * iters_per_epoch) + i + 1)
            writer.flush()

        # Resets gradient accumulator in optimizer
        optimizer.zero_grad()
        # back-propagates gradients through model weights
        loss.backward()
        # updates the weights
        optimizer.step()

        # Save model
        if (i + 1) % args.save_interval == 0 or (i + 1) == iters_per_epoch:
          filename = cwd + args.save_dir + "/model_e{}_i{}.pth".format(epoch, i + 1)
          state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
          torch.save(state, filename)

  writer.close()
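Because `iters_per_epoch = data_size // args.batch_size` and each epoch builds a fresh iterator over an unshuffled `SubsetSampler`, the images past the last full batch are skipped in every epoch, and they are always the same images. The pure-Python sketch below reproduces that bookkeeping, along with the TensorBoard step of the last iteration and the checkpoint names the loop writes. `epoch_plan` is a hypothetical helper, and the 1500-image size is just the number from the to-do list used as an example.

```python
def epoch_plan(data_size, batch_size, save_interval, epoch):
    """Mirror the bookkeeping in train.py for one epoch."""
    iters_per_epoch = data_size // batch_size
    skipped = data_size - iters_per_epoch * batch_size  # images never used in this epoch
    # global step passed to writer.add_scalar for the last iteration of this epoch
    last_step = epoch * iters_per_epoch + iters_per_epoch
    # checkpoint files written when (i + 1) % save_interval == 0 or at the final iteration
    checkpoints = [
        "model_e{}_i{}.pth".format(epoch, i + 1)
        for i in range(iters_per_epoch)
        if (i + 1) % save_interval == 0 or (i + 1) == iters_per_epoch
    ]
    return iters_per_epoch, skipped, last_step, checkpoints


print(epoch_plan(data_size=1500, batch_size=16, save_interval=5000, epoch=0))
# (93, 12, 93, ['model_e0_i93.pth'])
```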

# **Test.py**

In [None]:
import argparse
import torch
import os
import random
import pdb
import matplotlib.pyplot as plt
import torchvision
import numpy as np


from PIL import Image
from torchvision.utils import make_grid
from torchvision.utils import save_image
from torchvision import transforms
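from skimage.metrics import peak_signal_noise_ratio as psnr
import skimage
from skimage.metrics import structural_similarity as ssim
import easydict
# Not needed in this notebook if the earlier cells were run in order (see the Note in section 4):
# from places2_train import MEAN, STDDEV, unnormalize
# from partial_conv_net import PartialConvUNet
# from loss import CalculateLoss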


def display_image(img):
	plt.figure()
	plt.imshow(img)
	plt.show()

image_num = str(random.randint(0, 69))
mask_num = str(random.randint(0, 1749))




args = easydict.EasyDict({
      "img": f"/valid_img/{image_num}_0.jpg",
      "mask": f"/mask_valid_img/mask_{mask_num}.png",
      "model": "/model/model_e99_i96.pth",
      "size": 256
  })



img_list = []
mask_list = []

loss_hole = []
loss_valid = []
loss_perceptual = []
loss_style = []
loss_tv = []
loss_ssim = []
loss_psnr = []

# random.seed(0)

samples = 100

for i in range(samples):

    # based on the structure of the ICME 2019 Challenge Dataset
    folder_num = random.randint(0,1)
    if(folder_num == 0):
        val_type = "error_concealment_valid"
    else:
        val_type = "object_removal_valid"
    image_num = str(random.randint(0, 69))
    mask_num = image_num
    mask_subnum = str(random.randint(0, 69))
    img_name = f"/valid_img/data_png/valid_img/{image_num}.jpg"
    mask_name = f"/valid_img/data_png/{val_type}/{mask_num}/{mask_num}_{mask_subnum}.png"

    # # For testing custom data set
    # image_num = str(random.randint(0, 69))
    # image_subnum = str(random.randint(0,4))
    # mask_num = str(random.randint(0,1749))
    # img_name = f"/cropped_valid_img/{image_num}_{image_subnum}.jpg"
    # mask_name = f"/mask_valid_img/mask_{mask_num}.png"

    print(f"img, mask: {img_name}, {mask_name}")
    img_list.append(img_name)
    mask_list.append(mask_name)

cwd = os.getcwd()
device = torch.device("cpu")


img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STDDEV)])
mask_transform = transforms.ToTensor()
transform_resize = transforms.Resize(size=(256,256))


# earlier epoch
checkpoint_dict = torch.load(cwd + args.model, map_location="cpu")
model = PartialConvUNet()
model.load_state_dict(checkpoint_dict["model"])
model = model.to(device)
model.eval()


# later epoch
checkpoint_dict2 = torch.load(cwd + "/model/model_e152_i96.pth", map_location="cpu")
model2 = PartialConvUNet()
model2.load_state_dict(checkpoint_dict2["model"])
model2 = model2.to(device)
model2.eval()


for k,(img,mask) in enumerate(zip(img_list,mask_list)):
    print(f"img: {img}, mask = {mask}")
    # with Image.open(cwd + img) as gt_img, Image.open(cwd + mask) as mask:
        
    mask = Image.open(cwd + mask)
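    # nearest-neighbour resizing keeps the mask binary (bilinear resizing would create in-between values)
    mask = mask.resize((256, 256), Image.NEAREST)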
    mask = mask_transform(mask.convert("RGB"))

    gt_img = Image.open(cwd + img)
    gt_img = transform_resize(gt_img)
    gt_img = img_transform(gt_img.convert("RGB"))
    img = gt_img * mask

    img.unsqueeze_(0)
    gt_img.unsqueeze_(0)
    mask.unsqueeze_(0)

    
    # for the first model
    with torch.no_grad():
        output = model(img.to(device), mask.to(device))

    output = (mask * img) + ((1 - mask) * output)


    # for the second model
    with torch.no_grad():
        output2 = model2(img.to(device), mask.to(device))

    output2 = (mask * img) + ((1 - mask) * output2)
    loss_func = CalculateLoss()
    loss_out = loss_func(img, mask, output2, gt_img)


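    # undo the normalization and clamp to [0, 1] before converting to 8-bit images;
    # otherwise ToPILImage turns the negative normalized values into garbage pixels
    gt_img_ssim = unnormalize(gt_img.clone()).squeeze(0).clamp(0, 1)
    output2_ssim = unnormalize(output2.clone()).squeeze(0).clamp(0, 1)

    gt_img_ssim = torchvision.transforms.ToPILImage()(gt_img_ssim)
    output2_ssim = torchvision.transforms.ToPILImage()(output2_ssim)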
    
    
    for key, value in loss_out.items():
        # print("KEY:{} | VALUE:{}".format(key, value))
        # .item() stores plain Python floats instead of 0-dim tensors
        if key == "hole":
            loss_hole.append(value.item())
        if key == "valid":
            loss_valid.append(value.item())
        if key == "perceptual":
            loss_perceptual.append(value.item())
        if key == "tv":
            loss_tv.append(value.item())
        if key == "style":
            loss_style.append(value.item())

    gt_img_ssim = np.array(gt_img_ssim)
    output2_ssim = np.array(output2_ssim)
    out_ssim = ssim(gt_img_ssim, output2_ssim, channel_axis=-1)
    out_psnr = psnr(gt_img_ssim,output2_ssim)
    loss_ssim.append(out_ssim)
    loss_psnr.append(out_psnr)


avg_hole = sum(loss_hole)/len(loss_hole)
avg_valid = sum(loss_valid)/len(loss_valid)
avg_tv = sum(loss_tv)/len(loss_tv)
avg_perceptual = sum(loss_perceptual)/len(loss_perceptual)
avg_style = sum(loss_style)/len(loss_style)
avg_ssim = sum(loss_ssim)/len(loss_ssim)
avg_psnr = sum(loss_psnr)/len(loss_psnr)

print(f"Average loss: hole value - {avg_hole}")
print(f"Average loss: valid value - {avg_valid}")
print(f"Average loss: tv value - {avg_tv}")
print(f"Average loss: perceptual value - {avg_perceptual}")
print(f"Average loss: style value - {avg_style}")

print(f"Average: PSNR value - {avg_psnr}")
print(f"Average: ssim value - {avg_ssim}")



# unnormalize takes a single tensor; clamp to [0, 1] so ToPILImage receives valid pixel values
grid = make_grid(torch.cat((unnormalize(gt_img), unnormalize(img), unnormalize(output), unnormalize(output2)), dim=0)).clamp(0, 1)
plt.figure()
plt.imshow(torchvision.transforms.ToPILImage()(grid))
save_image(grid, "test.jpg")
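For to-do item 3, computing PSNR by hand is an easy way to sanity-check the skimage numbers test.py prints. For 8-bit images, PSNR = 10 log10(255^2 / MSE) in dB. The NumPy sketch below assumes both inputs are uint8 arrays of the same shape, like `gt_img_ssim` and `output2_ssim` after the `np.array` conversion; `psnr_uint8` is a hypothetical name.

```python
import numpy as np


def psnr_uint8(reference, test):
    """PSNR in dB between two uint8 images of the same shape (data range 255)."""
    ref = reference.astype(np.float64)
    out = test.astype(np.float64)
    mse = np.mean((ref - out) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(255.0 ** 2 / mse)


a = np.full((256, 256, 3), 100, dtype=np.uint8)
b = a.copy()
b[..., 0] = 110  # one channel off by 10 everywhere, so MSE = 100 / 3
print(psnr_uint8(a, b))
```

Restricting the MSE to the hole pixels (where the mask is 0) would give a hole-only PSNR, which may say more about inpainting quality; that variant is a suggestion, not something test.py currently computes.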