In [50]:
import pandas as pd

# Pool of generic questions used later as "off_topic" examples: SQuAD v2 on the Hugging Face Hub.
# Source: https://huggingface.co/datasets/rajpurkar/squad_v2
# Only the train split is read; the validation file path is listed for reference.
# NOTE(review): the hf:// scheme is resolved through fsspec — presumably needs huggingface_hub installed.
splits = {
    'train': 'squad_v2/train-00000-of-00001.parquet',
    'validation': 'squad_v2/validation-00000-of-00001.parquet',
}
df_squad_v2 = pd.read_parquet(f"hf://datasets/rajpurkar/squad_v2/{splits['train']}")

In [51]:
import numpy as np

# Sample 150 random SQuAD v2 questions to serve as "off_topic" examples, then split
# them positionally: the first 100 are mixed into the training data, and the remaining
# 50 are kept aside.
# `[['question']]` (double square brackets) selects a one-column DataFrame, not a Series.
# The frame is built in a single chain (no inplace rename / in-place column assignment),
# so re-running the cell always yields the same fresh object.
sampled_ot_examples = (
    df_squad_v2
    .sample(150, random_state=34197)[['question']]
    .rename(columns={'question': 'Question'})
    .assign(**{'Global Subject': 'off_topic', 'Question Intent': 'off_topic'})
)

sampled_ot_examples_100 = sampled_ot_examples.iloc[:100]
sampled_ot_examples_50 = sampled_ot_examples.iloc[100:]

In [52]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from BertClassifierModelTrainer import BertClassifierModelTrainer

# Load the manually labelled questions and the generated per-topic question sets,
# then append the 100 off-topic SQuAD questions sampled earlier.
df = pd.read_csv('./labelling/data_cleaned_manual.csv')
ds_automaton = pd.read_csv('./new_questions/automaton_questions.csv')
ds_state = pd.read_csv('./new_questions/state_questions.csv')
ds_transition = pd.read_csv('./new_questions/transition_questions.csv')
ds_grammar = pd.read_csv('./new_questions/grammar_questions.csv')

df = pd.concat([df, ds_automaton, ds_state, ds_transition, ds_grammar, sampled_ot_examples_100], ignore_index=True)

# One encoder per label column. Re-fitting a single shared encoder on
# 'Question Intent' would replace the 'Global Subject' classes_, so decoding
# 'gs' ids with it (inverse_transform) would return intent names, not subjects.
le_gs = LabelEncoder()
le_qi = LabelEncoder()
df['gs'] = le_gs.fit_transform(df['Global Subject'])
df['qi'] = le_qi.fit_transform(df['Question Intent'])

# The classifier below is trained on 'Global Subject' ('gs') only.
label_count = df['gs'].nunique()

# Stratify on the target so rare subjects still appear in the validation split.
train_questions, val_questions, train_labels, val_labels = train_test_split(
    df['Question'], df['gs'],
    test_size=0.2,
    random_state=34197,
    stratify=df['gs']
)

model_name = 'distilbert-base-uncased'

# Positional args: model name, number of labels, learning rate (2e-5), then the train/val data.
# NOTE(review): labels are passed as pandas Series while questions are passed as lists;
# this assumes BertClassifierModelTrainer handles a Series with a shuffled index — confirm.
trainer = BertClassifierModelTrainer(model_name, label_count, 2e-5,
                                     train_questions.tolist(),
                                     train_labels,
                                     val_questions.tolist(),
                                     val_labels)
trainer.train()

df.head()

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training model...: 100%|██████████| 20/20 [00:19<00:00,  1.05epochs/s]


Unnamed: 0,Question,Global Subject,Question Intent,gs,qi
0,Hi,start,greet,3,20
1,Hello,start,greet,3,20
2,Describe the automaton,automaton,description,0,6
3,Is there a transition between q2 and q0?,transition,existence_between,6,12
4,Is there a transition between q5 and q7,transition,existence_between,6,12


In [53]:
# Per-epoch metrics recorded by the trainer: training/validation loss, validation
# accuracy, epoch duration, and true/predicted label distributions (shown as the cell output).
trainer.stats

Unnamed: 0_level_0,training_loss,validation_loss,validation_accuracy,epoch_duration,true_labels_distribution,predicted_labels_distribution
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,1.882346,1.710048,0.358108,0.947951,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{6: 136, 1: 9, 2: 3}"
2,1.463576,1.095645,0.783784,0.941942,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{6: 56, 0: 24, 1: 38, 2: 21, 4: 9}"
3,0.864769,0.628628,0.844595,0.954141,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{1: 30, 0: 23, 6: 52, 2: 22, 4: 21}"
4,0.466599,0.44674,0.885135,0.949878,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{4: 24, 0: 25, 6: 47, 1: 30, 2: 22}"
5,0.286384,0.433209,0.858108,0.943125,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{0: 30, 6: 48, 1: 29, 2: 23, 4: 18}"
6,0.212533,0.449515,0.878378,0.946184,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{0: 26, 6: 49, 1: 31, 2: 21, 4: 21}"
7,0.164608,0.470299,0.858108,0.946136,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{0: 35, 6: 45, 1: 30, 2: 21, 4: 17}"
8,0.139086,0.493347,0.858108,0.949147,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{0: 30, 6: 49, 1: 31, 2: 20, 4: 18}"
9,0.106685,0.434713,0.891892,0.946367,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{1: 32, 0: 27, 6: 44, 2: 21, 4: 24}"
10,0.103223,0.456411,0.878378,0.946141,"{4: 21, 0: 28, 6: 45, 1: 29, 2: 21, 5: 3, 3: 1}","{0: 29, 6: 47, 1: 31, 2: 20, 4: 21}"


In [54]:
import matplotlib.pyplot as plt

%matplotlib notebook

# Plot the training and validation loss over epochs
plt.figure(figsize=(12, 6))

# Plot loss
plt.plot(trainer.stats['training_loss'], label='Training Loss')
plt.plot(trainer.stats['validation_loss'], label='Validation Loss')

# Format the plot
plt.title('Training & Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

<IPython.core.display.Javascript object>

In [55]:
# Validation accuracy per epoch; the x-axis values come from the stats index (epoch).
acc_fig, acc_ax = plt.subplots(figsize=(12, 6))
acc_ax.plot(trainer.stats['validation_accuracy'], label='Validation Accuracy', color='green')

acc_ax.set_title('Validation Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.grid(True)
plt.show()

<IPython.core.display.Javascript object>

In [66]:
import torch

sample_text = "what is a finite state machine?"

# Map each encoded 'gs' id back to its 'Global Subject' name. The map is built from df
# itself, so it is correct regardless of which encoder instance produced the ids.
unique_labels = df[['gs', 'Global Subject']].drop_duplicates()
label_map = dict(zip(unique_labels['gs'], unique_labels['Global Subject']))

top_predictions = trainer.predict_top(sample_text, label_map, top_k=5)
print('Top Predictions Confidence')
for label, confidence in top_predictions:
    # Confidence printed with 3 decimal places.
    print(f'{label}: {confidence:.3f}')

# TODO: possible overfitting. In the recorded run, validation loss stopped improving
# after ~epoch 5 while training loss kept falling. Try a lower learning rate, fewer
# epochs / early stopping, or few-shot training.
# NOTE: an earlier run (possibly on different data) ranked 'transition' first (0.924)
# for this same question, so the top prediction has not been stable across runs.


Top Predictions Confidence
theory: 0.796
automaton: 0.055
transition: 0.045
state: 0.040
off_topic: 0.033
