In [None]:
from src.train import eval_model

# Score the trained model on the held-out test split. Only the first value
# returned by eval_model is kept (used below as the accuracy); the second
# return value is discarded.
test_acc, _ = eval_model(model, test_data_loader, loss_fn, device, len(df_test))

# Unwrap the accuracy to a plain Python scalar for display.
test_acc.item()

This *test accuracy* score is within 1 percent of the validation accuracy score from our peak performing epoch of the most recent training. Thus, the *validation accuracy* reliably predicts our *test accuracy* (and by extension our real-world accuracy).


*Helper function for extracting predicted probabilities for each text description*


In [None]:
def get_predictions(model, data_loader, device=None, apply_softmax=True):
    """Run `model` over `data_loader` and collect per-example predictions.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns one row of class scores per example.
        data_loader: Iterable of batches. Each batch is a mapping with the
            keys ``'description_text'``, ``'input_ids'``, ``'attention_mask'``
            and ``'targets'``.
        device: Device the input tensors are moved to before the forward
            pass. Defaults to the device of the model's first parameter
            (CPU if the model has no parameters), so the function no longer
            depends on a notebook-global ``device`` variable.
        apply_softmax: When True (default), the class scores are normalised
            with a softmax over dim 1 so that every row of the returned
            ``prediction_probs`` is a probability distribution. Pass False
            if the model's forward pass already returns probabilities, or
            if the raw scores are wanted.

    Returns:
        A tuple ``(description_texts, predictions, prediction_probs,
        real_values)``: the list of input texts, and three CPU tensors
        holding the arg-max class per example, the per-class probabilities
        (or raw scores) per example, and the target labels. For an empty
        ``data_loader`` the list and all three tensors are empty.
    """
    model = model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    description_texts = []
    pred_batches = []
    prob_batches = []
    target_batches = []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # perform forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # The arg-max is the same for raw scores and for their softmax.
            _, preds = torch.max(outputs, dim=1)
            probs = torch.softmax(outputs, dim=1) if apply_softmax else outputs

            # Texts are a sequence of strings, so extend(); tensors are kept
            # as one CPU tensor per batch and concatenated once at the end,
            # which also releases device memory as we go.
            description_texts.extend(batch['description_text'])
            pred_batches.append(preds.cpu())
            prob_batches.append(probs.cpu())
            target_batches.append(batch['targets'].cpu())

    if not pred_batches:
        # Nothing was iterated: return empty results instead of failing
        # inside torch.cat / torch.stack on an empty list.
        return (
            description_texts,
            torch.empty(0, dtype=torch.long),
            torch.empty(0),
            torch.empty(0, dtype=torch.long),
        )

    predictions = torch.cat(pred_batches)
    prediction_probs = torch.cat(prob_batches)
    real_values = torch.cat(target_batches)

    return description_texts, predictions, prediction_probs, real_values


*Classification Report*

In [None]:
from sklearn.metrics import classification_report

# Human-readable topic names, indexed by integer class id.
class_names = ['World', 'Sports', 'Business', 'Sci/Tech']

# Collect the texts, predicted classes, per-class scores and true labels
# for everything the test loader yields.
y_description_texts, y_pred, y_pred_probs, y_test = get_predictions(model, test_data_loader)

# Per-topic summary of predicted vs. true labels.
print(classification_report(y_test, y_pred, target_names=class_names))

From the report above,
+ The model classifies **Sports** and **World** news articles best while classifying **Business** and **Sci/Tech** articles marginally less accurately.

*Confusion Matrix*

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Accent colours installed below as the default seaborn palette.
HAPPY_COLORS_PALETTE = [
    "#01BEFE", "#FFDD00", "#FF7D00",
    "#FF006D", "#ADFF02", "#8F00FF",
]

# Global plot theme; the palette chosen here is replaced by the
# set_palette call that follows.
sns.set_theme(style='whitegrid', palette='muted', font_scale=1.2)
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

def show_confusion_matrix(confusion_matrix):
    """Draw `confusion_matrix` as an annotated heatmap.

    Each cell is annotated with its integer value (``fmt="d"``). The y tick
    labels are kept horizontal and the x tick labels are rotated 30 degrees.
    The y axis is labelled 'Actual topic' and the x axis 'Predicted topic'.

    Args:
        confusion_matrix: 2-D data accepted by ``sns.heatmap``; its row and
            column labels (if any) become the tick labels.
    """
    ax = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")

    # Re-apply the existing tick labels with the desired rotation/alignment.
    ax.set_yticklabels(ax.yaxis.get_ticklabels(), rotation=0, ha='right')
    ax.set_xticklabels(ax.xaxis.get_ticklabels(), rotation=30, ha='right')

    ax.set_ylabel('Actual topic')
    ax.set_xlabel('Predicted topic')

# Rows correspond to the true labels, columns to the predicted labels.
cm = confusion_matrix(y_test, y_pred)

# Attach the topic names to both axes so the heatmap ticks are readable.
df_cm = pd.DataFrame(data=cm, index=class_names, columns=class_names)
show_confusion_matrix(df_cm)

*Note*. &nbsp;&nbsp;&nbsp;&nbsp; Here we use `sns.heatmap` from the `seaborn` library, a Python data-visualization library built on top of `matplotlib`. The heatmap visualizes the confusion matrix: each row is an actual topic, each column is a predicted topic, and each cell counts the test examples with that (actual, predicted) pair, so correct predictions lie on the diagonal and misclassifications lie off it. The heatmap confirms the model's marginal difficulty at classifying **Business** and **Sci/Tech** news relative to the other topics.

Taking a closer look at a single example from our test data,

In [None]:
idx = 6998  # any valid test-set position, i.e. 0 <= idx < len(y_test)

# Pull out the text, the true label and the model's per-class scores for
# that one example.
description_text = y_description_texts[idx]
actual_topic = y_test[idx]
pred_df = pd.DataFrame({'class_names': class_names, 'values': y_pred_probs[idx]})

In [None]:
from textwrap import wrap

# Print the description wrapped to textwrap's default width, one wrapped
# line per output row, followed by a blank line and the true topic name.
print(*wrap(description_text), sep="\n")
print()
print(f"Actual topic: {class_names[actual_topic]}")

Continuing, we can examine the *confidence* of each topic suggested by the model:

In [None]:
# Horizontal bars: one per topic, bar length is the model's score for it.
sns.barplot(data=pred_df, x='values', y='class_names', orient='h')
plt.xlabel('probability')
plt.ylabel('topic')
plt.xlim(0, 1)

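From manually testing multiple different indices in the range [0, 7600), the model appears nearly 100 percent confident in most cases. Occasionally, the model shows this high confidence for two competing topics at once. Note that two topics can both appear near 1 only if the plotted values are unnormalized scores clipped to the [0, 1] axis; softmax probabilities sum to 1 across topics, so at most one topic could be near 100 percent.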