### GAN with least-squares adversarial loss

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import random
import time
from models_v2 import ConvLSTM, PhyCell, EncoderRNN
from constrain_moments import K2M
import os
import h5py
import torch.utils.data as data
import math
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import datetime
import itertools

#### Create the training and validation sets for testing the PhyDNet model on the sky image dataset

In [None]:
# Day-block shuffling of the time stamps; returns the shuffled indices.
def day_block_shuffle(times_trainval):
    """Shuffle sample indices in whole-day blocks.

    Samples that share a calendar date stay together, in their original
    ascending order. Only the order of the days is permuted. The permutation
    comes from the global NumPy RNG re-seeded with 1, so it is reproducible
    (and it resets the global NumPy RNG state as a side effect).

    Args:
        times_trainval: sequence of time stamps that expose ``.date()``.

    Returns:
        np.ndarray of indices into ``times_trainval``, grouped by day.
    """
    # Calendar date of every sample, kept as an object array for comparisons.
    day_of_sample = np.array([stamp.date() for stamp in times_trainval], dtype=object)

    # One block per distinct date (np.unique yields the dates sorted), each
    # holding the ascending positions of that day's samples.
    day_blocks = [np.where(day_of_sample == day)[0] for day in np.unique(day_of_sample)]

    # Permute the days reproducibly, then flatten back into a single index array.
    np.random.seed(1)
    np.random.shuffle(day_blocks)
    return np.asarray(list(itertools.chain.from_iterable(day_blocks)))

In [None]:
# Splitting the day-block shuffled data into training and validation sets
def trainval_split(split_data, split_ratio):
    '''
    Split parallel arrays into a training part and a validation part.

    The FIRST ``int(split_ratio * len(split_data[0]))`` samples form the
    validation set and the remaining samples form the training set; the split
    itself is not random. It is intended for data already ordered by
    ``day_block_shuffle``, so that a leading slice gives a validation set made
    of whole days (except possibly the day straddling the boundary).

    input:
    split_data: list of arrays sharing the same first (sample) axis; every
        array is split with the same indices
    split_ratio: fraction of samples assigned to validation, in [0, 1]
    output:
    data_train: list (parallel to split_data) of the training portions
    data_val: list (parallel to split_data) of the validation portions
    Within each portion the samples are shuffled with the global NumPy RNG
    re-seeded with 0 (this resets the global NumPy RNG state).

    raises:
    ValueError: if split_ratio is outside [0, 1]
    '''
    # A negative ratio would otherwise slice from the end and silently put
    # most samples into validation.
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio!r}")

    num_samples = len(split_data[0])
    num_val = int(split_ratio * num_samples)
    # Leading block -> validation, remainder -> training (disjoint).
    val_indices = np.arange(num_val)
    train_indices = np.arange(num_val, num_samples)

    # Shuffle within each split reproducibly.
    np.random.seed(0)
    np.random.shuffle(train_indices)
    np.random.shuffle(val_indices)

    data_train = [one_data[train_indices] for one_data in split_data]
    data_val = [one_data[val_indices] for one_data in split_data]
    return data_train, data_val

In [None]:
# Locate the dataset: the shared `data` directory sits two levels above the
# notebook's current working directory.
cwd = os.getcwd()
pardir = os.path.dirname(
    os.path.dirname(cwd)
)
data_folder = os.path.join(pardir, 'data')
data_path = os.path.join(data_folder, 'video_prediction_dataset.hdf5')

print("data_folder:", data_folder)
print("data_path:", data_path)

In [None]:
# Load the context ("log") and future ("pred") frame sequences, keeping every
# second frame along axis 1. The stride is applied inside the h5py read, so
# only the kept frames are loaded and each result is a standalone array.
# Slicing after `[...]` instead would read the whole dataset and return a
# strided view that keeps the full-size array alive in memory.
with h5py.File(data_path, 'r') as f:
    trainval = f['trainval']
    images_log_train = trainval['images_log'][:, ::2, :, :, :]
    images_pred_train = trainval['images_pred'][:, ::2, :, :, :]

    test = f['test']
    images_log_test = test['images_log'][:, ::2, :, :, :]
    images_pred_test = test['images_pred'][:, ::2, :, :, :]

times_curr_train = np.load(os.path.join(data_folder, "times_curr_trainval.npy"), allow_pickle=True)
times_curr_test = np.load(os.path.join(data_folder, "times_curr_test.npy"), allow_pickle=True)
print('-'*50)
print("times_curr_train.shape:", times_curr_train.shape)
print("images_log_train.shape:", images_log_train.shape)
print("images_pred_train.shape:", images_pred_train.shape)
print("times_curr_test.shape:", times_curr_test.shape)
print("images_log_test.shape:", images_log_test.shape)
print("images_pred_test.shape:", images_pred_test.shape)
print('-'*50)
# Get the input dimensions for constructing the model. The indexing assumes a
# (samples, frames, side, side, channels) layout -- confirm against the file.
num_log_frame = images_log_train.shape[1]
img_side_len = images_log_train.shape[2]
num_color_channel = images_log_train.shape[4]
num_pred_frame = images_pred_train.shape[1]
image_log_dim = [num_log_frame, img_side_len, img_side_len, num_color_channel]
image_pred_dim = [num_pred_frame, img_side_len, img_side_len, num_color_channel]

print("image side length:", img_side_len)
print("number of log frames:", num_log_frame)
print("number of pred frames:", num_pred_frame)
print("number of color channels:", num_color_channel)
print("context(log) image dimension:", image_log_dim)
print("future(pred) image dimension:", image_pred_dim)

In [None]:
# (Re)load only the test split, keeping every second frame along axis 1.
# The stride is applied inside the h5py read, so only the kept frames are
# loaded and the full-size array is not retained behind a strided view.
with h5py.File(data_path, 'r') as f:

    test = f['test']
    images_log_test = test['images_log'][:, ::2, :, :, :]
    images_pred_test = test['images_pred'][:, ::2, :, :, :]

times_curr_test = np.load(os.path.join(data_folder, "times_curr_test.npy"), allow_pickle=True)
print("times_curr_test.shape:", times_curr_test.shape)
print("images_log_test.shape:", images_log_test.shape)
print("images_pred_test.shape:", images_pred_test.shape)

In [None]:
# Model input geometry, read off the test tensors. The indexing assumes a
# (samples, frames, side, side, channels) layout -- confirm against the data.
num_log_frame = images_log_test.shape[1]
num_pred_frame = images_pred_test.shape[1]
img_side_len = images_log_test.shape[2]
num_color_channel = images_log_test.shape[4]
image_log_dim = [num_log_frame] + [img_side_len] * 2 + [num_color_channel]
image_pred_dim = [num_pred_frame] + [img_side_len] * 2 + [num_color_channel]

print("image side length:", img_side_len)
print("number of log frames:", num_log_frame)
print("number of pred frames:", num_pred_frame)
print("number of color channels:", num_color_channel)
print("context(log) image dimension:", image_log_dim)
print("future(pred) image dimension:", image_pred_dim)

### dataloader

In [None]:
class SkyImageDataset(data.Dataset):
    """Map-style dataset pairing context frames with the frames to predict.

    Args:
        data_set: two-element sequence ``[inputs, targets]`` of arrays indexed
            by sample along axis 0. Each sample is transposed with axes
            (0, 3, 1, 2), so it is presumably a (frames, height, width,
            channels) stack of 0-255 pixel values -- confirm against the data.
        transform: stored on the instance but currently NOT applied by
            ``__getitem__``.

    Each item is ``[idx, input, target]`` where ``input``/``target`` are float32
    tensors laid out (frames, channels, height, width) and divided by 255.
    """

    def __init__(self, data_set, transform=None):
        self.data_set = data_set
        self.transform = transform
        self.length = self.data_set[0].shape[0]

    @staticmethod
    def _to_tensor(frames):
        """Move channels to axis 1 and scale by 1/255 into a contiguous float32 tensor."""
        channels_first = frames.transpose(0, 3, 1, 2)
        return torch.from_numpy(channels_first / 255.0).contiguous().float()

    def __getitem__(self, idx):
        input_data = self._to_tensor(self.data_set[0][idx])
        output_data = self._to_tensor(self.data_set[1][idx])
        return [idx, input_data, output_data]

    def __len__(self):
        return self.length


### Discriminator architecture

In [None]:
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of feature maps in discriminator
ndf = 16


# Frame discriminator
class Discr_frame(torch.nn.Module):
    """Per-frame discriminator that maps one image to one unbounded score.

    The layer stack is sized for (nc) x 64 x 64 inputs. Four stride-2
    convolutions shrink 64 -> 4, and a final 4x4 valid convolution gives a
    1x1 map with ``ndf * 16`` channels. A linear layer then turns that into a
    single score. No sigmoid is applied; the score is used directly by the
    least-squares adversarial losses.
    """

    def __init__(self):
        super(Discr_frame, self).__init__()
        layers = [
            # (nc) x 64 x 64 -> (ndf) x 32 x 32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
        ]
        # (ndf) x 32 x 32 -> (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8, each with dropout
        for in_mult, out_mult in ((1, 2), (2, 4)):
            layers += [
                nn.Conv2d(ndf * in_mult, ndf * out_mult, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * out_mult),
                nn.Dropout(p=0.3),
                nn.LeakyReLU(0.2),
            ]
        layers += [
            # (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4 (no dropout on this stage)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2),
            # (ndf*8) x 4 x 4 -> (ndf*16) x 1 x 1 -> flattened -> one score
            nn.Conv2d(ndf * 8, ndf * 16, 4, 1, 0, bias=False),
            nn.Flatten(),
            nn.Linear(ndf * 16, 1),
        ]
        # Same layer order as a literal nn.Sequential, so state_dict keys
        # (main.0, main.2, ...) stay compatible with saved checkpoints.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

### Train and Validate the Model

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cwd = os.getcwd()
# NOTE(review): this redefines `data_folder` as <cwd>/data, whereas the earlier
# path cell points two directories above cwd. Nothing below reads it, but
# re-running the cells that np.load the time stamps after this one would read
# from this path instead.
data_folder = os.path.join(cwd, "data")
batch_size = 16
nepochs = 50
print_every = 1
eval_every = 5
plot_every = 1 * eval_every  # plot right after an evaluation has produced predictions
save_model_every = 1
save_name = 'PhyDNetGAN'
# makedirs also creates the parent `save/` directory; os.mkdir raised
# FileNotFoundError when it did not exist yet.
os.makedirs('save/{}'.format(save_name), exist_ok=True)
lamda = 0.01  # weight for generator adversarial loss
training_discriminator_every = 1

In [None]:
# Training batches are reshuffled every epoch; the test order stays fixed so
# that evaluation output i lines up with test sample i.
train_loader = torch.utils.data.DataLoader(
    dataset=SkyImageDataset([images_log_train, images_pred_train]),
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

data_set = SkyImageDataset([images_log_test, images_pred_test])
test_loader = torch.utils.data.DataLoader(
    dataset=data_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Target for the moment regularization of PhyCell's 7x7 filters: slice k of
# the (49, 7, 7) tensor is one-hot at (k // 7, k % 7), i.e. the 49x49 identity
# reshaped. Built in a single vectorized op instead of 49 scalar writes into a
# (possibly CUDA) tensor.
constraints = torch.eye(49).reshape(49, 7, 7).to(device)

In [None]:
def plot_gen_images(predictions, times_curr, images_log, images_pred, select_idx,
                    frame_ids=(0, 2, 4, 7),
                    log_offsets_min=(-14, -10, -6, 0),
                    pred_offsets_min=(1, 5, 9, 15)):
    """Show context frames, ground-truth future frames and predictions.

    Draws one 2 x (2 * len(frame_ids)) figure per selected sample.
    - Top row: the chosen context frames, then the chosen ground-truth future
      frames. Each is titled with ``times_curr`` shifted by the given minutes.
    - Bottom row: the predicted frames, placed under the ground-truth columns.
    Channels are reversed with ``[:, :, ::-1]`` before display (presumably
    BGR -> RGB; confirm).

    Args:
        predictions: array shaped (N, T, C, H, W); transposed to channels-last here.
        times_curr: per-sample time stamp of the current step.
        images_log: per-sample context frame stacks, channels-last.
        images_pred: per-sample ground-truth future frame stacks, channels-last.
        select_idx: iterable of sample indices to plot.
        frame_ids: frame index shown in each column (context and future alike).
        log_offsets_min: minutes added to ``times_curr`` for the context titles.
        pred_offsets_min: minutes added to ``times_curr`` for the future titles.

    Raises:
        ValueError: if the three per-column tuples differ in length.
    """
    if not len(frame_ids) == len(log_offsets_min) == len(pred_offsets_min):
        raise ValueError("frame_ids, log_offsets_min and pred_offsets_min must have the same length")
    # NOTE(review): the sample-visualization cell at the end of the notebook
    # labels the context frames -15/-11/-7/-1 min; these defaults keep this
    # helper's original -14/-10/-6/0 labels -- confirm which matches the data.
    predictions = predictions.transpose((0, 1, 3, 4, 2))
    n_panels = len(frame_ids)
    for sample in select_idx:
        t_curr = times_curr[sample]
        f, ax = plt.subplots(2, 2 * n_panels)
        f.subplots_adjust(wspace=0, hspace=0)
        f.set_size_inches(24, 6)

        # Context frames: top-left columns.
        for col, (frame, minutes) in enumerate(zip(frame_ids, log_offsets_min)):
            ax[0, col].imshow(images_log[sample][frame][:, :, ::-1])
            ax[0, col].set_title(t_curr + datetime.timedelta(minutes=minutes))

        # Ground truth (top) and prediction (bottom) share the right-hand columns.
        for col, (frame, minutes) in enumerate(zip(frame_ids, pred_offsets_min), start=n_panels):
            ax[0, col].imshow(images_pred[sample][frame][:, :, ::-1])
            ax[0, col].set_title(t_curr + datetime.timedelta(minutes=minutes))
            ax[1, col].imshow(predictions[sample][frame][:, :, ::-1])

        for axis in ax.flat:
            axis.axis('off')

        plt.show()

In [None]:
sampling_step_1 = 15
sampling_step_2 = 30
r_exp_alpha = 2.5


def reserve_schedule_sampling_exp(epoch, log_length):
    """Draw teacher-forcing flags for the context (encoder) frames.

    The probability ``r_eta`` of feeding the ground-truth frame has three phases:
    - 0.5 before ``sampling_step_1``;
    - rising exponentially (time constant ``r_exp_alpha``) until ``sampling_step_2``;
    - 1.0 afterwards.

    Returns:
        (r_eta, flags): ``flags`` is a bool array of length ``log_length``;
        each entry is True when a ``random.random()`` draw is below ``r_eta``.
    """
    if epoch < sampling_step_1:
        r_eta = 0.5
    elif epoch < sampling_step_2:
        elapsed = float(epoch - sampling_step_1)
        r_eta = 1.0 - 0.5 * math.exp(-elapsed / r_exp_alpha)
    else:
        r_eta = 1.0
    flags = np.array([random.random() < r_eta for _ in range(log_length)], dtype=bool)
    return r_eta, flags


def schedule_sampling(epoch, pred_length):
    """Draw teacher-forcing flags for the future (decoder) frames.

    The probability ``eta`` of feeding the ground-truth frame has three phases:
    - 0.5 before ``sampling_step_1``;
    - decaying linearly towards 0 until ``sampling_step_2``;
    - 0 afterwards.

    Returns:
        (eta, flags): ``flags`` is a bool array of length ``pred_length``;
        each entry is True when a ``random.random()`` draw is below ``eta``.
    """
    if epoch < sampling_step_1:
        eta = 0.5
    elif epoch < sampling_step_2:
        decay_per_epoch = 0.5 / (sampling_step_2 - sampling_step_1)
        eta = 0.5 - decay_per_epoch * (epoch - sampling_step_1)
    else:
        eta = 0
    flags = np.array([random.random() < eta for _ in range(pred_length)], dtype=bool)
    return eta, flags

In [None]:
# Sample both teacher-forcing schedules once per epoch and plot their probabilities.
r_eta, eta = np.zeros(nepochs), np.zeros(nepochs)
for epoch in range(nepochs):
    r_eta[epoch], real_input_flag_encoder = reserve_schedule_sampling_exp(epoch, num_log_frame)
    eta[epoch], real_input_flag_decoder = schedule_sampling(epoch, num_pred_frame)

plt.plot(range(nepochs), r_eta, label='r_eta')
plt.plot(range(nepochs), eta, label='eta')
plt.legend()

In [None]:
def train_on_batch(epoch, input_tensor, target_tensor, encoder, encoder_optimizer, criterion_mae, criterion_mse, discr_frame, discr_frame_optimizer):
    """Train the generator and, on scheduled epochs, the frame discriminator on one batch.

    Generator (``encoder``) loss is the sum of:
    - L1 over every predicted context and future frame;
    - ``lamda`` times the least-squares adversarial loss (discriminator epochs only);
    - the MSE moment regularization of PhyCell's first-layer filters against ``constraints``.

    On discriminator epochs, ``discr_frame`` is then trained with the
    least-squares GAN loss: ground-truth future frames get target 1, and the
    generator's detached predictions from this same forward pass get target 0.

    Args:
        epoch: current epoch. It drives scheduled sampling and whether the
            discriminator is trained (every ``training_discriminator_every`` epochs).
        input_tensor: context frames, (batch, input_length, channels, H, W).
        target_tensor: future frames, (batch, target_length, channels, H, W).
        encoder, encoder_optimizer: generator model and its optimizer.
        criterion_mae: reconstruction loss (L1).
        criterion_mse: moment-regularization loss (MSE).
        discr_frame, discr_frame_optimizer: discriminator and its optimizer.

    Returns:
        (discr_loss, encoder_ad_loss, total_loss), each divided by
        ``target_length``. The first two are 0.0 on epochs without a
        discriminator update.
    """
    train_discr = (epoch + 1) % training_discriminator_every == 0
    input_length = input_tensor.size(1)
    target_length = target_tensor.size(1)

    encoder_optimizer.zero_grad()
    loss = 0
    encoder_frame_ad_loss = 0
    discr_frame_loss = 0
    _, real_input_flag_encoder = reserve_schedule_sampling_exp(epoch, input_length)
    _, real_input_flag_decoder = schedule_sampling(epoch, target_length)

    # Encoder pass over the context frames; the recurrent state is reset at ei == 0.
    encoder_input = input_tensor[:, 0, :, :, :]
    for ei in range(input_length - 1):
        _, _, encoder_output_image, _, _ = encoder(encoder_input, (ei == 0))
        encoder_target = input_tensor[:, ei + 1, :, :, :]
        loss += criterion_mae(encoder_output_image, encoder_target)
        # Scheduled sampling: feed either the ground truth (teacher forcing) or the prediction.
        encoder_input = encoder_target if real_input_flag_encoder[ei] else encoder_output_image

    # First decoder input = last image of the input sequence (real or predicted).
    if real_input_flag_encoder[-1]:
        decoder_input = input_tensor[:, -1, :, :, :]
    else:
        decoder_input = encoder_output_image

    # Detached generator outputs, reused for the discriminator update. Re-running
    # the decoder for that purpose would continue the recurrent state past the
    # prediction horizon (and after the weight update), pairing targets with
    # the wrong frames.
    fake_frames = []
    for di in range(target_length):
        _, _, output_image, _, _ = encoder(decoder_input)
        target = target_tensor[:, di, :, :, :]
        loss += criterion_mae(output_image, target)
        if train_discr:
            discr_frame_out_fake = discr_frame(output_image).view(-1)
            encoder_frame_ad_loss += 0.5 * torch.mean((discr_frame_out_fake - 1) ** 2)
            fake_frames.append(output_image.detach())
        decoder_input = target if real_input_flag_decoder[di] else output_image

    if train_discr:
        loss += lamda * encoder_frame_ad_loss

    # Moment regularization: encoder.phycell.cell_list[0].F.conv1.weight has size
    # (nb_filters, in_channels, 7, 7); each input channel's filters are pushed
    # towards the precomputed `constraints` moments.
    k2m = K2M([7, 7]).to(device)
    for b in range(0, encoder.phycell.cell_list[0].input_dim):
        filters = encoder.phycell.cell_list[0].F.conv1.weight[:, b, :, :]  # (nb_filters, 7, 7)
        m = k2m(filters.double()).float()
        loss += criterion_mse(m, constraints)
    loss.backward()
    encoder_optimizer.step()

    if not train_discr:
        return discr_frame_loss / target_length, encoder_frame_ad_loss / target_length, loss.item() / target_length

    # Discriminator step. Gradients accumulated in D by the generator's
    # adversarial loss are cleared first.
    discr_frame_optimizer.zero_grad()
    for di, fake_frame in enumerate(fake_frames):
        discr_frame_out_real = discr_frame(target_tensor[:, di, :, :, :]).view(-1)
        discr_frame_out_fake = discr_frame(fake_frame).view(-1)
        discr_frame_loss += 0.5 * (torch.mean((discr_frame_out_real - 1) ** 2) + torch.mean(discr_frame_out_fake ** 2))
    discr_frame_loss.backward()
    discr_frame_optimizer.step()

    return discr_frame_loss.item() / target_length, encoder_frame_ad_loss.item() / target_length, loss.item() / target_length

def trainIters(encoder, discr_frame, nepochs, print_every=10, eval_every=10, name=''):
    """Train the PhyDNet generator and the frame discriminator for ``nepochs`` epochs.

    Per epoch:
    - one pass over the global ``train_loader`` with ``train_on_batch``;
    - a progress printout every ``print_every`` epochs;
    - both state dicts saved to ``save/<name>/`` every ``save_model_every``
      epochs, overwriting the previous files;
    - adversarial loss curves plotted on discriminator epochs;
    - evaluation on the global ``test_loader`` every ``eval_every`` epochs,
      whose MAE drives both ReduceLROnPlateau schedulers;
    - a sample plot every ``plot_every`` epochs, using the latest evaluation.

    Returns:
        (encoder_total_train_losses, encoder_ad_train_losses, discr_train_losses):
        per-epoch sums of the per-batch values returned by ``train_on_batch``.
    """
    encoder_total_train_losses = []
    encoder_ad_train_losses = []
    discr_train_losses = []
    # Latest evaluation output. It is None until the first evaluation, so
    # plotting never touches an unbound name when plot_every is not a multiple
    # of eval_every.
    predictions = None

    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001, betas=(0.5, 0.99))
    discr_frame_optimizer = torch.optim.Adam(discr_frame.parameters(), lr=0.0002, betas=(0.5, 0.99))
    scheduler_enc = ReduceLROnPlateau(encoder_optimizer, mode='min', patience=5, factor=0.1, verbose=True)
    scheduler_discr = ReduceLROnPlateau(discr_frame_optimizer, mode='min', patience=5, factor=0.1, verbose=True)
    criterion_mae = nn.L1Loss()
    criterion_mse = nn.MSELoss()

    for epoch in range(0, nepochs):
        t0 = time.time()
        encoder_total_loss_epoch = 0
        discr_loss_epoch = 0
        encoder_ad_loss_epoch = 0

        for out in train_loader:
            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            discr_frame_loss, encoder_frame_ad_loss, loss = train_on_batch(epoch, input_tensor, target_tensor, encoder, encoder_optimizer, criterion_mae, criterion_mse, discr_frame, discr_frame_optimizer)
            encoder_total_loss_epoch += loss
            discr_loss_epoch += discr_frame_loss
            encoder_ad_loss_epoch += encoder_frame_ad_loss

        encoder_total_train_losses.append(encoder_total_loss_epoch)
        encoder_ad_train_losses.append(encoder_ad_loss_epoch)
        discr_train_losses.append(discr_loss_epoch)

        if (epoch + 1) % print_every == 0:
            print('training epoch {0}/{1}'.format(epoch + 1, nepochs))
            print('encoder total loss:{0:.3f}'.format(encoder_total_loss_epoch))
            print('time epoch:{0:.3f}s'.format(time.time() - t0))

        if (epoch + 1) % save_model_every == 0:
            print('saving the model...')
            torch.save(encoder.state_dict(), 'save/{0}/encoder.pth'.format(name))
            torch.save(discr_frame.state_dict(), 'save/{0}/discriminator.pth'.format(name))

        if (epoch + 1) % training_discriminator_every == 0:
            print('encoder adversarial loss:{0:.3f}'.format(encoder_ad_loss_epoch))
            print('discriminator loss:{0:.3f}'.format(discr_loss_epoch))
            f, ax = plt.subplots()
            ax.plot(range(len(encoder_ad_train_losses)), encoder_ad_train_losses, label="gen_loss")
            ax.plot(range(len(discr_train_losses)), discr_train_losses, label='disc_loss')
            ax.set_xlabel('epoch')
            ax.set_ylabel('loss')
            ax.legend()
            f.tight_layout()
            plt.show()

        if (epoch + 1) % eval_every == 0:
            mse, mae, predictions, _ = evaluate(encoder, test_loader)
            scheduler_enc.step(mae)
            # NOTE(review): the discriminator LR also follows the generator's test MAE.
            scheduler_discr.step(mae)

        if (epoch + 1) % plot_every == 0 and predictions is not None:
            # Fixed showcase sample, clamped so a smaller test set cannot raise IndexError.
            select_idx = [min(3445, len(predictions) - 1)]
            plot_gen_images(predictions, times_curr_test, images_log_test, images_pred_test, select_idx)

    return encoder_total_train_losses, encoder_ad_train_losses, discr_train_losses

def evaluate(encoder, loader):
    """Roll the model out on every batch of ``loader`` without teacher forcing.

    For each batch:
    - the first ``input_length - 1`` context frames are fed to the encoder
      (state reset on the first one);
    - the last context frame then seeds an autoregressive rollout of
      ``target_length`` predicted frames.
    Runs under ``torch.no_grad()``; it does not switch the model's train/eval mode.

    Returns:
        (mse, mae, predictions, indices)
        - mse / mae: squared / absolute error averaged over the predicted
          frames, summed over C*H*W, then averaged over all samples yielded
          by ``loader`` ("per frame" in the printout).
        - predictions: np.ndarray (num_samples, target_length, C, H, W), in loader order.
        - indices: list of the per-batch index tensors from the loader.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    total_mse, total_mae = 0, 0
    num_samples = 0
    pixels_per_frame = 1
    t0 = time.time()
    predictions = []
    indices = []
    with torch.no_grad():
        for out in loader:
            indices.append(out[0])
            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            input_length = input_tensor.size(1)
            target_length = target_tensor.size(1)

            # Warm up the recurrent state on the context frames (outputs unused).
            for ei in range(input_length - 1):
                encoder(input_tensor[:, ei, :, :, :], (ei == 0))

            decoder_input = input_tensor[:, -1, :, :, :]  # first decoder input = last image of input sequence
            prediction = []
            for _ in range(target_length):
                _, _, output_image, _, _ = encoder(decoder_input, False, False)
                decoder_input = output_image
                prediction.append(output_image.cpu().numpy())

            target = target_tensor.cpu().numpy()
            prediction = np.stack(prediction).swapaxes(0, 1)  # (batch, target_length, C, H, W)

            total_mse += np.mean((prediction - target) ** 2, axis=1).sum()
            total_mae += np.mean(np.abs(prediction - target), axis=1).sum()
            # Normalize by what the loader actually produced rather than by
            # the size of a particular global dataset.
            num_samples += prediction.shape[0]
            pixels_per_frame = int(np.prod(prediction.shape[2:]))

            predictions.append(prediction)

    if not predictions:
        raise ValueError("evaluate() received a loader that yielded no batches")
    predictions = np.concatenate(predictions, axis=0)  # (num_samples, target_length, C, H, W)
    print("validation...")
    print('mse per frame:{0:.3f}'.format(total_mse / num_samples))
    print('mae per frame:{0:.3f}'.format(total_mae / num_samples))
    print('mse per pixel:{0:.3f}'.format(total_mse / num_samples / pixels_per_frame))
    print('mae per pixel:{0:.3f}'.format(total_mae / num_samples / pixels_per_frame))
    print('time:{0:.3f}s'.format(time.time() - t0))
    print('-' * 40)
    return total_mse / num_samples, total_mae / num_samples, predictions, indices

In [None]:
# PhyDNet generator: a PhyCell and a 3-layer ConvLSTM wrapped by EncoderRNN,
# plus the per-frame discriminator.
phycell = PhyCell(input_shape=(16, 16), input_dim=64, F_hidden_dims=[49], n_layers=1,
                  kernel_size=(7, 7), device=device)
convcell = ConvLSTM(input_shape=(16, 16), input_dim=64, hidden_dims=[128, 128, 64], n_layers=3,
                    kernel_size=(3, 3), device=device)
encoder = EncoderRNN(phycell, convcell, device)
discriminator_frame = Discr_frame().to(device)


def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)


print('phycell ', count_parameters(phycell))
print('convcell ', count_parameters(convcell))
print('encoder ', count_parameters(encoder))
print('discriminator_frame ', count_parameters(discriminator_frame))

In [None]:
# Full training run; the per-epoch loss histories are kept for later inspection.
encoder_total_train_losses, encoder_ad_train_losses, discr_train_losses = trainIters(
    encoder,
    discriminator_frame,
    nepochs,
    print_every=print_every,
    eval_every=eval_every,
    name=save_name,
)

### Save Predicted Images from the Test Set

In [None]:
# Fresh, unshuffled loader over the test set for the final evaluation.
data_set = SkyImageDataset([images_log_test, images_pred_test])
test_loader = torch.utils.data.DataLoader(
    dataset=data_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Reload the saved generator weights and evaluate on the test set. trainIters
# overwrites this file on every save, so these are the weights from the last
# save, not necessarily the best epoch. map_location lets a checkpoint written
# on a GPU machine load on a CPU-only host.
encoder.load_state_dict(torch.load('save/{0}/encoder.pth'.format(save_name), map_location=device))
encoder.eval()
mse, mae, predictions, indices = evaluate(encoder, test_loader)

In [None]:
print(predictions.shape)
# (N, T, C, H, W) -> (N, T, H, W, C): channels-last for imshow and for saving.
# NOTE: re-running this cell transposes again; re-run the evaluation cell first.
predictions = np.transpose(predictions, (0, 1, 3, 4, 2))
print(predictions.shape)

In [None]:
# Persist the channels-last test-set predictions next to the saved model.
np.save(f'save/{save_name}/predicted_images.npy', predictions)

### Visualize Some Sample Predictions

In [None]:
# Reproducibly pick distinct test samples to visualize.
random.seed(0)
select_num_samples = 30
select_idx = random.sample(range(len(times_curr_test)), select_num_samples)

In [None]:
# (frame index, minutes relative to the current time stamp) for each column.
context_panels = ((0, -15), (2, -11), (4, -7), (7, -1))
future_panels = ((0, 1), (2, 5), (4, 9), (7, 15))

for i in range(select_num_samples):
    print("-"*50,"sample ",str(i+1), "-"*50)
    sample = select_idx[i]
    now = times_curr_test[sample]
    f, ax = plt.subplots(2, 8)
    f.set_size_inches(24, 6)

    # Top-left: context frames with their time stamps.
    for col, (frame, minutes) in enumerate(context_panels):
        ax[0, col].imshow(images_log_test[sample][frame][:, :, ::-1])
        ax[0, col].set_title(now + datetime.timedelta(minutes=minutes))

    # Right half: ground-truth future frames (top) above the predictions (bottom).
    for col, (frame, minutes) in enumerate(future_panels, start=4):
        ax[0, col].imshow(images_pred_test[sample][frame][:, :, ::-1])
        ax[0, col].set_title(now + datetime.timedelta(minutes=minutes))
        ax[1, col].imshow(predictions[sample][frame][:, :, ::-1])

    for panel in ax.flat:
        panel.axis('off')

    plt.show()