In [None]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import load_files
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Embedding, Bidirectional, LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import set_random_seed, to_categorical


# Once data has finished processing, load data from folder structure
# text-data/
# ├─ advertisement/
# ├─ email/
# ├─ invoice/
# ....
# Each sub-folder is one class; files are decoded as UTF-8 and undecodable
# bytes are dropped, so every entry of data.data is a str.
DATA_DIR = '../text-data'
data = load_files(DATA_DIR, encoding='utf-8', decode_error='ignore')

X = data.data               # list of document texts
y = data.target             # integer class index per document
class_names = data.target_names  # class index -> folder name

# Split data for training and testing. Stratify so every class keeps its share
# in the test split -- the per-class report, confusion matrix and error-rate
# analysis below all need each class represented. sklearn's stratified split
# requires at least 2 samples per class, so fall back to a plain split otherwise.
class_counts = np.bincount(y, minlength=len(class_names))
stratify_labels = y if class_counts.min() >= 2 else None
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=stratify_labels
)

In [None]:
# Tokenize: the vocabulary is built from the training texts only (no test
# leakage), and only the `max_words` most frequent words get kept ids.
max_words = 20000
max_len = 200  # every document becomes exactly this many token ids

# Characters stripped before splitting into words; this is Keras' documented
# default filter set with the apostrophe added in front.
punctuation = '\'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'
tokenizer = Tokenizer(num_words=max_words, filters=punctuation, lower=True)
tokenizer.fit_on_texts(X_train)

X_train_seq, X_test_seq = (tokenizer.texts_to_sequences(texts) for texts in (X_train, X_test))

# Pad sequences: padding and truncation both happen at the end, so long
# documents keep their beginning and short ones get trailing zeros.
pad_opts = dict(maxlen=max_len, padding='post', truncating='post')
X_train_pad = pad_sequences(X_train_seq, **pad_opts)
X_test_pad = pad_sequences(X_test_seq, **pad_opts)

# One-hot encode labels (required by categorical_crossentropy in the model cell)
num_classes = len(class_names)
y_train_cat, y_test_cat = (
    to_categorical(labels, num_classes=num_classes) for labels in (y_train, y_test)
)

In [None]:
# Build the model
# Seed Python's `random`, NumPy and TensorFlow in one call so weight init,
# dropout masks and per-epoch batch shuffling repeat across Restart & Run All
# (GPU kernels may still introduce some nondeterminism).
set_random_seed(42)

model = Sequential()
# mask_zero=True makes downstream layers skip index 0, which is the value
# pad_sequences pads with (Keras' Tokenizer reserves 0 and never assigns it to
# a word). Sequences are post-padded in the tokenization cell, so without a
# mask the forward LSTM's final state would be computed after running over the
# trailing padding instead of the last real token, and the backward LSTM would
# start on padding.
model.add(Embedding(input_dim=max_words, output_dim=128, mask_zero=True,
                    embeddings_regularizer=l2(0.001)))
# return_sequences=False: only the final state of each direction is passed on
# as the document representation.
model.add(Bidirectional(LSTM(64, return_sequences=False, dropout=0.1, recurrent_dropout=0.1)))
model.add(Dropout(0.25))
model.add(Dense(num_classes, activation='softmax'))

model.compile(
    loss='categorical_crossentropy',  # labels are one-hot (to_categorical)
    optimizer='adam',
    metrics=['accuracy']
)
model.build(input_shape=(None, max_len))
model.summary()

# Train. Keras takes validation_split from the *last* 10% of the training
# arrays (before shuffling); train_test_split already shuffled them.
history = model.fit(
    X_train_pad, y_train_cat,
    epochs=10,
    batch_size=32,
    validation_split=0.1,
    shuffle=True
)

# Evaluate on the held-out test split
loss, acc = model.evaluate(X_test_pad, y_test_cat, verbose=1)
print("Test accuracy:", acc)

In [None]:
# Side-by-side training curves from the History returned by model.fit:
# left panel accuracy, right panel loss, each with train vs. validation.
history_panels = (
    ('accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Model Loss', 'Loss'),
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for panel, (train_key, val_key, train_label, val_label, title, ylabel) in zip(axes, history_panels):
    panel.plot(history.history[train_key], label=train_label)
    panel.plot(history.history[val_key], label=val_label)
    panel.set_title(title)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(ylabel)
    panel.legend()

plt.tight_layout()
plt.savefig('training_history.png')
plt.show()

In [None]:
# Predict class probabilities for the test split; the arg-max is the predicted class.
y_pred = model.predict(X_test_pad)
y_pred_classes = np.argmax(y_pred, axis=1)

# Create the confusion matrix. Pin the label set to 0..n_classes-1: by default
# confusion_matrix only includes labels that occur in y_test or y_pred, which
# would shrink the matrix and misalign it with the class_names tick labels
# whenever a class is missing from both.
cm = confusion_matrix(y_test, y_pred_classes, labels=np.arange(len(class_names)))

fig_cm, ax_cm = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax_cm)
ax_cm.set_title('Confusion Matrix')
ax_cm.set_xlabel('Predicted Label')  # columns = predictions
ax_cm.set_ylabel('True Label')       # rows = ground truth
fig_cm.tight_layout()
fig_cm.savefig('confusion_matrix.png')
plt.show()

In [None]:
# Get predictions and misclassifications (re-predicting keeps this cell
# runnable on its own, independent of the confusion-matrix cell).
y_pred = model.predict(X_test_pad)
y_pred_classes = np.argmax(y_pred, axis=1)
y_pred_probs = np.max(y_pred, axis=1)  # confidence = probability of the predicted class

misclassified_mask = y_test != y_pred_classes
misclassified_indices = np.where(misclassified_mask)[0]

print(f"Total misclassified samples: {len(misclassified_indices)} out of {len(y_test)}")
print(f"Misclassification rate: {len(misclassified_indices)/len(y_test)*100:.2f}%\n")

# Classification report. Passing `labels` explicitly keeps the report aligned
# with `target_names` even if a class is absent from both y_test and the
# predictions (otherwise sklearn raises a ValueError about the number of
# classes not matching target_names); zero_division=0 silences the warning for
# such empty classes and reports 0 for them.
report = classification_report(
    y_test, y_pred_classes,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    zero_division=0,
)
print("Classification Report:")
print(report)

# Analyze misclassification patterns. Keep enough text for the 300-character
# previews printed further down (they would otherwise be silently cut at 200).
preview_chars = 300
misclass_df = pd.DataFrame({
    'true_label': [class_names[y_test[i]] for i in misclassified_indices],
    'predicted_label': [class_names[y_pred_classes[i]] for i in misclassified_indices],
    'confidence': y_pred_probs[misclassified_indices],
    'text': [
        (X_test[i] if isinstance(X_test[i], str)
         else X_test[i].decode('utf-8', errors='ignore'))[:preview_chars]
        for i in misclassified_indices
    ],
})

# Most common misclassification pairs (true label -> predicted label), most frequent first
misclass_pairs = (
    misclass_df.groupby(['true_label', 'predicted_label'])
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=False)
)

print("TOP 10 MISCLASSIFICATION PATTERNS")
print(misclass_pairs.head(10).to_string(index=False))

# Visualize misclassification patterns
top_pairs = misclass_pairs.head(15)
pair_labels = [
    f"{true_lbl} → {pred_lbl}"
    for true_lbl, pred_lbl in zip(top_pairs['true_label'], top_pairs['predicted_label'])
]

fig_pairs, ax_pairs = plt.subplots(figsize=(12, 8))
ax_pairs.barh(range(len(top_pairs)), top_pairs['count'])
ax_pairs.set_yticks(range(len(top_pairs)))
ax_pairs.set_yticklabels(pair_labels)
# barh draws row 0 at the bottom; flip so the most frequent pattern is on top.
ax_pairs.invert_yaxis()
ax_pairs.set_xlabel('Number of Misclassifications')
ax_pairs.set_title('Top 15 Misclassification Patterns')
fig_pairs.tight_layout()
fig_pairs.savefig('misclassification_patterns_rnn.png', dpi=300, bbox_inches='tight')
plt.show()

# Per-class error analysis: for each true class, how many test documents it
# has and how many of them were given a different label.
# bincount over non-negative integer labels == per-class counts; minlength keeps
# one row per class even when a class has no test samples.
n_labels = len(class_names)
wrong_pred = y_test != y_pred_classes

error_by_class = (
    pd.DataFrame({
        'class': class_names,
        'total_samples': np.bincount(y_test, minlength=n_labels),
        'misclassified': np.bincount(y_test[wrong_pred], minlength=n_labels),
    })
    .assign(error_rate=lambda frame: frame['misclassified'] / frame['total_samples'] * 100)
    .sort_values('error_rate', ascending=False)
)

banner = "="*80
print("\n" + banner)
print("ERROR RATE BY CLASS")
print(banner)
print(error_by_class.to_string(index=False))

# Visualize error rates: relative (left) and absolute (right) counts per class
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

for panel, column, color, xlabel, title in (
    (ax1, 'error_rate', 'coral', 'Error Rate (%)', 'Misclassification Rate by Class'),
    (ax2, 'misclassified', 'steelblue', 'Number of Misclassifications', 'Absolute Misclassifications by Class'),
):
    panel.barh(error_by_class['class'], error_by_class[column], color=color)
    panel.set_xlabel(xlabel)
    panel.set_title(title)
    panel.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.savefig('error_analysis_by_class_rnn.png', dpi=300, bbox_inches='tight')
plt.show()

# Confidence analysis: compare the predicted-class probability of correct vs.
# incorrect predictions, and look at how much probability errors gave the true class.
correct_mask = ~misclassified_mask
correct_confidence = y_pred_probs[correct_mask]
incorrect_confidence = y_pred_probs[misclassified_mask]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Histogram comparison
axes[0, 0].hist(correct_confidence, bins=30, alpha=0.7, label='Correct', color='green')
axes[0, 0].hist(incorrect_confidence, bins=30, alpha=0.7, label='Incorrect', color='red')
axes[0, 0].set_xlabel('Model Confidence (Probability)')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('Model Confidence: Correct vs Incorrect Predictions')
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)

# Boxplot comparison. Tick labels are set separately because boxplot's
# `labels=` keyword was renamed to `tick_labels` in Matplotlib 3.9 (the old
# name is deprecated); set_xticklabels works on old and new versions alike.
# boxplot places the two boxes at x = 1 and x = 2.
axes[0, 1].boxplot([correct_confidence, incorrect_confidence])
axes[0, 1].set_xticklabels(['Correct', 'Incorrect'])
axes[0, 1].set_ylabel('Model Confidence')
axes[0, 1].set_title('Confidence Distribution Comparison')
axes[0, 1].grid(alpha=0.3)

# True class probability for misclassifications: row i, column y_test[i] of y_pred
true_class_probs = y_pred[np.arange(len(y_test)), y_test]
misclass_true_probs = true_class_probs[misclassified_indices]

axes[1, 0].hist(misclass_true_probs, bins=30, color='purple', alpha=0.7)
axes[1, 0].set_xlabel('True Class Probability')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].set_title('Probability Assigned to True Class (Misclassifications)')
axes[1, 0].grid(alpha=0.3)

# Confidence vs True probability scatter for misclassifications
# (incorrect_confidence and misclass_true_probs share the same sample order)
axes[1, 1].scatter(incorrect_confidence, misclass_true_probs, alpha=0.5, c='red')
axes[1, 1].plot([0, 1], [0, 1], 'k--', alpha=0.3)
axes[1, 1].set_xlabel('Predicted Class Confidence')
axes[1, 1].set_ylabel('True Class Probability')
axes[1, 1].set_title('Confidence vs True Class Probability (Errors)')
axes[1, 1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig('confidence_analysis_rnn.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nAverage confidence for correct predictions: {correct_confidence.mean():.3f}")
print(f"Average confidence for incorrect predictions: {incorrect_confidence.mean():.3f}")
print(f"Average true class probability for misclassifications: {misclass_true_probs.mean():.3f}")

# Bucket the errors by how confident the model was in its (wrong) prediction:
# below the low threshold it was unsure, above the high threshold it was
# confidently wrong. Both lists hold indices into the test split.
low_conf_threshold = 0.5
high_conf_threshold = 0.8

error_confidence = y_pred_probs[misclassified_indices]
low_conf_errors = list(misclassified_indices[error_confidence < low_conf_threshold])
high_conf_errors = list(misclassified_indices[error_confidence > high_conf_threshold])

print(f"\nMisclassifications with low confidence (<{low_conf_threshold}): {len(low_conf_errors)}")
print(f"Misclassifications with high confidence (>{high_conf_threshold}): {len(high_conf_errors)}")

# Show up to two example documents for each of the three most frequent
# (true -> predicted) error patterns, with newlines flattened for display.
print("SAMPLE MISCLASSIFIED TEXTS (Top 3 Patterns)")
for row in misclass_pairs.head(3).to_dict('records'):
    true_label = row['true_label']
    pred_label = row['predicted_label']

    same_pattern = (
        misclass_df['true_label'].eq(true_label)
        & misclass_df['predicted_label'].eq(pred_label)
    )
    examples = misclass_df[same_pattern].head(2)

    print(f"TRUE: {true_label} | PREDICTED: {pred_label} | Count: {row['count']}")

    for i, example in enumerate(examples.to_dict('records'), start=1):
        text_preview = example['text'][:300].replace('\n', ' ')
        print(f"\nExample {i} (Confidence: {example['confidence']:.3f}):")
        print(f"{text_preview}...")
  
# High confidence errors deserve special attention: the model was sure of the
# wrong answer. Print details for at most the first three of them.
if high_conf_errors:
    print("HIGH CONFIDENCE ERRORS (Model was very confident but wrong)")

    for idx in high_conf_errors[:3]:
        text = X_test[idx]
        if not isinstance(text, str):
            text = text.decode('utf-8', errors='ignore')
        print(f"\nTrue: {class_names[y_test[idx]]} | Predicted: {class_names[y_pred_classes[idx]]}")
        print(f"Confidence: {y_pred_probs[idx]:.3f}")
        print(f"True class probability: {true_class_probs[idx]:.3f}")
        print(f"Text preview: {text[:300]}...")
  
# Sequence length analysis for misclassifications
def get_text_length(text):
    """Return the number of whitespace-separated words in ``text``.

    ``str`` input is used as-is; any other input is treated as ``bytes`` and
    decoded as UTF-8, silently dropping undecodable bytes, before counting.
    """
    decoded = text if isinstance(text, str) else text.decode('utf-8', errors='ignore')
    return len(decoded.split())

# Word counts of correctly vs. incorrectly classified test documents. Note the
# model itself only sees the first max_len tokens of each document.
correct_lengths = [get_text_length(doc) for doc, is_correct in zip(X_test, correct_mask) if is_correct]
incorrect_lengths = [get_text_length(X_test[i]) for i in misclassified_indices]

len_fig, len_ax = plt.subplots(figsize=(10, 5))
for lengths, label, color in ((correct_lengths, 'Correct', 'green'),
                              (incorrect_lengths, 'Incorrect', 'red')):
    len_ax.hist(lengths, bins=30, alpha=0.7, label=label, color=color)
len_ax.set_xlabel('Document Length (words)')
len_ax.set_ylabel('Frequency')
len_ax.set_title('Document Length Distribution: Correct vs Incorrect')
len_ax.legend()
len_ax.grid(alpha=0.3)
len_fig.tight_layout()
len_fig.savefig('length_analysis_rnn.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nAverage length (words) - Correct: {np.mean(correct_lengths):.1f}")
print(f"Average length (words) - Incorrect: {np.mean(incorrect_lengths):.1f}")

In [None]:
# Persist the trained model together with the preprocessing it depends on.
# The model alone cannot score new text: inputs must be encoded with the same
# fitted tokenizer and padded to max_len, and outputs mapped back through
# class_names. Only unpickle this file from a trusted source -- pickle.load can
# execute arbitrary code.
model.save('rnn_classifier.keras')
with open('rnn_preprocessing.pkl', 'wb') as fh:
    pickle.dump(
        {'tokenizer': tokenizer, 'max_len': max_len, 'class_names': class_names},
        fh,
    )