In [1]:
import os
import ray
import json
import tqdm
import torch
import logging
import numpy as np
from transformers import AutoTokenizer
from typing import List

from full_doc_loader import FullDocLoader

# Notebook-level logger, named after the running module.
logger = logging.getLogger(__name__)

In [179]:
class FullDocIterator:
    """Yield fixed-length token chunks from documents visited in random order.

    Each item is a ``(tokens, if_start_doc)`` pair.

    * ``allow_cross_doc=False``: each document is sliced independently into
      chunks of ``max_seq_length`` tokens; a document's last chunk may be
      shorter. ``if_start_doc`` is True for the first chunk of each document.
    * ``allow_cross_doc=True``: documents are packed back to back into lists
      of exactly ``max_seq_length`` items. When a document ends with a
      partially filled list, ``sep_token`` is appended before the next
      document continues filling it. A partial list left over after the last
      document is dropped (never yielded). ``if_start_doc`` is True when the
      segment that completed the list was taken from the start of its
      document.

    The document order is reshuffled with the global ``np.random`` state on
    every ``iter()`` call, so seed ``np.random`` for reproducible passes.
    """

    def __init__(self, processed_docs, sep_token: str, max_seq_length: int, allow_cross_doc=True):
        """
        Args:
            processed_docs: indexable, sized collection of documents; each
                document is a sliceable token sequence (e.g. a token list
                from ``tokenizer.tokenize``).
            sep_token: token inserted between packed documents (only used
                when ``allow_cross_doc`` is True).
            max_seq_length: tokens per chunk; must be >= 1.
            allow_cross_doc: whether chunks may span document boundaries.

        Raises:
            ValueError: if ``max_seq_length`` < 1 (cross-document packing
                could otherwise loop forever without consuming tokens).
        """
        if max_seq_length < 1:
            raise ValueError(f"max_seq_length must be >= 1, got {max_seq_length}")
        self.processed_docs = processed_docs
        self.max_seq_length = max_seq_length
        self.allow_cross_doc = allow_cross_doc
        self.total_num_docs = len(processed_docs)
        self.sep_token = sep_token

    def __iter__(self):
        # Same draws as np.arange + np.random.shuffle: a fresh random document order per pass.
        doc_order = np.random.permutation(self.total_num_docs)
        current_seq = []  # packed sequence under construction (cross-doc mode only)

        for doc_index in doc_order:
            doc = self.processed_docs[doc_index]

            if not self.allow_cross_doc:
                for start in range(0, len(doc), self.max_seq_length):
                    yield doc[start:start + self.max_seq_length], start == 0
                continue

            pointer = 0  # offset in `doc` of the next token still to be packed
            while pointer < len(doc):
                room = self.max_seq_length - len(current_seq)
                doc_seg = doc[pointer:pointer + room]
                current_seq.extend(doc_seg)
                # The flag describes the segment that may complete the sequence below.
                if_start_doc = pointer == 0
                pointer += len(doc_seg)

                if len(current_seq) == self.max_seq_length:
                    yield current_seq, if_start_doc
                    # Start a new list instead of clearing, so the yielded one is not mutated.
                    current_seq = []

            # Document finished mid-sequence: separate it from the next document's tokens.
            if current_seq:
                current_seq = current_seq + [self.sep_token]

class FullDocBatchIterator:
    """Placeholder with no behavior yet.

    TODO: implement; presumably meant to batch the output of FullDocIterator
    (confirm intended design).
    """

    def __init__(self) -> None:
        """Takes no arguments and stores no state."""


In [180]:
# Scratch check of range() stepping: a span of 5 with step 5 yields the single start offset [0].
list(range(0, 10-5, 5))

[0]

In [193]:
# RoBERTa tokenizer: tokenizes the toy documents below and supplies sep_token for FullDocIterator.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

In [194]:
# Corpus root; set the SECTOR_CORPUS_DIR environment variable to use another location
# without editing the notebook. The default is the original hard-coded path.
SECTOR_CORPUS_DIR = os.environ.get("SECTOR_CORPUS_DIR", "/data/en-corpus/SECTOR/")
corpus_loader = FullDocLoader(tokenizer, corpus_path=SECTOR_CORPUS_DIR)

In [195]:
# Two tiny hand-written documents, tokenized, for quick checks of FullDocIterator.
test_doc = [
    tokenizer.tokenize(text)
    for text in ("hello, I am good. it is not good", "I am fine. a is abcde.")
]

In [196]:
# Real-corpus load, presumably disabled while experimenting on the toy documents.
# NOTE(review): the later `len(processed_docs[:100])` cell needs this line; while it is
# commented out, that cell raises NameError on a fresh kernel.
# processed_docs = corpus_loader.load_sector(0)

In [197]:
# Toy configuration: 4-token chunks, each document chunked on its own (no packing).
iterator = FullDocIterator(
    processed_docs=test_doc,
    sep_token=tokenizer.sep_token,
    max_seq_length=4,
    allow_cross_doc=False,
)

In [198]:
# Start one shuffled pass: __iter__ is a generator, so next() can step through chunks.
# Note this rebinds `iterator` from the FullDocIterator object to that generator.
iterator = iter(iterator)

In [205]:
# Fetch the next (chunk, if_start_doc) pair. Once every chunk has been consumed this raises
# StopIteration (as in the saved output); re-run the two cells above for a fresh shuffled pass.
next(iterator)

StopIteration: 

In [132]:
# NOTE(review): `processed_docs` is only defined by the commented-out load_sector cell, so this
# raises NameError on a fresh kernel; uncomment that cell (or load the corpus) before running this.
len(processed_docs[:100])

100