# Prepare the dataset

We load a dataset of customer messages, each assigned to a category (invoices, orders, etc.).

We will use this dataset to tune a prompt that classifies a message into one of these categories with the highest possible accuracy.

https://huggingface.co/datasets/bitext/Bitext-customer-support-llm-chatbot-training-dataset
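Since accuracy is the target metric, it helps to pin down what we mean by it. The snippet below is a minimal sketch, assuming predictions and expected labels are plain lists of category strings; the function name `accuracy` and the example labels are illustrative and not part of the original notebook.

```python
def accuracy(predicted, expected):
    """Return the fraction of predictions that exactly match the expected labels."""
    if len(predicted) != len(expected):
        raise ValueError("predicted and expected must have the same length")
    if not expected:
        raise ValueError("cannot compute accuracy on an empty list")
    correct = sum(p == e for p, e in zip(predicted, expected))
    return correct / len(expected)


# Hypothetical labels for four messages (made up for illustration)
expected = ["INVOICE", "ORDER", "REFUND", "ORDER"]
predicted = ["INVOICE", "ORDER", "PAYMENT", "ORDER"]
print(accuracy(predicted, expected))  # 0.75
```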

In [28]:
%pip install datasets --quiet

Note: you may need to restart the kernel to use updated packages.


In [29]:
from datasets import load_dataset

dataset = load_dataset(
    "bitext/Bitext-customer-support-llm-chatbot-training-dataset", split="train"
)
dataset

Dataset({
    features: ['flags', 'instruction', 'category', 'intent', 'response'],
    num_rows: 26872
})

In [30]:
CLASSES = list(set(dataset["category"]))
CLASSES

['INVOICE',
 'PAYMENT',
 'SUBSCRIPTION',
 'FEEDBACK',
 'CANCEL',
 'DELIVERY',
 'ORDER',
 'SHIPPING',
 'REFUND',
 'CONTACT',
 'ACCOUNT']
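Note that the order of `list(set(...))` for strings can change between Python sessions because of hash randomization, so the list above may come out in a different order on another run. If a stable order matters (for example, to build a fixed label-to-index mapping), sorting is one option. The sketch below uses a short stand-in list instead of the real column; `CLASS_TO_INDEX` is a hypothetical helper name.

```python
# Stand-in for dataset["category"]; the real column has 26,872 entries
categories = ["ORDER", "INVOICE", "ORDER", "REFUND", "INVOICE", "ACCOUNT"]

# sorted() removes the dependence on set iteration order
CLASSES = sorted(set(categories))
print(CLASSES)  # ['ACCOUNT', 'INVOICE', 'ORDER', 'REFUND']

# A stable order also gives a stable label -> index mapping
CLASS_TO_INDEX = {name: i for i, name in enumerate(CLASSES)}
```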

### CONSTANTS

In [48]:
N = 200  # Number of samples per class / category
TRAIN_TEST_SPLIT = 0.75  # passed as test_size: fraction of the full dataset held out from train
TEST_VAL_SPLIT = 0.7  # passed as test_size: fraction of the held-out rows kept for test; the rest is validation
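Later in the notebook both constants are passed as `test_size`, so it is worth working out the split sizes they imply. The sketch below assumes 11 categories (as listed above) and assumes the test part is rounded up; `split_sizes` is a hypothetical helper, not a function of the `datasets` library.

```python
import math

N = 200
NUM_CLASSES = 11  # number of categories in CLASSES
TRAIN_TEST_SPLIT = 0.75
TEST_VAL_SPLIT = 0.7


def split_sizes(n_rows, test_fraction):
    """Return (rest, test) sizes; rounding the test part up is an assumption."""
    n_test = math.ceil(test_fraction * n_rows)
    return n_rows - n_test, n_test


total = N * NUM_CLASSES
n_train, n_held_out = split_sizes(total, TRAIN_TEST_SPLIT)
n_val, n_test = split_sizes(n_held_out, TEST_VAL_SPLIT)
print(total, n_train, n_val, n_test)  # 2200 550 495 1155
```

With these values only 25% of the rows end up in the train split, and the test split is the largest of the three.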

In [49]:
import pandas as pd
from datasets import Dataset

# Convert the dataset to a pandas DataFrame
df = pd.DataFrame(dataset)
# add an id column
df = df.reset_index(drop=False).rename(columns={"index": "id"})
# Group by the label column
grouped = df.groupby("category")
# Sample N records from each label
sampled_df = grouped.apply(lambda x: x.sample(n=N, random_state=42)).reset_index(
    drop=True
)
# Convert the sampled DataFrame back to a Hugging Face dataset
sampled_dataset = Dataset.from_pandas(sampled_df)
print(sampled_df["category"].value_counts())
shuffled_dataset = sampled_dataset.shuffle()

category
ACCOUNT         200
CANCEL          200
CONTACT         200
DELIVERY        200
FEEDBACK        200
INVOICE         200
ORDER           200
PAYMENT         200
REFUND          200
SHIPPING        200
SUBSCRIPTION    200
Name: count, dtype: int64
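The per-class sampling step can also be written without pandas. The following is a minimal sketch with toy rows, assuming each row is a dict; `sample_per_class` is a hypothetical helper. In pandas itself, `df.groupby("category").sample(n=N, random_state=42)` should give a comparable per-group sample without going through `apply`.

```python
import random
from collections import Counter, defaultdict


def sample_per_class(rows, label_key, n, seed=42):
    """Draw n rows without replacement from every class, using a fixed seed."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for row in rows:
        by_class[row[label_key]].append(row)
    sampled = []
    for label in sorted(by_class):  # sorted for a reproducible draw order
        # random.sample raises ValueError if a class has fewer than n rows
        sampled.extend(rng.sample(by_class[label], n))
    return sampled


# Toy rows standing in for the customer messages: 10 REFUND, 20 ORDER
rows = [{"id": i, "category": "ORDER" if i % 3 else "REFUND"} for i in range(30)]
sampled = sample_per_class(rows, "category", n=5)
print(Counter(row["category"] for row in sampled))
```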



Select only the relevant columns and rename them to match the fields of the class ClassificationDataClass.

In [50]:
selected_cols_dataset = shuffled_dataset.select_columns(
    ["category", "instruction", "id"]
)
renamed_dataset = selected_cols_dataset.rename_column(
    "category", "class_name"
).rename_column("instruction", "question")
full_dataset = renamed_dataset
full_dataset.to_csv("full_dataset.csv")
full_dataset

Dataset({
    features: ['class_name', 'question', 'id'],
    num_rows: 2200
})
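The notebook does not show the definition of ClassificationDataClass. As an illustration only, the sketch below assumes it is a simple record with the three renamed fields; `ClassificationRecord` and `to_record` are hypothetical names, and the example row is made up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClassificationRecord:  # hypothetical stand-in, not the real ClassificationDataClass
    class_name: str
    question: str
    id: int


def to_record(row):
    """Map a raw row (original column names) onto the renamed fields."""
    return ClassificationRecord(
        class_name=row["category"],
        question=row["instruction"],
        id=int(row["id"]),
    )


# Made-up row using the original column names; extra columns are ignored
raw = {"id": 7, "category": "REFUND", "instruction": "how do I get my money back?", "response": "..."}
record = to_record(raw)
print(record)
```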

In [51]:
# Split the dataset into train and held-out sets (test_size=0.75: 25% train, 75% held out)
train_test_split = full_dataset.train_test_split(test_size=TRAIN_TEST_SPLIT)
# Further split the held-out set into validation and test sets (test_size=0.7: 30% validation, 70% test)
val_test_split = train_test_split["test"].train_test_split(test_size=TEST_VAL_SPLIT)

train = train_test_split["train"]
train.to_csv("train.csv")
val = val_test_split["train"]
val.to_csv("val.csv")
test = val_test_split["test"]
test.to_csv("test.csv")

71935
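Because each row carries an `id`, a quick sanity check is to confirm that no message ends up in more than one split. This is a minimal sketch with toy ids; `check_disjoint` is a hypothetical helper. With the notebook's objects one could pass `train["id"]`, `val["id"]`, and `test["id"]` instead.

```python
def check_disjoint(**splits):
    """Raise if any id occurs more than once across the splits; return the id count."""
    seen = {}
    for name, ids in splits.items():
        for i in ids:
            if i in seen:
                raise ValueError(f"id {i} appears in both {seen[i]} and {name}")
            seen[i] = name
    return len(seen)


# Toy ids standing in for the real id columns
total_ids = check_disjoint(train=[0, 4, 7], val=[1, 5], test=[2, 3, 6])
print(total_ids)  # 8
```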

In [52]:
train.to_pandas()["class_name"].value_counts()

class_name
INVOICE         57
PAYMENT         56
DELIVERY        55
FEEDBACK        55
REFUND          55
ACCOUNT         50
CONTACT         49
CANCEL          46
SHIPPING        43
ORDER           42
SUBSCRIPTION    42
Name: count, dtype: int64

In [53]:
val.to_pandas()["class_name"].value_counts()

class_name
CANCEL          55
CONTACT         55
DELIVERY        49
SUBSCRIPTION    48
INVOICE         44
ACCOUNT         44
REFUND          43
SHIPPING        40
FEEDBACK        39
ORDER           39
PAYMENT         39
Name: count, dtype: int64

In [54]:
test.to_pandas()["class_name"].value_counts()

class_name
ORDER           119
SHIPPING        117
SUBSCRIPTION    110
ACCOUNT         106
FEEDBACK        106
PAYMENT         105
REFUND          102
CANCEL           99
INVOICE          99
CONTACT          96
DELIVERY         96
Name: count, dtype: int64
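The counts above show that the classes are no longer balanced within each split, because the splits are drawn at random over all rows. If balanced splits are wanted, a stratified split is one option. The sketch below is a pure-Python illustration on toy rows, not what the notebook does; `stratified_split` is a hypothetical helper. The `datasets` library also has a `stratify_by_column` argument on `train_test_split`, which, to our knowledge, requires the column to be a `ClassLabel` feature.

```python
import random
from collections import Counter, defaultdict


def stratified_split(rows, label_key, test_fraction, seed=0):
    """Split rows so each class contributes about the same fraction to the test part."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for row in rows:
        by_class[row[label_key]].append(row)
    rest, test = [], []
    for label in sorted(by_class):
        group = by_class[label][:]  # copy so the input order is left untouched
        rng.shuffle(group)
        n_test = round(test_fraction * len(group))
        test.extend(group[:n_test])
        rest.extend(group[n_test:])
    return rest, test


# Toy rows: 20 ORDER and 20 REFUND messages
rows = [{"id": i, "class_name": name} for i, name in enumerate(["ORDER", "REFUND"] * 20)]
train_rows, held_out_rows = stratified_split(rows, "class_name", test_fraction=0.75)
print(Counter(r["class_name"] for r in train_rows))
print(Counter(r["class_name"] for r in held_out_rows))
```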