In [1]:
# Core dependencies: PyTorch, the tonic event-data toolkit, and torchvision.
import torch
# Seed PyTorch's global RNG up front.
# NOTE(review): this call seeds torch only; other RNGs (e.g. NumPy's global
# state) are not affected by it.
torch.manual_seed(0)
import tonic
import tonic.transforms
import torchvision

In [2]:
import sys

# Directory added to the import search path; presumably where the local
# `helper` module imported later in the notebook lives.
# NOTE(review): the path is relative, so it resolves against the kernel's
# current working directory.
PATH = 'experiment/experiment-09'

# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if PATH not in sys.path:
    sys.path.append(PATH)

In [3]:
import numpy as np

# Perturbation strength used by the jittered pipeline.
n = 4000                 # passed to UniformNoise(n=...): number of noise events to add
time_jitter_std = 10000  # passed to TimeJitter(std=...); presumably microseconds — TODO confirm

# torch.manual_seed() does not seed NumPy. The tonic event transforms operate on
# NumPy arrays, so seed NumPy's global RNG too (same seed value as the torch
# seed) to make the random jitter/noise repeatable on Restart & Run All.
# NOTE(review): assumes tonic draws from the global np.random state — confirm.
np.random.seed(0)

In [4]:
sensor_size = tonic.datasets.NMNIST.sensor_size

# Settings shared by BOTH pipelines. They are defined once so the perturbed and
# the reference pipeline cannot drift apart, which would invalidate the
# frame-by-frame comparison made later in the notebook.
DENOISE_FILTER_TIME = 10000  # passed to Denoise(filter_time=...)
N_TIME_BINS = 30             # passed to ToFrame(n_time_bins=...)


def build_frame_transform(perturbations=()):
    """Compose the events -> float32 frame-tensor pipeline.

    The pipeline is: Denoise, then any ``perturbations`` (in the given order),
    then ToFrame, ``torch.from_numpy`` and a cast to ``torch.float32``.

    Args:
        perturbations: iterable of event-level transforms inserted between
            denoising and framing. Empty by default, which yields the
            unperturbed reference pipeline.

    Returns:
        A ``tonic.transforms.Compose`` instance.
    """
    return tonic.transforms.Compose([
        tonic.transforms.Denoise(filter_time=DENOISE_FILTER_TIME),
        *perturbations,
        tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=N_TIME_BINS),
        torch.from_numpy,
        torchvision.transforms.Lambda(lambda x: x.to(torch.float32)),
    ])


# Perturbed pipeline: timestamp jitter followed by added noise events.
jitter_transform = build_frame_transform([
    tonic.transforms.TimeJitter(std=time_jitter_std, clip_negative=True),
    tonic.transforms.UniformNoise(sensor_size=sensor_size, n=n),
])
# Reference pipeline: identical stages, no perturbation.
original_transform = build_frame_transform()

In [5]:
# NOTE(review): absolute, machine-specific location; adjust when running elsewhere.
DATA_ROOT = '/DATA/hwkang'


def _nmnist_eval_split(transform):
    """Build the NMNIST dataset with train=False and the given transform."""
    return tonic.datasets.NMNIST(save_to=DATA_ROOT, train=False, transform=transform)


# Same underlying recordings, seen through the perturbed and the reference pipeline.
jitter_dataset = _nmnist_eval_split(jitter_transform)
original_dataset = _nmnist_eval_split(original_transform)

In [6]:
import multiprocessing
from torch.utils.data import DataLoader

# Both loaders share identical settings. shuffle=False keeps the sample order
# the same in both, so batches can later be paired element by element.
loader_kwargs = dict(
    batch_size=100,
    num_workers=multiprocessing.cpu_count() // 2,  # half of the reported CPU count
    shuffle=False,
)

jitter_loader = DataLoader(jitter_dataset, **loader_kwargs)
original_loader = DataLoader(original_dataset, **loader_kwargs)

In [7]:
from helper import calculate_psnr

# Pull only the first batch from each loader; the second element of each batch
# is not used by the comparison below.
jitter_x, _jitter_targets = next(iter(jitter_loader))
original_x, _original_targets = next(iter(original_loader))

In [8]:
# One PSNR value per (jittered, original) pair, matched by position in the batch.
list_psnr = [
    calculate_psnr(jittered, reference)
    for jittered, reference in zip(jitter_x, original_x)
]

In [None]:
import math

# Keep only finite PSNR values (drops NaN and +/-inf) before averaging.
filtered_psnr = [value for value in list_psnr if math.isfinite(value)]

if filtered_psnr:
    average_psnr = sum(filtered_psnr) / len(filtered_psnr)
else:
    # Every value was NaN/inf (or list_psnr was empty): report NaN instead of
    # failing with ZeroDivisionError.
    average_psnr = float('nan')
print(average_psnr)

In [10]:
# Take the first element (the frames) of sample 0 from each dataset for visual inspection.
# NOTE(review): indexing the dataset presumably re-applies the random transform,
# so the jittered frames here may differ from the draw scored in the PSNR batch — confirm.
SAMPLE_INDEX = 0
jitter_frames, _jitter_label = jitter_dataset[SAMPLE_INDEX]
original_frames, _original_label = original_dataset[SAMPLE_INDEX]

In [None]:
import matplotlib.pyplot as plt  # not referenced in this cell; kept in case other cells rely on plt
from IPython.display import HTML

# Animate the jittered frames; the animation is converted with to_jshtml() and
# wrapped in HTML so the notebook renders it as this cell's output.
ani = tonic.utils.plot_animation(jitter_frames)
jitter_animation_html = ani.to_jshtml()
HTML(jitter_animation_html)

In [None]:
# Same visualisation for the unperturbed reference frames, for side-by-side comparison.
ani = tonic.utils.plot_animation(original_frames)
original_animation_html = ani.to_jshtml()
HTML(original_animation_html)