In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2 as cv
import glob
import sys
import os
import copy
import argparse
import random
import tqdm
import time
from PIL import Image
import shutil
import subprocess
from multiprocessing import Pool

import torch
from torch.autograd import Variable


# Print numpy floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)
# Automatically reload edited Python modules (such as the project code under ../src)
# before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Default figure size [width, height] in inches for every plot below.
plt.rcParams['figure.figsize'] = [15, 8]

In [None]:
# Import Crop GAN related libs.
# The project packages imported below (options, data, models, util) live in ../src,
# resolved relative to the notebook's working directory. Only add the path once so
# re-running this cell does not keep appending duplicate entries to sys.path.
gan_dir = os.path.abspath("../src/")
if gan_dir not in sys.path:
    sys.path.append(gan_dir)
from options.train_options import TrainOptions
# from options.test_options import TestOptions

from data import create_dataset
from models import create_model
import util.util_detector as util_detector
from models.yolo_model import Darknet
import util.util as utils
from util.dataset_yolo import ListDataset


## 1. Setup Model
The most important arguments to set are:

- `--checkpoints_dir`: the directory that holds the trained model checkpoints
- `--name`: the name of the model folder inside `--checkpoints_dir`



In [None]:
# Command-line style options for the double-task CycleGAN.
# `--checkpoints_dir` + `--name` together locate the folder holding the trained weights.
# A token list avoids backslash line-continuations inside a string literal, where a
# single stray space after a `\` turns into a SyntaxError.
option_tokens = [
    "--model", "double_task_cycle_gan",
    "--checkpoints_dir", "../data/models/",
    "--name", "Sythetic2bordenNight",
    "--no_flip",
    "--num_threads", "0",
    "--gpu_ids", "0",
    "--display_id", "-1",
    "--preprocess", "resize_and_crop",
    "--load_size", "416",
    "--crop_size", "416",
    "--batch_size", "1",
    "--task_model_def", "../src/config/yolov3-tiny.cfg",
]
# Kept as a single string for code that reads `arguments`; its .split() equals option_tokens.
arguments = " ".join(option_tokens)

opt = TrainOptions().parse_notebook(arguments.split())

In [None]:
# Build the model class selected by opt.model (here "double_task_cycle_gan").
model = create_model(opt)      # create a model given opt.model and other options
# Project setup hook: load and print networks; create schedulers.
model.setup(opt)               # regular setup: load and print networks; create schedulers
# Put the model into evaluation mode for inference below
# (presumably calls .eval() on each sub-network — confirm in models/base_model).
model.eval()

## 2. Load pretrained model weights


In [None]:
# Folder containing the trained weights. It is derived from --checkpoints_dir and --name
# so it cannot drift out of sync with the options parsed above. With the current options
# it should resolve to "../data/models/Sythetic2bordenNight/latest".
load_suffix = os.path.join(opt.checkpoints_dir, opt.name, "latest")
model.load_networks_from_folder(load_suffix)

## 3. Load a synthetic (Domain A) image you want to transfer


### Loop through all the synthetic images and generate realistic ones

In [None]:
def tensor_to_image(tensor):
    """Convert an image with values in [0, 1] into a PIL image.

    Args:
        tensor: numpy array or CPU torch tensor shaped (H, W, C) or (1, H, W, C),
            with values expected to lie in [0, 1].

    Returns:
        PIL.Image.Image built from the corresponding uint8 pixels.

    Values are scaled to [0, 255], rounded, and clipped before the uint8 cast.
    Without clipping, values slightly outside [0, 1] wrap around (or are
    undefined) in the float -> uint8 conversion and show up as speckled pixels.
    """
    pixels = np.asarray(tensor, dtype=np.float64) * 255.0
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    if pixels.ndim > 3:
        # Drop a leading batch dimension; only a single image is supported.
        assert pixels.shape[0] == 1
        pixels = pixels[0]
    return Image.fromarray(pixels)

# Input/output folders for batch translation. Set the CROPGAN_DATA_DIR environment
# variable to point at another data root instead of editing a user-specific path.
# Both paths keep a trailing separator because later code concatenates onto them.
data_root = os.environ.get("CROPGAN_DATA_DIR", "/home/michael/ucdavis/CropGANData")
out_path = os.path.join(data_root, "gan_created_images", "cropgan_default", "")
in_path = os.path.join(data_root, "crop_gan_data", "sytheticVis2bordenNight", "")

In [None]:
# Translate every synthetic (domain A) image in trainA/ with generator G_A and save the
# result under out_path with the same file name.
synth_dir = os.path.join(in_path, "trainA")
# Sorted so the processing order (and the "last image" inspected below) is reproducible.
synth_images = sorted(glob.glob(os.path.join(synth_dir, "*.jpg")))
if not synth_images:
    # Fail early: otherwise the loop silently does nothing and later cells hit NameErrors.
    raise FileNotFoundError(f"No .jpg images found in {synth_dir}")
# Image.save does not create missing directories.
os.makedirs(out_path, exist_ok=True)

with torch.no_grad():
    for img_path in tqdm.tqdm(synth_images, desc="translating"):
        # Context manager closes the file handle once the pixels are converted.
        with Image.open(img_path) as src_img:
            raw_image = src_img.convert('RGB')
        # 416 matches --load_size / --crop_size in the options above.
        raw_img_tensor, raw_img_np = utils.preprocess_images(raw_image, resize=[416, 416])
        fake_img = model.netG_A(raw_img_tensor)
        # Map generator output (assumed in [-1, 1]) to [0, 1] and reorder CHW -> HWC.
        fake_img_np = fake_img.detach().cpu().squeeze(0).permute([1, 2, 0]) * 0.5 + 0.5
        im = tensor_to_image(fake_img_np)
        im.save(os.path.join(out_path, os.path.basename(img_path)))


In [None]:
# Sanity-check value ranges for the last image processed by the loop above:
# row 1 = network input (min, max); rows 2-3 = raw generator output vs. rescaled output
# (min row, then max row).
for lhs, rhs in (
    (raw_img_tensor.min(), raw_img_tensor.max()),
    (fake_img.min(), fake_img_np.min()),
    (fake_img.max(), fake_img_np.max()),
):
    print(lhs, rhs)

In [None]:
# Alternative conversion of the generator output using torchvision's ToPILImage.
# torch and PIL.Image are already imported in the first cell; only torchvision is new.
import torchvision.transforms as T

# define a transform to convert a tensor to PIL image
transform = T.ToPILImage()

# fake_img is the raw generator output; the loop above maps it to [0, 1] with
# `* 0.5 + 0.5`. ToPILImage multiplies float tensors by 255 and casts to uint8
# without clipping, so negative values would wrap around. Apply the same rescale
# and clamp to [0, 1] before converting.
img = transform((fake_img[0].detach().cpu() * 0.5 + 0.5).clamp(0, 1))

In [None]:
# Shape of the HWC image from the last loop iteration (expected (416, 416, 3) — confirm).
fake_img_np.shape

In [None]:
# Shape of the raw generator output, presumably (1, C, H, W) with the batch dim first.
fake_img.shape

In [None]:
# Load one synthetic (domain A) sample and preprocess it for the generator.
image_a_path = "../data/samples/sythetic/00054.jpg"
with Image.open(image_a_path) as src_img:
    real_a_img = src_img.convert('RGB')
# No `resize` argument here, so preprocess_images uses its default size.
real_a_img_tensor, real_a_img_np = utils.preprocess_images(real_a_img)
plt.imshow(real_a_img)

## 4. Generate semantically constrained GAN image


In [None]:
# Translate the single sample with generator G_A; no gradients are needed for inference.
with torch.no_grad():
    fake_b_img = model.netG_A(real_a_img_tensor)
    # Map the output (assumed in [-1, 1]) to [0, 1] and reorder CHW -> HWC for imshow.
    fake_b_img_np = fake_b_img.detach().cpu().squeeze(0).permute(1, 2, 0).mul(0.5).add(0.5)

In [None]:
# Side-by-side comparison: synthetic input (left) vs. generated image (right).
# Resize the input to the generated image's (width, height) rather than a hard-coded
# 256x256, so both panels match whatever size preprocess_images produced.
gen_h, gen_w = fake_b_img_np.shape[:2]
real_a_img_resize = real_a_img.resize((gen_w, gen_h))
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.imshow(real_a_img_resize)
ax0.set_title("Synthetic input (domain A)")
ax1.imshow(fake_b_img_np)
ax1.set_title("Generated (domain B)")
for ax in (ax0, ax1):
    ax.axis("off")

In [None]:
# Score the generated image with both discriminators. This is inference only, so
# autograd is disabled; otherwise the discriminators' parameters would record a graph.
# NOTE(review): in the usual CycleGAN convention D_A judges domain-B images and D_B judges
# domain-A images, so pred_B scores fake_b against the other domain — confirm intent.
with torch.no_grad():
    pred_A = model.netD_A(fake_b_img)
    pred_B = model.netD_B(fake_b_img)
print(pred_A.mean())
print(pred_B.mean())