diff --git a/sweepai/core/lexical_search.py b/sweepai/core/lexical_search.py index eb49179478..387cf5afb3 100644 --- a/sweepai/core/lexical_search.py +++ b/sweepai/core/lexical_search.py @@ -2,6 +2,7 @@ import multiprocessing import re from collections import Counter, defaultdict +from multiprocessing import Manager from dataclasses import dataclass from math import log @@ -194,21 +195,32 @@ def prepare_index_from_snippets( if ticket_progress: ticket_progress.search_progress.indexing_total = len(all_docs) ticket_progress.save() - all_tokens = [] + # all_tokens will be managed by the multiprocessing Manager + # all_tokens = [] try: - # use 1/4 the max number of cores - with multiprocessing.Pool(processes=multiprocessing.cpu_count() // 4) as p: - for i, document_token_freq in tqdm(enumerate( - p.imap(compute_document_tokens, [doc.content for doc in all_docs]) - )): - all_tokens.append(document_token_freq) - if ticket_progress and i % 200 == 0: - ticket_progress.search_progress.indexing_progress = i - ticket_progress.save() - for doc, document_token_freq in tqdm(zip(all_docs, all_tokens), desc="Indexing"): - index.add_document( - title=doc.title, token_freq=document_token_freq # snippet.denotation - ) + manager = Manager() + all_tokens = manager.list() + + def add_document_worker(doc_title, doc_content, shared_index): + token_freq = compute_document_tokens(doc_content) + shared_index.add_document(title=doc_title, token_freq=token_freq) + return token_freq + + shared_index = manager.list() + + try: + with multiprocessing.Pool(processes=multiprocessing.cpu_count() // 4) as pool: + results = pool.starmap_async(add_document_worker, [(doc.title, doc.content, shared_index) for doc in all_docs]) + pool.close() + pool.join() + # Update the main index and progress after all processes are done + for document_token_freq in results.get(): + all_tokens.append(document_token_freq) + if ticket_progress: + ticket_progress.search_progress.indexing_progress += 1 + ticket_progress.save() + except Exception as e: + logger.exception(e) except FileNotFoundError as e: logger.exception(e)