# Finetune a text classification model
In this notebook, we show how to finetune a `DistilBERT` model to classify SMS messages as spam or not spam.

In this guide we load the SMS Spam Collection dataset from [DBFS](https://docs.databricks.com/dbfs/index.html) to show the full lifecycle of finetuning with Spark. You can also skip the DBFS part by loading the SMS Spam Collection dataset directly from HuggingFace: [link](https://huggingface.co/datasets/sms_spam)

## Cluster setup
For this notebook, we recommend a single GPU cluster, such as a `g4dn.xlarge` on AWS or `Standard_NC4as_T4_v3` on Azure. You can [create a single machine cluster](https://docs.databricks.com/clusters/configure.html) using the personal compute policy or by choosing "Single Node" when creating a cluster. This notebook requires Databricks Runtime ML GPU version 11.1 or greater.

## Install dependencies

We need the `datasets` and `evaluate` packages from HuggingFace. Additionally, TF 2.13 has a conflict with transformers <= 4.28, so we need to downgrade to TF 2.12. Note that the install cells below pin `tensorflow==2.14.0` rather than 2.12.

In [46]:
!pip install pyspark
!pip install datasets



In [55]:
!pip install mlflow

Collecting mlflow
  Downloading mlflow-2.8.1-py3-none-any.whl (19.0 MB)
Collecting databricks-cli<1,>=0.8.7 (from mlflow)
  Downloading databricks_cli-0.18.0-py2.py3-none-any.whl (150 kB)
Collecting gitpython<4,>=2.1.0 (from mlflow)
  Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)
Collecting alembic!=1.10.0,<2 (from mlflow)
  Downloading alembic-1.12.1-py3-none-any.whl (226 kB)
Collecting docker<7,>=4.0.0 (from mlflow)
  Downloading docker-6.1.3-py3-none-any.whl (148 kB)

In [None]:
!pip install -q datasets evaluate tensorflow==2.14.0

In [None]:
!pip install tensorflow==2.14.0



In [None]:
!pip install accelerate -U

Collecting accelerate
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
Installing collected packages: accelerate
Successfully installed accelerate-0.24.1


In [None]:
!pip install tensorflow_probability



In [None]:
!pip install --upgrade pyspark
!pip install --upgrade datasets

Collecting pyspark
  Downloading pyspark-3.5.0.tar.gz (316.9 MB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.5.0-py2.py3-none-any.whl size=317425344 sha256=4f1561f869840a5f0021b68c0742891f3021dafb394c02334a8fbd39de21b15c
  Stored in directory: /root/.cache/pip/wheels/41/4e/10/c2cf2467f71c678cfc8a6b9ac9241e5e44a01940da8fbb17fc
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.0


In [None]:
!pip install pyarrow



In [None]:
!wget http://setup.johnsnowlabs.com/colab.sh -O - | bash /dev/stdin -p 3.2.3 -s 5.1.4

--2023-11-21 15:28:36--  http://setup.johnsnowlabs.com/colab.sh
Resolving setup.johnsnowlabs.com (setup.johnsnowlabs.com)... 51.158.130.125
Connecting to setup.johnsnowlabs.com (setup.johnsnowlabs.com)|51.158.130.125|:80... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/scripts/colab_setup.sh [following]
--2023-11-21 15:28:36--  https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/scripts/colab_setup.sh
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1191 (1.2K) [text/plain]
Saving to: ‘STDOUT’


2023-11-21 15:28:36 (64.0 MB/s) - written to stdout [1191/1191]

Installing PySpark 3.2.3 and Spark NLP 5.1.4
setup Colab for PySpark 3.2.3 and Spark NLP 5

Restart the python runtime to use the updated dependencies.
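
After restarting, it can help to confirm which versions actually ended up installed, since several cells above install or upgrade the same packages. The following is a minimal sketch using only the standard library; the package list and the helper name `installed_versions` are illustrative assumptions, not part of the original notebook.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages this notebook relies on; adjust the list to your environment.
PACKAGES = ["pyspark", "datasets", "evaluate", "transformers", "tensorflow", "mlflow"]


def installed_versions(packages):
    """Return {package: version string, or None if the package is missing}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:<14}{ver or 'not installed'}")
```

If a version differs from what you expect, rerun the corresponding install cell and restart the runtime again.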

## [Optional] Download data and copy to Databricks file system
Let's download and extract the dataset. We will use the
[SMS Spam Collection Dataset](https://archive.ics.uci.edu/ml/datasets/sms+spam+collection) from the UCI Machine Learning Repository.

In [None]:
!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip

--2023-11-21 15:29:46--  https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified
Saving to: ‘smsspamcollection.zip’

smsspamcollection.z     [ <=>                ] 198.65K  1.01MB/s    in 0.2s    

2023-11-21 15:29:47 (1.01 MB/s) - ‘smsspamcollection.zip’ saved [203415]



In [None]:
!unzip -o smsspamcollection.zip

Archive:  smsspamcollection.zip
  inflating: SMSSpamCollection       
  inflating: readme                  


Copy the dataset to Databricks file system (DBFS). The `tutorial_path` sets the path in DBFS that the notebook uses to write the sample dataset. It is deleted by the last command in this notebook.

You can find the path to the dataset by clicking the triple dot next to `SMSSpamCollection` in the left sidebar, then choosing Copy => Path.

In [None]:
#tutorial_path = "/FileStore/sms_tutorial"
#dbutils.fs.mkdirs(f"dbfs:{tutorial_path}")
#dbutils.fs.cp(
#    "file:/Workspace/Repos/chen.qian@databricks.com/mlflow-guide/finetune_spam_classifier/SMSSpamCollection",
#    f"dbfs:{tutorial_path}/SMSSpamCollection.tsv",
#)

True

In [47]:
import findspark
findspark.init()

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("Spark NLP")\
    .master("local[*]")\
    .config("spark.driver.memory","16G")\
    .config("spark.driver.maxResultSize", "0") \
    .config("spark.kryoserializer.buffer.max", "2000M")\
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.1.4")\
    .getOrCreate()

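Now that the data is in place, we can load the dataset into a DataFrame. The cell below reads the local copy at `file:/content/SMSSpamCollection`; if you copied the file to DBFS, point the path at that location instead. The file is tab separated and does not contain a header, so we specify the separator using `sep` and specify the column names explicitly.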

In [None]:
sms = spark.read.csv(
    "file:/content/SMSSpamCollection",
    header=False,
    inferSchema=True,
    sep="\t"
).toDF("label", "text")
print(f"Total number of data records: {sms.count()}")

# Print out some sample data.
sms.toPandas().head(10)

Total number of data records: 5574


| | label | text |
|---|---|---|
| 0 | ham | Go until jurong point, crazy.. Available only ... |
| 1 | ham | Ok lar... Joking wif u oni... |
| 2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | ham | U dun say so early hor... U c already then say... |
| 4 | ham | Nah I don't think he goes to usf, he lives aro... |
| 5 | spam | FreeMsg Hey there darling it's been 3 week's n... |
| 6 | ham | Even my brother is not like to speak with me. ... |
| 7 | ham | As per your request 'Melle Melle (Oru Minnamin... |
| 8 | spam | WINNER!! As a valued network customer you have... |
| 9 | spam | Had your mobile 11 months or more? U R entitle... |



Convert string labels to integers, since finetuning requires an integer label.

In this exact dataset, we have the following mapping:
```
{
  "ham": 0,
  "spam": 1,
}
```

In [None]:
id2label = {0: "ham", 1: "spam"}
label2id = {'ham': 0, 'spam': 1}
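
Keeping two hand-written dictionaries in sync is easy to get wrong as the label set grows. As a small sketch (not the notebook's original approach), one mapping can be derived from the other, with a check that no two labels share an ID:

```python
label2id = {"ham": 0, "spam": 1}

# Invert the mapping instead of typing it out a second time.
id2label = {idx: name for name, idx in label2id.items()}

# If two labels shared an ID, the inverted dict would silently lose one.
assert len(id2label) == len(label2id), "label IDs must be unique"

print(id2label)  # {0: 'ham', 1: 'spam'}
```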

Replace the string labels with the IDs in the DataFrame.

In [None]:
from pyspark.sql.functions import pandas_udf
import pandas as pd

# `pandas_udf` is the decorator that turns a custom function
# into a UDF, so we can call this function inside `select`.
@pandas_udf('integer')
def replace_labels_with_ids(labels: pd.Series) -> pd.Series:
  return labels.apply(lambda x: label2id[x])

sms_id_labels = sms.select(replace_labels_with_ids(sms.label).alias('label'), sms.text)
sms_id_labels.toPandas().head(10)

| | label | text |
|---|---|---|
| 0 | 0 | Go until jurong point, crazy.. Available only ... |
| 1 | 0 | Ok lar... Joking wif u oni... |
| 2 | 1 | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | 0 | U dun say so early hor... U c already then say... |
| 4 | 0 | Nah I don't think he goes to usf, he lives aro... |
| 5 | 1 | FreeMsg Hey there darling it's been 3 week's n... |
| 6 | 0 | Even my brother is not like to speak with me. ... |
| 7 | 0 | As per your request 'Melle Melle (Oru Minnamin... |
| 8 | 1 | WINNER!! As a valued network customer you have... |
| 9 | 1 | Had your mobile 11 months or more? U R entitle... |
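
Because a pandas UDF receives each batch as a `pd.Series`, the mapping logic can be tried on plain pandas first, without starting Spark. The sketch below is an assumption-laden variant rather than the notebook's own function: the helper name `labels_to_ids` is hypothetical, and it uses `Series.map` so that an unexpected label raises a clear error instead of a `KeyError` from inside a lambda.

```python
import pandas as pd

label2id = {"ham": 0, "spam": 1}


def labels_to_ids(labels: pd.Series) -> pd.Series:
    """Map string labels to integer IDs, failing loudly on unknown labels."""
    ids = labels.map(label2id)  # labels missing from the dict become NaN
    if ids.isna().any():
        unknown = labels[ids.isna()].unique().tolist()
        raise ValueError(f"unmapped labels: {unknown}")
    return ids.astype("int32")


sample = pd.Series(["ham", "spam", "ham"])
print(labels_to_ids(sample).tolist())  # [0, 1, 0]
```

The same function body could then be decorated with `@pandas_udf('integer')` as in the cell above.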


In [38]:
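# Split the DataFrame into train and test sets, then preview the first training rows.
(train, test) = sms_id_labels.randomSplit([0.8, 0.2])
train.head(10)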

[Row(label=0, text=' &lt;#&gt;  in mca. But not conform.'),
 Row(label=0, text=' &lt;#&gt;  mins but i had to stop somewhere first.'),
 Row(label=0, text=' &lt;DECIMAL&gt; m but its not a common car here so its better to buy from china or asia. Or if i find it less expensive. I.ll holla'),
 Row(label=0, text=' and  picking them up from various points'),
 Row(label=0, text=' came to look at the flat, seems ok, in his 50s? * Is away alot wiv work. Got woman coming at 6.30 too.'),
 Row(label=0, text=" gonna let me know cos comes bak from holiday that day.  is coming. Don't4get2text me  number. "),
 Row(label=0, text=" said kiss, kiss, i can't do the sound effects! He is a gorgeous man isn't he! Kind of person who needs a smile to brighten his day! "),
 Row(label=0, text=" says that he's quitting at least5times a day so i wudn't take much notice of that. Nah, she didn't mind. Are you gonna see him again? Do you want to come to taunton tonight? U can tell me all about !"),
 Row(label=0, tex

In [48]:
from datasets import Dataset
import pandas as pd
# Split the DataFrame into train and test
#(train, test) = sms_id_labels.randomSplit([0.8, 0.2])

# Convert Spark DataFrames to pandas DataFrames
train_pd = train.toPandas()
test_pd = test.toPandas()

# Convert pandas DataFrames to HuggingFace datasets
train_dataset = Dataset.from_pandas(train_pd)
test_dataset = Dataset.from_pandas(test_pd)

The cell above converts the Spark DataFrames into HuggingFace datasets by way of pandas. HuggingFace also supports loading directly from Spark DataFrames using `datasets.Dataset.from_spark`. See the Hugging Face documentation to learn more about the [from_spark()](https://huggingface.co/docs/datasets/use_with_spark) method.

`Dataset.from_spark` caches the dataset. In this example, the model is trained on the driver, and the cached data is parallelized using Spark, so `cache_dir` must be accessible to the driver and to all the workers. You can use the Databricks File System (DBFS) root ([AWS](https://docs.databricks.com/dbfs/index.html#what-is-the-dbfs-root) | [Azure](https://learn.microsoft.com/azure/databricks/dbfs/#what-is-the-dbfs-root) | [GCP](https://docs.gcp.databricks.com/dbfs/index.html#what-is-the-dbfs-root)) or a mount point ([AWS](https://docs.databricks.com/dbfs/mounts.html) | [Azure](https://learn.microsoft.com/azure/databricks/dbfs/mounts) | [GCP](https://docs.gcp.databricks.com/dbfs/mounts.html)).

By using DBFS, you can reference "local" paths when creating the `transformers` compatible datasets used for model training.

## Alternative way to load dataset

If you skipped the previous step of loading the dataset with Spark, uncomment and run the cell below to load the dataset directly from HuggingFace.

In [None]:
# from datasets import load_dataset

# sms_dataset = load_dataset("sms_spam")
# sms_train_test = sms_dataset["train"].train_test_split(test_size=0.2)
# # For consistency, we rename "sms" => "text".
# sms_train_test = sms_train_test.rename_column("sms", "text")
# train_dataset = sms_train_test["train"]
# test_dataset = sms_train_test["test"]




## Data preprocessing

Before finetuning, let's tokenize and shuffle the datasets. Since the [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer) does not need the untokenized `text` columns for training,
the notebook removes them from the dataset.
In this step, `datasets` also caches the transformed datasets on local disk for fast subsequent loading during model training.

In [49]:
from transformers import AutoTokenizer

# Load the tokenizer for the "distilbert-base-uncased" model.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize_function(examples):
    # Pad/truncate each text to 128 tokens (see `max_length` below).
    # Enforcing the same shape can make training faster.
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=128,
    )

train_tokenized = train_dataset.map(tokenize_function).remove_columns(["text"]).shuffle(seed=42)
test_tokenized = test_dataset.map(tokenize_function).remove_columns(["text"]).shuffle(seed=42)

Map:   0%|          | 0/4468 [00:00<?, ? examples/s]

Map:   0%|          | 0/1106 [00:00<?, ? examples/s]
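
To make the effect of `padding="max_length"` and `truncation=True` concrete, here is a toy sketch of what happens to a single sequence of token IDs. It is an illustration under stated assumptions, not the tokenizer's actual implementation: the helper `pad_or_truncate` is hypothetical, the token IDs are made up, a pad ID of 0 is assumed, and special-token handling during truncation is ignored.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly `max_length` long."""
    kept = token_ids[:max_length]            # truncate sequences that are too long
    n_pad = max_length - len(kept)           # number of padding slots still needed
    input_ids = kept + [pad_id] * n_pad      # pad short sequences on the right
    attention_mask = [1] * len(kept) + [0] * n_pad  # 1 = real token, 0 = padding
    return input_ids, attention_mask


short_ids, short_mask = pad_or_truncate([101, 2489, 4443, 102], max_length=6)
long_ids, long_mask = pad_or_truncate(list(range(1, 10)), max_length=6)

print(short_ids, short_mask)  # [101, 2489, 4443, 102, 0, 0] [1, 1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [1, 2, 3, 4, 5, 6] [1, 1, 1, 1, 1, 1]
```

Every example ends up with the same length, which is what lets the examples be stacked into fixed-shape batches.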

## Model finetuning

We have prepared the data, so let's kick off the finetuning!

For finetuning we will rely on the HuggingFace `Trainer` API.

Create the evaluation metric to log. Loss is also logged, but adding other metrics such as accuracy can make model performance easier to understand. For this classification task, we use `accuracy` as the tracking metric.

In [50]:
import numpy as np
import evaluate
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

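
To see what `compute_metrics` does with the model output, the following sketch reproduces the same two steps (argmax over the class dimension, then the fraction of matches) with plain NumPy. The logits and labels are made-up toy values, not results from this notebook.

```python
import numpy as np

# Toy logits for 4 messages and 2 classes (column 0 = ham, column 1 = spam).
logits = np.array([[2.1, -1.3], [0.2, 0.9], [-0.5, 1.7], [1.0, 0.4]])
labels = np.array([0, 1, 0, 0])

# Pick the class with the highest logit for each message.
predictions = np.argmax(logits, axis=-1)

# Accuracy is the fraction of predictions that equal the reference labels.
accuracy = float((predictions == labels).mean())

print(predictions.tolist(), accuracy)  # [0, 1, 1, 0] 0.75
```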

Set the training arguments.
Please refer to the [transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments)
for the full argument list. Don't panic at the long list of arguments; usually only a few of them are needed.

**Important: `training_output_dir` cannot be inside the working directory because of write restrictions; we recommend a directory under `/tmp`.**

In [51]:
from transformers import TrainingArguments, Trainer

# Set the output directory to somewhere inside /tmp.
training_output_dir = "/tmp/weird_mouse/sms_trainer"
training_args = TrainingArguments(
    output_dir=training_output_dir,
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)

Let's load the pretrained DistilBERT model, and specify the label mappings and the number of classes.

In [52]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
    label2id=label2id,
    id2label=id2label,
)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier.weight', 'pre_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Construct the trainer object with the model, arguments, datasets, and metrics created above.

In [53]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    compute_metrics=compute_metrics,
)

In [57]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.0567 | 0.029554 | 0.993671 |
| 2 | 0.0188 | 0.031234 | 0.994575 |
| 3 | 0.0024 | 0.033895 | 0.994575 |


TrainOutput(global_step=1677, training_loss=0.02342929430283744, metrics={'train_runtime': 11522.3626, 'train_samples_per_second': 1.163, 'train_steps_per_second': 0.146, 'total_flos': 481518994114560.0, 'train_loss': 0.02342929430283744, 'epoch': 3.0})
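
The `global_step=1677` in the output can be checked by hand from the dataset size and batch size. The sketch below assumes a single device, no gradient accumulation, and that the last partial batch of each epoch is kept; the example count comes from the `Map` progress output for the training split.

```python
import math

train_examples = 4468  # size of the training split reported by the Map output
batch_size = 8         # per_device_train_batch_size
epochs = 3             # number of epochs shown in the training table
devices = 1            # assumption: training runs on a single device

# The last batch of an epoch may be smaller, so round up.
steps_per_epoch = math.ceil(train_examples / (batch_size * devices))
total_steps = steps_per_epoch * epochs

print(steps_per_epoch, total_steps)  # 559 1677
```

The result matches the reported `global_step`, which is a quick sanity check that the whole training split was used in every epoch.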

Calling `trainer.train()` trains the model; meanwhile, metrics and results are logged to MLflow.


Let's wrap the model into a HuggingFace `text-classification` pipeline so that we can directly feed text data for spam classification.
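
A minimal sketch of that wrapping step follows. The helper names `build_spam_pipeline` and `label_messages` are hypothetical and not part of the original notebook; the sketch assumes the `trainer` and `tokenizer` objects created in the cells above and uses the `transformers.pipeline` factory with the `text-classification` task.

```python
def build_spam_pipeline(model, tokenizer):
    """Wrap a finetuned model and its tokenizer in a text-classification pipeline."""
    # Imported inside the function so the helper can be defined before transformers is loaded.
    from transformers import pipeline

    return pipeline("text-classification", model=model, tokenizer=tokenizer)


def label_messages(classifier, messages):
    """Return (message, predicted label, score) tuples for a list of strings."""
    results = classifier(messages)  # one {"label": ..., "score": ...} dict per message
    return [(msg, res["label"], res["score"]) for msg, res in zip(messages, results)]


# Usage in this notebook (assumes `trainer` and `tokenizer` from the cells above):
# spam_pipeline = build_spam_pipeline(trainer.model, tokenizer)
# label_messages(spam_pipeline, ["Ok lar... Joking wif u oni..."])
```

Because `id2label` was passed when loading the model, the pipeline should report `ham` and `spam` rather than generic `LABEL_0` and `LABEL_1` names.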

In [59]:
# Evaluate the model
eval_result = trainer.evaluate(eval_dataset=test_tokenized)

# `Trainer.evaluate` returns a dictionary containing the evaluation metrics
print(f"Evaluation result: {eval_result}")

Evaluation result: {'eval_loss': 0.033895451575517654, 'eval_accuracy': 0.9945750452079566, 'eval_runtime': 250.6172, 'eval_samples_per_second': 4.413, 'eval_steps_per_second': 0.555, 'epoch': 3.0}


In [60]:
accuracy = eval_result['eval_accuracy']

# Now print the accuracy
print(f"Accuracy: {accuracy}")

Accuracy: 0.9945750452079566
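
An accuracy this close to 1.0 is easier to interpret as a count of mistakes. The sketch below converts the reported accuracy back into a number of misclassified messages, taking the test-split size from the `Map` progress output earlier in the notebook.

```python
eval_accuracy = 0.9945750452079566  # value reported by trainer.evaluate
test_examples = 1106                # size of the test split reported by the Map output

# Accuracy is correct / total, so multiplying back recovers the count.
correct = round(eval_accuracy * test_examples)
errors = test_examples - correct

print(correct, errors)  # 1100 6
```

In other words, the finetuned model misclassifies 6 of the 1106 held-out messages.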


The drastic improvement in Accuracy