Commit 915c612

better inference
rom1504 committed Mar 23, 2022
1 parent 40eb333 commit 915c612
Showing 1 changed file with 19 additions and 5 deletions.
examples/inference_example.py: 24 changes (19 additions & 5 deletions)

@@ -5,6 +5,7 @@
 import numpy as np
 import fsspec
 import math
+import pandas as pd
 
 def load_safety_model():
     """load the safety model"""
@@ -36,10 +37,20 @@ def load_safety_model():
 
     return loaded_model
 
+import mmh3
+def compute_hash(url, text):
+    if url is None:
+        url = ''
+
+    if text is None:
+        text = ''
+
+    total = (url + text).encode("utf-8")
+    return mmh3.hash64(total)[0]
 
-def main(input_folder, output_folder, batch_size=10**6, end=None):
+def main(embedding_folder, metadata_folder, output_folder, batch_size=10**6, end=None):
     """main function"""
-    reader = EmbeddingReader(input_folder)
+    reader = EmbeddingReader(embedding_folder, metadata_folder=metadata_folder, file_format="parquet_npy", meta_columns=["url", "caption"])
     fs, relative_output_path = fsspec.core.url_to_fs(output_folder)
     fs.mkdirs(relative_output_path, exist_ok=True)
 
@@ -49,13 +60,16 @@ def main(input_folder, output_folder, batch_size=10**6, end=None):
     batch_count = math.ceil(total // batch_size)
     padding = int(math.log10(batch_count)) + 1
 
-    for i, (embeddings, ids) in enumerate(reader(batch_size=batch_size, start=0, end=end)):
+    for i, (embeddings, ids) in enumerate(reader(batch_size=batch_size, start=0, end=end, parallel_pieces=10, max_piece_size=10**4)):
         predictions = model.predict(embeddings, batch_size=embeddings.shape[0])
         batch = np.hstack(predictions)
         padded_id = str(i).zfill(padding)
-        output_file_path = os.path.join(relative_output_path, padded_id + ".npy")
+        output_file_path = os.path.join(relative_output_path, padded_id + ".parquet")
+        df = pd.DataFrame(batch, columns=["prediction"])
+        df["hash"] = [compute_hash(x, y) for x, y in zip(ids['url'], ids['caption'])]
+        df["url"] = ids['url']
         with fs.open(output_file_path, "wb") as f:
-            np.save(f, batch)
+            df.to_parquet(f)
 
 
 if __name__ == '__main__':
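The new compute_hash helper gives each prediction row a stable join key: it concatenates url and caption, mapping missing values to empty strings, and keeps the first half of the 128-bit MurmurHash3 digest as a signed 64-bit integer. A minimal sketch of the behavior (the url and caption values are made up):

    import mmh3

    # same recipe as compute_hash in the diff above
    total = ("https://example.com/cat.jpg" + "a photo of a cat").encode("utf-8")
    key = mmh3.hash64(total)[0]  # first of two signed 64-bit ints, stable across runs

    # None fields collapse to '', so compute_hash(None, None) hashes the empty string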

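Each batch is now written as a parquet shard with prediction, hash and url columns instead of a bare .npy array, so the safety scores can be filtered and joined against metadata by hash. A minimal sketch of reading one shard back (the shard path and the cutoff are illustrative; pandas needs a parquet engine such as pyarrow installed):

    import pandas as pd

    df = pd.read_parquet("output_folder/0.parquet")  # hypothetical shard produced by main()
    print(df.columns.tolist())  # ['prediction', 'hash', 'url']
    flagged = df[df["prediction"] > 0.5]  # illustrative cutoff on the safety score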
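With the signature change, callers pass the embedding files and their metadata separately so EmbeddingReader can stream both together. A minimal invocation sketch (all folder paths are made up; the metadata parquet files are assumed to contain url and caption columns, matching meta_columns):

    # hypothetical call, e.g. after importing main from examples/inference_example.py
    main(
        embedding_folder="embeddings/img_emb",   # folder of .npy embedding files
        metadata_folder="embeddings/metadata",   # folder of parquet metadata
        output_folder="safety_predictions",
        batch_size=10**6,
    )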