In [None]:
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image
import os
import glob
import itertools
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import os    
import cv2
# Tells Intel's OpenMP runtime to tolerate more than one copy of itself being loaded.
# NOTE(review): presumably a workaround for the "libiomp5 already initialized" abort seen
# when several packages bundle OpenMP — confirm it is still needed in this environment.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torchvision.transforms as transforms
from torchvision.models import vgg19, VGG19_Weights

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions wrapped in a skip connection.

    The block keeps both the channel count and the spatial size of its input.

    Args:
        n_features: number of input (and output) channels.
        norm_layer: callable that takes a channel count and returns a
            normalization module; applied after each convolution.
    """
    def __init__(self, n_features,norm_layer=nn.InstanceNorm2d):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),  # pad by 1 so the unpadded 3x3 conv keeps H x W
            nn.Conv2d(n_features, n_features, 3),
            norm_layer(n_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(n_features, n_features, 3),
            norm_layer(n_features)
        )

    def forward(self, x):
        """Return ``x`` plus the residual computed by ``self.block``."""
        return x + self.block(x)


class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator (3 channels in, 3 channels out).

    Layout of ``self.model``: 7x7 stem conv -> two stride-2 downsampling convs
    -> ``num_residual_blocks`` residual blocks at 256 channels -> two stride-2
    transposed convs -> 7x7 output conv followed by ``Tanh`` (outputs in [-1, 1]).

    Args:
        num_residual_blocks: number of ``ResidualBlock`` modules in the bottleneck.
        default_norm_layer: normalization constructor (called with a channel
            count) used by the conv/deconv blocks and the residual blocks.
    """
    def __init__(self,num_residual_blocks=6,default_norm_layer=nn.InstanceNorm2d):
        super(ResnetGenerator, self).__init__()

        def norm_relu(out_channels, norm_layer):
            # Optional normalization followed by ReLU; None/False disables the normalization.
            layers = []
            if norm_layer is not None and norm_layer != False:
                layers.append(norm_layer(out_channels))
            layers.append(nn.ReLU(inplace=True))
            return layers

        def conv_block(in_channels,out_channels,kernel_size,stride,padding,norm_layer=default_norm_layer,padding_mode='reflect'):
            # Conv -> (norm) -> ReLU, returned as a list so it can be unpacked into nn.Sequential.
            return [nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,padding_mode=padding_mode),
                    *norm_relu(out_channels, norm_layer)]

        # Used in upsampling
        def deconv_block(in_channels,out_channels,kernel_size=3,stride=2,padding=1, output_padding=1,norm_layer=default_norm_layer):
            # With the defaults (k=3, s=2, p=1, output_padding=1) this doubles H and W.
            return [nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,output_padding),
                    *norm_relu(out_channels, norm_layer)]

        output_layer = [nn.Conv2d(in_channels=64,out_channels=3,kernel_size=7,padding=3,padding_mode='reflect'),nn.Tanh()]

        self.model = nn.Sequential(
            *conv_block(in_channels=3,out_channels=64,kernel_size=7,stride=1,padding=3,padding_mode='reflect'),
            *conv_block(in_channels=64,out_channels=128,kernel_size=3,stride=2,padding=1,padding_mode='zeros'),
            *conv_block(in_channels=128,out_channels=256,kernel_size=3,stride=2,padding=1,padding_mode='zeros'),
            # Build a fresh ResidualBlock per position. `[ResidualBlock(...)] * n` would repeat
            # ONE module instance n times, silently tying the weights of every block together.
            # The positions (indices 9 .. 9+n-1) are unchanged, so state_dict keys stay the same.
            *(ResidualBlock(256,default_norm_layer) for _ in range(num_residual_blocks)),
            *deconv_block(in_channels=256,out_channels=128),
            *deconv_block(in_channels=128,out_channels=64),
            *output_layer)

    def forward(self, x):
        """Run ``x`` through ``self.model`` and return the result."""
        return self.model(x)

In [None]:
# G_BA_9 = ResnetGenerator(9).to(device)
# G_BA_15 = ResnetGenerator(15).to(device)
# G_BA_30 = ResnetGenerator(30).to(device)


In [None]:
# G_BA_9.load_state_dict(torch.load('GBA 9 rblocks'))
# G_BA_15.load_state_dict(torch.load('GBA 15 rblocks'))
# G_BA_30.load_state_dict(torch.load('GBA 30 rblocks'))


In [None]:
# G_BA_impress = ResnetGenerator(9).to(device)
# G_BA_impress.load_state_dict(torch.load('GBA-landscape-epoch45'))
# Build the 15-residual-block generator and restore its trained weights.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on a GPU
# still loads when the notebook runs on a different device.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
G_BA_impress = ResnetGenerator(15).to(device)
G_BA_impress.load_state_dict(torch.load('GBA 15 rblocks', map_location=device))

In [None]:
source_video_dir = './source_videos/demo3.mp4'
video_frames_dir = "./source_frames/demo3/"

# cv2.imwrite does not create missing directories (it only returns False), so create it up front.
os.makedirs(video_frames_dir, exist_ok=True)

vidcap = cv2.VideoCapture(source_video_dir)
if not vidcap.isOpened():
    raise FileNotFoundError("Could not open source video: " + source_video_dir)

def getFrame(sec, index=None):
    """Seek to ``sec`` seconds, read one frame and save it as ``<index>.jpg``.

    Args:
        sec: timestamp in seconds to seek to before reading.
        index: number used for the output file name; when omitted, the
            notebook-level ``count`` variable is used.

    Returns:
        True if a frame was read (and written), False once the video is exhausted.

    Raises:
        IOError: if the frame was read but could not be written to disk.
    """
    if index is None:
        index = count  # backward-compatible fallback to the global counter
    vidcap.set(cv2.CAP_PROP_POS_MSEC,sec*1000)
    hasFrames,image = vidcap.read()
    if hasFrames:
        target = video_frames_dir+str(index)+".jpg"
        if not cv2.imwrite(target, image):     # save frame as JPG file
            raise IOError("Failed to write frame: " + target)
    return hasFrames

sec = 0
fps_source = 24
frameRate = 1/fps_source
count=1
try:
    success = getFrame(sec, count)
    while success:
        count = count + 1
        # Derive the timestamp from the frame index. Accumulating `round(sec + frameRate, 2)`
        # collapses the 1/24 s step to 0.04 s, i.e. it samples at 25 fps instead of fps_source.
        sec = (count - 1) * frameRate
        success = getFrame(sec, count)
finally:
    # Release the decoder/file handle even if extraction fails part-way.
    vidcap.release()

In [None]:
# Directory holding the numbered source frames (1.jpg, 2.jpg, ...) read by the cells below.
# NOTE(review): same value as in the extraction cell; presumably repeated so the later cells
# can run without re-extracting the frames — confirm.
video_frames_dir = "./source_frames/demo3/"

In [None]:
# Target size every frame is resized to before it is fed to the generator.
imsize = (512, 512)

# Preprocessing: resize -> tensor -> normalize each channel with mean 0.5 / std 0.5.
# These Normalize parameters must match the ones of the loaded model
# (the ImageNet statistics 0.485/0.456/0.406 and 0.229/0.224/0.225 are NOT used here).
transform = transforms.Compose(
    [
        transforms.Resize(imsize),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
def image_loader(image_name):
    """Load an image file and return it as a preprocessed batch of one.

    Args:
        image_name: path of the image file to open.

    Returns:
        The result of ``transform`` applied to the image, with a leading
        batch dimension of size 1 added.
    """
    # The context manager closes the underlying file once the pixels have been read.
    with Image.open(image_name) as img:
        # Force three channels so grayscale / RGBA / palette images also fit the
        # 3-channel Normalize in `transform`.
        image = transform(img.convert('RGB')).unsqueeze(0)
    return image

# invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
#                                                      std = [ 1/0.229, 1/0.224, 1/0.225 ]),
#                                 transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
#                                                      std = [ 1., 1., 1. ]),
#                                ])

# Inverse of Normalize(mean=0.5, std=0.5), done in two steps:
#   1) divide by 1/0.5 (i.e. multiply by the original std of 0.5),
#   2) subtract -0.5 (i.e. add the original mean of 0.5).
_undo_std = transforms.Normalize(mean=[0., 0., 0.], std=[1/0.5, 1/0.5, 1/0.5])
_undo_mean = transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.])
invTrans = transforms.Compose([_undo_std, _undo_mean])

In [None]:
# Quick check: run the generator on a single extracted frame (frame number 100).
img_name = f"{video_frames_dir}{100}.jpg"
G_BA_impress.eval()  # inference mode for the generator
with torch.no_grad():
    original_B = image_loader(img_name).to(device)
    generated_A = G_BA_impress(original_B).detach()
    # Optional upscaling, currently disabled:
    # generated_A = F.interpolate(generated_A, scale_factor=3, mode='bicubic')

In [None]:
# Undo the normalization, drop the batch axis, and move channels last for matplotlib.
denormalized = invTrans(generated_A)
plt.imshow(denormalized.squeeze(0).cpu().permute(1, 2, 0))
plt.show()

In [None]:
import os
import cv2
from os.path import isfile, join

from tqdm import tqdm

# Collect only the numbered JPG frames and order them numerically. Counting every file in
# the directory (as before) breaks as soon as a stray file such as .DS_Store is present,
# because the loop would then ask for a frame number that does not exist.
frame_files = sorted(
    (f for f in os.listdir(video_frames_dir)
     if isfile(join(video_frames_dir, f))
     and os.path.splitext(f)[1].lower() == '.jpg'
     and os.path.splitext(f)[0].isdecimal()),
    key=lambda f: int(os.path.splitext(f)[0]),
)

# Stylise every frame; results are kept as de-normalized (H, W, C) CPU tensors.
generated_frames = []
G_BA_impress.eval()
with torch.no_grad():
    for frame_file in tqdm(frame_files):
        content_img = join(video_frames_dir, frame_file)
        original_B = image_loader(content_img).to(device)
        generated_A = G_BA_impress(original_B).detach()
        generated_frames.append(invTrans(generated_A).squeeze(0).cpu().permute(1,2,0))

In [None]:
# NOTE(review): the frames were sampled with fps_source = 24 in the extraction cell, so
# writing at 60 fps plays the clip back 2.5x faster — confirm this is intended.
fps = 60
pathOut = 'video_0.mp4'

# Convert each (H, W, C) float frame to 8-bit BGR for OpenCV:
#   - clip before the cast, because astype(np.uint8) wraps out-of-range values around;
#   - [:, :, ::-1] flips RGB -> BGR; ascontiguousarray turns that negative-stride view
#     into a plain array.
generated_frames_array = [
    np.ascontiguousarray((i*255).numpy().round().clip(0, 255).astype(np.uint8)[:, :, ::-1])
    for i in generated_frames
]

# 'mp4v' is the FourCC that belongs to the .mp4 container ('DIVX' is an AVI tag).
# VideoWriter takes the frame size as (width, height), hence the reversed imsize.
out = cv2.VideoWriter(pathOut,cv2.VideoWriter_fourcc(*'mp4v'), fps, imsize[::-1])
if not out.isOpened():
    raise IOError("Could not open video writer for: " + pathOut)
try:
    for frame in generated_frames_array:
        out.write(frame)
finally:
    # Always finalize the file, even if writing a frame fails.
    out.release()