In [3]:
# Import necessary libraries
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


In [None]:
# Generate 40 small support cases with labels
# Each case is written as a (text, label) pair and expanded into the
# {'text': ..., 'label': ...} records that the rest of the notebook reads.
data = [
    dict(text=case_text, label=case_label)
    for case_text, case_label in (
        ('The motor is overheating after prolonged use.', 'motor'),
        ('Printhead alignment is off, causing blurry prints.', 'printhead'),
        ('The printer shows an unknown error code.', 'other'),
        ('Motor makes a grinding noise during operation.', 'motor'),
        ('Printhead nozzles are clogged and need cleaning.', 'printhead'),
        ('Cannot connect the printer to Wi-Fi.', 'other'),
        ('Motor fails to start when initiating a print job.', 'motor'),
        ('Printhead alignment utility fails every time.', 'printhead'),
        ('Paper jams occur frequently in the tray.', 'other'),
        ('Motor belt is worn out and slips.', 'motor'),
        ('Printhead produces streaks across all prints.', 'printhead'),
        ('The touch screen is unresponsive.', 'other'),
        ('Motor overheats and shuts down.', 'motor'),
        ('Printhead is not recognized by the printer.', 'printhead'),
        ('Printer driver installation fails.', 'other'),
        ('Motor emits a burning smell during use.', 'motor'),
        ('Printhead colors are inaccurate and faded.', 'printhead'),
        ('Cannot update printer firmware.', 'other'),
        ('Motor vibrates excessively when printing.', 'motor'),
        ('Printhead cleaning cycle does not complete.', 'printhead'),
        ('The motor is not responding to commands.', 'motor'),
        ('Printhead is scratching the paper surface.', 'printhead'),
        ('Unable to print from mobile devices.', 'other'),
        ('Motor speed is inconsistent during operation.', 'motor'),
        ('Printhead calibration fails to improve quality.', 'printhead'),
        ('Error message says "cartridge not recognized".', 'other'),
        ('Motor bearings are worn out.', 'motor'),
        ('Printhead drips ink even when idle.', 'printhead'),
        ('Printer disconnects from network intermittently.', 'other'),
        ('Motor consumes excessive power.', 'motor'),
        ('Printhead alignment is perfect but colors are off.', 'printhead'),
        ('Cannot access printer settings via control panel.', 'other'),
        ('Motor controller board needs replacement.', 'motor'),
        ('Printhead cleaning does not resolve smudging.', 'printhead'),
        ('Printer tray does not hold paper properly.', 'other'),
        ('Motor makes unusual humming sounds.', 'motor'),
        ('Printhead requires frequent replacement.', 'printhead'),
        ('Cannot perform duplex printing.', 'other'),
        ('Motor wiring harness is damaged.', 'motor'),
        ('Printhead temperature exceeds normal levels.', 'printhead'),
    )
]

# Map labels to integers
label_mapping = dict(motor=0, printhead=1, other=2)

# Walk the cases once, collecting the raw text and its integer class id
# into two parallel lists.
texts = []
labels = []
for item in data:
    texts.append(item['text'])
    labels.append(label_mapping[item['label']])

# Split data into training and validation sets
# (20% held out, stratified so every class appears in both parts)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts,
    labels,
    stratify=labels,
    test_size=0.2,
    random_state=42,
)

# Generate 10 test examples for prediction
# These are unlabeled and are only fed to the fine-tuned model at the end.
test_data = [
    'Printhead leaks ink onto the paper.',
    'Motor starts but then stops immediately.',
    'Unable to scan documents to email.',
    'Motor shaft is misaligned and needs adjustment.',
    'Printhead alignment successful but quality is poor.',
    'Cannot connect to the printer via Bluetooth.',
    'Motor overheating error appears after startup.',
    'Printhead produces faded text on the left side.',
    'Paper feeder is not pulling paper correctly.',
    'Motor drive belt snapped during operation.',
]

# Load pre-trained BERT tokenizer and model.
# The 3-class classification head is not part of the 'bert-base-uncased'
# checkpoint and is freshly (randomly) initialised when the model is built.
# Seed torch first so that initialisation is repeatable: the seed=42 given to
# TrainingArguments is only applied once the Trainer is created, which is
# after the model already exists.
torch.manual_seed(42)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)

# Tokenize the training and validation data.
# padding=True pads every sequence to the longest one in the same call;
# truncation=True clips anything longer than the model's maximum length.
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)

# Create a custom Dataset class
class SupportCasesDataset(Dataset):
    """Torch dataset pairing tokenizer encodings with optional class labels.

    Args:
        encodings: Mapping from feature name to a sequence holding one entry
            per example (the object returned by the tokenizer call).
        labels: Sequence of integer class ids, one per example. May be left
            as ``None`` for inference-only data, in which case the returned
            items carry no ``'labels'`` entry.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return the features (and label, when available) of example ``idx`` as tensors."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:
            item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples in the dataset."""
        if self.labels is not None:
            return len(self.labels)
        # Without labels, every encoded field has one entry per example,
        # so the length of the first field is the dataset size.
        for values in self.encodings.values():
            return len(values)
        return 0

def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (per-class scores, one row per
            example) and ``label_ids`` (the reference class ids).

    Returns:
        dict with the keys ``'accuracy'``, ``'precision'``, ``'recall'`` and
        ``'f1'``.
    """
    labels = pred.label_ids
    # The predicted class is the column with the highest score in each row.
    preds = np.argmax(pred.predictions, axis=1)
    acc = accuracy_score(labels, preds)
    # zero_division=0 scores an undefined precision/recall (a class that is
    # never predicted or never present) as 0 without emitting a warning on
    # every evaluation; the reported numbers are the same as the default.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }


# Wrap the tokenized splits so the Trainer can index them example by example
train_dataset = SupportCasesDataset(train_encodings, train_labels)
val_dataset = SupportCasesDataset(val_encodings, val_labels)

# Fine-tuning configuration: 10 epochs, batches of 4, validation after every
# epoch, training loss logged every 10 optimisation steps.
training_args = TrainingArguments(
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=10,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    evaluation_strategy='epoch',
    seed=42,
)

# Bundle model, configuration, data and metric function into a Trainer
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Fine-tune on the training split
trainer.train()

# Score the fine-tuned model on the validation split
eval_results = trainer.evaluate()
print(f"Validation Loss: {eval_results['eval_loss']:.4f}")

# Encode the unlabeled examples; the all-zero labels are placeholders only,
# the true classes of these texts are unknown.
test_encodings = tokenizer(test_data, truncation=True, padding=True)
test_dataset = SupportCasesDataset(test_encodings, [0 for _ in test_data])

# Run inference and take the highest-scoring class for every example
predictions = trainer.predict(test_dataset)
pred_labels = predictions.predictions.argmax(axis=1)

# Translate integer class ids back into their names
inv_label_mapping = dict((class_id, class_name) for class_name, class_id in label_mapping.items())
predicted_classes = [inv_label_mapping[label] for label in pred_labels]

# Show each test sentence next to its predicted class
for position, text in enumerate(test_data):
    pred_class = predicted_classes[position]
    print(f'Text: "{text}"')
    print(f'Predicted class: {pred_class}')
    print('---')



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                              
 10%|█         | 8/80 [00:19<02:37,  2.19s/it]

{'eval_loss': 0.8828052878379822, 'eval_accuracy': 0.625, 'eval_precision': 0.85, 'eval_recall': 0.625, 'eval_f1': 0.6303571428571428, 'eval_runtime': 0.2653, 'eval_samples_per_second': 30.15, 'eval_steps_per_second': 3.769, 'epoch': 1.0}


 12%|█▎        | 10/80 [00:23<02:25,  2.07s/it]

{'loss': 0.9949, 'grad_norm': 8.293554306030273, 'learning_rate': 4.375e-05, 'epoch': 1.25}


 20%|██        | 16/80 [00:34<01:59,  1.86s/it]
 20%|██        | 16/80 [00:35<01:59,  1.86s/it]

{'eval_loss': 0.4597586989402771, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.2415, 'eval_samples_per_second': 33.132, 'eval_steps_per_second': 4.142, 'epoch': 2.0}


 25%|██▌       | 20/80 [00:42<01:57,  1.96s/it]

{'loss': 0.5256, 'grad_norm': 8.005005836486816, 'learning_rate': 3.7500000000000003e-05, 'epoch': 2.5}


 30%|███       | 24/80 [00:50<01:47,  1.92s/it]
 30%|███       | 24/80 [00:50<01:47,  1.92s/it]

{'eval_loss': 0.1681937873363495, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.2986, 'eval_samples_per_second': 26.792, 'eval_steps_per_second': 3.349, 'epoch': 3.0}


 38%|███▊      | 30/80 [01:02<01:37,  1.95s/it]

{'loss': 0.2042, 'grad_norm': 3.780911445617676, 'learning_rate': 3.125e-05, 'epoch': 3.75}


 40%|████      | 32/80 [01:06<01:41,  2.12s/it]
 40%|████      | 32/80 [01:07<01:41,  2.12s/it]

{'eval_loss': 0.06567766517400742, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.3354, 'eval_samples_per_second': 23.85, 'eval_steps_per_second': 2.981, 'epoch': 4.0}


 50%|█████     | 40/80 [01:24<01:25,  2.13s/it]

{'loss': 0.0838, 'grad_norm': 1.0667589902877808, 'learning_rate': 2.5e-05, 'epoch': 5.0}



 50%|█████     | 40/80 [01:24<01:25,  2.13s/it]

{'eval_loss': 0.03519752621650696, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.3289, 'eval_samples_per_second': 24.324, 'eval_steps_per_second': 3.041, 'epoch': 5.0}


 60%|██████    | 48/80 [01:41<01:06,  2.09s/it]
 60%|██████    | 48/80 [01:41<01:06,  2.09s/it]

{'eval_loss': 0.015720052644610405, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.3142, 'eval_samples_per_second': 25.459, 'eval_steps_per_second': 3.182, 'epoch': 6.0}


 62%|██████▎   | 50/80 [01:46<01:05,  2.17s/it]

{'loss': 0.0275, 'grad_norm': 0.3481942415237427, 'learning_rate': 1.8750000000000002e-05, 'epoch': 6.25}


 70%|███████   | 56/80 [02:00<00:58,  2.43s/it]
 70%|███████   | 56/80 [02:00<00:58,  2.43s/it]

{'eval_loss': 0.009631572291254997, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.4624, 'eval_samples_per_second': 17.301, 'eval_steps_per_second': 2.163, 'epoch': 7.0}


 75%|███████▌  | 60/80 [02:11<00:53,  2.70s/it]

{'loss': 0.0141, 'grad_norm': 0.20672811567783356, 'learning_rate': 1.25e-05, 'epoch': 7.5}


 80%|████████  | 64/80 [02:22<00:41,  2.62s/it]
 80%|████████  | 64/80 [02:22<00:41,  2.62s/it]

{'eval_loss': 0.007248944137245417, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.4469, 'eval_samples_per_second': 17.9, 'eval_steps_per_second': 2.238, 'epoch': 8.0}


 88%|████████▊ | 70/80 [02:37<00:24,  2.47s/it]

{'loss': 0.0096, 'grad_norm': 0.2077576071023941, 'learning_rate': 6.25e-06, 'epoch': 8.75}


 90%|█████████ | 72/80 [02:42<00:19,  2.38s/it]
 90%|█████████ | 72/80 [02:42<00:19,  2.38s/it]

{'eval_loss': 0.006440388038754463, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.3674, 'eval_samples_per_second': 21.772, 'eval_steps_per_second': 2.721, 'epoch': 9.0}


100%|██████████| 80/80 [02:56<00:00,  1.84s/it]

{'loss': 0.0083, 'grad_norm': 0.1477527767419815, 'learning_rate': 0.0, 'epoch': 10.0}


                                               
100%|██████████| 80/80 [02:59<00:00,  2.24s/it]


{'eval_loss': 0.00618194742128253, 'eval_accuracy': 1.0, 'eval_precision': 1.0, 'eval_recall': 1.0, 'eval_f1': 1.0, 'eval_runtime': 0.2497, 'eval_samples_per_second': 32.042, 'eval_steps_per_second': 4.005, 'epoch': 10.0}
{'train_runtime': 179.4024, 'train_samples_per_second': 1.784, 'train_steps_per_second': 0.446, 'train_loss': 0.23349415017291902, 'epoch': 10.0}


100%|██████████| 1/1 [00:00<00:00, 333.36it/s]


Validation Loss: 0.0062


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
100%|██████████| 2/2 [00:00<00:00, 12.70it/s]

Text: "Printhead leaks ink onto the paper."
Predicted class: printhead
---
Text: "Motor starts but then stops immediately."
Predicted class: motor
---
Text: "Unable to scan documents to email."
Predicted class: other
---
Text: "Motor shaft is misaligned and needs adjustment."
Predicted class: motor
---
Text: "Printhead alignment successful but quality is poor."
Predicted class: printhead
---
Text: "Cannot connect to the printer via Bluetooth."
Predicted class: other
---
Text: "Motor overheating error appears after startup."
Predicted class: motor
---
Text: "Printhead produces faded text on the left side."
Predicted class: printhead
---
Text: "Paper feeder is not pulling paper correctly."
Predicted class: other
---
Text: "Motor drive belt snapped during operation."
Predicted class: motor
---





In [5]:
# Re-run validation and report every tracked metric to four decimal places
eval_results = trainer.evaluate()
for metric_title, metric_key in (
    ('Validation Loss', 'eval_loss'),
    ('Accuracy', 'eval_accuracy'),
    ('Precision', 'eval_precision'),
    ('Recall', 'eval_recall'),
    ('F1 Score', 'eval_f1'),
):
    print(f"{metric_title}: {eval_results[metric_key]:.4f}")

100%|██████████| 1/1 [00:00<00:00, 249.50it/s]

Validation Loss: 0.0062
Accuracy: 1.0000
Precision: 1.0000
Recall: 1.0000
F1 Score: 1.0000



