In [2]:
import numpy as np
import os
import torch
from tqdm import tqdm
import zipfile

from matplotlib import pyplot as plt

from datasets.evimo_dataset import EVIMODataset
import dataloader.base as base

from utils.masks_to_boxes import masks_to_boxes, detection_rate
from utils.gpu import moveToGPUDevice
from utils.rbase import RBase

from torchvision import transforms
from tensorboardX import SummaryWriter
from torch.utils.data.dataloader import default_collate
import torch.nn as nn 
from PIL import Image

In [3]:
def my_collate(batch):
    """Collate a batch after discarding samples that came back as ``None``.

    Used as the DataLoader ``collate_fn`` so that the dataset can signal an
    unusable sample by returning ``None``.

    Args:
        batch: iterable of samples produced by the dataset.

    Returns:
        ``default_collate`` applied to the non-``None`` samples, or ``None``
        when no sample survives the filter.
    """
    kept = [sample for sample in batch if sample is not None]
    return default_collate(kept) if kept else None

class Runner(RBase):
    """Evaluate OMS spike maps against ground-truth OMS masks on EVIMO/MOD data.

    The constructor builds a DataLoader over the selected dataset. ``test``
    then prints, for every sample, ``detection_rate`` between boxes derived
    (via ``masks_to_boxes``) from the mask and from the OMS spikes. It also
    prints each batch's IoU and, at the end, the frame-weighted mean IoU.

    NOTE(review): ``snn`` (used for ``snn.params`` / ``snn.loss``) and
    ``MODDataset`` are not imported in the visible part of this file;
    presumably they come from another notebook cell -- confirm.
    """

    def __init__(self, crop, maxBackgroundRatio, datasetType, datafile,
                    checkpoint, modeltype,
                    log_config, general_config,
                    maskDir, incrementalPercent,
                    saveImages, saveImageInterval, imageDir, imageLabel=""):
        """Store the run settings and build the evaluation DataLoader.

        Args:
            crop, maxBackgroundRatio: only consumed by the disabled
                ``base.*DatasetBase`` constructors; kept for interface
                compatibility.
            datasetType: ``"EVIMO"`` or ``"MOD"``.
            datafile: dataset path. For MOD it is a directory prefix to which
                the hard-coded sequence file names are appended.
            checkpoint, modeltype: stored on the instance.
            log_config, general_config: forwarded to ``RBase``. The latter is
                also parsed with ``snn.params``, which supplies the
                ``batchsize`` and ``hardware.readerThreads`` loader settings.
            maskDir, incrementalPercent: stored on the instance.
            saveImages, saveImageInterval, imageDir, imageLabel: image-dump
                settings read by ``test``.

        Raises:
            Exception: if ``datasetType`` is neither ``"EVIMO"`` nor ``"MOD"``.
        """
        super().__init__(datafile, log_config, general_config)

        self.output_dir = self.log_config.getOutDir()
        self.genconfigs = snn.params(general_config)
        self.checkpoint = checkpoint
        self.modeltype = modeltype
        self.maskDir = maskDir
        self.incrementalPercent = incrementalPercent
        # Use the caller-supplied image settings rather than hard-coded ones.
        self.saveImages = saveImages
        self.saveImageInterval = saveImageInterval
        self.imageDir = imageDir
        self.imageLabel = imageLabel

        if datasetType == "EVIMO":
            #database = base.EVIMODatasetBase(datafile, self.genconfigs, self.maskDir, crop, maxBackgroundRatio, incrementalPercent)
            database = EVIMODataset(datafile, 100, False)
            print("EVIMO used")
        elif datasetType == "MOD":
            #database = base.MODDatasetBase(datafile, self.genconfigs, self.maskDir, crop, maxBackgroundRatio, incrementalPercent)
            database = MODDataset(datafile + 'room1obj1-001.hdf5', datafile + 'seq_room1_obj1/masks/', datafile + 'seq_room1_obj1_neuroscience.h5', 100)
            print("MOD used")
        else:
            raise Exception("Only EVIMO or MOD datasets with hdf5 format generated by preprocessing scripts handled with this code.")

        # Loader sizing comes from the config (previously read but ignored).
        num_workers = self.genconfigs['hardware']['readerThreads']
        batch_size = self.genconfigs['batchsize']
        self.loader = torch.utils.data.DataLoader(
            database,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=my_collate,
            drop_last=False,
        )
        self.tb_writer = SummaryWriter(self.output_dir)

    def test(self):
        """Run the OMS evaluation loop over ``self.loader`` without gradients.

        Batches for which ``my_collate`` returned ``None`` are skipped. The
        summed IoU is weighted by the number of frames in each batch, so the
        final printed value is a per-frame mean.
        """
        # NOTE(review): checkpoint loading is disabled, so ``self.net`` must
        # already exist (presumably set by RBase); it is not used below.
        #self._loadNetFromCheckpoint(self.checkpoint, self.modeltype)
        self.net = self.net.eval()

        total_input_IOU = 0
        scalar_i = 0
        tot_frames = 0
        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

        if self.saveImages and self.imageDir:
            os.makedirs(self.imageDir, exist_ok=True)

        # Built once and placed on the same device the batches are moved to.
        ioucriterion = snn.loss(self.genconfigs).to(device)

        with torch.no_grad():
            for i, data in enumerate(self.loader):
                if data is None:  # every sample in this batch was filtered out
                    continue
                data = moveToGPUDevice(data, device, None)
                oms_spikes_input = data['oms_spike_tensor']
                oms_spikes_masked = data['oms_mask']
                spikes_input = data['dvs_spike_tensor']

                # Calculate metrics.
                tot_frames += spikes_input.shape[0]

                # Collapse axes 1 and 4 to get one 2-D map per sample.
                oms_mask_2D = torch.sum(oms_spikes_masked, axis=(1, 4))
                boxes_gt = masks_to_boxes(oms_mask_2D)

                oms_spike_2D = torch.sum(oms_spikes_input, axis=(1, 4))
                boxes_pred = masks_to_boxes(oms_spike_2D)

                # Separate index so the batch index ``i`` is not clobbered.
                for b in range(oms_mask_2D.shape[0]):
                    print(detection_rate(boxes_gt[b], boxes_pred[b]))

                input_iou = ioucriterion.getIOU(oms_spike_2D, oms_mask_2D)
                print("OMS INPUT IoU", input_iou)
                print("-------------------------------------------------")

                total_input_IOU += input_iou * spikes_input.shape[0]
                scalar_i += 1
                #self.tb_writer.add_scalar('iou', iou.item(), scalar_i)

                # Image dumping is currently disabled; kept for re-enabling.
                # if self.saveImages and i % self.saveImageInterval == 0:
                #     spikes_maskednp = np.array(oms_spikes_masked.detach().cpu())
                #     spikesInput_np = np.array(oms_spikes_input.detach().cpu())
                #     for batch in range(0, spikes_input.shape[0]):
                #         # Global sample index (shuffle=False); ``i + batch`` collided across batches.
                #         curr_num = i * self.loader.batch_size + batch
                #         im2 = Image.fromarray(np.uint8(255 - np.sum(spikes_maskednp[batch, :, :, :, :], axis=(0, 3)) * 255))
                #         im2.save(os.path.join(self.imageDir, "_ideal_epoch{}".format(curr_num) + self.imageLabel + ".jpg"))
                #         im3 = Image.fromarray(np.uint8(255 - np.sum(spikesInput_np[batch, :, :, :, :], axis=(0, 3)) * 255))
                #         im3.save(os.path.join(self.imageDir, "_input_epoch{}".format(curr_num) + self.imageLabel + ".jpg"))

        if self.saveImages:
            print("saving images to", os.getcwd(), self.imageDir)

        # Guard against an empty loader / all-None batches (ZeroDivisionError).
        if tot_frames == 0:
            print("no frames evaluated; mean OMS IoU undefined")
        else:
            print("mean OMS IoU for {} batches of frames".format(tot_frames), total_input_IOU / tot_frames)

