In [1]:
import numpy as np
import tonic
import tonic.transforms as transforms
from torchvision import transforms as tt
import pandas as pd
from torch.utils.data import Dataset, DataLoader
#from tonic.dataset import Dataset
from typing import Callable, Optional
import torch
from torch import nn
import matplotlib.pyplot as plt
from torchvision.ops import masks_to_boxes
from torchvision.utils import draw_bounding_boxes
from typing import Tuple
from tqdm.notebook import tqdm
from statistics import mean

import snntorch as snn
from snntorch import utils
from snntorch import functional as SF

In [2]:
# Don't change unless also changed in EVIMO saving
num_bins_per_frame = 8 
framerate = 200

# Standardized sizes, from EVIMO recording
sensor_size = [640, 480, 2]
input_size=(480, 640)

beta = 0.9

batch_size = 4
num_epochs = 1

num_classes = 25

output_size= (480, 640) #(30, 40) # Can be changed

dtype=torch.float
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")
device

device(type='cpu')

In [3]:
class EVIMOMask(Dataset):
    def __init__(self,
                 dirs: list,
                 num_bins_per_frame: int,
                 output_size: Tuple,
                ):
        self.dirs = dirs
        self.num_bins_per_frame = num_bins_per_frame
        self.output_size = output_size

        self.length = 0
        self.lengths = []
        for dir in self.dirs:
            curr_len = np.load(dir + "/length.npy")
            self.length += curr_len
            self.lengths.append(curr_len)

    def find_num_classes():
        self.num_classes = 0
        self.classes = torch.empty((1, ))
        for idx in range(0, len(self)):
            dir, index = self.get_dir_index(idx)

            item = np.load(dir + "/" + str(index) + ".npy", allow_pickle=True).tolist()
            mask = torch.from_numpy(np.asarray([item["mask"]])).to(torch.int64)

            classes = torch.unique(mask, sorted=False)
            self.classes = torch.cat((self.classes, classes), dim=0)

        self.classes = torch.unique(self.classes)
        self.num_classes = int(self.classes.max()) + 1
            

    def get_dir_index(self, index):
        curr_idx_sum = 0
        for i, length in enumerate(self.lengths):
            #print(index, length, curr_idx_sum)
            if curr_idx_sum <= index < curr_idx_sum + length:
                dir = self.dirs[i]
                index -= curr_idx_sum
                break
            curr_idx_sum += length
                

        return dir, index

    def __getitem__(self, index):
        dir, index = self.get_dir_index(index)
        
        item = np.load(dir + "/" + str(index) + ".npy", allow_pickle=True).tolist()

        events = np.asarray(item["events"])

        frame_transform = transforms.Compose([# transforms.Denoise(filter_time=0.01),
                                       transforms.ToVoxelGrid(sensor_size=sensor_size,
                                                          n_time_bins=self.num_bins_per_frame)
                                      ])

        events = frame_transform(events)

        mask = torch.from_numpy(np.asarray([item["mask"]])).to(torch.int64)

        one_hot_mask = torch.nn.functional.one_hot(mask, num_classes=num_classes).transpose(1, 3).transpose(2, 3) # Conversion into Batch, Channels, H, W

        # Downsize the mask.
        resized_mask = tt.functional.resize(one_hot_mask, self.output_size, antialias=True)
        
        return torch.from_numpy(events).to(torch.float), resized_mask.squeeze()

    def get_original_mask(self, index):
        dir, index = self.get_dir_index(index)
        
        item = np.load(dir + "/" + str(index) + ".npy", allow_pickle=True).tolist()
        mask = torch.from_numpy(np.asarray([item["mask"]])).to(torch.int64)
        return mask

    def __len__(self) -> int:
        return self.length # - self.start_idx

    # def get_item(self, index):
    #     item = np.load(self.dir + "/" + str(index) + ".npy", allow_pickle=True).tolist()
    #     return item


In [4]:
dirs = ["./data/EVIMO/left_cam/scene13_test5",
       "./data/EVIMO/left_cam/scene14_test3",
       "./data/EVIMO/left_cam/scene15_test1"]

In [5]:
dataset = EVIMOMask(dirs=dirs, output_size=output_size, num_bins_per_frame=num_bins_per_frame)

In [6]:
len(dataset)

2181

In [7]:
bins, mask = dataset[2100]
bins.shape, mask.shape

(torch.Size([8, 1, 480, 640]), torch.Size([25, 480, 640]))

In [8]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 4, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(2)
        self.lif1 = snn.Leaky(beta=beta)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(2)
        self.lif2 = snn.Leaky(beta=beta)
        
        self.upconv1 = nn.ConvTranspose2d(4, 8, kernel_size=2, stride=2)
        #self.unpool1 = nn.MaxUnpool2d(2)
        #nn.Flatten(),
        ##nn.Linear(64*4*4, 10),
        self.lif3 = snn.Leaky(beta=beta)
        self.upconv2 = nn.ConvTranspose2d(8, num_classes, kernel_size=2, stride=2)
        #self.unpool2 = nn.MaxUnpool2d(2)
        self.lif4 = snn.Leaky(beta=beta, output=True)

    
    def forward(self, x):
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()

        x = self.conv1(x)
        x = self.pool1(x)
        spk1, mem1 = self.lif1(x, mem1)
        x = self.pool2(self.conv2(spk1))
        spk2, mem2 = self.lif2(x, mem2)

        #print(spk2.shape)
        x = self.upconv1(spk2)
        #print(x.shape)
        spk3, mem3 = self.lif3(x, mem3)

        x = self.upconv2(spk3)
        spk4, mem4 = self.lif4(x, mem4)

        return spk4, mem4

In [9]:
model = Model().to(device)

In [10]:
def forward_pass(data):
    mem_rec = []
    spk_rec = []
    utils.reset(model)  # resets hidden states for all LIF neurons in net
    
    data = data.transpose(0, 1) # num_steps, batch_size, C, H, W
    
    for step in range(num_bins_per_frame):
      spk_out, mem_out = model(data[step])
      spk_rec.append(spk_out)
      mem_rec.append(mem_out)

    return torch.stack(spk_rec), torch.stack(mem_rec)

In [11]:
spk_rec, mem_rec = model(bins)

In [12]:
spk_rec.shape

torch.Size([8, 25, 480, 640])

In [13]:
loss_fn = SF.ce_temporal_loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, betas=(0.9, 0.999))

In [14]:
trainloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
for epoch in range(num_epochs):
    train_batch = tqdm(iter(trainloader), desc=f"Epoch {epoch}")
    for data, masks in train_batch:
        data = data.to(device).to(torch.float) # Data currently in batch_size, num_steps, Channels, H, W
        masks = masks.to(device).to(torch.float)

        model.train()
        spk_rec, _ = forward_pass(data)
        loss_val = loss_fn(spk_rec, masks)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

Epoch 0:   0%|          | 0/546 [00:00<?, ?it/s]