# PySpark Huggingface Inferencing
## Conditional generation

From: https://huggingface.co/docs/transformers/model_doc/t5

### Using PyTorch

In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the pretrained T5-small tokenizer and its PyTorch conditional-generation model.
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

# Token-length limits for the source and target sequences.
max_source_length = 512
max_target_length = 128

# T5 selects its task from a natural-language prefix on every input.
task_prefix = "translate English to German: "

lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company",
]

# Prepend the task prefix to each sentence to form the model inputs.
input_sequences = [task_prefix + sentence for sentence in lines]

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [2]:
# Tokenize the whole batch in one call: pad to the longest sequence, truncate at
# max_source_length, and return PyTorch tensors.
input_ids = tokenizer(
    input_sequences,
    padding="longest",
    max_length=max_source_length,
    truncation=True,
    return_tensors="pt",
).input_ids

# Generate output token ids using the model's default generation settings.
outputs = model.generate(input_ids)



In [3]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

['Das Haus ist wunderbar',
 'Willkommen in NYC',
 'HuggingFace ist ein Unternehmen']

In [4]:
model.framework

'pt'

### Using TensorFlow

In [5]:
from transformers import AutoTokenizer, TFT5ForConditionalGeneration

2024-09-18 18:04:10.354829: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-18 18:04:10.354873: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-18 18:04:10.356281: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-18 18:04:10.363572: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [6]:
# Load the same T5-small checkpoint, this time as a TensorFlow model.
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

# Token-length limits for the source and target sequences.
max_source_length = 512
max_target_length = 128

# T5 selects its task from a natural-language prefix on every input.
task_prefix = "translate English to German: "

lines = [
    "The house is wonderful",
    "Welcome to NYC",
    "HuggingFace is a company",
]

# Prepend the task prefix to each sentence to form the model inputs.
input_sequences = [task_prefix + sentence for sentence in lines]

2024-09-18 18:04:17.110820: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30765 MB memory:  -> device: 0, name: Tesla V100-SXM3-32GB-H, pci bus id: 0000:34:00.0, compute capability: 7.0
2024-09-18 18:04:17.112209: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31135 MB memory:  -> device: 1, name: Tesla V100-SXM3-32GB-H, pci bus id: 0000:36:00.0, compute capability: 7.0
2024-09-18 18:04:17.113375: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 31135 MB memory:  -> device: 2, name: Tesla V100-SXM3-32GB-H, pci bus id: 0000:39:00.0, compute capability: 7.0
2024-09-18 18:04:17.114536: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 31135 MB memory:  -> device: 3, name: Tesla V100-SXM3-32GB-H, pc

In [7]:
# Tensorflow GPUS and version:

import tensorflow as tf
print(tf.config.list_physical_devices('GPU'))
print(tf.__version__)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:4', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:5', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:6', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:7', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:8', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:9', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:10', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:11', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:12', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:13', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:14', device_type='GPU'), Phys

In [8]:
# If XLA cannot locate libdevice, uncomment the two lines below and point the flag at
# the directory containing cuda/nvvm/libdevice.
#import os
#os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/lib/cuda" # set to location of cuda/nvvm/libdevice

# Tokenize the whole batch: pad to the longest sequence and return TensorFlow tensors.
# truncation=True is needed for max_length to take effect, and matches the PyTorch
# cell earlier in this notebook.
input_ids = tokenizer(input_sequences,
                      padding="longest",
                      max_length=max_source_length,
                      truncation=True,
                      return_tensors="tf").input_ids
outputs = model.generate(input_ids)

2024-09-18 18:04:20.353882: I external/local_xla/xla/service/service.cc:168] XLA service 0x280e86d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-09-18 18:04:20.353910: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
2024-09-18 18:04:20.353917: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (1): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
2024-09-18 18:04:20.353925: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (2): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
2024-09-18 18:04:20.353931: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (3): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
2024-09-18 18:04:20.353938: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (4): Tesla V100-SXM3-32GB-H, Compute Capability 7.0
2024-09-18 18:04:20.353944: I external/local_xl

In [9]:
[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

['Das Haus ist wunderbar',
 'Willkommen in NYC',
 'HuggingFace ist ein Unternehmen']

In [10]:
model.framework

'tf'

## PySpark

In [11]:
import os
from pathlib import Path
from datasets import load_dataset

In [14]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession

In [15]:
num_threads = 6

# Creating a local Spark session for demonstration, in case it hasn't already been created.

_config = {
    "spark.master": f"local[{num_threads}]",
    "spark.driver.host": "127.0.0.1",
    "spark.task.maxFailures": "1",
    "spark.driver.memory": "8g",
    "spark.executor.memory": "8g",
    "spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": "false",
    "spark.sql.pyspark.jvmStacktrace.enabled": "true",
    "spark.sql.execution.arrow.pyspark.enabled": "true",
    "spark.python.worker.reuse": "true",
}

# Apply every option to the builder, then create (or reuse) the session.
builder = SparkSession.builder.appName("spark-dl-example")
for option, setting in _config.items():
    builder = builder.config(option, setting)
spark = builder.getOrCreate()

# SparkContext handle, used later to parallelize per-executor tasks.
sc = spark.sparkContext

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
24/09/18 18:04:26 WARN Utils: Your hostname, dgx2h0194.spark.sjc4.nvmetal.net resolves to a loopback address: 127.0.1.1; using 10.150.30.2 instead (on interface enp134s0f0np0)
24/09/18 18:04:26 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/18 18:04:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [12]:
# load IMDB reviews (test) dataset
data = load_dataset("imdb", split="test")

In [13]:
# Keep only the text before the first "." of each review, wrapped in a one-element
# list so that every entry becomes a single-column row.
lines = [[example["text"].split(".")[0]] for example in data]

len(lines)

25000

### Create PySpark DataFrame

In [16]:
df = spark.createDataFrame(lines, ['lines']).repartition(10)
df.schema

StructType([StructField('lines', StringType(), True)])

In [17]:
df.take(1)

24/09/18 18:04:30 WARN TaskSetManager: Stage 0 contains a task of very large size (5123 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

[Row(lines='EVAN ALMIGHTY (2007) ** Steve Carell, Morgan Freeman, Lauren Graham, Johnny Simmons, Graham Phillips, Jimmy Bennett, John Goodman, Wanda Sykes, John Michael Higgins, Jonah Hill, Molly Shannon, Ed Helms, (Cameo: Jon Stewart as himself) Strained \'sequel\' to "BRUCE ALMIGHTY" with Carell\'s jerk anchorman Evan Baxter leaving TV to begin his stint as a freshman Congressional rep has his hands full when God (Freeman reprising his holy role; Jim Carrey wisely avoided the \'calling\') demands he build an ark like Noah and the hilarity ensues (or should have). The Godforsaken sitcom-y script by Steve Oedekerk, Joel Cohen & Alec Sokolow is absolutely lame and only Carell\'s amiable persona transcends his vain Evan into something resembling a human being. The end result is a lot of bird poop gags and overall bloat (reportedly costing $175 M for the CGI F/X). Sykes steals the show as Evan\'s sarcastic assistant. Sacrilegiously unfunny. (Dir: Tom Shadyac)')]

### Save the test dataset as parquet files

In [18]:
df.write.mode("overwrite").parquet("imdb_test")

24/09/18 18:04:32 WARN TaskSetManager: Stage 3 contains a task of very large size (5123 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

### Check arrow memory configuration

In [19]:
# Cap the number of rows per Arrow record batch handed to pandas UDFs.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "512")
# This line will fail if the vectorized reader runs out of memory.
# head(1) returns a list of rows, so len() counts rows; head() with no argument
# returns a single Row (or None when empty), whose len() is the column count.
assert len(df.head(1)) > 0, "`df` should not be empty"

24/09/18 18:04:33 WARN TaskSetManager: Stage 6 contains a task of very large size (5123 KiB). The maximum recommended task size is 1000 KiB.


## Inference using Spark DL API (PyTorch)
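Note: you can restart the kernel and run from this point to simulate running in a different node or environment. The cells below still need a Spark session (`spark`), so re-run the session-creation cell above first; they read the `imdb_test` parquet files saved earlier.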

In [20]:
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [21]:
# only use first sentence and add prefix for conditional generation
def preprocess(text, prefix: str = ""):
    """Return a column holding `prefix` + the first sentence of each input string.

    Args:
        text: The string column to transform. In this notebook it is a Spark
            column expression such as ``col("lines")``, not a pandas Series.
        prefix: Text prepended to every value, e.g. the T5 task prefix.

    Returns:
        The result of applying a string-typed pandas UDF to `text`.
    """
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # Keep only the text before the first "."; pass null/non-string values
        # through as null instead of failing on `.split`.
        return pd.Series(
            [prefix + s.split(".")[0] if isinstance(s, str) else None for s in text]
        )
    return _preprocess(text)

In [22]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100)
df.show(truncate=120)
df.count()

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   lines|
+------------------------------------------------------------------------------------------------------------------------+
|In following Dylan Moran's star from the charming misanthrope bookstore owner in the surrealist sitcom Black Books, I...|
|Here in Australia Nights in Rodanthe is being promoted in the same class as the Notebook. Quite frankly what a lot of...|
|The Tender Hook, or, Who Killed The Australian Film Industry? Case No. 278. This sorry excuse for a period drama take...|
|The only reason I'm even giving this movie a 4 is because it was made in to an episode of Mystery Science Theater 300...|
|Hooray for Title Misspellings! After reading reviews and contemplating, my girlfriend and I confirmed that this movie...|
|This movie make

100

In [23]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100).cache()

In [24]:
df1.count()

                                                                                

100

In [25]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|Translate English to German: In following Dylan Moran's star from the charming misanthrope bookstore owner in the sur...|
|   Translate English to German: Here in Australia Nights in Rodanthe is being promoted in the same class as the Notebook|
|                      Translate English to German: The Tender Hook, or, Who Killed The Australian Film Industry? Case No|
|Translate English to German: The only reason I'm even giving this movie a 4 is because it was made in to an episode o...|
|Translate English to German: Hooray for Title Misspellings! After reading reviews and contemplating, my girlfriend an...|
|               

In [26]:
def predict_batch_fn():
    """Build the batch prediction function passed to `predict_batch_udf`.

    Loads the "t5-small" model and tokenizer once, then returns a closure that
    generates text for a batch of input strings.

    Returns:
        A function `predict(inputs)` that takes a numpy array of strings and
        returns a 1-D numpy array with one generated string per input.
    """
    import numpy as np
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    def predict(inputs):
        # Flatten to a python list of strings. reshape(-1) is used instead of
        # np.squeeze because squeeze collapses a one-row batch to a 0-d array,
        # whose tolist() is a bare string rather than a one-element list.
        flattened = np.asarray(inputs).reshape(-1).tolist()
        # truncation=True is required for max_length to actually limit the input.
        input_ids = tokenizer(flattened,
                              padding="longest",
                              max_length=128,
                              truncation=True,
                              return_tensors="pt").input_ids
        output_ids = model.generate(input_ids)
        string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in output_ids])
        # Report the number of rows handled in this batch.
        print("predict: {}".format(len(flattened)))
        return string_outputs

    return predict

In [27]:
generate = predict_batch_udf(predict_batch_fn,
                             return_type=StringType(),
                             batch_size=10)

In [28]:
%%time
# first pass caches model/fn
preds = df1.withColumn("preds", generate(struct("input")))
results = preds.collect()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 80.1 ms, sys: 63 ms, total: 143 ms
Wall time: 18.9 s


predict: 10
                                                                                

In [29]:
%%time
preds = df1.withColumn("preds", generate("input"))
results = preds.collect()

predict: 10                                                         (0 + 1) / 1]
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 59.1 ms, sys: 62.9 ms, total: 122 ms
Wall time: 15.6 s


predict: 10
                                                                                

In [30]:
%%time
preds = df1.withColumn("preds", generate(col("input")))
results = preds.collect()

predict: 10                                                         (0 + 1) / 1]
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 69.1 ms, sys: 49.2 ms, total: 118 ms
Wall time: 15.5 s


predict: 10
                                                                                

In [31]:
preds.show(truncate=60)

predict: 10                                                         (0 + 1) / 1]
predict: 10


+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to German: In following Dylan Moran's s...|Indem ich Dylan Morans Star aus dem charmanten Buchhalter...|
|Translate English to German: Here in Australia Nights in ...|Hier in Australien Nights in Rodanthe wird in der gleiche...|
|Translate English to German: The Tender Hook, or, Who Kil...|The Tender Hook, or, Who Killed The Australian Film Indus...|
|Translate English to German: The only reason I'm even giv...|Der einzige Grund, warum ich diesen Film sogar eine 4 ver...|
|Translate English to German: Hooray for Title Misspelling...|Nach dem Lesen von Rezensionen und der Überlegung bestäti...|
|Transla

predict: 204
                                                                                

In [32]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100).cache()

In [33]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|Translate English to French: In following Dylan Moran's star from the charming misanthrope bookstore owner in the sur...|
|   Translate English to French: Here in Australia Nights in Rodanthe is being promoted in the same class as the Notebook|
|                      Translate English to French: The Tender Hook, or, Who Killed The Australian Film Industry? Case No|
|Translate English to French: The only reason I'm even giving this movie a 4 is because it was made in to an episode o...|
|Translate English to French: Hooray for Title Misspellings! After reading reviews and contemplating, my girlfriend an...|
|               

In [34]:
%%time
# first pass caches model/fn
preds = df2.withColumn("preds", generate(struct("input")))
result = preds.collect()

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 76.7 ms, sys: 65 ms, total: 142 ms
Wall time: 19.1 s


predict: 10
                                                                                

In [35]:
%%time
preds = df2.withColumn("preds", generate("input"))
result = preds.collect()

predict: 10                                                         (0 + 1) / 1]
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 54.4 ms, sys: 61.3 ms, total: 116 ms
Wall time: 15.4 s


predict: 10
                                                                                

In [36]:
%%time
preds = df2.withColumn("preds", generate(col("input")))
result = preds.collect()

predict: 10                                                         (0 + 1) / 1]
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10
predict: 10


CPU times: user 44.4 ms, sys: 71.7 ms, total: 116 ms
Wall time: 15.5 s


predict: 10
                                                                                

In [37]:
preds.show(truncate=60)

predict: 10                                                         (0 + 1) / 1]
predict: 10


+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to French: In following Dylan Moran's s...|       En suivant l'étoile de Dylan Moran, tirée du charmant|
|Translate English to French: Here in Australia Nights in ...|              Ici, en Australie Nights à Rodanthe, est promu|
|Translate English to French: The Tender Hook, or, Who Kil...|The Tender Hook, or, Who Killed The Australian Film Indus...|
|Translate English to French: The only reason I'm even giv...|La seule raison pour laquelle je donne même ce film un 4 ...|
|Translate English to French: Hooray for Title Misspelling...|Après avoir lu les critiques et réfléchi à la question, m...|
|Transla

predict: 204
                                                                                

### Using Triton Inference Server

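Note: you can restart the kernel and run from this point to simulate running in a different node or environment. The cells below still need a Spark session (`spark` and `sc`), so re-run the session-creation cell above first; they read the `imdb_test` parquet files saved earlier.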

This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments), using a conda-pack environment created as follows:
```
conda create -n huggingface -c conda-forge python=3.8
conda activate huggingface

export PYTHONNOUSERSITE=True
pip install conda-pack sentencepiece sentence_transformers transformers

conda-pack  # huggingface.tar.gz
```

In [38]:
import os

In [39]:
%%bash
# Stop at the first failing command so a missing model directory or archive fails
# the cell instead of leaving a partially populated repository.
set -euo pipefail

# copy custom model to expected layout for Triton
rm -rf models
mkdir -p models
cp -r models_config/hf_generation models

# add custom execution environment
cp huggingface.tar.gz models

#### Start Triton Server on each executor

In [40]:
num_executors = 1
triton_models_dir = "{}/models".format(os.getcwd())
huggingface_cache_dir = "{}/.cache/huggingface".format(os.path.expanduser('~'))
nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)

def start_triton(it):
    """Start a Triton Inference Server container unless one is already running.

    Intended to run once per partition via `mapPartitions`. Looks for a Docker
    container whose name matches "spark-triton"; if none exists, launches one
    with the model repository (`triton_models_dir`) and the Hugging Face cache
    (`huggingface_cache_dir`) mounted, then waits until the server reports ready.

    Args:
        it: The partition iterator supplied by Spark (not consumed).

    Returns:
        ``[True]`` once a container is found or the new server is ready.

    Raises:
        TimeoutError: If a newly started server is not ready within the wait limit.
    """
    import docker
    import time
    import tritonclient.grpc as grpcclient

    # Upper bound on the readiness wait, so a server that never comes up fails
    # the task instead of hanging it forever.
    ready_timeout_s = 600
    poll_interval_s = 5

    docker_client = docker.from_env()
    containers = docker_client.containers.list(filters={"name": "spark-triton"})
    if containers:
        print(">>>> containers: {}".format([c.short_id for c in containers]))
    else:
        container = docker_client.containers.run(
            "nvcr.io/nvidia/tritonserver:23.04-py3", "tritonserver --model-repository=/models",
            detach=True,
            device_requests=[docker.types.DeviceRequest(device_ids=["0"], capabilities=[['gpu']])],
            environment=[
                "TRANSFORMERS_CACHE=/cache"
            ],
            name="spark-triton",
            network_mode="host",
            remove=True,
            shm_size="1G",
            volumes={
                triton_models_dir: {"bind": "/models", "mode": "ro"},
                huggingface_cache_dir: {"bind": "/cache", "mode": "rw"}
            }
        )
        print(">>>> starting triton: {}".format(container.short_id))

        # wait for triton to be running
        time.sleep(15)
        triton_client = grpcclient.InferenceServerClient("localhost:8001")
        deadline = time.monotonic() + ready_timeout_s
        ready = False
        while not ready:
            try:
                ready = triton_client.is_server_ready()
            except Exception:
                # The gRPC endpoint may not be accepting connections yet.
                ready = False
            if not ready:
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        "Triton server was not ready after {} seconds".format(ready_timeout_s)
                    )
                # Sleep on every unsuccessful poll, including when the server
                # answers "not ready" without raising, to avoid a busy loop.
                time.sleep(poll_interval_s)

    return [True]

nodeRDD.barrier().mapPartitions(start_triton).collect()

>>>> containers: ['14361b461f08']


[True]

#### Run inference

In [41]:
import pandas as pd
from functools import partial
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import StringType

In [42]:
# only use first N examples, since this is slow
df = spark.read.parquet("imdb_test").limit(100).cache()

In [43]:
# only use first sentence and add prefix for conditional generation
def preprocess(text, prefix: str = ""):
    """Return a column holding `prefix` + the first sentence of each input string.

    Args:
        text: The string column to transform. In this notebook it is a Spark
            column expression such as ``col("lines")``, not a pandas Series.
        prefix: Text prepended to every value, e.g. the T5 task prefix.

    Returns:
        The result of applying a string-typed pandas UDF to `text`.
    """
    @pandas_udf("string")
    def _preprocess(text: pd.Series) -> pd.Series:
        # Keep only the text before the first "."; pass null/non-string values
        # through as null instead of failing on `.split`.
        return pd.Series(
            [prefix + s.split(".")[0] if isinstance(s, str) else None for s in text]
        )
    return _preprocess(text)

In [44]:
# only use first 100 rows, since generation takes a while
df1 = df.withColumn("input", preprocess(col("lines"), "Translate English to German: ")).select("input").limit(100)

In [45]:
df1.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|Translate English to German: In following Dylan Moran's star from the charming misanthrope bookstore owner in the sur...|
|   Translate English to German: Here in Australia Nights in Rodanthe is being promoted in the same class as the Notebook|
|                      Translate English to German: The Tender Hook, or, Who Killed The Australian Film Industry? Case No|
|Translate English to German: The only reason I'm even giving this movie a 4 is because it was made in to an episode o...|
|Translate English to German: Hooray for Title Misspellings! After reading reviews and contemplating, my girlfriend an...|
|               

In [46]:
def triton_fn(triton_uri, model_name):
    """Build a prediction function that forwards batches to a Triton server over gRPC.

    Args:
        triton_uri: Address of the Triton gRPC endpoint, e.g. "localhost:8001".
        model_name: Name of the model to query on the server.

    Returns:
        A function `predict(inputs)` accepting either a single numpy array (sent
        as the model's first input) or a dict of numpy arrays keyed by input
        name. It returns a single numpy array when the model declares one
        output, otherwise a dict of numpy arrays keyed by output name.
    """
    import numpy as np
    import tritonclient.grpc as grpcclient

    # Triton datatype name -> numpy dtype used to cast inputs before sending.
    # np.bool_ is used because the np.bool8 alias was deprecated in NumPy 1.24
    # and removed in NumPy 2.0.
    np_types = {
        "BOOL": np.dtype(np.bool_),
        "INT8": np.dtype(np.int8),
        "INT16": np.dtype(np.int16),
        "INT32": np.dtype(np.int32),
        "INT64": np.dtype(np.int64),
        "FP16": np.dtype(np.float16),
        "FP32": np.dtype(np.float32),
        "FP64": np.dtype(np.float64),
        "BYTES": np.dtype(object)
    }

    client = grpcclient.InferenceServerClient(triton_uri)
    # Fetched once; describes the model's declared inputs and outputs.
    model_meta = client.get_model_metadata(model_name)

    def predict(inputs):
        if isinstance(inputs, np.ndarray):
            # single ndarray input
            request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]
            request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))
        else:
            # dict of multiple ndarray inputs
            request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]
            for i in request:
                # Here `i` is an InferInput, whose name/datatype are accessed as
                # method calls (unlike the metadata entries above).
                i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))

        response = client.infer(model_name, inputs=request)

        if len(model_meta.outputs) > 1:
            # return dictionary of numpy arrays
            return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}
        else:
            # return single numpy array
            return response.as_numpy(model_meta.outputs[0].name)

    return predict

In [47]:
generate = predict_batch_udf(partial(triton_fn, triton_uri="localhost:8001", model_name="hf_generation"),
                             return_type=StringType(),
                             input_tensor_shapes=[[1]],
                             batch_size=100)

In [48]:
%%time
# first pass caches model/fn
preds = df1.withColumn("preds", generate(struct("input")))
results = preds.collect()



CPU times: user 15.6 ms, sys: 31.6 ms, total: 47.2 ms
Wall time: 5.13 s


                                                                                

In [49]:
%%time
preds = df1.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 47:>                                                         (0 + 1) / 1]

CPU times: user 20 ms, sys: 22 ms, total: 41.9 ms
Wall time: 4.52 s


                                                                                

In [50]:
%%time
preds = df1.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 49:>                                                         (0 + 1) / 1]

CPU times: user 8.11 ms, sys: 34.1 ms, total: 42.3 ms
Wall time: 4.52 s


                                                                                

In [51]:
preds.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to German: In following Dylan Moran's s...|Indem ich Dylan Morans Star aus dem charmanten Buchhalter...|
|Translate English to German: Here in Australia Nights in ...|Hier in Australien Nights in Rodanthe wird in der gleiche...|
|Translate English to German: The Tender Hook, or, Who Kil...|The Tender Hook, or, Who Killed The Australian Film Indus...|
|Translate English to German: The only reason I'm even giv...|Der einzige Grund, warum ich diesen Film sogar eine 4 ver...|
|Translate English to German: Hooray for Title Misspelling...|Nach dem Lesen von Rezensionen und der Überlegung bestäti...|
|Transla

                                                                                

In [52]:
# only use first 100 rows, since generation takes a while
df2 = df.withColumn("input", preprocess(col("lines"), "Translate English to French: ")).select("input").limit(100).cache()

24/09/18 18:06:42 WARN CacheManager: Asked to cache already cached data.


In [53]:
df2.show(truncate=120)

+------------------------------------------------------------------------------------------------------------------------+
|                                                                                                                   input|
+------------------------------------------------------------------------------------------------------------------------+
|Translate English to French: In following Dylan Moran's star from the charming misanthrope bookstore owner in the sur...|
|   Translate English to French: Here in Australia Nights in Rodanthe is being promoted in the same class as the Notebook|
|                      Translate English to French: The Tender Hook, or, Who Killed The Australian Film Industry? Case No|
|Translate English to French: The only reason I'm even giving this movie a 4 is because it was made in to an episode o...|
|Translate English to French: Hooray for Title Misspellings! After reading reviews and contemplating, my girlfriend an...|
|               

In [54]:
%%time
preds = df2.withColumn("preds", generate(struct("input")))
results = preds.collect()



CPU times: user 18.4 ms, sys: 28.9 ms, total: 47.3 ms
Wall time: 5.19 s


                                                                                

In [55]:
%%time
preds = df2.withColumn("preds", generate("input"))
results = preds.collect()

[Stage 57:>                                                         (0 + 1) / 1]

CPU times: user 18.2 ms, sys: 22.8 ms, total: 41 ms
Wall time: 4.5 s


                                                                                

In [56]:
%%time
preds = df2.withColumn("preds", generate(col("input")))
results = preds.collect()

[Stage 59:>                                                         (0 + 1) / 1]

CPU times: user 19.9 ms, sys: 19.5 ms, total: 39.4 ms
Wall time: 4.3 s


                                                                                

In [57]:
preds.show(truncate=60)

+------------------------------------------------------------+------------------------------------------------------------+
|                                                       input|                                                       preds|
+------------------------------------------------------------+------------------------------------------------------------+
|Translate English to French: In following Dylan Moran's s...|       En suivant l'étoile de Dylan Moran, tirée du charmant|
|Translate English to French: Here in Australia Nights in ...|              Ici, en Australie Nights à Rodanthe, est promu|
|Translate English to French: The Tender Hook, or, Who Kil...|The Tender Hook, or, Who Killed The Australian Film Indus...|
|Translate English to French: The only reason I'm even giv...|La seule raison pour laquelle je donne même ce film un 4 ...|
|Translate English to French: Hooray for Title Misspelling...|Après avoir lu les critiques et réfléchi à la question, m...|
|Transla

#### Stop Triton Server on each executor

In [58]:
def stop_triton(it):
    """Stop the Triton Inference Server container(s) started for this notebook.

    Intended to run once per partition via `mapPartitions`. Finds Docker
    containers whose name matches "spark-triton" and stops each of them.

    Args:
        it: The partition iterator supplied by Spark (not consumed).

    Returns:
        ``[True]``, whether or not any container was found.
    """
    import docker

    client = docker.from_env()
    containers = client.containers.list(filters={"name": "spark-triton"})
    print(">>>> stopping containers: {}".format([c.short_id for c in containers]))
    # Stop every container that was just reported, not only the first match.
    for container in containers:
        container.stop(timeout=120)

    return [True]

nodeRDD.barrier().mapPartitions(stop_triton).collect()

>>>> stopping containers: ['14361b461f08']
                                                                                

[True]

In [59]:
spark.stop()