Skip to content

Commit

Permalink
feat: Updated sweepai/core/lexical_search.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] committed Feb 23, 2024
1 parent 6430b99 commit 041ee3c
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions sweepai/core/lexical_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import multiprocessing
import re
from collections import Counter, defaultdict
from multiprocessing import Manager
from dataclasses import dataclass
from math import log

Expand Down Expand Up @@ -194,21 +195,32 @@ def prepare_index_from_snippets(
if ticket_progress:
ticket_progress.search_progress.indexing_total = len(all_docs)
ticket_progress.save()
all_tokens = []
# all_tokens will be managed by the multiprocessing Manager
# all_tokens = []
try:
# use 1/4 the max number of cores
with multiprocessing.Pool(processes=multiprocessing.cpu_count() // 4) as p:
for i, document_token_freq in tqdm(enumerate(
p.imap(compute_document_tokens, [doc.content for doc in all_docs])
)):
all_tokens.append(document_token_freq)
if ticket_progress and i % 200 == 0:
ticket_progress.search_progress.indexing_progress = i
ticket_progress.save()
for doc, document_token_freq in tqdm(zip(all_docs, all_tokens), desc="Indexing"):
index.add_document(
title=doc.title, token_freq=document_token_freq # snippet.denotation
)
manager = Manager()
all_tokens = manager.list()

def add_document_worker(doc_title, doc_content, shared_index):
token_freq = compute_document_tokens(doc_content)
shared_index.add_document(title=doc_title, token_freq=token_freq)
return token_freq

shared_index = manager.list()

try:
with multiprocessing.Pool(processes=multiprocessing.cpu_count() // 4) as pool:
results = pool.starmap_async(add_document_worker, [(doc.title, doc.content, shared_index) for doc in all_docs])
pool.close()
pool.join()
# Update the main index and progress after all processes are done
for document_token_freq in results.get():
all_tokens.append(document_token_freq)
if ticket_progress:
ticket_progress.search_progress.indexing_progress += 1
ticket_progress.save()
except Exception as e:
logger.exception(e)
except FileNotFoundError as e:
logger.exception(e)

Expand Down

0 comments on commit 041ee3c

Please sign in to comment.