In [1]:
import itertools
import signal
import logging

sigint_handler = signal.getsignal(signal.SIGINT)

def critical(f=None):
    if f is not None:
        it = iter(f)
    else:
        f = itertools.count()
    signal_received = ()
    
    def handler(sig, frame):
        nonlocal signal_received
        signal_received = (sig, frame)
        logging.warning('SIGINT received. Delaying KeyboardInterrupt.')

    while True:
        try:
            signal.signal(signal.SIGINT, handler)
            yield next(it)
            signal.signal(signal.SIGINT, sigint_handler)
            if signal_received:
                sigint_handler(signal_received)
        except StopIteration:
            break

In [2]:
import time
for i in critical(range(10)):
    print('Don\'t stop here!')
    time.sleep(0.5)
    print('ok now')

Don't stop here!
ok now
Don't stop here!




ok now


KeyboardInterrupt: 

In [16]:
def stateful(states):

    def wrapper(cls):
        def state_dict(self):
            return {s: getattr(self, s) for s in states}

        def load_state_dict(self, state):
            for s in states:
                setattr(self, s, state[s])

        cls.state_dict = state_dict
        cls.load_state_dict = load_state_dict
        return cls

    return wrapper

In [17]:
import math
import queue
import random
import threading

@stateful(['batch_size', 'index', 'pos'])
class PrefetchIter:
    """Iterator on data and labels, with states for save and restore."""

    def __init__(self, data, *label, length=None, batch_size=1):
        self.data = data
        self.label = label
        self.batch_size = batch_size
        self.queue = queue.Queue(maxsize=8)
        self.length = length if length is not None else len(data)
        
        assert all(self.length == len(lab) for lab in label), \
            'data and label must have same lengths'

        self.index = list(range(len(self)))
        random.shuffle(self.index)
        self.thread = None
        self.pos = 0

    def __len__(self):
        return math.ceil(self.length / self.batch_size)

    def __iter__(self):
        return self

    def __next__(self):
        if self.thread is None:
            self.thread = threading.Thread(target=self.produce)
            self.thread.start()

        if self.pos >= len(self.index):
            self.thread.join()
            raise StopIteration

        self.pos += 1
        return self.queue.get()

    def produce(self):
        for i in range(self.pos, len(self.index)):
            index = self.index[self.pos]

            bs = self.batch_size

            if callable(self.data):
                data_batch = self.data(index * bs, (index + 1) * bs)
            else:
                data_batch = self.data[index*bs:(index+1)*bs]

            label_batch = [label[index*bs:(index+1)*bs]
                           for label in self.label]
            self.queue.put([data_batch] + label_batch)

In [21]:
import time

def data(low, high):
    time.sleep(1)
    return low

it = PrefetchIter(data, length=32)

print(next(it))
print(it.state_dict())

for data in it:
    print(data)
    time.sleep(1)

[0]
{'batch_size': 1, 'index': [0, 25, 14, 4, 22, 20, 31, 19, 24, 29, 5, 10, 12, 13, 15, 3, 16, 17, 6, 23, 18, 26, 9, 1, 21, 7, 8, 11, 28, 30, 27, 2], 'pos': 1}
[25]
[14]
[14]
[4]


KeyboardInterrupt: 

In [20]:
it.pos

2