Update Mapping with Semaphore (#783)
dongreenberg committed May 7, 2024
1 parent 18a810d commit c5d478a
Showing 1 changed file with 61 additions and 78 deletions.
139 changes: 61 additions & 78 deletions examples/parallel-hf-embedding-ec2/parallel_hf_embedding_ec2.py
@@ -43,24 +43,24 @@
 import requests

 import runhouse as rh
+import torch
 from bs4 import BeautifulSoup
+from tqdm.asyncio import tqdm

 # Then, we define some utility functions that we'll use for our embedding task.
 def _extract_urls_helper(url, visited, original_url, max_depth=1, current_depth=1):
     """
     Extracts all URLs from a given URL, recursively up to a maximum depth.
     """
-    if url in visited:
+    if (
+        url in visited
+        or urlparse(url).netloc != urlparse(original_url).netloc
+        or "redirect" in url
+    ):
         return []

     visited.add(url)

-    if urlparse(url).netloc != urlparse(original_url).netloc:
-        return []
-
-    if "redirect" in url:
-        return []
-
     urls = [url]

     if current_depth <= max_depth:
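The recursive body that follows is collapsed in this view. Purely as a hypothetical sketch of the elided step, assuming the page is fetched with requests and parsed with BeautifulSoup (both imported above) and that urljoin is imported from urllib.parse alongside urlparse:

# Hypothetical reconstruction of the collapsed body, not the committed code:
response = requests.get(url)
soup = BeautifulSoup(response.text, "html.parser")
for link in soup.find_all("a", href=True):
    child_url = urljoin(url, link["href"])  # urljoin assumed imported
    urls.extend(
        _extract_urls_helper(
            child_url, visited, original_url, max_depth, current_depth + 1
        )
    )
# ...followed by `return urls` at function level. The guard clause above
# already filters visited, off-site, and redirect links on each recursion.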
@@ -101,52 +101,26 @@ def extract_urls(url, max_depth=1):
 #
 # Learn more in the [Runhouse docs on functions and modules](/docs/tutorials/api-modules).
 class URLEmbedder:
-    def __init__(self, gpu_number: int):
-        self.model = None
-        self.vectorstore = None
-        self.gpu_number = gpu_number
-
-    def initialize_model(self):
-        if self.model is None:
-            from langchain.embeddings import HuggingFaceBgeEmbeddings
-
-            model_name = "BAAI/bge-large-en-v1.5"
-            model_kwargs = {"device": f"cuda:{self.gpu_number}"}
-            encode_kwargs = {
-                "normalize_embeddings": True
-            }  # set True to compute cosine similarity
-
-            self.model = HuggingFaceBgeEmbeddings(
-                model_name=model_name,
-                model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs,
-            )
+    def __init__(self, **model_kwargs):
+        from sentence_transformers import SentenceTransformer
+
+        self.model = torch.compile(SentenceTransformer(**model_kwargs))

-    def embed_docs(self, urls: List[str]):
+    def embed_docs(self, urls: List[str], **embed_kwargs):
         from langchain_community.document_loaders import WebBaseLoader
         from langchain_text_splitters import RecursiveCharacterTextSplitter

-        self.initialize_model()
-
-        # Occasionally, the loader will fail to load the URLs, so we catch the exception and return None.
-        loader = WebBaseLoader(
+        start = time.time()
+        docs = WebBaseLoader(
             web_paths=urls,
-        )
-        docs = loader.load()
-
-        text_splitter = RecursiveCharacterTextSplitter(
+        ).load()
+        splits = RecursiveCharacterTextSplitter(
             chunk_size=1000, chunk_overlap=200
-        )
-        splits = text_splitter.split_documents(docs)
+        ).split_documents(docs)
         splits_as_str = [doc.page_content for doc in splits]

-        # Time the actual embedding
-        start_time = time.time()
-        embeddings = self.model.embed_documents(splits_as_str)
-        print(
-            f"Time to embed {len(splits_as_str)} text chunks: {time.time() - start_time}"
-        )
-        return embeddings
+        downloaded = time.time()
+        embedding = self.model.encode(splits_as_str, **embed_kwargs)
+        return urls[0], embedding, downloaded - start, time.time() - downloaded
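The slimmed-down class can be smoke-tested locally before any cluster enters the picture. A minimal sketch, assuming a CUDA machine with sentence_transformers, langchain_community, and langchain_text_splitters installed (the URL is illustrative):

embedder = URLEmbedder(model_name_or_path="BAAI/bge-large-en-v1.5", device="cuda")
url, embedding, download_time, embed_time = embedder.embed_docs(
    ["https://en.wikipedia.org/wiki/Poker"], normalize_embeddings=True
)
# encode() returns one vector per text chunk, so len(embedding) is the chunk count.
print(f"{url}: {len(embedding)} chunk embeddings "
      f"(download {download_time:.1f}s, embed {embed_time:.1f}s)")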


# ## Setting up Runhouse primitives
@@ -164,11 +138,9 @@ def embed_docs(self, urls: List[str]):
 # the script code will run when Runhouse attempts to import your code remotely.
 # :::
 async def main():
-    cluster = rh.cluster("rh-a10g", instance_type="A10G:4").save().up_if_not()
-
     # We set up some parameters for our embedding task.
     num_replicas = 4  # Number of models to load side by side
-    num_parallel_tasks = 128  # Number of parallel calls to make to the replicas
+    max_concurrency_per_replica = 32  # Number of parallel calls to make to each replica
     url_to_recursively_embed = "https://en.wikipedia.org/wiki/Poker"
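These two knobs replace the old single `num_parallel_tasks = 128`: the semaphore introduced further down is sized to their product, so the total number of in-flight calls is unchanged.

# 4 replicas x 32 concurrent calls per replica = 128 requests in flight,
# the same bound the removed num_parallel_tasks = 128 expressed directly.
assert num_replicas * max_concurrency_per_replica == 128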

# We recursively extract all children URLs from the given URL.
@@ -186,28 +158,33 @@ async def main():
     # returned by `get_or_to` functions exactly the same as a local instance of the module, but when we call a
     # function (like `initialize_model`) on it, the function is run on the remote machine.
     start_time = time.time()
-    replicas = []
+    embedder_replicas = []
+    cluster = rh.cluster(
+        f"rh-{num_replicas}xa10g",
+        instance_type="A10G:1",
+        num_instances=num_replicas,
+        spot=True,
+    ).up_if_not()
     for i in range(num_replicas):
         env = rh.env(
             name=f"langchain_embed_env_{i}",
             reqs=[
                 "langchain",
                 "langchain-community",
                 "langchainhub",
                 "lancedb",
                 "bs4",
+                "sentence_transformers",
+                "fake_useragent",
             ],
             secrets=["huggingface"],
-            compute={"GPU": 1},
         )
-        local_url_embedder_module = rh.module(URLEmbedder, name=f"doc_embedder_{i}")(
-            gpu_number=i
+        RemoteURLEmbedder = rh.module(URLEmbedder).to(cluster, env)
+        remote_url_embedder = RemoteURLEmbedder(
+            model_name_or_path="BAAI/bge-large-en-v1.5",
+            device="cuda",
+            name=f"doc_embedder_{i}",
         )
-        remote_url_embedder_module = local_url_embedder_module.get_or_to(
-            system=cluster, env=env
-        )
-        remote_url_embedder_module.initialize_model()
-        replicas.append(remote_url_embedder_module)
+        embedder_replicas.append(remote_url_embedder)
     print(f"Time to initialize {num_replicas} replicas: {time.time() - start_time}")

# ## Calling the Runhouse modules in parallel
@@ -216,27 +193,33 @@
 # `asyncio` library to make parallel calls, we need to use a special `run_async=True` argument to the
 # Runhouse function. This tells Runhouse to return a coroutine that we can await on, rather than making
 # a blocking network call to the server.
+    semaphore = asyncio.Semaphore(max_concurrency_per_replica * num_replicas)
+
+    async def load_and_embed(url, idx):
+        async with semaphore:
+            print(f"Embedding {url} on replica {idx % num_replicas}")
+            embedder_replica = embedder_replicas[idx % num_replicas]
+            return await embedder_replica.embed_docs(
+                [url], normalize_embeddings=True, run_async=True, stream_logs=False
+            )
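The pattern above is plain asyncio: the semaphore caps how many `embed_docs` coroutines may be past the `async with` at any moment, while `idx % num_replicas` round-robins the calls across replicas. A self-contained sketch of the same bounding behavior, with a sleep standing in for the remote call:

import asyncio

async def _semaphore_demo():
    sem = asyncio.Semaphore(4)  # at most 4 workers hold a permit at once

    async def work(i):
        async with sem:  # the 5th task waits here until a permit frees up
            await asyncio.sleep(0.1)  # stand-in for the remote embed_docs call
            return i % 4  # stand-in for round-robin replica selection

    return await asyncio.gather(*(work(i) for i in range(16)))

# asyncio.run(_semaphore_demo())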

     start_time = time.time()
-    futs = [
-        asyncio.create_task(
-            replicas[i % num_replicas].embed_docs([urls[i]], run_async=True)
-        )
-        for i in range(len(urls))
-    ]
-
-    all_embeddings = []
-    failures = 0
-    task_results = await asyncio.gather(*futs)
-    for res in task_results:
-        if res is not None:
-            all_embeddings.extend(res)
-        else:
-            print("An embedding call failed.")
-            failures += 1
-
-    print(f"Received {len(all_embeddings)} total embeddings, with {failures} failures.")
+    futs = [load_and_embed(url, idx) for idx, url in enumerate(urls)]
+    task_results = await tqdm.gather(*futs)
+
+    failures = len([res for res in task_results if isinstance(res, Exception)])
+    total_download_time = sum(
+        [res[2] for res in task_results if not isinstance(res, Exception)]
+    )
+    total_embed_time = sum(
+        [res[3] for res in task_results if not isinstance(res, Exception)]
+    )
     print(
-        f"Embedded {len(urls)} docs across {num_replicas} replicas with {num_parallel_tasks} concurrent calls: {time.time() - start_time}"
+        f"Received {len(task_results) - failures} total embeddings, with {failures} failures.\n"
+        f"Embedded {len(urls)} docs across {num_replicas} replicas with {max_concurrency_per_replica} "
+        f"concurrent calls: {time.time() - start_time} \n"
+        f"Total sys time for downloads: {total_download_time} \n"
+        f"Total sys time for embeddings: {total_embed_time}"
)
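One caveat on the failure accounting: `tqdm.gather` follows `asyncio.gather`'s default semantics, so an exception raised inside `load_and_embed` propagates out of the `await` instead of landing in `task_results`, and the `isinstance(res, Exception)` checks only see failures returned as values. A hedged sketch of a wrapper that would make those checks effective, plus the conventional entrypoint (the file's tail is collapsed in this view, so the `__main__` guard is an assumption):

    # Inside main(), alongside load_and_embed (hypothetical hardening):
    async def load_and_embed_safe(url, idx):
        try:
            return await load_and_embed(url, idx)
        except Exception as e:
            return e  # surface the failure as a value so the filters above apply

# At module level (assumed, since the tail of the file is collapsed):
if __name__ == "__main__":
    asyncio.run(main())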


