In [7]:
import numpy as np
import pickle
import torch
from random import randrange
from datetime import timedelta
from chunkified_npset import ChunkifiedDataset

In [5]:
class CHMainDataset(torch.utils.data.Dataset):
    def __init__(self, np_set, gen_tables):
        self.in_width = 128
        self.out_width = 64
        self.inner_offset = int((128 - 64) / 2)
        self.np_set = np_set
        self._gen_tables(gen_tables)

    def _gen_tables(self, gen_tables):
        idx_2_time = None
        with open('uk_data_np/idx_2_time', 'rb') as i2t_f:
            idx_2_time = pickle.load(i2t_f)
        i2t_f.close()

        self.corner_and_idxs = gen_tables(idx_2_time)

    def __len__(self):
        return len(self.corner_and_idx)

    def _crop(self, section, corner, width):
        return section[:, corner[0]:corner[0]+width, corner[1]:corner[1]+width]
    
    def _in_crop(self, section, corner):
        return self._crop(section, corner, self.in_width)

    def _out_crop(self, section, corner):
        corner[0] += self.inner_offset
        corner[1] += self.inner_offset

        return self._crop(section, corner, self.out_width)
    
    def __getitem__(self, idx):
        corner, np_idx = self.corner_and_idxs[idx]
        in_section = self.np_set[np_idx : np_idx+12]
        in_section = self._in_crop(in_section, corner)

        out_section = self.np_set[np_idx+12 : np_idx+36]
        out_section = self._out_crop(out_section, corner)

        return (in_section, out_section)

In [27]:
class SingleCropGenerator:
    def __init__(self, raw_w, raw_h):
        self.period = 36
        self.crop_width = 128
        self.raw_w = raw_w
        self.raw_h = raw_h

    def _generate_corner(self):
        corner = (randrange(0, self.raw_w - self.crop_width), randrange(0, self.raw_h - self.crop_width))
        return corner

    def __call__(self, idx_2_time):
        table = []
        for i in range(len(idx_2_time)):
            end_bound = i + self.period
            if end_bound >= len(idx_2_time): #no more 3 hour intervals to check!
                break
            
            if idx_2_time[i] + timedelta(hours=3) != idx_2_time[end_bound]: #if period # of points down the line isn't 3 hours exactly in the future
                continue

            table.append((self._generate_corner(), i))
        return table


In [28]:
with open('uk_data_np/idx_2_time', 'rb') as i2t_f:
    idx_2_time = pickle.load(i2t_f)
i2t_f.close()

In [30]:
gen = SingleCropGenerator(891, 6969)
gen(idx_2_time)

[((36, 1415), 0),
 ((493, 284), 1),
 ((140, 1419), 2),
 ((369, 2375), 3),
 ((340, 4966), 4),
 ((154, 5400), 5),
 ((340, 5275), 6),
 ((224, 4522), 7),
 ((686, 3824), 8),
 ((104, 4370), 9),
 ((511, 3400), 10),
 ((188, 1999), 11),
 ((452, 2702), 12),
 ((684, 1755), 13),
 ((618, 3295), 14),
 ((373, 5862), 15),
 ((334, 1978), 16),
 ((192, 5405), 17),
 ((295, 3788), 18),
 ((495, 4585), 19),
 ((334, 250), 20),
 ((284, 3813), 21),
 ((161, 321), 22),
 ((479, 5670), 23),
 ((131, 2111), 24),
 ((262, 6353), 25),
 ((626, 2904), 26),
 ((205, 2819), 27),
 ((356, 2622), 28),
 ((147, 4381), 29),
 ((673, 186), 30),
 ((408, 2424), 31),
 ((478, 4015), 32),
 ((623, 939), 33),
 ((270, 5728), 34),
 ((477, 680), 35),
 ((433, 102), 36),
 ((71, 3176), 37),
 ((52, 1435), 38),
 ((457, 5883), 39),
 ((467, 3365), 40),
 ((378, 1195), 41),
 ((115, 3169), 42),
 ((637, 2678), 43),
 ((634, 4448), 44),
 ((131, 4543), 45),
 ((748, 4493), 46),
 ((120, 5943), 47),
 ((24, 5257), 48),
 ((148, 6799), 49),
 ((184, 5675), 50),
 

In [4]:
dataset = CHMainDataset()
print(dataset)
dataset[0]

<__main__.CHMainDataset object at 0x000001B2CB3D7C48>


BadZipFile: File is not a zip file