# In[1]:
import os
import ray

# In[ ]:
class SegmentCorpusLoader:
    """Load one "sector" of a JSONL corpus, sentence-segmented, with an on-disk cache.

    A sector is read from ``<corpus_path>/<sector_id>.jsonl``, split into
    line chunks that are processed by ``_process.remote`` (a Ray remote
    function defined elsewhere), and the concatenated result is cached with
    ``torch.save`` under ``cache_dir``.
    """

    def __init__(
        self, tokenizer, max_seq_length: int, corpus_path: str, cache_dir: str = "cache/cached_corpus_sectors"
    ):
        """Create the loader.

        Args:
            tokenizer: Tokenizer passed to ``SentenceSegmenter`` (defined elsewhere).
            max_seq_length: Maximum sequence length passed to ``SentenceSegmenter``.
            corpus_path: Directory containing ``<sector_id>.jsonl`` files.
            cache_dir: Directory for cached results, created if missing.
                A falsy value disables caching.

        Side effects:
            Sets ``OMP_NUM_THREADS`` in ``os.environ`` if it is not already set.
        """
        self.tokenizer = tokenizer
        self.corpus_path = corpus_path
        self.cache_dir = cache_dir
        self.max_seq_length = max_seq_length
        self.sent_segmenter = SentenceSegmenter(tokenizer, max_seq_length)

        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        # Default to half the cores. The value is also read by load_sector as
        # the number of chunks. os.cpu_count() may return None, and on a
        # single-core machine "half" floors to 0, so clamp to at least 1.
        if "OMP_NUM_THREADS" not in os.environ:
            os.environ["OMP_NUM_THREADS"] = str(max(1, (os.cpu_count() or 1) // 2))

    @staticmethod
    def _num_workers() -> int:
        """Return the number of chunks to split a sector into, from OMP_NUM_THREADS.

        Uses the first entry of OpenMP's comma-separated list form (e.g. "4,2").
        Falls back to 1 if the variable is unset, non-numeric or not positive.
        """
        raw = os.environ.get("OMP_NUM_THREADS", "1")
        try:
            return max(1, int(raw.split(",")[0]))
        except ValueError:
            return 1

    def load_sector(self, sector_id):
        """Return the processed documents for ``sector_id``, using the cache when possible.

        Args:
            sector_id: Identifier used to build both ``<sector_id>.jsonl`` and
                ``<sector_id>_cache.pkl``.

        Returns:
            list: The per-chunk results of ``_process`` concatenated in file
            order, or the cached list when a readable cache file exists.

        Raises:
            FileNotFoundError: If there is no usable cache and the sector file
                is missing.
        """
        cache_path = (
            os.path.join(self.cache_dir, f"{sector_id}_cache.pkl") if self.cache_dir else None
        )

        if cache_path is not None and os.path.exists(cache_path):
            try:
                print("Loadding Cache")
                # NOTE(review): torch.load unpickles, so only point cache_dir at
                # trusted files. On torch>=2.6 (weights_only=True by default) a
                # cache containing non-allowlisted objects fails to load and is
                # re-processed on every call. Confirm what _process returns.
                return torch.load(cache_path)
            except Exception:
                # Corrupted or partial cache: fall through and rebuild it.
                # A bare except here would also swallow KeyboardInterrupt.
                print("File Corrupted. Data will be re-processed")

        # JSON Lines files are UTF-8 by specification; don't rely on the locale.
        with open(os.path.join(self.corpus_path, str(sector_id) + ".jsonl"), "r", encoding="utf-8") as f:
            data = f.readlines()

        processed_docs = []

        print("Processing Data. Takes about 10 mins")

        if data:
            # Keep the floor-divided chunk size, but never let it drop to 0.
            # A step of 0 made range() raise whenever the file had fewer lines
            # than workers.
            step_size = max(1, len(data) // self._num_workers())
            ray_objs = [
                _process.remote(self.sent_segmenter, data[start:start + step_size], start)
                for start in range(0, len(data), step_size)
            ]
            # ray.get on a list returns results in the same order, so the
            # documents stay in file order.
            for chunk_docs in ray.get(ray_objs):
                processed_docs.extend(chunk_docs)

        if cache_path is not None:
            print("Saving Into Cache")
            torch.save(processed_docs, cache_path)

        return processed_docs