<a href="https://colab.research.google.com/github/ryancburke/AG_news/blob/main/bert_sentence_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import os

# Install java
! apt-get update -qq
! apt-get install -y openjdk-8-jdk-headless -qq > /dev/null

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]
! java -version

# Install pyspark
! pip install --ignore-installed -q pyspark==2.4.4
! pip install --ignore-installed -q spark-nlp==2.7.1

openjdk version "1.8.0_282"
OpenJDK Runtime Environment (build 1.8.0_282-8u282-b08-0ubuntu1~18.04-b08)
OpenJDK 64-Bit Server VM (build 25.282-b08, mixed mode)


In [6]:
import sparknlp

spark = sparknlp.start(gpu = True) # for GPU training >> sparknlp.start(gpu = True) # for Spark 2.3 =>> sparknlp.start(spark23 = True)

from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
import pandas as pd

print("Spark NLP version", sparknlp.version())

print("Apache Spark version:", spark.version)

spark

Spark NLP version 2.7.1
Apache Spark version: 2.4.4


In [16]:
! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_train.csv
! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_test.csv


--2021-03-24 15:45:59--  https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_train.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24032125 (23M) [text/plain]
Saving to: ‘news_category_train.csv.4’


2021-03-24 15:46:00 (90.7 MB/s) - ‘news_category_train.csv.4’ saved [24032125/24032125]

--2021-03-24 15:46:00--  https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_test.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HT

In [17]:
df = spark.read \
      .option("header", True) \
      .csv("news_category_train.csv.4")

df.show(truncate=50)

+--------+--------------------------------------------------+
|category|                                       description|
+--------+--------------------------------------------------+
|Business| Short sellers, Wall Street's dwindling band of...|
|Business| Private investment firm Carlyle Group, which h...|
|Business| Soaring crude prices plus worries about the ec...|
|Business| Authorities have halted oil export flows from ...|
|Business| Tearaway world oil prices, toppling records an...|
|Business| Stocks ended slightly higher on Friday but sta...|
|Business| Assets of the nation's retail money market mut...|
|Business| Retail sales bounced back a bit in July, and n...|
|Business|" After earning a PH.D. in Sociology, Danny Baz...|
|Business| Short sellers, Wall Street's dwindling  band o...|
|Business| Soaring crude prices plus worries  about the e...|
|Business| OPEC can do nothing to douse scorching  oil pr...|
|Business| Non OPEC oil exporters should consider  increa...|
|Busines

In [19]:
from pyspark.sql.functions import col
df.groupBy("category") \
    .count() \
    .orderBy(col("count").desc()) \
    .show()

+--------+-----+
|category|count|
+--------+-----+
|   World|30000|
|  Sports|30000|
|Sci/Tech|30000|
|Business|30000|
+--------+-----+



In [20]:
(train_df, val_df) = df.randomSplit([0.7, 0.3], seed = 8)
print("Training Dataset Count: " + str(train_df.count()))
print("Validation Dataset Count: " + str(val_df.count()))

Training Dataset Count: 84018
Validation Dataset Count: 35982


In [21]:
# actual content is inside description column
document = DocumentAssembler() \
.setInputCol("description") \
.setOutputCol("document") \
.setCleanupMode("shrink")

bert = BertSentenceEmbeddings.pretrained('sent_bert_base_cased') \
.setInputCols("document") \
.setOutputCol("bert_sentence_embeddings") \
.setLazyAnnotator(False)

# the classes/labels/categories are in category column
classifierdl = ClassifierDLApproach()\
.setInputCols(["bert_sentence_embeddings"])\
.setOutputCol("class")\
.setLabelColumn("category")\
.setMaxEpochs(4)\
.setLr(0.001)\
.setBatchSize(64)\
.setEnableOutputLogs(True)
#.setOutputLogsPath('logs')

pipeline = Pipeline(
    stages = [
        document,
        bert,
        classifierdl
    ])

sent_bert_base_cased download started this may take some time.
Approximate size to download 389.1 MB
[OK!]


In [22]:
%%time
pipelineModel = pipeline.fit(train_df)

CPU times: user 97.1 ms, sys: 22.2 ms, total: 119 ms
Wall time: 9min 30s


In [23]:
# get the predictions on validation Set

preds = pipelineModel.transform(val_df)

In [24]:
preds.select('description','category',"class.result").show(10, truncate=100)

+----------------------------------------------------------------------------------------------------+--------+----------+
|                                                                                         description|category|    result|
+----------------------------------------------------------------------------------------------------+--------+----------+
|  ''The Oprah Winfrey Show quot; was the best advertising an estimated  $8 million could buy for ...|Business|[Business]|
|  A  $120 million fine levied on Royal Dutch/Shell Group by the Securities and Exchange Commissio...|Business|[Business]|
|  A Colorado assistant store manager at Costco has filed a federal lawsuit, alleging she was pass...|Business|[Business]|
|  A drop in oil prices and upbeat outlooks from Wal Mart and Lowe's helped send stocks sharply hi...|Business|[Business]|
|  A federal bankruptcy judge ruled against United Airlines yesterday in a procedural dispute, sid...|Business|[Business]|
|  A federal ini

In [25]:
preds_df = preds.select('description','category',"class.result").toPandas()

# The result is an array since in Spark NLP you can have multiple sentences.
# Let's explode the array and get the item(s) inside of result column out
preds_df['result'] = preds_df['result'].apply(lambda x : x[0])

In [26]:
# We are going to use sklearn to evalute the results on test dataset
from sklearn.metrics import classification_report

print (classification_report(preds_df['result'], preds_df['category']))

              precision    recall  f1-score   support

    Business       0.86      0.85      0.86      8974
    Sci/Tech       0.89      0.85      0.87      9266
      Sports       0.98      0.95      0.97      9332
       World       0.87      0.94      0.90      8410

    accuracy                           0.90     35982
   macro avg       0.90      0.90      0.90     35982
weighted avg       0.90      0.90      0.90     35982

