In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from data_factory import *
from model_factory.models import *
from eval import *


In [None]:
# Prepare test data.
# The Moving-MNIST test split is stored as an .npz archive and wrapped in the
# project's MMDataset loader. NOTE(review): the meaning of `size` is defined in
# data_factory, not here — TODO document it (image side? sample count?).
file_address = "data_factory/data/moving-mnist-test.npz"
test_dataset = MMDataset(
    file_address,
    size=32,
)

In [None]:
# Prepare models: build each architecture on the GPU, load its trained weights,
# and switch it to inference mode before any evaluation runs.

def _load_eval_model(model_cls, checkpoint_path):
    """Build ``model_cls()`` on the GPU, load a saved state_dict into it, and return it in eval mode.

    Args:
        model_cls: Zero-argument model constructor (presumably a ``torch.nn.Module``
            subclass exported by ``model_factory.models``).
        checkpoint_path: Path to a file written with ``torch.save`` whose contents are
            passed directly to ``load_state_dict`` (i.e. a bare state_dict).

    Returns:
        The model on CUDA with ``eval()`` applied, so layers such as dropout or
        batch-norm (if present) use inference behaviour during evaluation.
    """
    model = model_cls().cuda()
    # map_location="cpu" avoids device-mismatch errors when the checkpoint was saved
    # on a GPU index that does not exist on this machine; load_state_dict then copies
    # the tensors into the model's CUDA parameters.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


model_Baseline = _load_eval_model(
    Baseline, "model_factory/state_dict/Baseline.tar")
model_BaselineWide4Deep4 = _load_eval_model(
    BaselineWide4Deep4, "model_factory/state_dict/BaselineWide4Deep4.tar")
model_BaselineWide4Deep4Skip = _load_eval_model(
    BaselineWide4Deep4Skip, "model_factory/state_dict/BaselineWide4Deep4Skip.tar")
model_BaselineWide4Deep3SkipAutoreg = _load_eval_model(
    BaselineWide4Deep3SkipAutoreg, "model_factory/state_dict/BaselineWide4Deep3SkipAutoreg_c.tar")

In [None]:
# Compute and plot SSIM per frame for each model.
# evaluate_ssim comes from the project's eval module; it presumably returns one
# score per predicted target frame (TODO confirm), so the x-axis is sized from
# its length rather than assuming exactly 10 frames.

ssim_Baseline = evaluate_ssim(model_Baseline, test_dataset)
ssim_BaselineWide4Deep4 = evaluate_ssim(model_BaselineWide4Deep4, test_dataset)
ssim_BaselineWide4Deep4Skip = evaluate_ssim(model_BaselineWide4Deep4Skip, test_dataset)
ssim_BaselineWide4Deep3SkipAutoreg = evaluate_ssim(model_BaselineWide4Deep3SkipAutoreg, test_dataset)

# (legend label, per-frame scores) in the order the curves are drawn.
ssim_curves = [
    ("BaselineWide4Deep3SkipAutoreg", ssim_BaselineWide4Deep3SkipAutoreg),
    ("BaselineWide4Deep4Skip", ssim_BaselineWide4Deep4Skip),
    ("BaselineWide4Deep4", ssim_BaselineWide4Deep4),
    ("Baseline", ssim_Baseline),
]

# Target frames are numbered starting at 1.
x_axis = np.arange(1, len(ssim_Baseline) + 1)

# A dedicated figure guarantees this plot never draws onto axes left over from a
# previous cell via pyplot's implicit "current figure".
fig, ax = plt.subplots()
for label, scores in ssim_curves:
    ax.plot(x_axis, scores, label=label)
ax.legend()
ax.set_title("SSIM Per Frame")
ax.set_ylabel("SSIM")
ax.set_xlabel("Target Frames")
ax.set_xticks(x_axis)

# savefig raises FileNotFoundError when the output directory does not exist.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/ssim.png")
plt.show()

In [None]:
# Compute and plot PSNR per frame for each model.
# evaluate_psnr comes from the project's eval module; it presumably returns one
# score per predicted target frame (TODO confirm), so the x-axis is sized from
# its length rather than assuming exactly 10 frames.

psnr_Baseline = evaluate_psnr(model_Baseline, test_dataset)
psnr_BaselineWide4Deep4 = evaluate_psnr(model_BaselineWide4Deep4, test_dataset)
psnr_BaselineWide4Deep4Skip = evaluate_psnr(model_BaselineWide4Deep4Skip, test_dataset)
psnr_BaselineWide4Deep3SkipAutoreg = evaluate_psnr(model_BaselineWide4Deep3SkipAutoreg, test_dataset)

# (legend label, per-frame scores) in the order the curves are drawn.
psnr_curves = [
    ("BaselineWide4Deep3SkipAutoreg", psnr_BaselineWide4Deep3SkipAutoreg),
    ("BaselineWide4Deep4Skip", psnr_BaselineWide4Deep4Skip),
    ("BaselineWide4Deep4", psnr_BaselineWide4Deep4),
    ("Baseline", psnr_Baseline),
]

# Target frames are numbered starting at 1.
x_axis = np.arange(1, len(psnr_Baseline) + 1)

# A dedicated figure guarantees the PSNR curves are not drawn on top of the SSIM
# figure when figures are not auto-closed between cells (e.g. run as a script).
fig, ax = plt.subplots()
for label, scores in psnr_curves:
    ax.plot(x_axis, scores, label=label)
ax.legend()
ax.set_title("PSNR Per Frame")
ax.set_ylabel("PSNR")
ax.set_xlabel("Target Frames")
ax.set_xticks(x_axis)

# savefig raises FileNotFoundError when the output directory does not exist.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/psnr.png")
plt.show()

In [None]:
# Show sample results for the Baseline model on a few fixed test-set indices.
# Each figure is saved as plots/Baseline_<k>.png, where k is the position of the
# index in BASELINE_SAMPLE_INDICES.

BASELINE_SAMPLE_INDICES = [25, 40, 170]

# savefig raises FileNotFoundError when the output directory does not exist.
os.makedirs("plots", exist_ok=True)

for sample_no, sample_index in enumerate(BASELINE_SAMPLE_INDICES):
    show_result(model_Baseline, test_dataset, index=sample_index)
    # NOTE(review): show_result presumably draws onto the current pyplot figure —
    # TODO confirm; savefig writes whichever figure is current after the call.
    plt.savefig(f"plots/Baseline_{sample_no}.png")

In [None]:
# Show sample results for the BaselineWide4Deep3SkipAutoreg model on a few fixed
# test-set indices. Each figure is saved as
# plots/BaselineWide4Deep3SkipAutoreg_<k>.png, where k is the position of the
# index in AUTOREG_SAMPLE_INDICES.

AUTOREG_SAMPLE_INDICES = [25, 40, 170, 180, 195, 210]

# savefig raises FileNotFoundError when the output directory does not exist.
os.makedirs("plots", exist_ok=True)

for sample_no, sample_index in enumerate(AUTOREG_SAMPLE_INDICES):
    show_result(model_BaselineWide4Deep3SkipAutoreg, test_dataset, index=sample_index)
    # NOTE(review): show_result presumably draws onto the current pyplot figure —
    # TODO confirm; savefig writes whichever figure is current after the call.
    plt.savefig(f"plots/BaselineWide4Deep3SkipAutoreg_{sample_no}.png")