In [1]:
import os
import ray
import numpy as np
from transformers import RobertaTokenizer
from corpus_loader import SentenceSegmenter, CorpusLoader

In [2]:
# Start a local Ray runtime. Nothing else in this notebook calls Ray yet
# (see the "TODO: extend it with Ray" in CorpusBatchIterator.__iter__).
ray.init()

2020-02-18 09:27:18,032	INFO resource_spec.py:216 -- Starting Ray with 14.89 GiB memory available for workers and up to 7.45 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).


{'node_ip_address': '169.237.10.101',
 'redis_address': '169.237.10.101:52300',
 'object_store_address': '/tmp/ray/session_2020-02-18_09-27-18_030104_25589/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2020-02-18_09-27-18_030104_25589/sockets/raylet',
 'webui_url': None,
 'session_dir': '/tmp/ray/session_2020-02-18_09-27-18_030104_25589'}

In [3]:
# Pretrained "roberta-base" tokenizer; handed to the corpus loaders below.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

In [4]:
# Standalone loader for manual inspection of a single sector.
# NOTE(review): it is only referenced by commented-out cells below, and its
# corpus_path ("/data/SECTOR/") differs from the one passed to
# CorpusBatchIterator later ("/data/en-corpus/SECTOR/") — confirm which is current.
corpus_loader = CorpusLoader(tokenizer, max_seq_length=128, corpus_path="/data/SECTOR/")

In [5]:
class CorpusIterator:
    """Walk every segment of every document, visiting documents in a freshly
    shuffled order on each pass.

    Each item is a ``(segment, is_doc_start)`` tuple; ``is_doc_start`` is True
    only for the first segment of a document.
    """

    def __init__(self, processed_docs):
        """
        Args:
            processed_docs: indexable, sized sequence of documents; each document
                is an iterable of segments (presumably the output of
                ``CorpusLoader.load_sector`` — confirm).
        """
        self.processed_docs = processed_docs
        self.total_num_docs = len(processed_docs)

    def __iter__(self):
        # Every document is visited exactly once, in an order drawn from NumPy's
        # global RNG (so np.random.seed controls it).
        order = np.arange(self.total_num_docs)
        np.random.shuffle(order)

        for doc_index in order:
            is_first = True
            for segment in self.processed_docs[doc_index]:
                # Flag whether this segment opens its document.
                yield segment, is_first
                is_first = False

In [6]:
class CorpusBatchIterator:
    """Endless batch iterator over a corpus stored as numbered sectors.

    A sector is loaded with ``CorpusLoader.load_sector`` and ``batch_size``
    independent ``CorpusIterator`` instances are built over it, each with its own
    shuffled document order. Every batch takes one ``(segment, is_doc_start)``
    item from each of them. When the sector is exhausted, the next sector
    (wrapping around modulo the number of entries in ``corpus_path``) is loaded.
    """

    def __init__(self, tokenizer, corpus_path:str, batch_size:int, max_seq_length:int, rank:int = 0):
        """
        Args:
            tokenizer: tokenizer forwarded to ``CorpusLoader``.
            corpus_path: directory containing the corpus sectors; the number of
                entries in it is taken as the number of sectors.
            batch_size: number of items per batch (one per parallel iterator).
            max_seq_length: forwarded to ``CorpusLoader``.
            rank: for distributed learning; used as the id of the first sector read.
        """
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.current_sector_id = rank
        self.corpus_loader = CorpusLoader(tokenizer,
                                          max_seq_length=max_seq_length,
                                          corpus_path=corpus_path)
        self.total_num_sectors = len(os.listdir(corpus_path))

        # TODO: process the data and save it into cache

    def __iter__(self):
        """Yield batches (lists of ``batch_size`` items), cycling through sectors.

        Iteration only ends if a full cycle over all sectors produces no batch
        at all (e.g. every sector is empty), which would otherwise loop forever.
        """
        iterators = self.create_corpus_iterators(self.current_sector_id)
        sector_yielded = False        # did the current sector produce a batch yet?
        empty_sectors_in_a_row = 0

        while True:
            try:
                # TODO: extend it with Ray
                batch = [next(iterator) for iterator in iterators]
            except StopIteration:
                # The current sector is exhausted (or was empty to begin with).
                if sector_yielded:
                    empty_sectors_in_a_row = 0
                else:
                    empty_sectors_in_a_row += 1
                    if empty_sectors_in_a_row >= self.total_num_sectors:
                        # A whole cycle produced nothing: stop instead of spinning.
                        return
                # Advance from the sector this iterator is currently on.
                # NOTE(review): with several ranks, every rank walks all sectors
                # in the same order offset by its rank; a stride of world size
                # may be intended — confirm.
                self.current_sector_id = (self.current_sector_id + 1) % self.total_num_sectors
                iterators = self.create_corpus_iterators(self.current_sector_id)
                sector_yielded = False
                continue

            # Yield outside the try so exceptions thrown into the generator by the
            # consumer are not mistaken for sector exhaustion.
            sector_yielded = True
            yield batch

    def create_corpus_iterators(self, corpus_sector_id):
        """Load sector ``corpus_sector_id`` and return ``batch_size`` fresh,
        independently shuffled iterators over its documents."""
        processed_docs = self.corpus_loader.load_sector(corpus_sector_id)
        return [iter(CorpusIterator(processed_docs)) for _ in range(self.batch_size)]

In [7]:
# processed_docs = corpus_loader.load_sector(0)

In [8]:
# iterator = CorpusIterator(processed_docs)

In [9]:
# Batch iterator over the English corpus sectors: 2 items per batch,
# segments of at most 128 tokens.
corpus_iter = CorpusBatchIterator(
    tokenizer,
    max_seq_length=128,
    batch_size=2,
    corpus_path="/data/en-corpus/SECTOR/",
)

In [10]:
# NOTE(review): this rebinds `corpus_iter` from the CorpusBatchIterator object to
# the generator returned by its __iter__. Re-running this cell calls iter() on the
# generator, which returns the same generator, so iteration is not restarted.
# A distinct name would make that clearer.
corpus_iter = iter(corpus_iter)

In [21]:
# Fetch one batch: a list of `batch_size` (segment, is_doc_start) pairs.
next(corpus_iter)

[(['The',
   'Ġdeep',
   'Ġlove',
   'Ġfor',
   'ĠPis',
   'g',
   'ah',
   'Ġis',
   'Ġevident',
   'Ġin',
   'Ġthe',
   'Ġsheer',
   'Ġnumbers',
   'Ġof',
   'Ġvolunteers',
   'Ġwilling',
   'Ġto',
   'Ġput',
   'Ġin',
   'Ġthe',
   'Ġwork',
   'Ġthat',
   'Ġshould',
   'Ġbe',
   'Ġperformed',
   'Ġby',
   'Ġthe',
   'Ġfederal',
   'Ġland',
   'Ġagency',
   '.',
   'Buy',
   'ĠPhoto',
   'ĠTr',
   'acey',
   'ĠArmstrong',
   'Ġspl',
   'ashes',
   'Ġthrough',
   'Ġa',
   'Ġcreek',
   'Ġas',
   'Ġshe',
   'Ġdemonstrates',
   'Ġone',
   'Ġof',
   'Ġthe',
   'Ġmany',
   'Ġmountain',
   'Ġbiking',
   'Ġtrails',
   'Ġin',
   'Ġthe',
   'ĠBent',
   'ĠCreek',
   'ĠExperimental',
   'ĠForest',
   'Ġon',
   'ĠWednesday',
   ',',
   'ĠSept',
   '.',
   'Ġ27',
   ',',
   'Ġ2017',
   '.',
   'Ġ(',
   'Photo',
   ':',
   'ĠAngel',
   'i',
   'ĠWright',
   '/',
   'aw',
   'right',
   '@',
   'c',
   'itizen',
   '-',
   'times',
   '.',
   'com',
   ')',
   'Ċ',
   'Ċ',
   'In',
   'ĠNovember',
 

In [None]:
# Scratch cell: number of CPUs reported by the OS.
os.cpu_count()

In [None]:
# Scratch cell: evaluates to [0, 5, 10] (step 10//2 == 5).
list(range(0, 11, 10//2))