In [None]:
import multiprocessing
import multiprocessing.pool
import transformers

import numpy as np
import pandas as pd

from torch import tensor, no_grad
from pyspark.sql import SparkSession

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.functions import predict_batch_udf, array_to_vector

from pyspark.sql.functions import *
from pyspark.sql.types import *

In [2]:
# Load the pretrained DistilBERT encoder and its matching tokenizer by hub name.
# These driver-side objects are used by the single-text demo and are broadcast to Spark later on.
model = transformers.AutoModel.from_pretrained("distilbert-base-uncased")
tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased")

### DistilBert

In [15]:
# Tokenize one sample string, truncating/padding it to exactly 512 positions and
# returning PyTorch tensors.
tokens = tokenizer(
    "Hello World",
    truncation=True,
    padding="max_length",
    max_length=512,
    return_tensors="pt"
)

# Forward pass with gradient tracking disabled (inference only).
with no_grad():
    last_hidden_states = model(tokens["input_ids"], attention_mask=tokens["attention_mask"])

# Element 0 of the model output, sliced at sequence position 0 for every batch row;
# presumably the [CLS] embedding used as a sentence-level feature vector — TODO confirm.
features = last_hidden_states[0][:, 0, :]

In [14]:
# Peek at the first five components of the first feature vector.
features[0][:5]

tensor([-0.1698, -0.1662,  0.0256, -0.1442, -0.1771])

### DistilBert + Spark

In [None]:
# Build (or reuse) a local Spark session that runs on every available core, with
# the Delta Lake SQL extension and catalog registered.
spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.driver.memory", "2g")
    .config("spark.executor.memory", "2g")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)

# Handle to the underlying SparkContext (used later for broadcast/parallelize).
sc = spark.sparkContext

# Show the SparkContext summary as the cell output.
sc

In [None]:
# Read the raw reviews CSV (header row, comma-delimited, schema inferred by Spark).
# NOTE(review): free-text columns with embedded quotes/newlines may need the
# `quote` / `escape` / `multiLine` reader options — confirm against the data.
df_raw = (
    spark.read.format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .option("delimiter", ",")
    .load("./archive/Reviews.csv")
)

# Keep only the columns used downstream.
df = df_raw.select(
    "Id",
    "Text",
    "Score"
)

# Persist as a single-partition Parquet dataset. mode("overwrite") keeps this cell
# re-runnable: with the default save mode the write fails as soon as the target
# path already exists (i.e. on every run after the first).
df.coalesce(1).write.format("parquet").mode("overwrite").save("./archive/reviews.parquet")

                                                                                

In [7]:
# Reload the Parquet copy of the reviews and discard rows that have no review text.
df = (
    spark.read.format("parquet")
    .load("./archive/reviews.parquet")
    .filter(col("Text").isNotNull())
)

                                                                                

In [None]:
# Number of rows left after dropping null Text values.
df.count()

                                                                                

568444

#### Method 1: Batch UDF

With a batch UDF, Python needs to load each batch of values into memory, which causes out-of-memory (OOM) errors.

In [None]:
def get_features(inference_batch_size: int = 32):
    """Build the prediction function handed to ``predict_batch_udf``.

    The model and tokenizer are loaded once, when this factory is called, and
    captured by the returned closure.

    Args:
        inference_batch_size: Maximum number of texts sent through the model
            in one forward pass. The incoming batch is processed in slices of
            this size so peak memory is bounded by the slice rather than by the
            full batch handed to ``predict``.

    Returns:
        A ``predict(inputs)`` callable mapping a 1-D array of strings to a 2-D
        array with one feature row (position-0 hidden state) per input.

    Raises:
        ValueError: If ``inference_batch_size`` is smaller than 1.
    """
    if inference_batch_size < 1:
        raise ValueError("inference_batch_size must be >= 1")

    model = transformers.AutoModel.from_pretrained("distilbert-base-uncased")
    tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased")

    def predict(inputs: np.ndarray) -> np.ndarray:
        """Embed a batch of texts, one output row per input text."""
        print(f"Batch Input Length: {len(inputs)}")

        # Plain Python strings for the tokenizer's batch interface.
        texts = inputs.tolist()

        chunks = []
        for start in range(0, len(texts), inference_batch_size):
            # padding=True pads only to the longest text in this slice instead
            # of always padding to 512; the tokenizer also returns the matching
            # attention mask, so pad positions are ignored by the model.
            encoded = tokenizer(
                texts[start:start + inference_batch_size],
                truncation=True,
                max_length=512,
                padding=True,
                return_tensors="pt",
            )

            with no_grad():
                outputs = model(
                    encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                )

            # Element 0 of the output, sequence position 0 for every row.
            chunks.append(outputs[0][:, 0, :].detach().cpu().numpy())

        if not chunks:
            # Empty batch: keep the 2-D (rows, hidden) contract.
            return np.empty((0, model.config.hidden_size), dtype=np.float32)

        return np.concatenate(chunks, axis=0)

    return predict

# Wrap the prediction factory as a Spark batch UDF that returns an array of
# doubles per row and hands `predict` up to 1,000 rows per call.
# NOTE: the name `udf` shadows `pyspark.sql.functions.udf`, which the star import
# at the top of the notebook brings into scope.
udf = predict_batch_udf(
    get_features,
    return_type=ArrayType(DoubleType()),
    batch_size=1_000,
)

In [5]:
# Add a lazily evaluated "features" column by applying the batch UDF to the Text column.
dff = df.withColumn(
    "features",
    udf("Text")
)

In [None]:
# Writing is the action that actually runs the UDF; any previous output at this path is replaced.
dff.write.format("parquet").mode("overwrite").save("./archive/ev.parquet")

#### Method 2: flatMap

This is a better solution because it does not cause OOM errors in the driver.

I apply a flatMap function to each row, convert the result back to a DataFrame, and write the embeddings to disk.

In [12]:
# Broadcast the driver-side tokenizer and model so Spark tasks can read them via `.value`.
tb = sc.broadcast(tokenizer)
mb = sc.broadcast(model)

In [19]:
def get_tokens(inputs):
    """Embed one (Id, Text) row with the broadcast DistilBERT model.

    Args:
        inputs: Row-like sequence whose first element is the review Id and
            whose second element is the review text.

    Returns:
        A one-element list holding ``Row(Id=..., features=...)``, where
        ``features`` is the position-0 hidden state as a flat list of floats.
        A list is returned so the function can be passed to ``flatMap``.
    """
    model = mb.value
    tokenizer = tb.value

    # Named row_id rather than id so the `id` builtin is not shadowed.
    row_id, text = inputs[0], inputs[1]

    # A single sequence needs no padding. Padding it to 512 made every forward
    # pass run at full length while the attention mask hid the pad positions
    # anyway, so dropping the padding should leave the embedding unchanged up
    # to floating-point rounding.
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    with no_grad():
        output = model(encoded["input_ids"], attention_mask=encoded["attention_mask"])

    features = output.last_hidden_state[:, 0, :].detach().cpu().numpy().flatten().tolist()

    return [Row(Id=row_id, features=features)]

In [20]:
# Smoke test on a single row: embed it through the RDD API ...
embs_rdd = (
    df.select("Id", "Text")
    .limit(1)
    .rdd
    .flatMap(get_tokens)
)

# ... and turn the resulting Rows back into a DataFrame with an explicit schema.
df_embs = spark.createDataFrame(
    embs_rdd,
    schema=StructType(
        [
            StructField("Id", IntegerType(), nullable=False),
            StructField("features", ArrayType(DoubleType()), nullable=False),
        ]
    ),
)

In [21]:
# Display the embedded row(s); this is the action that runs the flatMap.
df_embs.show()

                                                                                

+---+--------------------+
| Id|            features|
+---+--------------------+
|  1|[-0.0356425233185...|
+---+--------------------+



#### Method 3: flatMap + Multiprocessing

Now I will try to break the PySpark DataFrame into chunks and use multiple threads to write the files to disk.

In [8]:
# Work on a 50-row sample spread over four partitions, tagging every row with the
# id of the partition it landed in.
df = df.limit(50).repartition(4).withColumn("id_part", spark_partition_id())

# Sanity check: partition count and rows per partition.
print(df.rdd.getNumPartitions())
df.groupBy("id_part").count().show()

# One lazily filtered DataFrame per partition id (0..3).
# NOTE(review): each filtered frame re-evaluates limit/repartition when it is
# executed; the slices only form an exact split of the sample if that evaluation
# is deterministic — consider caching `df` first; confirm.
dfs = [df.filter(col("id_part") == part) for part in range(4)]



4


                                                                                

+-------+-----+
|id_part|count|
+-------+-----+
|      0|   12|
|      1|   13|
|      2|   13|
|      3|   12|
+-------+-----+



In [None]:
# Initialise (or reset) the Delta table at ./archive/ev by overwriting it with an
# empty DataFrame that carries the target schema.
(
    spark.createDataFrame(
        sc.parallelize([]),
        schema=StructType(
            [
                StructField("Id", IntegerType(), nullable=True),
                StructField("features", ArrayType(DoubleType()), nullable=True),
            ]
        ),
    )
    .write.format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .save("./archive/ev")
)

In [9]:
def get_tokens(inputs):
    """Embed one (Id, Text) row with the broadcast DistilBERT model.

    Args:
        inputs: Row-like sequence whose first element is the review Id and
            whose second element is the review text.

    Returns:
        A one-element list holding ``Row(Id=..., features=...)``, where
        ``features`` is the position-0 hidden state as a flat list of floats.
        A list is returned so the function can be passed to ``flatMap``.
    """
    # Read the broadcast handles directly instead of guarding on module-level
    # `model` / `tokenizer` globals. Those globals are already set on the
    # driver, so an `is None` guard never reaches the broadcast, and it raises
    # NameError when they are not defined at all. Referencing them as globals
    # presumably also drags the driver-side objects into the pickled function
    # — TODO confirm. PySpark caches the deserialized object on the Broadcast
    # handle, so repeated `.value` reads are cheap.
    model = mb.value
    tokenizer = tb.value

    # Named row_id rather than id so the `id` builtin is not shadowed.
    row_id, text = inputs[0], inputs[1]

    # The tokenizer returns input ids and the matching attention mask directly.
    # A single sequence needs no padding; truncation caps it at 512 tokens, so
    # manual padding to a computed max length is unnecessary.
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    with no_grad():
        output = model(encoded["input_ids"], attention_mask=encoded["attention_mask"])

    features = output.last_hidden_state[:, 0, :].detach().cpu().numpy().flatten().tolist()

    return [Row(Id=row_id, features=features)]

def write_delta_batch(df):
    """Embed every row of one partition slice and append the result to Delta.

    Args:
        df: DataFrame with ``Id``, ``Text`` and ``id_part`` columns. The
            ``id_part`` of its first row is used only for progress logging.

    Returns:
        None. Embeddings are appended to the Delta table at ``./archive/ev``.
        An empty ``df`` is skipped without writing anything.
    """
    # first() yields None for an empty frame, whereas indexing collect()[0]
    # raises IndexError when a slice happens to contain no rows.
    head = df.select("id_part").first()
    if head is None:
        return

    partition = head[0]

    print(f"START Partition: {partition}")
    embs_rdd = df.select("Id", "Text").rdd.flatMap(get_tokens)

    df_embs = embs_rdd.toDF(
        schema=StructType([
            StructField("Id", IntegerType(), False),
            StructField("features", ArrayType(DoubleType()), False)
        ])
    )

    df_embs.write.format("delta").mode("append").save("./archive/ev")

    print(f"END Partition: {partition}")

In [10]:
# Run one writer per partition slice on four threads. The context manager shuts
# the pool down once the (blocking) map has finished, so worker threads are not
# left alive in the kernel.
with multiprocessing.pool.ThreadPool(4) as pool:
    results = pool.map(write_delta_batch, dfs)

# One entry per slice; write_delta_batch returns None.
results

START Partition: 0
START Partition: 2
START Partition: 3
START Partition: 1


25/04/23 14:19:49 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

END Partition: 1
END Partition: 0
END Partition: 2
END Partition: 3


[None, None, None, None]

In [11]:
# Read the Delta table back and display the embeddings written by the worker threads.
spark.read.format("delta").load("./archive/ev").show()

+---+--------------------+
| Id|            features|
+---+--------------------+
|  2|[-0.0753862112760...|
| 25|[0.09975848346948...|
| 11|[0.12347056716680...|
|  3|[-0.1324467360973...|
| 50|[0.05504393950104...|
| 32|[-0.1600995659828...|
| 33|[-0.2392839342355...|
|  8|[-0.0289996191859...|
| 22|[0.13547386229038...|
| 24|[0.04533419385552...|
|  6|[-0.0724526047706...|
| 40|[-0.0022358724381...|
| 10|[0.01900614425539...|
|  4|[-0.0978876128792...|
| 38|[-0.1678747832775...|
| 43|[0.02412325888872...|
| 39|[-0.0451349355280...|
| 31|[-0.1412217020988...|
| 14|[0.09105278551578...|
| 19|[-0.2863656878471...|
+---+--------------------+
only showing top 20 rows

