# arXiv Paper Embedding


## Multi GPU w/ Dask + CUDF
Using Dask and CuDF to orchestrate sentence embedding over multiple GPU workers.

![Rapids and Dask Logos](https://saturn-public-assets.s3.us-east-2.amazonaws.com/example-resources/rapids_dask.png "doc-image")

### Important Imports

* [`dask_saturn`](https://github.com/saturncloud/dask-saturn) and [`dask_distributed`](http://distributed.dask.org/en/stable/): Set up and run the Dask cluster in Saturn Cloud.
* [`dask-cudf`](https://docs.rapids.ai/api/cudf/stable/basics/dask-cudf.html): Create distributed `cudf` dataframes using Dask.

In [1]:
import dask_cudf
import cudf
import json
import os
import re
import string

from dask_saturn import SaturnCluster
from dask.distributed import Client, wait


DATA_PATH = "arxiv-metadata-oai-snapshot.json"
YEAR_CUTOFF = 2022
YEAR_PATTERN = r"(19|20[0-9]{2})"
ML_CATEGORY = "cs.LG"

### Start the Dask Cluster

The template resource you are running has a Dask cluster already attached to it with three workers. The `dask-saturn` code below creates two important objects: a cluster and a client.

* `cluster`: knows about and manages the scheduler and workers
    - can be used to create, resize, reconfigure, or destroy those resources
    - knows how to communicate with the scheduler, and where to find logs and diagnostic dashboards
* `client`: tells the cluster to do things
    - can send work to the cluster
    - can restart all the worker processes
    - can send data to the cluster or pull data back from the cluster

In [2]:
n_workers = 4
cluster = SaturnCluster(n_workers=n_workers)
client = Client(cluster)

INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:Success!


If you already started the Dask cluster on the resource page, then the code above will run much more quickly since it will not have to wait for the cluster to turn on.

>**Pro tip**: Create and start the cluster in the Saturn Cloud UI before opening JupyterLab if you want to get a head start!

The last command ensures the kernel waits until all the desired workers are online before continuing.

In [3]:
client.wait_for_workers(n_workers=n_workers)

In [None]:
def clean_description(description: str):
    if not description:
        return ""
    # remove unicode characters
    description = description.encode('ascii', 'ignore').decode()

    # remove punctuation
    description = re.sub('[%s]' % re.escape(string.punctuation), ' ', description)

    # clean up the spacing
    description = re.sub('\s{2,}', " ", description)

    # remove urls
    #description = re.sub("https*\S+", " ", description)

    # remove newlines
    description = description.replace("\n", " ")

    # remove all numbers
    #description = re.sub('\w*\d+\w*', '', description)

    # split on capitalized words
    description = " ".join(re.split('(?=[A-Z])', description))

    # clean up the spacing again
    description = re.sub('\s{2,}', " ", description)

    # make all words lowercase
    description = description.lower()

    return description

In [None]:
def process(paper: dict):
    paper = json.loads(paper)
    if paper['journal-ref']:
        years = [int(year) for year in re.findall(YEAR_PATTERN, paper['journal-ref'])]
        years = [year for year in years if (year <= 2022 and year >= 1991)]
        year = min(years) if years else None
    else:
        year = None
    return {
        'id': paper['id'],
        'title': paper['title'],
        'year': year,
        'authors': paper['authors'],
        'categories': ','.join(paper['categories'].split(' ')),
        'abstract': paper['abstract'],
        'input': clean_description(paper['title'] + ' ' + paper['abstract'])
    }

def papers():
    with open(DATA_PATH, 'r') as f:
        for paper in f:
            paper = process(paper)
            if paper['year']:
                yield paper


In [None]:
cdf = cudf.DataFrame(list(papers()))

In [None]:
cdf.head()

In [4]:
import pickle
# Pro Tip: Pickle the dataframe
# This might save you time in the future so you don't have to do all of that again
with open('cdf.pkl', 'wb') as f:
    pickle.dump(cdf, f)
    
# Load pickle
# with open('cdf.pkl', 'rb') as f:
#     cdf = pickle.load(f)

In [5]:
# Still going to downsample here
len(cdf)

713361

## Using Dask to parallelize things

In [6]:
# Convert our CuDF to a Dask-CuDF
ddf = dask_cudf.from_cudf(cdf, npartitions=n_workers).persist()

In [7]:
from dask.distributed import get_worker
import numpy as np

def embed_partition(df: dask_cudf.DataFrame):
    """
    Create embeddings on single partition of DF (one dask worker)
    """
    worker = get_worker()
    if hasattr(worker, "model"):
        model = worker.model
    else:
        from sentence_transformers import SentenceTransformer

        model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
        worker.model = model

    print("embedding input", flush=True)
        
    # embed the input      
    vectors = model.encode(
        sentences = df.input.values_host,
        normalize_embeddings = True,
        show_progress_bar = True
    )
    
    # Convert to cudf series and return
    df['vector'] = cudf.Series(vectors.tolist(), index=df.index)
    return df[['id', 'vector']]

def clear_workers():
    """
    Deletes model attribute, freeing up memory on the Dask workers
    """
    import torch
    import gc

    worker = get_worker()
    if hasattr(worker, "model"):
        del worker.model
    torch.cuda.empty_cache()
    gc.collect()
    return

In [8]:
output_df = ddf[["id", "input"]].map_partitions(
    func = embed_partition,
    meta = {
      "id": object,
      "vector": cudf.ListDtype('float32')
    }
)
# Gather results
output_df = output_df.persist()
%time _ = wait(output_df)

CPU times: user 2.02 s, sys: 340 ms, total: 2.36 s
Wall time: 47min 23s


In [9]:
# Check output
len(output_df)

713361

In [10]:
# Check output
output_df.vector.isna().sum().compute()

0

In [31]:
# Merge and then take a sample of all ML papers AND papers older than 2015
full_ddf = ddf.merge(output_df)
full_ddf = full_ddf[(full_ddf.categories.str.contains(ML_CATEGORY)) | (full_ddf.year >= 2015)]

In [32]:
len(full_ddf)

309164

In [33]:
# Store as pickled pandas df

with open('arxiv_embeddings_300000.pkl', 'wb') as f:
    pickle.dump(full_ddf.compute().to_pandas(), f)

In [None]:
# Cleanup dask worker RAM
#client.run(clear_workers)