In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import random
import time
from models import ConvLSTM, EncoderRNN
import os
import h5py
import torch.utils.data as data
import math
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import datetime
import itertools

In [None]:
# Report the PyTorch version and, when CUDA is present, the active GPU.
print("pytorch version:", torch.__version__)
# get_device_name raises on machines without CUDA, so only query it when a GPU
# exists; the training config cell falls back to the CPU in that case.
if torch.cuda.is_available():
    gpus = torch.cuda.get_device_name(torch.cuda.current_device())
else:
    gpus = None
print("available gpus:", gpus)

In [None]:
# Resolve all paths from the notebook's working directory:
#   dataset   -> <grandparent of cwd>/data/video_prediction_dataset.hdf5
#   artifacts -> <cwd>/save/<model_name>
cwd = os.getcwd()
pardir = os.path.dirname(os.path.dirname(cwd))
data_folder = os.path.join(pardir, 'data')
data_path = os.path.join(data_folder, 'video_prediction_dataset.hdf5')
model_name = 'ConvLSTM'
output_folder = os.path.join(cwd, "save", model_name)
os.makedirs(output_folder, exist_ok=True)  # no-op when the directory already exists

for label, value in (("data_folder:", data_folder),
                     ("data_path:", data_path),
                     ("output_folder:", output_folder)):
    print(label, value)

In [None]:
# Print every group and dataset in the HDF5 file to show its structure.
def get_all(name):
    """Callback for `File.visit`: print the HDF5 object stored at path `name`."""
    if name is not None:
        print(forecast_dataset[name])


# The context manager releases the file handle even if printing raises.
with h5py.File(data_path, 'r') as forecast_dataset:
    forecast_dataset.visit(get_all)

In [None]:
def _read_split(group):
    """Return ``(images_log, images_pred)`` from an HDF5 group, keeping every 2nd frame.

    The stride is applied inside the h5py read, so only the kept frames are
    loaded. Reading everything with ``[...]`` and slicing afterwards would
    return a strided *view* that keeps the full, un-subsampled array alive.
    """
    return group['images_log'][:, ::2], group['images_pred'][:, ::2]


with h5py.File(data_path, 'r') as f:
    images_log_train, images_pred_train = _read_split(f['trainval'])
    images_log_test, images_pred_test = _read_split(f['test'])

# allow_pickle=True suggests the timestamps are stored as an object array
# (presumably datetimes; they are combined with datetime.timedelta when
# plotting). Unpickling can execute arbitrary code, so only load trusted files.
times_curr_train = np.load(os.path.join(data_folder, "times_curr_trainval.npy"), allow_pickle=True)
times_curr_test = np.load(os.path.join(data_folder, "times_curr_test.npy"), allow_pickle=True)

# The dataset indexes log/pred arrays in lockstep, and evaluation/plotting align
# test predictions with times_curr_test purely by position.
assert len(images_log_train) == len(images_pred_train), "trainval log/pred sample counts differ"
assert len(images_log_test) == len(images_pred_test) == len(times_curr_test), \
    "test images and times_curr_test are not aligned"

print('-'*50)
print("times_curr_train.shape:", times_curr_train.shape)
print("images_log_train.shape:", images_log_train.shape)
print("images_pred_train.shape:", images_pred_train.shape)
print("times_curr_test.shape:", times_curr_test.shape)
print("images_log_test.shape:", images_log_test.shape)
print("images_pred_test.shape:", images_pred_test.shape)
print('-'*50)

# get the input dimension for constructing the model
# (arrays are indexed as (sample, frame, height, width, channel))
num_log_frame = images_log_train.shape[1]
img_side_len = images_log_train.shape[2]
num_color_channel = images_log_train.shape[4]
num_pred_frame = images_pred_train.shape[1]
image_log_dim = [num_log_frame, img_side_len, img_side_len, num_color_channel]
image_pred_dim = [num_pred_frame, img_side_len, img_side_len, num_color_channel]

print("image side length:", img_side_len)
print("number of log frames:", num_log_frame)
print("number of pred frames:", num_pred_frame)
print("number of color channels:", num_color_channel)
print("context(log) image dimension:", image_log_dim)
print("future(pred) image dimension:", image_pred_dim)

### Dataset and data loaders

In [None]:
class SkyImageDataset(data.Dataset):
    """Pairs of (context frames, future frames) for sky-image video prediction.

    Args:
        data_set: two-element sequence ``[images_log, images_pred]`` of arrays
            indexed by sample first. Each sample is laid out as
            (frames, height, width, channels); values are divided by 255, so
            0-255 pixel values are assumed.
        transform: optional callable applied to both the input and the target
            tensor after conversion. ``None`` (default) applies nothing.

    Raises:
        ValueError: if the two arrays hold different numbers of samples.
    """

    def __init__(self, data_set, transform=None):
        self.data_set = data_set
        self.transform = transform
        self.length = self.data_set[0].shape[0]
        if self.data_set[1].shape[0] != self.length:
            raise ValueError(
                "input and target arrays have different sample counts: "
                "{} vs {}".format(self.length, self.data_set[1].shape[0]))

    @staticmethod
    def _to_tensor(frames):
        """(T, H, W, C) array -> contiguous float32 (T, C, H, W) tensor divided by 255."""
        frames = frames.transpose(0, 3, 1, 2)
        return torch.from_numpy(frames / 255.0).contiguous().float()

    def __getitem__(self, idx):
        """Return ``[idx, input_frames, target_frames]`` as (T, C, H, W) float tensors."""
        input_data = self._to_tensor(self.data_set[0][idx])
        output_data = self._to_tensor(self.data_set[1][idx])
        if self.transform is not None:
            input_data = self.transform(input_data)
            output_data = self.transform(output_data)
        return [idx, input_data, output_data]

    def __len__(self):
        return self.length


### Train and Validate the Model

In [None]:
# Training configuration: everything a reader might tune lives here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
batch_size_train = 16
batch_size_test = 64
nepochs = 50
print('nepochs:',nepochs)
print_every = 1   # print the epoch loss and write a checkpoint every N epochs
eval_every = 5    # evaluate on test_loader and step the LR scheduler every N epochs

# Seed every RNG the run touches (DataLoader shuffling and, presumably, weight
# initialisation via torch; the teacher-forcing coin flip via `random`) so a
# Restart & Run All reproduces the same training run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Wrap the arrays in Datasets and batch them. Only training data is shuffled;
# the test order must stay aligned with times_curr_test for plotting.
data_set = SkyImageDataset([images_log_train, images_pred_train])
train_loader = torch.utils.data.DataLoader(
    dataset=data_set,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=0,
)

data_set = SkyImageDataset([images_log_test, images_pred_test])
test_loader = torch.utils.data.DataLoader(
    dataset=data_set,
    batch_size=batch_size_test,
    shuffle=False,
    num_workers=0,
)

In [None]:
def train_on_batch(input_tensor, target_tensor, encoder, encoder_optimizer, criterion,teacher_forcing_ratio):
    """Run one optimisation step on a batch of frame sequences.

    The encoder first steps through the context frames, each step predicting
    the next context frame (those errors are included in the loss). It is then
    rolled out once per target frame, starting from the last context frame. A
    single coin flip per batch decides whether the rollout feeds back the
    ground-truth frame (teacher forcing) or the model's own prediction.

    Args:
        input_tensor: context frames indexed as (batch, time, ...); the dataset
            produces (B, T_in, C, H, W).
        target_tensor: future frames, (B, T_out, C, H, W).
        encoder: model called as ``encoder(frame, first_timestep)``; the flag is
            True exactly once per sequence, on its first call (presumably it
            resets the recurrent state - confirm in models.EncoderRNN).
        encoder_optimizer: optimiser over the encoder's parameters.
        criterion: loss between a predicted and a true frame.
        teacher_forcing_ratio: probability in [0, 1] of using teacher forcing.

    Returns:
        float: total batch loss (context and future terms) divided by the
        number of future frames.
    """
    encoder_optimizer.zero_grad()
    input_length = input_tensor.size(1)
    target_length = target_tensor.size(1)
    loss = 0
    for ei in range(input_length - 1):
        output_image = encoder(input_tensor[:, ei], ei == 0)
        loss += criterion(output_image, input_tensor[:, ei + 1])

    decoder_input = input_tensor[:, -1]  # first decoder input = last image of input sequence

    use_teacher_forcing = random.random() < teacher_forcing_ratio
    for di in range(target_length):
        # With a single context frame there was no warm-up call above, so the
        # first rollout step must be the one that starts the sequence.
        first_timestep = di == 0 and input_length < 2
        output_image = encoder(decoder_input, first_timestep)
        target = target_tensor[:, di]
        loss += criterion(output_image, target)
        decoder_input = target if use_teacher_forcing else output_image

    loss.backward()
    encoder_optimizer.step()
    return loss.item() / target_length

def trainIters(encoder, nepochs, print_every=10,eval_every=10,name=''):
    """Train `encoder` for `nepochs` epochs on the global `train_loader`.

    Every `print_every` epochs the summed epoch loss is printed and the current
    weights are written to ``save/<name>/encoder.pth``. Every `eval_every`
    epochs the model is evaluated on the global `test_loader`. The per-frame
    MSE drives the ReduceLROnPlateau scheduler, and whenever it beats the best
    value seen so far the weights are also written to
    ``save/<name>/encoder_best.pth``.

    NOTE(review): both the LR schedule and the "best" checkpoint are chosen on
    the test loader, so the reported test metrics are optimistically biased; a
    separate validation split would remove that leak.

    Args:
        encoder: model to train (updated in place).
        nepochs: number of passes over `train_loader`.
        print_every: reporting / checkpoint interval, in epochs.
        eval_every: evaluation / scheduler interval, in epochs.
        name: sub-directory of ``save/`` for checkpoints (it must exist).

    Returns:
        list: one value per epoch, the sum over batches of `train_on_batch`'s result.
    """
    train_losses = []
    best_mse = float('inf')

    encoder_optimizer = torch.optim.Adam(encoder.parameters(),lr=0.001)
    scheduler_enc = ReduceLROnPlateau(encoder_optimizer, mode='min', patience=2,factor=0.1,verbose=True)
    criterion = nn.MSELoss()

    for epoch in range(nepochs):
        t0 = time.time()
        # Make sure training-mode behaviour is active; the caller or an
        # evaluation pass may have left the model in eval mode.
        encoder.train()
        loss_epoch = 0
        # Teacher forcing is annealed linearly: 1.0 at epoch 0, minus 0.03 per epoch, floored at 0.
        teacher_forcing_ratio = max(0.0, 1 - epoch * 0.03)

        for out in train_loader:
            input_tensor = out[1].to(device)
            target_tensor = out[2].to(device)
            loss_epoch += train_on_batch(input_tensor, target_tensor, encoder, encoder_optimizer,
                                         criterion, teacher_forcing_ratio)

        train_losses.append(loss_epoch)
        if (epoch+1) % print_every == 0:
            print('-'*50)
            print('epoch ',epoch,  ' loss ',loss_epoch, ' time epoch ',time.time()-t0)
            print("saving model...")
            torch.save(encoder.state_dict(),'save/{0}/encoder.pth'.format(name))

        if (epoch+1) % eval_every == 0:
            mse_per_frame, *_ = evaluate(encoder, test_loader)
            scheduler_enc.step(mse_per_frame)
            if mse_per_frame < best_mse:
                best_mse = mse_per_frame
                print("new best eval mse per frame:", best_mse, "- saving best model...")
                torch.save(encoder.state_dict(),'save/{0}/encoder_best.pth'.format(name))

    return train_losses

def evaluate(encoder,loader):
    """Closed-loop evaluation of `encoder` over every batch of `loader`.

    For each batch the encoder is run over all context frames but the last,
    then rolled out once per target frame, feeding each prediction back as the
    next input (no teacher forcing). The model is switched to eval mode for the
    duration of the call, and its previous train/eval mode is restored afterwards.

    Returns:
        tuple ``(mse_per_frame, mae_per_frame, mse_per_pixel, mae_per_pixel,
        predictions, indices)`` where
        - ``*_per_frame``: squared / absolute error summed over one frame's
          C*H*W values, averaged over predicted frames and samples;
        - ``*_per_pixel``: the per-frame value divided by C*H*W;
        - ``predictions``: array of shape (N, T_out, C, H, W) in loader order;
        - ``indices``: list holding each batch's dataset-index tensor.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    print("validation start...")
    total_mse, total_mae = 0, 0
    num_samples = 0  # counted from the data so any loader works, not just test_loader
    t0 = time.time()
    predictions = []
    indices = []
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            for out in loader:
                indices.append(out[0])
                input_tensor = out[1].to(device)
                target_tensor = out[2].to(device)
                input_length = input_tensor.size(1)
                target_length = target_tensor.size(1)

                # Warm up on the context frames; the flag marks the sequence's first step.
                for ei in range(input_length - 1):
                    encoder(input_tensor[:, ei], ei == 0)

                # Autoregressive rollout from the last context frame. With a single
                # context frame there was no warm-up, so the first rollout step
                # starts the sequence instead.
                decoder_input = input_tensor[:, -1]
                prediction = []
                for di in range(target_length):
                    output_image = encoder(decoder_input, di == 0 and input_length < 2)
                    decoder_input = output_image
                    prediction.append(output_image.cpu().numpy())

                target = target_tensor.cpu().numpy()
                prediction = np.stack(prediction, axis=1)  # (batch, T_out, C, H, W)

                # mean over the time axis, summed over batch, channels and pixels
                total_mse += np.mean((prediction - target) ** 2, axis=1).sum()
                total_mae += np.mean(np.abs(prediction - target), axis=1).sum()
                num_samples += prediction.shape[0]
                predictions.append(prediction)
    finally:
        encoder.train(was_training)

    if num_samples == 0:
        raise ValueError("evaluate: loader yielded no samples")

    predictions = np.concatenate(predictions, axis=0)  # (N, T_out, C, H, W)
    values_per_frame = int(np.prod(predictions.shape[2:]))  # C * H * W
    total_mse_per_frame = total_mse / num_samples
    total_mae_per_frame = total_mae / num_samples
    total_mse_per_pixel = total_mse_per_frame / values_per_frame
    total_mae_per_pixel = total_mae_per_frame / values_per_frame

    print('eval mse per frame:',total_mse_per_frame)
    print('eval mae per frame:', total_mae_per_frame)
    print('eval mse per pixel:',total_mse_per_pixel)
    print('eval mae per pixel:', total_mae_per_pixel)
    print('time:', time.time()-t0)

    return total_mse_per_frame,  total_mae_per_frame, total_mse_per_pixel, total_mae_per_pixel, predictions, indices

In [None]:
# Three stacked ConvLSTM layers wrapped by the encoder. The spatial size and
# channel count are taken from the loaded data instead of being hard-coded, so
# the model stays consistent with the dataset.
convcell = ConvLSTM(input_shape=(img_side_len, img_side_len), input_dim=num_color_channel,
                    hidden_dims=[128,128,64], n_layers=3, kernel_size=(3,3), device=device)
encoder = EncoderRNN(convcell, device)

def count_parameters(model):
    """Return the total number of scalar values in `model`'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
   
# Report trainable parameter counts for the recurrent cell and the encoder wrapping it.
for label, module in (('convcell ', convcell), ('encoder ', encoder)):
    print(label, count_parameters(module))

In [None]:
# Train the model; the result holds one summed training loss per epoch.
train_loss = trainIters(
    encoder,
    nepochs,
    print_every=print_every,
    eval_every=eval_every,
    name=model_name,
)

### Save Predicted Images for the Test Set

In [None]:
# Reload the checkpoint written by trainIters and evaluate it on the test loader.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only kernel.
encoder.load_state_dict(torch.load('save/{0}/encoder.pth'.format(model_name), map_location=device))
encoder.eval()
mse_per_frame, mae_per_frame, mse_per_pixel, mae_per_pixel, predictions, indices = evaluate(encoder,test_loader)

In [None]:
# Convert predictions from channels-first (N, T, C, H, W) to channels-last
# (N, T, H, W, C) for saving and plotting. The shape guard keeps the cell
# idempotent: re-running it does not transpose already-converted data again.
print(predictions.shape)
if predictions.shape[2] == num_color_channel and predictions.shape[-1] != num_color_channel:
    predictions = predictions.transpose((0,1,3,4,2))
print(predictions.shape)

In [None]:
# Persist the predictions array as currently held (channels-last once the transpose cell has run).
prediction_file = f'save/{model_name}/predicted_images.npy'
np.save(prediction_file, predictions)

### Visualize Some Sample Predictions

In [None]:
# Draw a reproducible random subset of test-sample positions to visualise.
random.seed(0)
select_num_samples = 30
select_idx = random.sample(range(len(predictions)), select_num_samples)

In [None]:
# For each selected test sample draw a 2x8 grid:
#   row 0, cols 0-3: context frames;   row 0, cols 4-7: ground-truth future frames;
#   row 1, cols 4-7: the model's predictions for the same future frames.
# Each (frame, minutes) pair is an index into the ::2-subsampled sequence and
# its time offset relative to times_curr_test. `[:, :, ::-1]` reverses the
# channel order (presumably BGR -> RGB).
log_frames = [(0, -15), (2, -11), (4, -7), (7, -1)]
pred_frames = [(0, 1), (2, 5), (4, 9), (7, 15)]

for i, sample_idx in enumerate(select_idx[:select_num_samples]):
    print("-"*50,"sample ",str(i+1), "-"*50)
    fig, ax = plt.subplots(2, 8)
    fig.set_size_inches(24, 6)
    t_curr = times_curr_test[sample_idx]

    for col, (frame, minutes) in enumerate(log_frames):
        ax[0, col].imshow(images_log_test[sample_idx][frame][:, :, ::-1])
        ax[0, col].set_title(t_curr + datetime.timedelta(minutes=minutes))

    for col, (frame, minutes) in enumerate(pred_frames, start=4):
        ax[0, col].imshow(images_pred_test[sample_idx][frame][:, :, ::-1])
        ax[0, col].set_title(t_curr + datetime.timedelta(minutes=minutes))
        # Model outputs may fall outside [0, 1]; clip explicitly so imshow does
        # not emit a clipping warning for every out-of-range float image.
        ax[1, col].imshow(np.clip(predictions[sample_idx][frame][:, :, ::-1], 0.0, 1.0))

    for axis in ax.flat:
        axis.axis('off')

    plt.show()