## Segment legal documents into passages (chunks)

##### Install Prerequisites

In [3]:
%%capture

!pip install tiktoken==0.3.3
!pip install tqdm

#### Imports 

In [4]:
import importlib.metadata
import logging
import os

import requests
import tiktoken
from tqdm import tqdm

##### Set up logging

In [5]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# getLogger returns the same logger object on every call, so unconditionally adding a
# handler would stack one more StreamHandler (and one more copy of every message) each
# time this cell is re-run. Only attach the console handler if it is not already there.
if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [6]:
# Record the installed version of every third-party package this notebook imports, so
# results (e.g. token counts, which depend on the tiktoken release) can be reproduced.
for _pkg in ('tiktoken', 'tqdm', 'requests'):
    logger.info(f'Using {_pkg}=={importlib.metadata.version(_pkg)}')

Using requests==2.28.2


#### Set up essentials

##### Initialize encoder
`cl100k_base` is the encoding used by OpenAI's `gpt-4`, `gpt-3.5-turbo` and `text-embedding-ada-002` models, so token counts computed here match what those models see.

In [7]:
# Shared tokenizer used by the token-counting and segmentation cells below.
encoding = tiktoken.get_encoding(encoding_name='cl100k_base')

In [8]:
# Root of the input corpus; every file beneath it (recursively) is treated as one document.
DOC_DIR_PATH: str = './data/docs'

#### Count the tokens in a document with tiktoken

In [9]:
def num_tokens_from_doc(doc: str) -> int:
    """Count how many tokens ``doc`` encodes to.

    Uses the module-level ``encoding`` tokenizer.

    Args:
        doc: Text to tokenize.

    Returns:
        Length of the token sequence returned by ``encoding.encode(doc)``.
    """
    return len(encoding.encode(doc))

In [10]:
# Passage length in tokens: documents are cut into consecutive windows of this size.
CHUNK_SIZE: int = 768

In [11]:
def doc_iterator(dir_path: str, text_encoding: str = 'utf-8'):
    """Yield ``(filename, contents)`` for every regular file under ``dir_path``.

    The tree is walked recursively in sorted order so that runs are reproducible
    (``os.walk`` itself makes no ordering guarantee). Hidden files and directories
    (names starting with ``'.'``, e.g. ``.gitkeep``, ``.DS_Store``,
    ``.ipynb_checkpoints``) are skipped because they are not documents.

    Args:
        dir_path: Root directory to search.
        text_encoding: Codec used to decode each file. Defaults to UTF-8 instead of
            the platform-dependent locale encoding.

    Yields:
        ``(filename, file_contents)`` where ``filename`` is the bare file name with
        no directory component. Files with the same name in different
        subdirectories therefore yield identical names.
    """
    for root, dirnames, filenames in os.walk(dir_path):
        # Assigning to dirnames in place controls which subdirectories os.walk
        # descends into, and in which order.
        dirnames[:] = sorted(d for d in dirnames if not d.startswith('.'))
        for filename in sorted(filenames):
            if filename.startswith('.'):
                continue
            file_path = os.path.join(root, filename)
            # os.walk also lists non-directory entries that are not regular files
            # (e.g. broken symlinks); skip those.
            if not os.path.isfile(file_path):
                continue
            with open(file_path, 'r', encoding=text_encoding) as file:
                file_contents = file.read()
            # Yield after the file is closed so an abandoned generator holds no handle.
            yield filename, file_contents

#### Segment docs into passages (chunks)

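Note: each document is split into consecutive passages of `CHUNK_SIZE` (768) tokens. Only a document's final passage can be shorter than that, and it is dropped if it has fewer than 512 tokens. As a consequence, a document shorter than 512 tokens produces no passages at all.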

In [12]:
%%time 

n_docs = 0
n_passages = 0

for doc_name, doc in tqdm(doc_iterator(DOC_DIR_PATH)):
    doc_id = doc_name.split('.')[0]
    tokens = encoding.encode(doc)
    chunks = []
    chunk_id = 1
    n_docs += 1
    for i in range(0, len(tokens), CHUNK_SIZE):
        chunk_tokens = tokens[i: i+CHUNK_SIZE]
        if not len(chunk_tokens) < 512:
            chunk = encoding.decode(chunk_tokens)
            with open(f'./data/chunks/{doc_id}_{chunk_id}', 'w') as f:
                f.write(chunk)
            chunk_id += 1
            n_passages += 1
logger.info(f'{n_docs} documents segmented into {n_passages} passages')

50it [00:00, 90.15it/s]
50 documents segmented into 44 passages


CPU times: user 90.2 ms, sys: 11.7 ms, total: 102 ms
Wall time: 579 ms
