In [None]:
import torch
from torch.utils.data import DataLoader, IterableDataset
import pickle
import math
import random 
import numpy as np

# Axis indices into 2-D shapes / poses: Y indexes rows (axis 0),
# X indexes columns (axis 1).
X = 1
Y = 0

ROW = 0
COL = 1


def local_frame(m, output_shape, pose, fill=0):
    """Crop an ``output_shape`` window out of the 2-D array *m*, centred on *pose*.

    Cells of the window that fall outside *m* are set to *fill*.

    Args:
        m: 2-D array-like (anything ``np.pad`` accepts).
        output_shape: ``(rows, cols)`` of the returned window, indexed with
            ``Y``/``X``.
        pose: ``(row, col)`` position in *m* the window is centred on,
            indexed with ``Y``/``X``.
        fill: constant value used for out-of-bounds cells.

    Returns:
        numpy array of shape ``output_shape``. For odd sizes, *pose* is the
        exact centre; for even sizes, it sits just past the middle.
    """
    out_shape = np.asarray(output_shape, dtype=int)
    # Pad by half the window on every side so that padded[pose : pose + out]
    # corresponds to m[pose - half : pose - half + out], i.e. a window that
    # contains (and is centred on) `pose`.
    half_out_shape = out_shape // 2
    padded = np.pad(
        m,
        ([half_out_shape[Y]] * 2, [half_out_shape[X]] * 2),
        mode='constant',
        constant_values=fill,
    )
    return padded[pose[Y]:pose[Y] + out_shape[Y],
                  pose[X]:pose[X] + out_shape[X]]


class CoverageDataset(IterableDataset):
    """Stream one training sample per agent from a pickled list of items.

    Based on the key accesses below, each item in the pickled list is
    presumably a dict with ``item['obs']['agents'][i]['map']``,
    ``item['obs']['agents'][i]['pos']`` and ``item['feats'][i]``; shapes and
    dtypes are not validated here (TODO confirm against the producer).

    Each yielded sample is a dict with:
        'x': the agent's entry from ``item['feats']`` (passed through as-is),
        'w': a float32 tensor of shape (2, H, W): the map-sized window
             centred on the agent's pose, 1 inside the map and 0 outside,
             duplicated for the two label channels,
        'y': tensor of shape (2, H, W) stacking channels 0 and 1 of the
             agent's map.
    """

    def __init__(self, dataset, max_len=None, shuffle=False):
        """
        Args:
            dataset: path to a pickle file containing a list of items.
            max_len: if not None, only the first ``max_len`` items are used.
            shuffle: shuffle item order (within each worker's shard) on every
                iteration.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # dataset files from trusted sources.
        with open(dataset, 'rb') as handle:
            data_list = pickle.load(handle)
        self.dataset = data_list
        self.max_len = max_len
        self.shuffle = shuffle

    def __iter__(self):
        length = len(self.dataset)
        if self.max_len is not None:
            length = min(length, self.max_len)
        indices = self._shard_indices(list(range(length)))
        # A fresh generator per call, so the dataset can be iterated per epoch.
        return self._dataset_generator(indices)

    @staticmethod
    def _shard_indices(indices):
        """Return the contiguous chunk of *indices* owned by the current DataLoader worker."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this process handles everything.
            return indices
        per_worker = int(math.ceil(len(indices) / float(worker_info.num_workers)))
        iter_start = worker_info.id * per_worker
        return indices[iter_start:iter_start + per_worker]

    def _dataset_generator(self, indices):
        """Yield one sample per agent for every item referenced by *indices*."""
        if self.shuffle:
            # Shuffles only this worker's shard (indices is a private copy).
            random.shuffle(indices)
        for idx in indices:
            item = self.dataset[idx]
            feats = item['feats']
            for a_id, agent in enumerate(item['obs']['agents']):
                yield self._make_sample(agent, feats[a_id])

    @staticmethod
    def _make_sample(agent, agent_obs):
        """Build the {'x', 'w', 'y'} sample dict for a single agent."""
        # as_tensor accepts both tensors and numpy arrays (no copy for tensors).
        agent_state = torch.as_tensor(agent['map'])
        agent_pos = agent['pos']
        obs_map = agent_state[..., 0]
        obs_cov = agent_state[..., 1]
        label = torch.stack([obs_map, obs_cov], dim=0)

        # Mask of ones the size of the map, re-centred on the agent's pose;
        # cells shifted in from outside the map are zero.
        ones = np.ones(tuple(obs_map.shape), dtype=np.float32)
        frame = local_frame(ones, obs_map.shape, agent_pos, fill=0)
        weight = torch.from_numpy(np.stack([frame] * 2, axis=0))

        return {'x': agent_obs, 'w': weight, 'y': label}

        

