In [1]:
import argparse
import os
import numpy as np
import math
import itertools
import datetime
import time
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch
import sys
import metric
import KL_flow as KL
import cv2
import argparse
import lpips.lpips as lpips
from pytorch_msssim import ssim,ms_ssim,SSIM,MS_SSIM
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from models import *
from datasets import *
from utils import *
from laplacian_of_guassian import *
from flownet2.models import FlowNet2C  # the path is depended on where you create this module
from flownet2.utils.frame_utils import read_gen  # the path is depended on where you create this module
import time
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=20, help="number of epochs of training")
parser.add_argument("--train_data", type=str, default="/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/dataset/pickle_data/train_data.pkl", help="the path of pickle file about train data")
parser.add_argument("--val_data", type=str, default="/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/dataset/pickle_data/val_data.pkl", help="the path of pickle file about validation data")
parser.add_argument("--test_data", type=str, default="/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/dataset/pickle_data/test_data.pkl", help="the path of pickle file about testing data")
parser.add_argument("--batch_size", type=int, default=5, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0003, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=0, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=128, help="size of image height") #当图片的宽和高设置为256的时候会导致内存溢出
parser.add_argument("--img_width", type=int, default=128, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=2000, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=2, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_LoG", type=float, default=0.005, help="cycle loss weight")
parser.add_argument("--lambda_frame_GAN", type=float, default=0.003, help="identity loss weight")
parser.add_argument("--lambda_seq_GAN", type=float, default=0.003, help="identity loss weight")
parser.add_argument("--sequence_len", type=float, default=2, help="the original length of frame sequence(n+1)")
parser.add_argument("--save_model_path", type=str, default="saved_models/mm/", help="the path of saving the models")
parser.add_argument("--save_image_path", type=str, default="saved_images/mm/", help="the path of saving the images")
parser.add_argument("--log_file", type=str, default="log_mm.txt", help="the logging info of training")


_StoreAction(option_strings=['--log_file'], dest='log_file', nargs=None, const=None, default='log_mm.txt', type=<class 'str'>, choices=None, help='the logging info of training', metavar=None)

In [2]:
opt = parser.parse_args([])##返回一个命名空间,如果想要使用变量,可用args.attr
print(opt)

Namespace(b1=0.5, b2=0.999, batch_size=5, channels=3, checkpoint_interval=2, decay_epoch=0, epoch=0, img_height=128, img_width=128, lambda_LoG=0.005, lambda_frame_GAN=0.003, lambda_seq_GAN=0.003, log_file='log_mm.txt', lr=0.0003, n_cpu=0, n_epochs=20, n_residual_blocks=9, sample_interval=2000, save_image_path='saved_images/mm/', save_model_path='saved_models/mm/', sequence_len=2, test_data='/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/dataset/pickle_data/test_data.pkl', train_data='/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/dataset/pickle_data/train_data.pkl', val_data='/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/dataset/pickle_data/val_data.pkl')


In [3]:
os.makedirs(opt.save_image_path, exist_ok=True)
os.makedirs(opt.save_model_path, exist_ok=True)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_Limage = torch.nn.L1Loss()
cuda = torch.cuda.is_available()

# Image transformations
transforms_ = [
    transforms.Resize((int(opt.img_height),int(opt.img_width)), Image.BICUBIC),
    transforms.ToTensor(),#Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
                          #Converts a PIL Image or numpy.ndarray (H x W x C) in the range
                          #[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),#[0,1] -> [-1,1]
    #transforms.Normalize([0.485, ], [0.229, ])
]

# Training data loader
dataloader = DataLoader(
    ImageTrainDataset(opt.train_data, transforms_=transforms_,nt=opt.sequence_len),
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=opt.n_cpu,
)

# val data loader
val_dataloader = DataLoader(
      ImageValDataset(opt.val_data, transforms_=transforms_,nt=opt.sequence_len),
      batch_size=1,
      shuffle=False,
      num_workers=opt.n_cpu,
  )

# test data loader
test_dataloader = DataLoader(
    ImageTestDataset(opt.test_data, transforms_=transforms_,nt=opt.sequence_len),
    batch_size=1,
    shuffle=False,
    num_workers=opt.n_cpu,
)


input_shape = (opt.batch_size, opt.sequence_len, opt.channels, opt.img_height, opt.img_width)
G_future = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_past = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = DiscriminatorA(input_shape)
D_B = DiscriminatorB(input_shape)

#loss_fn = lpips.LPIPS(net='alex')
Laplacian = Laplacian()

if cuda:
    G_future = torch.nn.DataParallel(G_future, device_ids=range(torch.cuda.device_count()))
    G_future = G_future.cuda()
    G_past = torch.nn.DataParallel(G_past, device_ids=range(torch.cuda.device_count()))
    G_past = G_past.cuda()
    D_A = torch.nn.DataParallel(D_A, device_ids=range(torch.cuda.device_count()))
    D_A = D_A.cuda()
    D_B = torch.nn.DataParallel(D_B, device_ids=range(torch.cuda.device_count()))
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_Limage.cuda()
    Laplacian = Laplacian.cuda()
    #loss_fn = loss_fn.cuda()
if opt.epoch != 0:
    # Load pretrained models
    G_future.load_state_dict(torch.load(opt.save_model_path+"G_future_%d.pth" %  opt.epoch))
    G_past.load_state_dict(torch.load(opt.save_model_path+"G_past_%d.pth" %  opt.epoch))
    D_A.load_state_dict(torch.load(opt.save_model_path+"D_A_%d.pth" %  opt.epoch))
    D_B.load_state_dict(torch.load(opt.save_model_path+"D_B_%d.pth" %  opt.epoch))
else:
    # Initialize weights
    G_future.apply(weights_init_normal)#apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上
    G_past.apply(weights_init_normal)#apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)


# Optimizers
optimizer_G = torch.optim.Adam(itertools.chain(G_future.parameters(),G_past.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step# lr_lambda为操作学习率的函数
)

lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

#保存中间的训练结果
def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    imgs = imgs.type(Tensor)
    input_A = imgs[:,:-1,...]
    input_A = input_A.view((imgs.size(0),-1,)+imgs.size()[3:])
    G_future.eval()
    real_A = Variable(imgs[:,-1:,...])
    A1 = G_future(input_A)
    frames = torch.cat((imgs[0,:-1,],A1[0].unsqueeze(0),imgs[0]), 0)
    image_grid = make_grid(frames,nrow=opt.sequence_len,normalize=False)
    save_image(image_grid, opt.save_image_path+"fake_%s.png" % (batches_done), normalize=False)

def ReverseSeq(Seq):
    length = Seq.size(1)
    return torch.cat([Seq[:,i-2:i+1,...] for i in range(length-1,-1,-3)],1)

In [4]:
import matplotlib.pyplot as pyplot
import pylab

In [5]:
count = 0
#################加载Flownet2.0模型
parser0 = argparse.ArgumentParser()
parser0.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).')
parser0.add_argument("--rgb_max", type=float, default=255.)
args0 = parser0.parse_args([])
flownet = FlowNet2C(args0).cuda()
    # load the state_dict
dict = torch.load("./flownet2/FlowNet2-C_checkpoint.pth.tar")
flownet.load_state_dict(dict["state_dict"])
# loss_fn = lpips.LPIPS(net='alex')
#################加载Flownet2.0模型
prev_time = time.time()

In [6]:
for i, frame_seq in enumerate(dataloader):
    print(i,frame_seq.shape)
    if i==2:
        break

0 torch.Size([5, 2, 3, 128, 128])
1 torch.Size([5, 2, 3, 128, 128])
2 torch.Size([5, 2, 3, 128, 128])


In [7]:
frame_seq = frame_seq.type(Tensor)

real_A = Variable(frame_seq[:,-1,...]) #[bs,1,c,h,w]
input_A = Variable(frame_seq[:,:-1,...].view((frame_seq.size(0),-1)+frame_seq.size()[3:]))
real_B = Variable(frame_seq[:,0,...]) #[bs,1,c,h,w]
input_B_ = Variable(frame_seq[:,1:,...].view((frame_seq.size(0),-1)+frame_seq.size()[3:]))
input_B = ReverseSeq(input_B_)

# Adversarial ground truths
valid = Variable(Tensor(np.ones((frame_seq.size(0), *D_A.module.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((frame_seq.size(0), *D_A.module.output_shape))), requires_grad=False)

#------------------------
#  Train Generator
#------------------------
G_future.train()
G_past.train()        
optimizer_G.zero_grad()#梯度清零
print('input_A',input_A.shape)
print('input_B',input_B.shape)
#L_Image loss which minimize the L1 Distance between the image pair
A1 = G_future(input_A) # x^'_{n}  generated future frame
B1 = G_past(input_B) # x^'_{m} generated past frame
print(B1.shape)

input_A torch.Size([5, 3, 128, 128])
input_B torch.Size([5, 3, 128, 128])
torch.Size([5, 3, 128, 128])


In [8]:
for j in range(4):
                prev = A1[j].cpu().clone()
                curr = A1[j+1].cpu().clone()
                image_grid_prev = prev
                save_image(image_grid_prev, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_prev.png" % (j), normalize=False)
                image_grid_curr = curr
                save_image(image_grid_curr, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_curr.png" % (j), normalize=False)  
    
                prev_true = real_A[j].cpu().clone()
                curr_true = real_A[j+1].cpu().clone()
                image_grid_prev_true = prev_true
                save_image(image_grid_prev_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_prev.png" % (j), normalize=False)
                image_grid_curr_true = curr_true
                save_image(image_grid_curr_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_curr.png" % (j), normalize=False)  

In [9]:
optical_loss_0 = []
optical_loss_1 = []

In [10]:
for j in range(4):
                pim1_fake = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_prev.png" % (j))
                pim2_fake = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_curr.png" % (j))
                images_fake = [pim1_fake, pim2_fake]
                images_fake = np.array(images_fake).transpose(3, 0, 1, 2)
                im_fake = torch.from_numpy(images_fake.astype(np.float32)).unsqueeze(0).cuda()
                result_fake = flownet(im_fake).squeeze()
    
                pim1_real = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_prev.png" % (j))
                pim2_real = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_curr.png" % (j))
                images_real = [pim1_real, pim2_real]
                images_real = np.array(images_real).transpose(3, 0, 1, 2)
                im_real = torch.from_numpy(images_real.astype(np.float32)).unsqueeze(0).cuda()
                result_real= flownet(im_real).squeeze()
    
                loss_opt_0 = criterion_Limage(result_fake[0],result_real[0]) 
                optical_loss_0.append(loss_opt_0)
    
                loss_opt_1 = criterion_Limage(result_fake[1],result_real[1])

  "See the documentation of nn.Upsample for details.".format(mode))


In [11]:
optical_loss_total = (sum(optical_loss_0) / 4 + sum(optical_loss_1) / 4 )/2

       
        #############################################
        #                                           #
        #                                           #
        #        Optical flow Loss Function         #  
        #                                           # 
        #                                           # 
        #############################################
input_A_A1_ = torch.cat((input_A[:,3:,...],A1),1)
input_A_A1 = ReverseSeq(input_A_A1_)
input_B_B1 = torch.cat((B1,input_A[:,3:,...]),1)
A11 = G_future(input_B_B1)# x^''_{n}
B11 = G_past(input_A_A1)

loss_A_A1 = criterion_Limage(real_A,A1)
loss_A_A11 = criterion_Limage(real_A,A11)
loss_A1_A11 = criterion_Limage(A1,A11)
loss_B_B1 = criterion_Limage(real_B,B1)
loss_B_B11 = criterion_Limage(real_B,B11)
loss_B1_B11 = criterion_Limage(B1,B11)

loss_Image = (loss_A_A1 + loss_A_A11 + loss_A1_A11 + loss_B_B1 + loss_B_B11 + loss_B1_B11 ) / 6

print('real_A',real_A.shape)
print('A11',A11.shape)
print('A1',A1.shape)
print('real_B',real_B.shape)
print('B11',B11.shape)
print('B1',B1.shape)
#L_LoG loss 
L_LoG_A_A1 = criterion_Limage(Laplacian(real_A),Laplacian(A1))
L_LoG_A_A11 = criterion_Limage(Laplacian(real_A),Laplacian(A11))
L_LoG_A1_A11 = criterion_Limage(Laplacian(A1),Laplacian(A11))

L_LoG_B_B1 = criterion_Limage(Laplacian(real_B),Laplacian(B1))
L_LoG_B_B11 = criterion_Limage(Laplacian(real_B),Laplacian(B11))
L_LoG_B1_B11 = criterion_Limage(Laplacian(B1),Laplacian(B11))


real_A torch.Size([5, 3, 128, 128])
A11 torch.Size([5, 3, 128, 128])
A1 torch.Size([5, 3, 128, 128])
real_B torch.Size([5, 3, 128, 128])
B11 torch.Size([5, 3, 128, 128])
B1 torch.Size([5, 3, 128, 128])


In [12]:
for i, frame_seq_test in tqdm(enumerate(val_dataloader)):
    print(i)
    if i==2:
        break

0it [00:00, ?it/s]

0
1
2





In [13]:
num = 0
tatal_PSNR = 0
total_SSIM = 0
total_MSE = 0
total_MSSSIM = 0
# total_LPIPS = 0        
ms_ssim_module=MS_SSIM(win_size=7,win_sigma=1.5,data_range=1,size_average=True,channel=3)

frame_seq_test = frame_seq_test.type(Tensor)
real_A_test = Variable(frame_seq_test[:,-1,...]) #[bs,1,c,h,w]
input_A = Variable(frame_seq_test[:,:-1,...].view((frame_seq_test.size(0),-1)+frame_seq_test.size()[3:]))
# A1 = G_future(input_A)

In [14]:
A1 = G_future(input_A)

In [16]:
num = 0
tatal_PSNR = 0
total_SSIM = 0
total_MSE = 0
total_MSSSIM = 0
# total_LPIPS = 0        
ms_ssim_module=MS_SSIM(win_size=7,win_sigma=1.5,data_range=1,size_average=True,channel=3)
for i, frame_seq_test in enumerate(test_dataloader):
            frame_seq_test = frame_seq_test.type(Tensor)
            real_A_test = Variable(frame_seq_test[:,-1,...]) #[bs,1,c,h,w]
            input_A = Variable(frame_seq_test[:,:-1,...].view((frame_seq_test.size(0),-1)+frame_seq_test.size()[3:]))
            A1 = G_future(input_A)
            print('real_A_test',type(real_A_test))
            print('A1',type(A1))
            num += 1
            psnr = metric.PSNR(real_A_test.squeeze(0).detach().cpu().clone().numpy(), A1.squeeze(0).detach().cpu().clone().numpy())
            ssim = metric.SSIM(real_A_test.squeeze(0).detach().cpu().clone().numpy(), A1.squeeze(0).detach().cpu().clone().numpy())
            mse = metric.MSE(real_A_test.squeeze(0).detach().cpu().clone().numpy(), A1.squeeze(0).detach().cpu().clone().numpy())*1000     
            ms_ssim_loss=ms_ssim_module(real_A_test.detach(),A1.detach())           
#             lpips =  loss_fn.forward(real_A_test.squeeze(0).cuda(), A1.squeeze(0).cuda())    
            
            tatal_PSNR += psnr
            total_SSIM += ssim
            total_MSE += mse
            total_MSSSIM += ms_ssim_loss
#             total_LPIPS += lpips
#             break

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torch.Tensor'>
A1 <class 'torch.Tensor'>
real_A_test <class 'torc

KeyboardInterrupt: 

In [31]:
real_A_test.shape

torch.Size([5, 3, 128, 128])

In [33]:
type(real_A_test.squeeze(0))

torch.Tensor

In [11]:
a = "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/dataset/pickle_data/train_data.pkl"

In [12]:
video_pkl_file = a

train_pkl_file = open(video_pkl_file, 'rb')
data = pickle.load(train_pkl_file)

In [10]:
# normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# train_transform = transforms.Compose([
#     transforms.Resize((int(opt.img_height),int(opt.img_width)), Image.BICUBIC),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# ])#[0,1] -> [-1,1]
# transforms_ = [
#     transforms.Resize((int(opt.img_height),int(opt.img_width)), Image.BICUBIC),
#     transforms.ToTensor(),#Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
#                           #Converts a PIL Image or numpy.ndarray (H x W x C) in the range
#                           #[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),#[0,1] -> [-1,1]
# ]  
# train_transform0 = transforms.Compose(transforms_)

In [10]:
# # /home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/tinyjpg/train/train_jpg
# # ImageFolder(os.path.join('/home/ubd/ganomaly/umntotal/streak_u_add/U/scene3/', x), transform)
# # ROOT_TRAIN = '/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/tinyjpg/test/test0/'
# target = '/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/tinyjpg/tinyjpg/train/train/train_jpg/'
# imagedir = os.listdir(target)
# imagedir.sort()


In [11]:
# start = 0
# end = 5
# frame_seq = []
# for img_name in imagedir[start:end]:
#             img = Image.open(target+img_name)
#             img = train_transform0(img)
#             frame_seq.append(img)
# frame_seq = torch.stack(frame_seq, 0)

In [12]:
# from torchvision.datasets import ImageFolder

In [13]:
# dataset = ImageFolder(target, train_transform0) 
# dataloader = torch.utils.data.DataLoader(dataset,
#                                                      batch_size=5,
#                                                      shuffle=False,
#                                                      num_workers=int(8),
#                                                      )

In [14]:
# for i, frame_seq in enumerate(dataloader):
#     print(frame_seq[0].shape)
#     break

In [15]:
# frame_seq = []
# for i in range(10):
#         start = 0
#         end = 5
#         for img_name in imagedir[start:end]:
#             img = Image.open(target+img_name)
#             img = train_transform(img)
#             frame_seq.append(img)
#         frame_seq = torch.stack(frame_seq, 0)
#         print(frame_seq.shape)
#         count += 1
#         frame_seq = frame_seq.type(Tensor)
#         print(frame_seq.shape)
#         real_A = Variable(frame_seq[:,-1,...]) #[bs,1,c,h,w]
#         input_A = Variable(frame_seq[:,:-1,...].view((frame_seq.size(0),-1)+frame_seq.size()[3:]))
#         real_B = Variable(frame_seq[:,0,...]) #[bs,1,c,h,w]
#         input_B_ = Variable(frame_seq[:,1:,...].view((frame_seq.size(0),-1)+frame_seq.size()[3:]))
#         input_B = ReverseSeq(input_B_)

#         # Adversarial ground truths
#         valid = Variable(Tensor(np.ones((frame_seq.size(0), *D_A.module.output_shape))), requires_grad=False)
#         fake = Variable(Tensor(np.zeros((frame_seq.size(0), *D_A.module.output_shape))), requires_grad=False)

#         #------------------------
#         #  Train Generator
#         #------------------------
#         G_future.train()
#         G_past.train()
#         optimizer_G.zero_grad()#梯度清零

# #         #L_Image loss which minimize the L1 Distance between the image pair
#         A1 = G_future(input_A) # x^'_{n}  generated future frame   ############
#         B1 = G_past(input_B) # x^'_{m} generated past frame
#         input_A_A1_ = torch.cat((input_A[:,3:,...],A1),1)
#         input_A_A1 = ReverseSeq(input_A_A1_)
#         input_B_B1 = torch.cat((B1,input_A[:,3:,...]),1)
#         A11 = G_future(input_B_B1)# x^''_{n}
#         B11 = G_past(input_A_A1)# x^''_{m}
#         start = end
#         end = end*2
# #         frame_seq = []
#         break

In [16]:
import torch
import numpy as np
import argparse

In [17]:
from flownet2.models import FlowNet2C  # the path is depended on where you create this module
from flownet2.utils.frame_utils import read_gen  # the path is depended on where you create this module

In [18]:
parser0 = argparse.ArgumentParser()
parser0.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).')
parser0.add_argument("--rgb_max", type=float, default=255.)
args0 = parser0.parse_args([])
flownet = FlowNet2C(args0).cuda()
    # load the state_dict
dict = torch.load("./flownet2/FlowNet2-C_checkpoint.pth.tar")
flownet.load_state_dict(dict["state_dict"])

<All keys matched successfully>

In [32]:
# for j in range(4):
prev = A1[0]
curr = A1[0+1]
# image_grid_prev = make_grid(prev,nrow=1,normalize=False)
image_grid_prev = prev.squeeze(0).detach().cpu().clone()
save_image(image_grid_prev, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_prev.png" % (4), normalize=False)
image_grid_curr = curr.squeeze(0).detach().cpu().clone()
save_image(image_grid_curr, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_curr.png" % (4), normalize=False)  
    
#     prev_true = real_A[j]
#     curr_true = real_A[j+1]
#     image_grid_prev_true = make_grid(prev_true,nrow=1,normalize=False)
#     save_image(image_grid_prev_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_prev.png" % (j), normalize=False)
#     image_grid_curr_true = make_grid(curr_true,nrow=1,normalize=False)
#     save_image(image_grid_curr_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_curr.png" % (j), normalize=False)   

In [73]:
# for j in range(4): 
#     prev_true = real_A[j].cpu().clone()
#     curr_true = real_A[j+1].cpu().clone()
#     image_grid_prev_true = prev_true
#     save_image(image_grid_prev_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true0_%s0_prev.png" % (j), normalize=False)
#     image_grid_curr_true = curr_true
#     save_image(image_grid_curr_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true0_%s0_curr.png" % (j), normalize=False)  

In [79]:
j=0

In [80]:
prev = A1[j].cpu().clone()
curr = A1[j+1].cpu().clone()
image_grid_prev = prev
save_image(image_grid_prev, "fake_%s_prev0.png" % (j), normalize=False)
image_grid_curr = curr
save_image(image_grid_curr, "fake_%s_curr0.png" % (j), normalize=False)  

In [81]:
prev_true = real_A[j].cpu().clone()
curr_true = real_A[j+1].cpu().clone()
image_grid_prev_true = prev_true
save_image(image_grid_prev_true, "true_%s_prev0.png" % (j), normalize=False)
image_grid_curr_true = curr_true
save_image(image_grid_curr_true, "true_%s_curr0.png" % (j), normalize=False)      

In [82]:
prev_true = real_A[j]
curr_true = real_A[j+1]
image_grid_prev_true = make_grid(prev_true,nrow=1,normalize=False)
save_image(image_grid_prev_true, "0r.png", normalize=False)
image_grid_curr_true = make_grid(curr_true,nrow=1,normalize=False)
save_image(image_grid_curr_true, "0r.png" , normalize=False)

In [83]:
prev_true = A1[j]
curr_true = A1[j+1]
image_grid_prev_true = make_grid(prev_true,nrow=1,normalize=False)
save_image(image_grid_prev_true, "0f.png", normalize=False)
image_grid_curr_true = make_grid(curr_true,nrow=1,normalize=False)
save_image(image_grid_curr_true, "0f.png" , normalize=False)

In [78]:
real_A[0].shape

torch.Size([3, 128, 128])

In [77]:
A1[0].shape

torch.Size([3, 128, 128])

In [19]:
torch.__version__

'1.3.1'

In [76]:
for i, frame_seq in enumerate(dataloader):
        count += 1
        optical_loss_0 = 0
        optical_loss_1 = 0
        frame_seq = frame_seq.type(Tensor)

        real_A = Variable(frame_seq[:,-1,...]) #[bs,1,c,h,w]
        input_A = Variable(frame_seq[:,:-1,...].view((frame_seq.size(0),-1)+frame_seq.size()[3:]))
        real_B = Variable(frame_seq[:,0,...]) #[bs,1,c,h,w]
        input_B_ = Variable(frame_seq[:,1:,...].view((frame_seq.size(0),-1)+frame_seq.size()[3:]))
        input_B = ReverseSeq(input_B_)

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((frame_seq.size(0), *D_A.module.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((frame_seq.size(0), *D_A.module.output_shape))), requires_grad=False)

        #------------------------
        #  Train Generator
        #------------------------
        G_future.train()
        G_past.train()        
        optimizer_G.zero_grad()#梯度清零

        #L_Image loss which minimize the L1 Distance between the image pair
        A1 = G_future(input_A) # x^'_{n}  generated future frame
        B1 = G_past(input_B) # x^'_{m} generated past frame
        input_A_A1_ = torch.cat((input_A[:,3:,...],A1),1)
        input_A_A1 = ReverseSeq(input_A_A1_)
        input_B_B1 = torch.cat((B1,input_A[:,3:,...]),1)
        A11 = G_future(input_B_B1)# x^''_{n}
        B11 = G_past(input_A_A1)

        input_A_A1_ = torch.cat((input_A[:,3:,...],A1),1)
        input_A_A1 = ReverseSeq(input_A_A1_)
        input_B_B1 = torch.cat((B1,input_A[:,3:,...]),1)
        A11 = G_future(input_B_B1)# x^''_{n}
        B11 = G_past(input_A_A1)

        loss_A_A1 = criterion_Limage(real_A,A1)
        loss_A_A11 = criterion_Limage(real_A,A11)
        loss_A1_A11 = criterion_Limage(A1,A11)
        loss_B_B1 = criterion_Limage(real_B,B1)
        loss_B_B11 = criterion_Limage(real_B,B11)
        loss_B1_B11 = criterion_Limage(B1,B11)

        loss_Image = (loss_A_A1 + loss_A_A11 + loss_A1_A11 + loss_B_B1 + loss_B_B11 + loss_B1_B11 ) / 6

        
        #L_LoG loss 
        L_LoG_A_A1 = criterion_Limage(Laplacian(real_A),Laplacian(A1))
        L_LoG_A_A11 = criterion_Limage(Laplacian(real_A),Laplacian(A11))
        L_LoG_A1_A11 = criterion_Limage(Laplacian(A1),Laplacian(A11))
        L_LoG_B_B1 = criterion_Limage(Laplacian(real_B),Laplacian(B1))
        L_LoG_B_B11 = criterion_Limage(Laplacian(real_B),Laplacian(B11))
        L_LoG_B1_B11 = criterion_Limage(Laplacian(B1),Laplacian(B11))

        loss_LoG = (L_LoG_A_A1 + L_LoG_A_A11 + L_LoG_A1_A11 + L_LoG_B_B1 + L_LoG_B_B11 + L_LoG_B1_B11) / 6

        #GAN frame Loss(Least Square Loss)
        loss_frame_GAN_A1  = criterion_GAN(D_A(A1),valid)
        loss_frame_GAN_B1  = criterion_GAN(D_A(B1),valid)
        loss_frame_GAN_A11 = criterion_GAN(D_A(A11),valid)
        loss_frame_GAN_B11 = criterion_GAN(D_A(B11),valid)
        #Total frame loss
        loss_frame_GAN = (loss_frame_GAN_A1 + loss_frame_GAN_B1 + loss_frame_GAN_A11 + loss_frame_GAN_B11) / 4

        #GAN seq Loss 
        #four kinds of the synthetic frame sequence
        input_B1_A  = torch.cat((B1,input_A[:,3:,...],real_A),1)
        input_B11_A1 = torch.cat((B11,input_A[:,3:,...],A1),1)
        input_B_A1 = torch.cat((real_B,input_A[:,3:,...],A1),1)
        input_B1_A11 = torch.cat((B1,input_A[:,3:,...],A11),1)
        loss_seq_GAN_B1_A = criterion_GAN(D_B(input_B1_A),valid)
        loss_seq_GAN_B11_A1 = criterion_GAN(D_B(input_B11_A1),valid)
        loss_seq_GAN_B_A1 = criterion_GAN(D_B(input_B_A1),valid)
        loss_seq_GAN_B1_A11 = criterion_GAN(D_B(input_B1_A11),valid)
                
        # Total seq loss
        loss_seq_GAN = (loss_seq_GAN_B1_A + loss_seq_GAN_B11_A1 + loss_seq_GAN_B_A1 + loss_seq_GAN_B1_A11) / 4
        # Total GAN loss
        total_loss_GAN = loss_Image + opt.lambda_LoG*loss_LoG + opt.lambda_frame_GAN *loss_frame_GAN + opt.lambda_seq_GAN*loss_seq_GAN        
        

        for j in range(4):
            prev = A1[j].cpu().clone()
            curr = A1[j+1].cpu().clone()
            image_grid_prev = prev
            save_image(image_grid_prev, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_prev.png" % (j), normalize=False)
            image_grid_curr = curr
            save_image(image_grid_curr, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_curr.png" % (j), normalize=False)  
    
            prev_true = real_A[j].cpu().clone()
            curr_true = real_A[j+1].cpu().clone()
            image_grid_prev_true = prev_true
            save_image(image_grid_prev_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_prev.png" % (j), normalize=False)
            image_grid_curr_true = curr_true
            save_image(image_grid_curr_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_curr.png" % (j), normalize=False)      
        
        for j in range(4):
            pim1_fake = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_prev.png" % (j))
            pim2_fake = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_curr.png" % (j))
            images_fake = [pim1_fake, pim2_fake]
            images_fake0 = np.array(images_fake).transpose(3, 0, 1, 2)
            im_fake = torch.from_numpy(images_fake0.astype(np.float32)).unsqueeze(0).cuda()
            result_fake = flownet(im_fake).squeeze()
    
            pim1_real = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_prev.png" % (j))
            pim2_real = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_curr.png" % (j))
            images_real = [pim1_real, pim2_real]
            images_real1 = np.array(images_real).transpose(3, 0, 1, 2)
            im_real = torch.from_numpy(images_real1.astype(np.float32)).unsqueeze(0).cuda()
            result_real= flownet(im_real).squeeze()
    
            loss_A_A1_opt_0 = criterion_Limage(result_fake[0],result_real[0]) 
            optical_loss_0 += (loss_A_A1_opt_0) 
    
            loss_A_A1_opt_1 = criterion_Limage(result_fake[1],result_real[1])
            optical_loss_1 +=(loss_A_A1_opt_1)  
             
        optical_loss_total = (optical_loss_0/4 + optical_loss_1/4 )/2 
        
        total_loss_GAN += optical_loss_total
        
        total_loss_GAN.backward()
        optimizer_G.step()
        break

  "See the documentation of nn.Upsample for details.".format(mode))


In [None]:
# optical_loss_0 = []
# optical_loss_1 = []
# #         #optical_loss_0_0 = []
# #         #optical_loss_1_1 = []
# # for j in range(4):
# #     prev = A1[j]
# #     curr = A1[j+1]
# #     image_grid_prev = make_grid(prev,nrow=1,normalize=False)
# #     save_image(image_grid_prev, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_prev.png" % (j), normalize=False)
# #     image_grid_curr = make_grid(curr,nrow=1,normalize=False)
# #     save_image(image_grid_curr, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_curr.png" % (j), normalize=False)  
    
# #     prev_true = real_A[j]
# #     curr_true = real_A[j+1]
# #     image_grid_prev_true = make_grid(prev,nrow=1,normalize=False)
# #     save_image(image_grid_prev_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_prev.png" % (j), normalize=False)
# #     image_grid_curr_true = make_grid(curr,nrow=1,normalize=False)
# #     save_image(image_grid_curr_true, "/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_curr.png" % (j), normalize=False)  
    
# # optical_loss_0 = []
# # optical_loss_1 = []
        
# for j in range(4):
#     pim1_fake = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_prev.png" % (j))
#     pim2_fake = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/fake_jpg/fake_%s_curr.png" % (j))
#     images_fake = [pim1_fake, pim2_fake]
#     images_fake = np.array(images_fake).transpose(3, 0, 1, 2)
#     im_fake = torch.from_numpy(images_fake.astype(np.float32)).unsqueeze(0).cuda()
#     result_fake = net(im_fake).squeeze()
    
#     pim1_real = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_prev.png" % (j))
#     pim2_real = read_gen("/home/ubd/EMANet-master/Video_Prediction_ZOO-master/RetrospectiveCycleGAN/true_jpg/true_%s_curr.png" % (j))
#     images_real = [pim1_real, pim2_real]
#     images_real = np.array(images_real).transpose(3, 0, 1, 2)
#     im_real = torch.from_numpy(images_real.astype(np.float32)).unsqueeze(0).cuda()
#     result_real= net(im_real).squeeze()
    
#     loss_A_A1_opt_0 = criterion_Limage(result_fake[0],result_real[0]) 
#     optical_loss_0.append(loss_A_A1_opt_0) 
    
#     loss_A_A1_opt_1 = criterion_Limage(result_fake[1],result_real[1])
#     optical_loss_1.append(loss_A_A1_opt_1)    
# #     break
    

In [None]:
# o_l_0 = (sum(optical_loss_0)) 
# o_l_1 = (sum(optical_loss_1)) 
#         #o_l_0_0 = (sum(optical_loss_0_0)) 
#         #o_l_1_1 = (sum(optical_loss_1_1))   
# optical_loss_total = (o_l_0/4 + o_l_1/4 )/2

In [None]:
# input_A_A1_ = torch.cat((input_A[:,3:,...],A1),1)
# input_A_A1 = ReverseSeq(input_A_A1_)
# input_B_B1 = torch.cat((B1,input_A[:,3:,...]),1)
# A11 = G_future(input_B_B1)# x^''_{n}
# B11 = G_past(input_A_A1)

# loss_A_A1 = criterion_Limage(real_A,A1)
# loss_A_A11 = criterion_Limage(real_A,A11)
# loss_A1_A11 = criterion_Limage(A1,A11)
# loss_B_B1 = criterion_Limage(real_B,B1)
# loss_B_B11 = criterion_Limage(real_B,B11)
# loss_B1_B11 = criterion_Limage(B1,B11)

# loss_Image = (loss_A_A1 + loss_A_A11 + loss_A1_A11 + loss_B_B1 + loss_B_B11 + loss_B1_B11 ) / 6
# loss_Image += optical_loss_total
        
#         #L_LoG loss 
# L_LoG_A_A1 = criterion_Limage(Laplacian(real_A),Laplacian(A1))
# L_LoG_A_A11 = criterion_Limage(Laplacian(real_A),Laplacian(A11))
# L_LoG_A1_A11 = criterion_Limage(Laplacian(A1),Laplacian(A11))
# L_LoG_B_B1 = criterion_Limage(Laplacian(real_B),Laplacian(B1))
# L_LoG_B_B11 = criterion_Limage(Laplacian(real_B),Laplacian(B11))
# L_LoG_B1_B11 = criterion_Limage(Laplacian(B1),Laplacian(B11))

# loss_LoG = (L_LoG_A_A1 + L_LoG_A_A11 + L_LoG_A1_A11 + L_LoG_B_B1 + L_LoG_B_B11 + L_LoG_B1_B11) / 6
# #         loss_LoG.backward()
#         #GAN frame Loss(Least Square Loss)
# loss_frame_GAN_A1  = criterion_GAN(D_A(A1),valid)# lead the synthetic frame become similiar to the real frame
# loss_frame_GAN_B1  = criterion_GAN(D_A(B1),valid)
# loss_frame_GAN_A11 = criterion_GAN(D_A(A11),valid)
# loss_frame_GAN_B11 = criterion_GAN(D_A(B11),valid)
#         #Total frame loss
# loss_frame_GAN = (loss_frame_GAN_A1 + loss_frame_GAN_B1 + loss_frame_GAN_A11 + loss_frame_GAN_B11) / 4
#         # print("Frame Loss Done")
# #         loss_frame_GAN.backward()
#         #GAN seq Loss 
#         #four kinds of the synthetic frame sequence
# input_B1_A  = torch.cat((B1,input_A[:,3:,...],real_A),1)
# input_B11_A1 = torch.cat((B11,input_A[:,3:,...],A1),1)
# input_B_A1 = torch.cat((real_B,input_A[:,3:,...],A1),1)
# input_B1_A11 = torch.cat((B1,input_A[:,3:,...],A11),1)
# loss_seq_GAN_B1_A = criterion_GAN(D_B(input_B1_A),valid)
# loss_seq_GAN_B11_A1 = criterion_GAN(D_B(input_B11_A1),valid)
# loss_seq_GAN_B_A1 = criterion_GAN(D_B(input_B_A1),valid)
# loss_seq_GAN_B1_A11 = criterion_GAN(D_B(input_B1_A11),valid)
#         # Total seq loss
# loss_seq_GAN = (loss_seq_GAN_B1_A + loss_seq_GAN_B11_A1 + loss_seq_GAN_B_A1 + loss_seq_GAN_B1_A11) / 4
# #         loss_seq_GAN.backward()
#         # Total GAN loss
# total_loss_GAN = loss_Image + opt.lambda_LoG*loss_LoG + opt.lambda_frame_GAN *loss_frame_GAN + opt.lambda_seq_GAN*loss_seq_GAN
# total_loss_GAN.backward()