In [None]:
import data
import time
import numpy as np
import tarfile
import io
import json
from scipy.ndimage import imread

In [None]:
class BatchBuffer:
    """Base class for a file-like cursor over a window of items.

    The window starts at ``offset`` in the underlying source and spans
    ``size`` items. ``index`` is the current position relative to the start
    of the window. Subclasses are expected to provide ``read``.
    """

    def __init__(self, offset, size):
        """Store the window bounds and place the cursor at position 0."""
        self.index = 0
        self.size = size
        self.offset = offset

    def read(self, num=None):
        """Return the next ``num`` items; not implemented in the base class."""
        raise NotImplementedError('Unimplemented')

    def seek(self, num, mode=0):
        """Reposition the cursor.

        ``mode`` selects the reference point, following the usual file
        convention: 0 = start of the window, 1 = current position,
        2 = end of the window. Any other mode raises NotImplementedError.
        No bounds checking is applied to the resulting position.
        """
        if mode == 2:
            self.index = self.size + num
            return
        if mode == 1:
            self.index += num
            return
        if mode == 0:
            self.index = num
            return
        raise NotImplementedError('Unimplemented')

    def tell(self):
        """Return the current cursor position within the window."""
        return self.index

class ArrayBuffer(BatchBuffer):
    """Batch reader over a contiguous window of an in-memory sequence.

    ``array`` must support ``len()`` and slicing. The window covers
    ``array[offset:offset + size]``; a negative ``size`` means "everything
    from ``offset`` to the end of ``array``".
    """

    def __init__(self, array, offset, size):
        if size < 0:
            size = len(array) - offset
        super().__init__(offset, size)
        self.array = array

    def read(self, num=None):
        """Return the next ``num`` items as a slice of the underlying array.

        Args:
            num: Maximum number of items to return. ``None`` reads everything
                remaining in the window.

        Returns:
            A slice of ``array`` (shorter than ``num`` at the end of the
            window), or ``None`` once the cursor has reached the end.
        """
        if self.index >= self.size:
            return None

        if num is None:
            num = self.size - self.index

        # Clamp the end of the batch to the window.
        end = min(self.index + num, self.size)
        batch = self.array[self.offset + self.index:self.offset + end]
        # Advance by what was actually returned so tell() never reports a
        # position past the end of the window after a short final read.
        self.index = end
        return batch
    
class TarBuffer(BatchBuffer):
    """Batch reader that decodes images lazily from an open tar archive.

    Members are taken from ``tar.getmembers()[offset:-1]``: the final member
    of the archive is always skipped (``get_stride`` reads that member as the
    JSON index). A negative ``size`` means "all remaining members".
    """

    def __init__(self, tar, offset, size):
        self.tar = tar
        self.members = tar.getmembers()[offset:-1]
        if size >= 0:
            self.members = self.members[:size]
        # Use the number of members actually available so the window never
        # claims more items than can be read.
        size = len(self.members)
        super().__init__(offset, size)

        # Number of read() calls that produced a batch; used for the log line.
        self.i = 0

    def read(self, num=None):
        """Decode and return the next ``num`` images as one numpy array.

        Each image is loaded with ``imread``, scaled by 1/255 and inverted
        (``1 - value``). Prints the batch counter and elapsed wall time.

        Args:
            num: Maximum number of images to return. ``None`` reads everything
                remaining in the window.

        Returns:
            ``np.array`` of the decoded images (shorter than ``num`` at the
            end of the window), or ``None`` once the cursor has reached the
            end.
        """
        if self.index >= self.size:
            return None

        if num is None:
            num = self.size - self.index

        end = min(self.index + num, self.size)

        start_time = time.time()

        # Read the members at the current cursor position. Iterating
        # range(num) here would return the first `num` images on every call
        # and ignore the clamped end of the window.
        # NOTE(review): scipy.ndimage.imread was removed in SciPy 1.2; this
        # requires an older SciPy unless the file-level import is replaced.
        images = [
            1 - (imread(self.tar.extractfile(member)) / 255)
            for member in self.members[self.index:end]
        ]
        ret = np.array(images)
        self.index = end

        self.i += 1
        print("%3dth, %ss" % (self.i, time.time() - start_time))

        return ret
        
    
def run_epoch(imgbuf, labelbuf, batchsize=100, epochs=4):
    """Read both buffers in lock-step batches for a number of full passes.

    Both buffers are rewound first. Batches are read from each buffer until
    the image buffer returns ``None``, at which point both buffers are
    rewound and the pass counter is printed. The batches themselves are not
    used; the function only drives the ``read`` calls.

    Args:
        imgbuf: Object with ``seek(pos)`` and ``read(num)``; ``read`` must
            return ``None`` when exhausted.
        labelbuf: Same protocol as ``imgbuf``; expected to be exhausted at the
            same time as ``imgbuf``.
        batchsize: Number of items requested per ``read`` call.
        epochs: Number of full passes to perform (default 4).

    Raises:
        ValueError: If ``batchsize`` is not positive; such a read never
            advances the cursor, so the loop would not terminate.
    """
    if batchsize is not None and batchsize <= 0:
        raise ValueError("batchsize must be positive, got %r" % (batchsize,))

    imgbuf.seek(0)
    labelbuf.seek(0)
    epoch = 0
    while epoch < epochs:
        batch_x = imgbuf.read(batchsize)
        batch_y = labelbuf.read(batchsize)
        if batch_x is None:
            # Both buffers must run out on the same read.
            assert batch_y is None, "label buffer longer than image buffer"
            imgbuf.seek(0)
            labelbuf.seek(0)
            epoch += 1
            print("%dth epoch" % epoch)

def get_all_in_one():
    """Load the whole archive via ``data.get_all`` and run the epoch loop.

    Images and labels from index 15000 onward are wrapped in ``ArrayBuffer``
    objects and handed to ``run_epoch``. The first value returned by
    ``data.get_all`` is not used here.
    """
    _, images, labels = data.get_all('data/161020.tgz')
    run_epoch(
        ArrayBuffer(images, 15000, -1),
        ArrayBuffer(labels, 15000, -1),
    )

def get_stride():
    """Run the epoch loop while decoding images from the archive on demand.

    The last member of the archive is parsed as JSON and passed to
    ``data.get_label`` to build the labels. Images are then read batch by
    batch through a ``TarBuffer`` over the same archive, starting at member
    15000, with labels served from an ``ArrayBuffer`` at the same offset.
    """
    data_path = 'data/161020.tgz'

    # The context manager closes the archive; no explicit close() is needed.
    with tarfile.open(data_path, "r:*") as tar:
        print("tar opened for index")
        ft = tar.extractfile(tar.getmembers()[-1])
        index_data = list(json.load(io.TextIOWrapper(ft)))

    label = data.get_label(index_data)

    # Keep the archive open for the whole run because TarBuffer reads from
    # it lazily, and make sure it is closed afterwards even on error.
    with tarfile.open(data_path, "r:*") as tar:
        trainlabel = ArrayBuffer(label, 15000, -1)
        trainimg = TarBuffer(tar, 15000, -1)
        assert trainlabel.size == trainimg.size, "%d == %d" % (trainlabel.size, trainimg.size)
        print("buffer loaded")
        run_epoch(trainimg, trainlabel)

In [None]:
# Time the run that loads the whole archive before iterating.
start_time = time.time()
get_all_in_one()
elapsed = time.time() - start_time
print("--- %s seconds ---" % elapsed)

In [None]:
# Time the run that decodes images from the tar archive batch by batch.
start_time = time.time()
get_stride()
elapsed = time.time() - start_time
print("--- %s seconds ---" % elapsed)