In [14]:
import time
import sparknlp
import pandas as pd
from pyspark.sql.window import Window
import pyspark.sql.functions as F
from sparknlp.pretrained import PretrainedPipeline
# CSV file (read with a header row) that run_unknown and get_entity_agg load their column from.
path = "data/data_set_fusion.csv"

In [2]:
# Start the Spark session through Spark NLP; every later cell uses `spark`.
spark = sparknlp.start()
# Load the pretrained pipeline stored at the local path 'entity_recognizer_lg_fr'.
# NOTE(review): presumably the French large NER model -- confirm the directory exists next to the notebook.
pipeline = PretrainedPipeline.from_disk('entity_recognizer_lg_fr')

# Breakdown pipeline transform

In [3]:
def run_unknown(column):
    """Run the NER pipeline over one CSV column and collect the predicted tags.

    The CSV at ``path`` is read with a header row, the requested column is
    renamed to ``text`` and passed through ``pipeline``. The time spent
    materialising the result is printed.

    Args:
        column: name of the CSV column to annotate.

    Returns:
        The list produced by ``collect()``; each row exposes a ``ner`` field
        holding the tag strings predicted for that input value.
    """
    source_df = spark.read.option("header", "true").csv(path)
    text_df = source_df.select(column).toDF("text")
    annotated_df = pipeline.transform(text_df)

    # Spark evaluates lazily: nothing above has executed yet. collect() is the
    # action that triggers the read, the pipeline transform and the projection,
    # so the measured interval covers all of that work.
    t_begin = time.time()
    rows = annotated_df.selectExpr("ner.result AS ner").collect()
    t_end = time.time()
    elapsed = t_end - t_begin
    print("collect time: %.2f" % elapsed, "s")

    return rows

In [4]:
#%%time
#list_ner_carte = run_unknown("type_de_carte")

In [5]:
#%%time
#list_ner_adresse = run_unknown("adresse")

In [6]:
#%%time
#list_ner_nom = run_unknown("nom")

# Breakdown get entity

In [7]:
def get_entity(list_ner, min_ratio=2/3):
    """Return the entity type that wins a per-row majority vote.

    Every row votes for a single label. Within a row, each tag is matched
    against the known entity types (substring match, so ``"B-PER"`` and
    ``"I-PER"`` both count as ``PER``). The most frequent type becomes the
    row's vote only if it covers at least ``min_ratio`` of the row's tags;
    otherwise the row votes ``"UNKNOWN"``. The label with the most votes wins;
    ties go to the label that received a vote first.

    Args:
        list_ner: iterable of row-like objects exposing a ``ner`` attribute
            that holds a sequence of tag strings (e.g. the rows returned by
            ``run_unknown``).
        min_ratio: minimum share of a row's tags the dominant type must reach
            for the row to vote for it. Defaults to 2/3.

    Returns:
        One of ``'PER'``, ``'LOC'``, ``'MISC'``, ``'ORG'`` or ``'UNKNOWN'``.
        ``'UNKNOWN'`` is also returned when ``list_ner`` is empty.
    """
    votes = {}
    for row in list_ner:
        tags = row.ner
        if not tags:
            # No tags at all (empty or null text): nothing to divide by, so the
            # row cannot support any entity type.
            label = "UNKNOWN"
        else:
            counts = {'PER': 0, 'LOC': 0, 'MISC': 0, 'ORG': 0}
            for tag in tags:
                for entity_type in counts:
                    if entity_type in tag:
                        counts[entity_type] += 1
            # max() keeps the first key on ties, i.e. dict order PER, LOC, MISC, ORG.
            label = max(counts, key=counts.get)
            if counts[label] / len(tags) < min_ratio:
                label = "UNKNOWN"
        # Tally in one pass; dict insertion order records which label was seen first.
        votes[label] = votes.get(label, 0) + 1

    if not votes:
        return "UNKNOWN"
    return max(votes, key=votes.get)

In [8]:
#%%time
#get_entity(list_ner_carte)

In [9]:
#%%time
#get_entity(list_ner_adresse)

In [10]:
#%%time
#get_entity(list_ner_nom)

# Agg spark version

In [11]:
def get_entity_agg(column):
    """Find the dominant entity type of a CSV column using Spark aggregations only.

    The CSV at ``path`` is read with a header row, the column is renamed to
    ``text`` and annotated by ``pipeline``. The predicted tags are exploded to
    one row per (text, tag), reduced to their entity label, and counted per
    text. Each distinct text then votes for its most frequent label, and the
    label with the most votes is returned.

    Notes:
        * Rows are grouped by their ``text`` value, so identical texts share
          a single vote.
        * The tag ``"O"`` is kept as a label of its own and can win the vote.
        * Ties are broken alphabetically on the label so repeated runs return
          the same answer.

    Args:
        column: name of the CSV column to analyse.

    Returns:
        The winning label, or ``None`` when no tag was produced at all
        (for instance an empty column).
    """
    df_spark = spark.read.option("header", "true").csv(path)
    data = df_spark.select(column).toDF("text")
    annotations = pipeline.transform(data)

    # One row per (text, predicted tag).
    tags = annotations.select(F.col("text"), F.explode("ner.result").alias("entity"))

    # Keep what follows the "-" in tags such as "B-PER"; leave "O" untouched.
    label = (
        F.when(F.col("entity") != "O", F.split(F.col("entity"), "-").getItem(1))
        .otherwise(F.col("entity"))
    )
    label_counts = (
        tags.withColumn("ent", label)
        .drop("entity")
        .groupby("text", "ent")
        .agg(F.count("ent").alias("count"))
    )

    # Rank labels inside each text; the secondary key makes ties deterministic.
    per_text = Window.partitionBy("text").orderBy(F.col("count").desc(), F.col("ent"))

    # No sort on "text" is needed before the final groupby: it would only add a
    # global shuffle whose ordering the aggregation discards.
    votes = (
        label_counts.withColumn("row", F.row_number().over(per_text))
        .filter(F.col("row") == 1)
        .drop("row")
        .groupby("ent")
        .count()
        .orderBy(F.col("count").desc(), F.col("ent"))
    )

    # first() returns None on an empty DataFrame; indexing it would raise TypeError.
    winner = votes.first()
    return winner[0] if winner is not None else None

# Total time

In [18]:
%%time
# Baseline on the "nom" column: collect the NER tags to the driver, then vote in plain Python.
print("NORMAL VERSION: ")
list_ner_nom = run_unknown("nom")
get_entity(list_ner_nom)

NORMAL VERSION: 
collect time: 49.32 s
CPU times: user 8.93 s, sys: 201 ms, total: 9.13 s
Wall time: 59.2 s


'PER'

In [19]:
%%time
# Same question for the "nom" column, answered with Spark aggregations instead of a Python loop.
print("AGG VERSION: ")
get_entity_agg("nom")

AGG VERSION: 
CPU times: user 39.7 ms, sys: 16.5 ms, total: 56.2 ms
Wall time: 54.6 s


'PER'

In [20]:
%%time
# Baseline on the "adresse" column. The result gets its own name so the "nom"
# results held in list_ner_nom are not silently overwritten with address data.
print("NORMAL VERSION: ")
list_ner_adresse = run_unknown("adresse")
get_entity(list_ner_adresse)

NORMAL VERSION: 
collect time: 39.85 s
CPU times: user 13.9 s, sys: 177 ms, total: 14.1 s
Wall time: 54.1 s


'LOC'

In [21]:
%%time
# Same question for the "adresse" column, answered with Spark aggregations instead of a Python loop.
print("AGG VERSION: ")
get_entity_agg("adresse")

AGG VERSION: 
CPU times: user 37.8 ms, sys: 15.9 ms, total: 53.7 ms
Wall time: 50.1 s


'LOC'

In [1]:
# Shut down the Spark session. Run this last and in the same kernel as the cell that
# calls sparknlp.start(): on a fresh kernel `spark` does not exist and this raises NameError.
spark.stop()

NameError: name 'spark' is not defined