In [None]:
import matplotlib.pyplot as plt
import torch

from data_factory import *
from model_factory.models import *
from eval import *

In [None]:
## Locate the test split
# Path to the Moving MNIST test file as distributed by (Wang, Y., et al., 2019).
# The train and valid splits live alongside it: replace "test" with "train"
# or "valid" in the path below to load them instead.
file_address = "data_factory/data/moving-mnist-test.npz"

## Build the test dataset
# MMDataset wraps the data file as a PyTorch Dataset object.
# `size` is the frame width in pixels and must match the model being evaluated
# (one model was trained with 64-pixel frames, all others with 32).
test_dataset = MMDataset(file_address, size=32)

In [None]:
## Load a model
# This example uses the Baseline model; any other model can be swapped in by
# its class name (the model names appear as file names in model_factory/models/*.py).
model = Baseline().cuda()

## Load saved weights -- the evaluations presented in the blog are based on these weights
# See model_factory/state_dict/*.tar for the list of saved weights.
# map_location="cpu" lets the checkpoint load no matter which device it was
# saved from; load_state_dict then copies the values into the model's
# (GPU-resident) parameters.
model.load_state_dict(
    torch.load("model_factory/state_dict/Baseline.tar", map_location="cpu")
)

# Put the model in inference mode so layers such as dropout / batch-norm (if
# the model has any) behave deterministically during evaluation.
model.eval()

## Evaluate the MSE loss of the model on the test dataset
# The frame size of test_dataset must match the model's frame size.
evaluate_mse(model, test_dataset, batch_size=64)

In [None]:
## SSIM
# Evaluate SSIM per target frame for the model on the test dataset.
# evaluate_ssim returns a numpy array with one component per target frame.
# Ending the cell with it displays the per-frame scores instead of
# silently discarding them.
ssim = evaluate_ssim(model, test_dataset)
ssim


In [None]:
## PSNR
# Evaluate PSNR per target frame for the model on the test dataset.
# evaluate_psnr returns a numpy array with one component per target frame.
# Ending the cell with it displays the per-frame scores instead of
# silently discarding them.
psnr = evaluate_psnr(model, test_dataset)
psnr


In [None]:
## Qualitative results
# Render the model's output for a single sequence from the test dataset.
# `index` picks the sequence and must satisfy 0 <= index < len(test_dataset).
show_result(model, test_dataset, index=0)