## Pipeline to Save Embeddings for Wikipedia Articles

In [1]:
import math
import os
os.environ["TOKENIZERS_PARALLELISM"] = 'false'

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from torch.utils.data import DataLoader
import transformers
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util

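# Fall back to CPU when no GPU is available (much slower, but the notebook still runs)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')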

### Parameters
Change these paths when running on your own machine. Some of them can instead be set to Hugging Face Hub identifiers so that the assets are downloaded.

In [2]:
embedding_model = '/home/stefanwebb/models/llms/multi-qa-MiniLM-L6-cos-v1'
embeddings_path = '/home/stefanwebb/embeddings/wikimedia/wikipedia/20231101.en'

dataset_path = '/home/stefanwebb/data/wikimedia/wikipedia/20231101.en'
dataset_name = ''
dataset_split = 'train'

batch_size = 1024
num_data_workers = 4

### Chunking Strategy
* Chunks are formed per paragraph; a paragraph longer than 510 tokens is split into overlapping windows.
* Chunks that correspond to headers or short sentences are filtered out.
* Is this what is usually called semantic chunking?

In [3]:
document_encoder = SentenceTransformer(embedding_model).to(device)

In [4]:
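# NOTE: May need extra space for special tokens, hence 510 rather than 512
def chunk_iterator(seq, max_seq_length=510, overlap=128):
    """
    Yield `seq` unchanged if it fits in `max_seq_length` tokens; otherwise yield
    overlapping windows of at most `max_seq_length` tokens.

    Note: windows start every `overlap` tokens, so `overlap` acts as the stride and
    consecutive full windows share `max_seq_length - overlap` tokens.
    """
    if len(seq) <= max_seq_length:
        yield seq
    else:
        # Number of window starts needed for the last window to reach the end of seq
        count_chunks = ((len(seq) - max_seq_length) + overlap - 1) // overlap + 1
        for idx in range(count_chunks):
            start = idx * overlap
            yield seq[start:min(len(seq), start + max_seq_length)]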
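
To see where the windows fall, here is a small worked example on a toy sequence. The function is repeated so that the snippet runs on its own, and the toy values `max_seq_length=5` and `overlap=2` are chosen purely for illustration.

```python
def chunk_iterator(seq, max_seq_length=510, overlap=128):
    # Copy of the function above, repeated so this example is self-contained
    if len(seq) <= max_seq_length:
        yield seq
    else:
        count_chunks = ((len(seq) - max_seq_length) + overlap - 1) // overlap + 1
        for idx in range(count_chunks):
            yield seq[(idx * overlap):min(len(seq), (idx * overlap) + max_seq_length)]

tokens = list(range(12))  # stand-in for 12 token ids
windows = list(chunk_iterator(tokens, max_seq_length=5, overlap=2))

for w in windows:
    print(w)
# [0, 1, 2, 3, 4]
# [2, 3, 4, 5, 6]
# [4, 5, 6, 7, 8]
# [6, 7, 8, 9, 10]
# [8, 9, 10, 11]

# Tokens shared by the first two windows: max_seq_length - overlap = 3
shared = len(set(windows[0]) & set(windows[1]))
print(shared)  # 3
```

With the notebook defaults (510 and 128), consecutive full windows therefore share 382 tokens, and the final window may be shorter than the rest.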

In [5]:
def chunk_by_paragraph(articles):
    """
    Forms chunks by breaking each document into paragraphs, and then into sub-paragraphs if necessary.
    """
    chunks = []
    for article in articles['text']:
        # Keep non-empty paragraphs that start with a letter
        paragraphs = article.split('\n\n')
        paragraphs = [p for p in paragraphs if p != '' and p[0].isalpha()]
        # TODO: Can I avoid tokenizing sentences twice: once here and once in model.encode?
        # model.encode doesn't seem to be able to take tokens, only strings
        # I could dig into source code in library to find a solution...
        # Alternatively, just do fixed length chunking with overlap

        if len(paragraphs) > 0:
            tokens = document_encoder.tokenizer(paragraphs)['input_ids']
            # Split long paragraphs into windows and drop chunks of 7 tokens or fewer
            # (the count includes any special tokens added by the tokenizer)
            tokens = [x for p in tokens for x in chunk_iterator(p) if len(x) > 7]
            # Decode back to strings; skip_special_tokens keeps special tokens out of the saved text
            paragraphs = document_encoder.tokenizer.batch_decode(tokens, skip_special_tokens=True)

        chunks += paragraphs

    return {'chunks': chunks}
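
The string-level filter can be tried without loading the model. The sketch below isolates that step in a hypothetical helper, `keep_paragraphs` (not part of the notebook), and runs it on made-up text rather than a real article.

```python
def keep_paragraphs(article):
    """Hypothetical helper: the string-level filter used in chunk_by_paragraph."""
    paragraphs = article.split('\n\n')
    # Drop empty paragraphs and those that do not start with a letter
    return [p for p in paragraphs if p != '' and p[0].isalpha()]

# Made-up text, for illustration only
article = (
    "History\n\n"
    "The town was founded in 1850 by settlers from the coast.\n\n"
    "\n\n"
    "1850 also saw the first census.\n\n"
    "(See also the list of mayors.)"
)

kept = keep_paragraphs(article)
for p in kept:
    print(repr(p))
# 'History'
# 'The town was founded in 1850 by settlers from the coast.'
```

A short header-like line such as `History` passes this first filter; it is the later token-length check (`len(x) > 7`) that is meant to remove it. Paragraphs starting with a digit or punctuation are dropped here, even when they are ordinary sentences.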

#### Load dataset
* One issue I noticed: formulae, tables, etc. are missing from the data, so some paragraphs end abruptly and that information is lost.


In [40]:
dataset = load_dataset(dataset_path, dataset_name, data_files=["train-00000-of-00041.parquet"], split=dataset_split, streaming=True)

In [41]:
# Use the parameters defined above rather than hard-coded values
chunked_ds = dataset.map(chunk_by_paragraph, batched=True, batch_size=batch_size*4, remove_columns=dataset.column_names)
chunked_pt = DataLoader(chunked_ds, batch_size=batch_size, num_workers=num_data_workers)

### Time Embedding of Chunks
* Start with a single file for testing purposes.
* The initial pause of ~40 s is the dataloader filling its buffer.
* Ignore the warning "Token indices sequence length is longer than the specified maximum sequence length". Long paragraphs are tokenized first and then broken into overlapping chunks, so the sequences passed to the model are at most 510 tokens long.
* Throughput can be optimized further so that the GPU is kept busy rather than being blocked by synchronous operations or starved of data.
* About 32 min for 1 of the 41 data files => roughly 22 hours as an upper bound for the full dataset.
* About 4 GB per Parquet file.
* To speed this up, I'll increase the number of data workers and save embeddings as float16.
* TODO: Experiment with 8-bit and 1-bit embeddings from SentenceTransformers
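
The arithmetic behind the time estimate, together with the per-vector storage cost at each precision, can be checked in a few lines. The only inputs are the figures quoted above and the 384-dimensional embedding size used in the cells below; nothing here is measured.

```python
# Figures quoted in the notes above
minutes_per_file = 32
num_files = 41

total_hours = minutes_per_file * num_files / 60
print(f"{total_hours:.1f} hours")  # 21.9 hours

# Storage per embedding vector, before any Parquet compression or overhead
dim = 384
bytes_float32 = dim * 4  # 4 bytes per float32 value
bytes_float16 = dim * 2  # 2 bytes per float16 value
print(bytes_float32, bytes_float16)  # 1536 768
```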

In [None]:
# Note: I find this an easier way to construct schema
table = pa.table([
    pa.array([
        "The quick brown fox jumped over the log.",
        "Colorless dreams sleep furiously."
    ]),
    pa.array([
        torch.randn(384).numpy(),
        torch.randn(384).numpy()
    ])
], names=["chunks", "embeddings"])

with pq.ParquetWriter('testing.tmp.parquet', table.schema) as writer:
    for idx, batch in enumerate(chunked_pt):
        # Embed chunks
        # TODO: Cast to lower float precision for saving? np.float16?
        chunks = batch['chunks']
        embeddings = document_encoder.encode(chunks, batch_size=1024, show_progress_bar=True)


        # Stream output
        # TODO: Collate batches before saving to reduce I/O ops?
        # TODO: Is write_table synchronous or asynchronous? I.e. is this holding up GPU?
        table = pa.table([
            pa.array(chunks),
            pa.array(list(embeddings))
        ], names=["chunks", "embeddings"])
        writer.write_table(table)
        
        # DEBUG
        # if idx == 10:
        #     break

### Embed Entire Wikipedia and Save
* Changed my mind on quantization. I'll save as `np.float32` and do quantization in FAISS.
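
For intuition about what quantizing later costs, here is a NumPy-only sketch of per-dimension 8-bit scalar quantization. It illustrates the general idea and is not FAISS's implementation; random unit vectors stand in for real embeddings, and all names are made up for this example.

```python
import numpy as np

# Random unit vectors standing in for 384-dimensional embeddings
rng = np.random.default_rng(0)
emb = rng.standard_normal((1000, 384)).astype(np.float32)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)

# Per-dimension range, mapped onto the 256 levels of a uint8
lo = emb.min(axis=0)
hi = emb.max(axis=0)
scale = (hi - lo) / 255.0

# Encode to 8-bit codes, then decode back to float32
codes = np.round((emb - lo) / scale).astype(np.uint8)
decoded = codes.astype(np.float32) * scale + lo

# Reconstruction error is bounded by half a quantization step
max_err = float(np.abs(decoded - emb).max())
print(codes.nbytes, emb.nbytes)  # 384000 1536000
print(max_err <= float(scale.max()) / 2 + 1e-5)  # True
```

Keeping the saved files at float32 leaves this choice open: the same vectors can later be quantized in different ways without re-running the encoder.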

In [6]:
table = pa.table([
    pa.array([
        "The quick brown fox jumped over the log.",
        "Colorless dreams sleep furiously."
    ]),
    pa.array([
        np.zeros((384), dtype=np.float32),
        np.zeros((384), dtype=np.float32)
    ])
], names=["chunks", "embeddings"])

print(table.schema)

chunks: string
embeddings: list<item: float>
  child 0, item: float


In [7]:
files = [f"train-{idx:05d}-of-00041.parquet" for idx in range(41)]

# Make sure the output directory exists before opening any writer
os.makedirs(embeddings_path, exist_ok=True)

for file in files[1:]:
    # Estimate the number of batches for progress reporting (assumes ~10 chunks per article on average)
    metadata = pq.ParquetFile(os.path.join(dataset_path, file)).metadata
    count_batches = math.ceil(metadata.num_rows / batch_size) * 10
    del metadata

    # Load input stream
    dataset = load_dataset(dataset_path, dataset_name, data_files=[file], split=dataset_split, streaming=True)
    chunked_ds = dataset.map(chunk_by_paragraph, batched=True, batch_size=batch_size*4, remove_columns=dataset.column_names)
    chunked_pt = DataLoader(chunked_ds, batch_size=batch_size)
    
    # Note: I find this an easier way to construct schema
    table = pa.table([
        pa.array([
            "The quick brown fox jumped over the log.",
            "Colorless dreams sleep furiously."
        ]),
        pa.array([
            torch.randn(384).numpy(),
            torch.randn(384).numpy()
        ])
    ], names=["chunks", "embeddings"])

    with pq.ParquetWriter(os.path.join(embeddings_path, file), table.schema) as writer:
        print(f'Processing {file}')
        for idx, batch in enumerate(chunked_pt):
            # Embed chunks
            chunks = batch['chunks']
            embeddings = document_encoder.encode(chunks, batch_size=batch_size, show_progress_bar=False).astype(np.float32)

            # Stream output
            table = pa.table([
                pa.array(chunks),
                pa.array(list(embeddings))
            ], names=["chunks", "embeddings"])
            writer.write_table(table)

            if idx % 300 == 0:
                print(f'\t batch {idx} of {count_batches}')

    # Not sure if following is necessary, just in case.
    del chunked_ds
    del chunked_pt

    print('')

Processing train-00001-of-00041.parquet


Token indices sequence length is longer than the specified maximum sequence length for this model (1310 > 512). Running this sequence through the model will result in indexing errors


	 batch 0 of 1530
	 batch 300 of 1530
	 batch 600 of 1530
	 batch 900 of 1530
	 batch 1200 of 1530

Processing train-00002-of-00041.parquet
	 batch 0 of 1530
	 batch 300 of 1530
	 batch 600 of 1530
	 batch 900 of 1530
	 batch 1200 of 1530

Processing train-00003-of-00041.parquet
	 batch 0 of 1530
	 batch 300 of 1530
	 batch 600 of 1530
	 batch 900 of 1530
	 batch 1200 of 1530

Processing train-00004-of-00041.parquet
	 batch 0 of 1530
	 batch 300 of 1530
	 batch 600 of 1530
	 batch 900 of 1530
	 batch 1200 of 1530

Processing train-00005-of-00041.parquet
	 batch 0 of 1530
	 batch 300 of 1530
	 batch 600 of 1530
	 batch 900 of 1530

Processing train-00006-of-00041.parquet
	 batch 0 of 1530
	 batch 300 of 1530
	 batch 600 of 1530
	 batch 900 of 1530

Processing train-00007-of-00041.parquet
	 batch 0 of 1530
	 batch 300 of 1530
	 batch 600 of 1530
	 batch 900 of 1530

Processing train-00008-of-00041.parquet
	 batch 0 of 1530
	 batch 300 of 1530
	 batch 600 of 1530
	 batch 900 of 1530

Proc

### Scratchpad

In [57]:
# float16 arrays, matching the halffloat schema printed below
table = pa.table([
    pa.array([
        "The quick brown fox jumped over the log.",
        "Colorless dreams sleep furiously."
    ]),
    pa.array([
        np.zeros((384), dtype=np.float16),
        np.zeros((384), dtype=np.float16)
    ])
], names=["chunks", "embeddings"])

print(table.schema)

chunks: string
embeddings: list<item: halffloat>
  child 0, item: halffloat
