In [None]:
import glob
import os
import os.path as ospath
from os import listdir, scandir
from os.path import isfile, join, exists
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torchdata
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torchvision
from torchvision import models, transforms
from torchvision.transforms import Compose,ToTensor, ToPILImage, Resize,Normalize
from torch.autograd import Variable

import cv2
import time
import math
import pandas as pd

In [None]:

# Super-resolution ratio: the generator upsamples its input by this factor.
scale_factor = 8
# Side length (px) of the square HD images, rounded down to a multiple of scale_factor
# so that mysize // scale_factor is an exact LD size.
mysize = (512 // scale_factor) * scale_factor

EPOCH = 100      # number of training epochs
batch_size = 3   # images per DataLoader batch

In [None]:
class TrainDatasetFromPath(Dataset):
    """Training pairs (LD input, HD target) built from the ``*.png`` files in a folder.

    Each image is resized to ``mysize`` x ``mysize`` (module-level setting) to form the
    HD target tensor; that tensor is then downscaled by ``scale_factor`` (bicubic) to
    form the LD input.

    Args:
        data_path: directory searched (non-recursively) for ``*.png`` files.
        scale_factor: ratio between HD target size and LD input size.

    ``__getitem__`` returns ``(ld_tensor, hd_tensor)``.
    """

    def __init__(self, data_path, scale_factor):
        super(TrainDatasetFromPath, self).__init__()
        self.scale_factor = scale_factor
        # Use the constructor argument rather than a global path.
        self.image_list = make_datapath_list(data_path)
        # Identity normalisation (mean 0, std 1) keeps the HD target in [0, 1]:
        #  * Generator.forward returns (tanh(.) + 1) / 2, i.e. values in [0, 1], and the
        #    validation set compares against plain ToTensor() images, so the training
        #    target must live in the same range;
        #  * LDTransformer begins with ToPILImage, which multiplies float tensors by 255
        #    and casts to uint8 -- negative values from a (0.5, 0.5) normalisation would
        #    be corrupted by that cast.
        self.hd_transformer = HDTransformer(mysize, (0.0,), (1.0,))
        self.ld_transformer = LDTransformer(mysize, scale_factor)

    def __getitem__(self, index):
        image_path = self.image_list[index]
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        hd_transformed = self.hd_transformer(image)
        ld_transformed = self.ld_transformer(hd_transformed)

        return ld_transformed, hd_transformed

    def __len__(self):
        return len(self.image_list)

In [None]:
class ValDatasetFromPath(Dataset):
    """Validation triples built from the ``*.png`` files in a folder.

    Each image is bicubically resized to ``mysize`` x ``mysize`` (HD), downscaled by
    ``scale_factor`` (LD), and the LD image is bicubically upscaled back to
    ``mysize`` x ``mysize`` as a naive baseline.

    Args:
        data_path: directory searched (non-recursively) for ``*.png`` files.
        scale_factor: ratio between HD size and LD size.

    ``__getitem__`` returns ``(ld_tensor, hd_restored_tensor, hd_tensor)``, all from
    ``ToTensor()`` (no normalisation).
    """

    def __init__(self, data_path, scale_factor):
        super(ValDatasetFromPath, self).__init__()
        self.scale_factor = scale_factor
        # Use the constructor arguments, not the globals val_data_path / scale_factor.
        self.image_list = make_datapath_list(data_path)
        ld_side = mysize // scale_factor
        # Built once instead of on every __getitem__ call.
        self.ld_scaler = Resize((ld_side, ld_side), interpolation=Image.BICUBIC)
        self.hd_scaler = Resize((mysize, mysize), interpolation=Image.BICUBIC)

    def __getitem__(self, index):
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(self.image_list[index]) as img:
            hd_data = img.convert('RGB')
        hd_data = self.hd_scaler(hd_data)
        ld_data = self.ld_scaler(hd_data)
        hd_restored = self.hd_scaler(ld_data)
        return ToTensor()(ld_data), ToTensor()(hd_restored), ToTensor()(hd_data)

    def __len__(self):
        return len(self.image_list)

In [None]:
# wa = Image.open(val_data_path + '/wa.png').convert('RGB')

#  wa_scaler = Resize((mysize,mysize),interpolation=Image.BICUBIC)
# wa_data =wa_scaler(wa)    
# print(wa_data)
# wa_loader = DataLoader(dataset=wa_data, num_workers=0, batch_size=1, shuffle=True)
# wa_sr = G(wa_data)


In [None]:
class TestDatasetFromPath(Dataset):
    """Paired test images read from ``<data_path>SR_<scale_factor>/data/`` (LD inputs)
    and ``<data_path>SR_<scale_factor>/target/`` (HD targets).

    ``data_path`` is concatenated directly, so it must end with a path separator.
    Both folders are listed in sorted order and paired by index, so they are expected
    to contain the same number of correspondingly named files.

    ``__getitem__`` returns ``(image_name, ld_tensor, hd_restored_tensor, hd_tensor)``
    where ``hd_restored`` is the LD image bicubically upscaled by ``scale_factor``.
    """

    def __init__(self, data_path, scale_factor):
        super(TestDatasetFromPath, self).__init__()
        self.ld_path = data_path + 'SR_' + str(scale_factor) + '/data/'
        self.hd_path = data_path + 'SR_' + str(scale_factor) + '/target/'
        self.scale_factor = scale_factor
        # listdir order is arbitrary; sort so index i refers to matching LD/HD files.
        self.ld_list = [join(self.ld_path, x) for x in sorted(listdir(self.ld_path))]
        self.hd_list = [join(self.hd_path, x) for x in sorted(listdir(self.hd_path))]

    def __getitem__(self, index):
        image_name = os.path.basename(self.ld_list[index])

        with Image.open(self.ld_list[index]) as img:
            ld_data = img.convert('RGB')
        # The HD target was never loaded before, so every call raised NameError.
        with Image.open(self.hd_list[index]) as img:
            hd_data = img.convert('RGB')

        # Upscale the actual LD size (PIL size is (w, h); Resize expects (h, w)).
        w, h = ld_data.size
        hd_scaler = Resize((self.scale_factor * h, self.scale_factor * w), interpolation=Image.BICUBIC)
        hd_restored = hd_scaler(ld_data)

        return image_name, ToTensor()(ld_data), ToTensor()(hd_restored), ToTensor()(hd_data)

    def __len__(self):
        return len(self.ld_list)
        
        

In [None]:
class DisplayTransformer():
    """Resize an image tensor to ``mysize`` x ``mysize`` for saving in a sample grid.

    Round-trips through PIL: ToPILImage -> Resize((mysize, mysize)) -> ToTensor.
    """

    def __init__(self):
        super(DisplayTransformer, self).__init__()
        steps = [
            transforms.ToPILImage(),
            transforms.Resize((mysize, mysize)),
            transforms.ToTensor(),
        ]
        self.data_transform = transforms.Compose(steps)

    def __call__(self, img):
        pipeline = self.data_transform
        return pipeline(img)

In [None]:
class HDTransformer():
    """HD pipeline for PIL images: Resize((mysize, mysize)) -> ToTensor -> Normalize(mean, std).

    Args:
        mysize: output side length in pixels.
        mean, std: per-channel statistics passed to ``Normalize``.
    """

    def __init__(self, mysize, mean, std):
        super(HDTransformer, self).__init__()
        steps = [Resize((mysize, mysize)), ToTensor(), Normalize(mean, std)]
        self.hd_transform = Compose(steps)

    def __call__(self, img):
        pipeline = self.hd_transform
        return pipeline(img)

class LDTransformer():
    """LD pipeline for image tensors: ToPILImage -> bicubic Resize to
    (mysize // scale_factor) squared -> ToTensor.

    Args:
        mysize: side length of the incoming HD image.
        scale_factor: downscaling ratio.
    """

    def __init__(self, mysize, scale_factor):
        super(LDTransformer, self).__init__()
        ld_side = mysize // scale_factor
        steps = [
            ToPILImage(),
            Resize((ld_side, ld_side), interpolation=Image.BICUBIC),
            ToTensor(),
        ]
        self.ld_transform = Compose(steps)

    def __call__(self, img):
        pipeline = self.ld_transform
        return pipeline(img)

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: x + BN(Conv3x3(PReLU(BN(Conv3x3(x))))).

    Channel count and spatial size are preserved (3x3 kernels, padding 1). Attribute
    names are part of the state_dict keys of saved checkpoints.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()

        def conv3x3():
            # Same-size 3x3 convolution, channels -> channels.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.conv2d_1 = conv3x3()
        self.batch_norm1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2d_2 = conv3x3()
        self.batch_norm2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = x
        for stage in (self.conv2d_1, self.batch_norm1, self.prelu,
                      self.conv2d_2, self.batch_norm2):
            out = stage(out)
        # Identity skip connection.
        return x + out

In [None]:
class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling: Conv3x3 -> PixelShuffle(up_scale) -> PReLU.

    The convolution expands channels to ``in_channels * up_scale**2``; PixelShuffle
    rearranges them into an ``up_scale``-times larger height and width, so the output
    has ``in_channels`` channels again.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBlock, self).__init__()
        expanded = in_channels * up_scale ** 2
        self.conv2d = nn.Conv2d(in_channels, expanded, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv2d(x)))

In [None]:
class Generator(nn.Module):
    """SRGAN-style generator that upsamples an RGB image by ``scale_factor``.

    Layout: 9x9 conv + PReLU, five ResidualBlocks, conv + BatchNorm, a global skip
    connection from the first layer, then log2(scale_factor) x2 UpsampleBlocks and a
    final 9x9 conv to 3 channels. The output is (tanh(.) + 1) / 2, i.e. in [0, 1].

    Args:
        scale_factor: total upscaling ratio; must be a power of two >= 1 because each
            UpsampleBlock doubles height and width.

    Raises:
        ValueError: if ``scale_factor`` is not a power of two >= 1.
    """

    def __init__(self, scale_factor):
        # Validate explicitly: int(math.log(scale_factor, 2)) silently truncated
        # non-powers of two (e.g. 3 -> one x2 block), and log with an explicit base
        # is not exact in floating point.
        if scale_factor < 1:
            raise ValueError('scale_factor must be a power of two >= 1, got %r' % (scale_factor,))
        upsample_block_num = int(round(math.log2(scale_factor)))
        if 2 ** upsample_block_num != scale_factor:
            raise ValueError('scale_factor must be a power of two >= 1, got %r' % (scale_factor,))

        super(Generator, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )

        self.layer2 = ResidualBlock(64)
        self.layer3 = ResidualBlock(64)
        self.layer4 = ResidualBlock(64)
        self.layer5 = ResidualBlock(64)
        self.layer6 = ResidualBlock(64)
        self.layer7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64))

        layer8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
        layer8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.layer8 = nn.Sequential(*layer8)

    def forward(self, x):
        layer1 = self.layer1(x)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        layer5 = self.layer5(layer4)
        layer6 = self.layer6(layer5)
        layer7 = self.layer7(layer6)
        # Global skip connection from the first feature map.
        layer8 = self.layer8(layer1 + layer7)

        # Map tanh's [-1, 1] range to [0, 1].
        return (torch.tanh(layer8) + 1) / 2
        
        

In [None]:
class AllSeeingEye(nn.Module):
    """Discriminator: maps a batch of RGB images to one sigmoid probability per image.

    ``z_dim`` and ``image_size`` are accepted but not used by this implementation.
    AdaptiveAvgPool2d(1) collapses the spatial dimensions before the 1x1 convolutions.
    """

    def __init__(self, z_dim=20, image_size=mysize):
        super(AllSeeingEye, self).__init__()

        # (in_channels, out_channels, stride) of each Conv3x3 -> BatchNorm -> LeakyReLU
        # stage that follows the first (BatchNorm-free) convolution.
        stages = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 2),
            (512, 512, 2),
        ]
        layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        for c_in, c_out, stride in stages:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        # A single nn.Sequential named ``net`` keeps the state_dict keys
        # (net.0.weight, net.2.weight, ...) of existing checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x).view(x.size(0))
        return torch.sigmoid(scores)
    

In [None]:
def make_datapath_list(data_path):
    """Return the paths of the ``*.png`` files directly inside ``data_path``.

    The search is non-recursive and the order is whatever ``glob.glob`` yields
    (not sorted). A missing directory yields an empty list.
    """
    pattern = os.path.join(data_path, '*.png')
    return [path for path in glob.glob(pattern)]

In [None]:
class TVLoss(nn.Module):
    """Total-variation penalty on an (N, C, H, W) batch.

    Sums squared differences between vertically and horizontally adjacent pixels,
    normalises each sum by the per-sample element count of the difference tensor,
    and returns ``weight * 2 * (h_term + w_term) / N``.
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        n = x.size()[0]
        below, above = x[:, :, 1:, :], x[:, :, :-1, :]
        right, left = x[:, :, :, 1:], x[:, :, :, :-1]

        h_tv = (below - above).pow(2).sum()
        w_tv = (right - left).pow(2).sum()

        h_term = h_tv / self.tensor_size(below)
        w_term = w_tv / self.tensor_size(right)
        return self.tv_loss_weight * 2 * (h_term + w_term) / n

    @staticmethod
    def tensor_size(t):
        # Elements per sample: C * H * W.
        c, h, w = t.size()[1:4]
        return c * h * w

class GeneratorLoss(nn.Module):
    """Generator objective: MSE(out, target) + 0.001 * mean(1 - D(out)) + 2e-8 * TV(out).

    forward(out_labels, out_images, target_images):
        out_labels: discriminator scores for the generated images.
        out_images: generated images.
        target_images: ground-truth HD images.
    Returns the weighted sum as a scalar tensor. A perception-loss term is sketched
    in a comment below but is not active.
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial term: smaller when the discriminator scores the fakes closer to 1.
        adversarial = (1 - out_labels).mean()
        pixel = self.mse_loss(out_images, target_images)
        smoothness = self.tv_loss(out_images)
        # Inactive perception term (would need a self.loss_network):
        # perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        return pixel + 0.001 * adversarial + 2e-8 * smoothness
        

In [None]:


import torch.nn.functional as F
from math import exp

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size``, normalised to sum to 1.

    The kernel is centred on index ``window_size // 2``.
    """
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-(i - centre) ** 2 / denom) for i in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel):
    """Build a (channel, 1, window_size, window_size) 2-D Gaussian window (sigma 1.5)
    for depthwise convolution in ``ssim_conv``.

    The 2-D kernel is the outer product of the 1-D ``gaussian`` kernel with itself,
    repeated once per channel.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)          # (window_size, 1)
    kernel_2d = (column @ column.t()).float()[None, None]     # (1, 1, ws, ws)
    per_channel = kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
    return Variable(per_channel)

def ssim_conv(img1, img2, window, window_size, channel, size_average=True):
    """Structural similarity between two (N, C, H, W) batches using a given window.

    Local means, variances and covariance are computed with a depthwise
    convolution (``groups=channel``) by ``window``; constants are C1 = 0.01**2 and
    C2 = 0.03**2.

    Returns:
        The mean of the SSIM map if ``size_average`` is True, otherwise a tensor with
        one mean SSIM value per batch element.
    """
    pad = window_size // 2

    def blur(t):
        # Window-weighted local average, per channel.
        return F.conv2d(t, window, padding=pad, groups=channel)

    mu1, mu2 = blur(img1), blur(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2

    sigma1_sq = blur(img1 * img1) - mu1_sq
    sigma2_sq = blur(img2 * img2) - mu2_sq
    sigma12 = blur(img1 * img2) - mu1_mu2

    C1, C2 = 0.01 ** 2, 0.03 ** 2
    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    if not size_average:
        return ssim_map.mean(1).mean(1).mean(1)
    return ssim_map.mean()
    
class SSIM(torch.nn.Module):
    """Module wrapper around ``ssim_conv`` that caches its Gaussian window.

    The window is rebuilt whenever the channel count, dtype or device of ``img1``
    differs from the cached one.

    Args:
        window_size: side length of the Gaussian window (sigma 1.5, from create_window).
        size_average: if True return the mean SSIM, otherwise one value per image.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        cached = self.window
        if (channel == self.channel and cached.dtype == img1.dtype
                and cached.device == img1.device):
            window = cached
        else:
            # Tensor.to returns a new tensor rather than moving in place, so the result
            # must be kept; matching img1 directly avoids relying on a global ``device``.
            window = create_window(self.window_size, channel).to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel

        return ssim_conv(img1, img2, window, self.window_size, channel, self.size_average)
    
def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM between two (N, C, H, W) batches.

    Builds a fresh Gaussian window on every call, placed on ``img1``'s device and
    cast to ``img1``'s dtype, then delegates to ``ssim_conv``.
    """
    (_, channel, _, _) = img1.size()
    # Tensor.to returns a new tensor; keep the result instead of discarding it, and
    # take device/dtype from img1 instead of a global ``device``.
    window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)

    return ssim_conv(img1, img2, window, window_size, channel, size_average)
    
        

In [None]:


# Single-element normalisation statistics (same value for every channel).
mean = (0.5,)
std = (0.5,)

# Data and output folders under ~/downloads.
_downloads_dir = os.path.join(os.path.expanduser('~'), 'downloads')
train_data_path = os.path.join(_downloads_dir, 'trainsource')
val_data_path = os.path.join(_downloads_dir, 'testsource')
out_path = os.path.join(_downloads_dir, 'trainingresults')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Generator first, then discriminator (same construction order as before, so seeded
# initialisation is unchanged).
G = Generator(scale_factor).to(device)
D = AllSeeingEye().to(device)

train_set = TrainDatasetFromPath(train_data_path, scale_factor=scale_factor)
val_set = ValDatasetFromPath(val_data_path, scale_factor=scale_factor)

_loader_kwargs = dict(num_workers=0, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(dataset=train_set, **_loader_kwargs)
val_loader = DataLoader(dataset=val_set, **_loader_kwargs)

In [None]:
# train_set.image_list

In [None]:
# Combined generator objective (MSE + adversarial + TV), placed on the training device.
criterion = GeneratorLoss()
criterion = criterion.to(device)

In [None]:
# Adam with PyTorch's default hyper-parameters for both networks.
optimizer_g, optimizer_d = (optim.Adam(net.parameters()) for net in (G, D))
# Per-epoch training history, one list per metric.
results = {key: [] for key in ('d_loss', 'g_loss', 'd_score', 'g_score', 'psnr', 'ssim')}


In [None]:
# train_loader.dataset.image_list

In [None]:
def load_weigh(g_model_path=None, d_model_path=None, map_location=None):
    """Load saved weights into the global generator ``G`` and discriminator ``D``.

    Args:
        g_model_path: generator checkpoint; defaults to ``~/onedrive\\g_srgan_model.pth``.
        d_model_path: discriminator checkpoint; defaults to ``~/onedrive\\d_srgan_model.pth``.
        map_location: forwarded to ``torch.load``; defaults to the global ``device`` so
            checkpoints written on a GPU can also be loaded on a CPU-only machine.

    Raises:
        FileNotFoundError: if a checkpoint file does not exist.
    """
    home = os.path.expanduser('~')
    # NOTE: '\\' only acts as a folder separator on Windows; elsewhere it becomes part
    # of the file name. Kept identical to the paths used by the save-model cell.
    if g_model_path is None:
        g_model_path = os.path.join(home, 'onedrive\\g_srgan_model.pth')
    if d_model_path is None:
        d_model_path = os.path.join(home, 'onedrive\\d_srgan_model.pth')
    if map_location is None:
        map_location = device

    d_weights = torch.load(d_model_path, map_location=map_location)
    g_weights = torch.load(g_model_path, map_location=map_location)
    D.load_state_dict(d_weights)
    G.load_state_dict(g_weights)

In [None]:
# Output folders under ~/downloads: checkpoints, CSV statistics and sample images.
_downloads_dir = os.path.join(os.path.expanduser('~'), 'downloads')
stats_path = os.path.join(_downloads_dir, "statistics")
trained_path = os.path.join(_downloads_dir, "epochs")
checkimg_path = os.path.join(_downloads_dir, "checkimg")

# exist_ok avoids the check-then-create race and makes the cell idempotent.
for _dir in (trained_path, stats_path, checkimg_path):
    os.makedirs(_dir, exist_ok=True)

In [None]:
# Resume from saved checkpoints when they exist. On a fresh setup there are no
# weight files yet, so keep the freshly initialised G and D instead of aborting.
try:
    load_weigh()
except FileNotFoundError as err:
    print('No saved weights loaded ({}); training from scratch.'.format(err))

In [None]:
torch.backends.cudnn.benchmark = True
# Built once and reused for every saved sample image.
display_transform = DisplayTransformer()

for epoch in range(1, EPOCH + 1):
    # ------------------------------------------------------------------ training
    train_bar = tqdm(train_loader)
    running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
    G.train()
    D.train()
    for data, target in train_bar:
        # Local name: reassigning the global ``batch_size`` here would silently change
        # the configured value for any DataLoader cell re-run after training.
        n_batch = data.size(0)
        running_results['batch_sizes'] += n_batch

        real_img = target.to(device)
        z = data.to(device)
        fake_img = G(z)

        ############################
        # (1) Update D network: minimize 1 - D(x) + D(G(z))
        ############################
        # fake_img is detached so the D update does not back-propagate into G and
        # leaves G's graph intact for step (2).
        optimizer_d.zero_grad()
        real_out = D(real_img).mean()
        fake_out = D(fake_img.detach()).mean()
        d_loss = 1 - real_out + fake_out
        d_loss.backward()
        optimizer_d.step()

        ############################
        # (2) Update G network: minimize 1 - D(G(z)) + Image Loss + TV Loss
        ############################
        # Re-score fake_img with the freshly updated D. Back-propagating a score computed
        # before optimizer_d.step() would go through D weights that the step modified in
        # place, which autograd rejects in recent PyTorch versions.
        optimizer_g.zero_grad()
        fake_out = D(fake_img).mean()
        g_loss = criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        optimizer_g.step()

        running_results['d_loss'] += d_loss.item() * n_batch
        running_results['g_loss'] += g_loss.item() * n_batch
        running_results['d_score'] += real_out.item() * n_batch
        running_results['g_score'] += fake_out.item() * n_batch
        seen = running_results['batch_sizes']
        train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
            epoch, EPOCH,
            running_results['d_loss'] / seen,
            running_results['g_loss'] / seen,
            running_results['d_score'] / seen,
            running_results['g_score'] / seen))

    # ---------------------------------------------------------------- validation
    G.eval()
    out_path_fin = out_path + str(scale_factor) + '/'
    os.makedirs(out_path_fin, exist_ok=True)

    with torch.no_grad():
        val_bar = tqdm(val_loader)
        val_check_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        val_images = []

        for val_ld, val_hd_restore, val_hd in val_bar:
            n_batch = val_ld.size(0)
            val_check_results['batch_sizes'] += n_batch

            ld = val_ld.to(device)
            hd = val_hd.to(device)
            sr = G(ld)

            # .item() keeps the running sums as Python floats instead of device tensors.
            batch_mse = ((sr - hd) ** 2).mean().item()
            val_check_results['mse'] += batch_mse * n_batch
            batch_ssim = ssim(sr, hd).item()
            val_check_results['ssims'] += batch_ssim * n_batch

            mean_mse = val_check_results['mse'] / val_check_results['batch_sizes']
            # Images are in [0, 1] (peak value 1); a perfect match has infinite PSNR.
            val_check_results['psnr'] = 10 * math.log10(1 / mean_mse) if mean_mse > 0 else float('inf')
            val_check_results['ssim'] = val_check_results['ssims'] / val_check_results['batch_sizes']
            val_bar.set_description(
                desc='[converting LD images to SR images] PSNR:%.4f dB SSIM: %.4f' % (
                    val_check_results['psnr'], val_check_results['ssim']))

            # First image of the batch: bicubic baseline, ground truth, SR output.
            val_images.extend([
                display_transform(val_hd_restore[0]),
                display_transform(hd[0].cpu()),
                display_transform(sr[0].cpu()),
            ])

        if val_images:
            val_images = torch.stack(val_images)
            # Split into size // 15 roughly equal chunks (about 5 rows of 3 images each);
            # at least one chunk so validation sets with < 15 images are still saved.
            val_images = torch.chunk(val_images, max(1, val_images.size(0) // 15))
            val_save_bar = tqdm(val_images, desc='[saving training results]')
            for index, image in enumerate(val_save_bar, start=1):
                image = torchvision.utils.make_grid(image, nrow=3, padding=5)
                torchvision.utils.save_image(image, checkimg_path + '/epoch_%d_index_%d.png' % (epoch, index), padding=5)

    # ------------------------------------------------------- checkpoints & stats
    if epoch % 10 == 0:
        torch.save(G.state_dict(), trained_path + '/G_epoch_%d_%d.pth' % (scale_factor, epoch))
        torch.save(D.state_dict(), trained_path + '/D_epoch_%d_%d.pth' % (scale_factor, epoch))

    seen = running_results['batch_sizes']
    results['d_loss'].append(running_results['d_loss'] / seen)
    results['g_loss'].append(running_results['g_loss'] / seen)
    results['d_score'].append(running_results['d_score'] / seen)
    results['g_score'].append(running_results['g_score'] / seen)
    results['psnr'].append(val_check_results['psnr'])
    results['ssim'].append(val_check_results['ssim'])

    if epoch % 100 == 0:
        t = time.localtime()
        timestamp = time.strftime('%H%M%S', t)
        data_frame = pd.DataFrame(
            data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'],
                  'Score_D': results['d_score'], 'Score_G': results['g_score'],
                  'PSNR': results['psnr'], 'SSIM': results['ssim']}, index=range(1, epoch + 1))
        data_frame.to_csv(stats_path + '/srf_' + str(scale_factor) + timestamp + '_train_results.csv', index_label='Epoch')
        
        



In [None]:
#save model

In [None]:
# Checkpoint locations (same strings as load_weigh). NOTE: '\\' only separates
# folders on Windows.
_home = os.path.expanduser('~')
g_model_path = os.path.join(_home, 'onedrive\\g_srgan_model.pth')
d_model_path = os.path.join(_home, 'onedrive\\d_srgan_model.pth')
for _net, _path in ((D, d_model_path), (G, g_model_path)):
    torch.save(_net.state_dict(), _path)

In [None]:
# D.eval()

In [None]:
# G.eval()

In [None]:
def weights_init(m):
    """DCGAN-style initialisation, intended for ``Module.apply``.

    Conv layers: weight ~ N(0, 0.02), bias = 0.
    BatchNorm layers: weight ~ N(1, 0.02), bias = 0.
    Any other module is left untouched. Matching is by class name, so every class
    whose name contains 'Conv' or 'BatchNorm' is affected.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        # Conv layers built with bias=False have m.bias = None.
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif 'BatchNorm' in classname:
        # BatchNorm with affine=False has no weight/bias to initialise.
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

In [None]:
# Re-initialise D's Conv/BatchNorm parameters with weights_init.
# NOTE(review): this cell comes after the training and save cells, so running the
# notebook top to bottom discards D's trained/loaded weights here; it presumably
# belongs before the training loop — confirm intended order.
D.apply(weights_init)

In [None]:
# Re-initialise G's Conv/BatchNorm parameters with weights_init.
# NOTE(review): placed after training, this wipes G's trained weights; the export
# cells below reload them from g_model_path before exporting. Presumably meant to
# run before the training loop — confirm intended order.
G.apply(weights_init)

In [None]:
#exporting onnx model

In [None]:
# Example-input shape for the ONNX export: (batch, channels, height, width).
exp_small_batchsiz, exp_channel, exp_height, exp_width = 1, 3, 80, 80
exp_path = os.path.join(os.path.expanduser('~'), 'downloads\\onnx_srgan80.onnx')

In [None]:
# print(exp_path)

In [None]:
import torch
import torch.onnx
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# torch.backends.cudnn.benchmark = True
# exp_model = Generator(scale_factor).to(device)

# Name only the single real graph input. With export_params=True the weights are
# stored as ONNX initializers rather than graph inputs, and torch.onnx raises an
# error when more input names than graph inputs are supplied (the former
# "learned_%d" names came from an example where parameters were inputs).
input_names = ["actual_input_1"]
output_names = ["output1"]

# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
exp_g_weights = torch.load(g_model_path, map_location=device)
G.load_state_dict(exp_g_weights)
# Export inference behaviour (BatchNorm running statistics, not batch statistics).
G.eval()

dummy_input = torch.randn(exp_small_batchsiz, exp_channel, exp_height, exp_width).to(device)
torch.onnx.export(G, dummy_input, exp_path, export_params=True, verbose=True,
                  input_names=input_names, output_names=output_names)

In [None]:
#exporting torchscript

In [None]:
# Export a TorchScript version of G by tracing it on an 80x80 RGB example.
# Tracing records whatever mode the model is in, so switch to eval() first;
# otherwise BatchNorm's training-mode (batch statistics) behaviour is baked in.
G.eval()
example = torch.rand(1, 3, 80, 80).to(device)
traced_script_module = torch.jit.trace(G, example, check_trace=False).to(device)
traced_script_module.save(os.path.join(os.path.expanduser('~'), "srgan_80.pt"))