# LSTM vs CNN for Toxic Comment Classification

Multi-label classification on the Jigsaw toxic-comment dataset: each comment can carry any combination of six toxicity labels.

In [1]:
import pickle
import sys
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

sys.path.append('..')  # must run before the project-local `src` imports below

from src.models import LSTMClassifier, SimpleCNN
from src.utils import train_epoch, evaluate, get_predictions, calculate_metrics
from src.train import clean_text, build_vocab, TextDataset

In [2]:
# Read the raw training CSV into a DataFrame and report how many rows it has.
df = pd.read_csv("../data/raw/train.csv")
print(f"Total samples: {len(df)}")

# Target columns, one per toxicity category; a comment may carry several.
label_cols = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]

# Column-wise sum: the total of each label column over the whole frame.
label_totals = df[label_cols].sum()
print(f"\nLabel distribution:")
print(label_totals)

Total samples: 159571

Label distribution:
toxic            15294
severe_toxic      1595
obscene           8449
threat             478
insult            7877
identity_hate     1405
dtype: int64


In [3]:
# Multi-label check: sum the label columns across each row, keep that per-row
# total on the frame as `num_labels`, then tabulate how often each total occurs.
df['num_labels'] = df[label_cols].sum(axis='columns')
label_count_freq = df['num_labels'].value_counts().sort_index()

print("Number of labels per comment:")
print(label_count_freq)

Number of labels per comment:
num_labels
0    143346
1      6360
2      3480
3      4209
4      1760
5       385
6        31
Name: count, dtype: int64


In [4]:
# Run the project's clean_text helper over every raw comment and store the
# result in a new column, leaving `comment_text` untouched.
df["clean_text"] = df["comment_text"].apply(clean_text)

# Preview the first two cleaned comments (truncated to 100 characters) together
# with the names of the labels set to 1 on that row; unlabeled rows are shown
# as ['clean'].
print("Sample comments:")
for idx in range(2):
    snippet = df['clean_text'].iloc[idx][:100]
    print(f"\n{idx+1}. {snippet}...")
    active = [name for name in label_cols if df[name].iloc[idx] == 1]
    print(f"   Labels: {active if active else ['clean']}")

Sample comments:

1. explanation why the edits made under my username hardcore metallica fan were reverted they werent va...
   Labels: ['clean']

2. daww he matches this background colour im seemingly stuck with thanks talk january utc...
   Labels: ['clean']


In [5]:
# Hold out 20% of the comments for testing. The fixed random_state keeps the
# partition identical from run to run.
texts = df["clean_text"].values
targets = df[label_cols].values

X_train, X_test, y_train, y_test = train_test_split(
    texts,
    targets,
    test_size=0.2,
    random_state=42,
)

print(f"Train: {len(X_train)}, Test: {len(X_test)}")

Train: 127656, Test: 31915


In [6]:
# Build the vocabulary from the training texts only (the test split is not
# passed in), with the size cap handed to build_vocab via max_vocab.
MAX_VOCAB = 10000
vocab = build_vocab(X_train, max_vocab=MAX_VOCAB)
print(f"Vocab size: {len(vocab)}")

Vocab size: 10000


In [7]:
# Dataset / loader configuration.
MAX_LEN = 100      # forwarded to TextDataset as max_len
BATCH_SIZE = 64    # batch size used by both loaders

# Wrap each split in a TextDataset that shares the same vocabulary and max_len.
train_dataset, test_dataset = (
    TextDataset(texts, targets, vocab, max_len=MAX_LEN)
    for texts, targets in ((X_train, y_train), (X_test, y_test))
)

# Only the training loader asks for shuffling; the test loader keeps
# DataLoader's default ordering.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

## LSTM Model

Bidirectional LSTM for sequence modeling

In [8]:
# Seed the random number generators before any weights are created so that
# parameter initialisation and the shuffled training batches are repeatable
# under "Restart Kernel -> Run All". 42 matches the random_state used for the
# train/test split. (Some CUDA kernels may still be non-deterministic.)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# LSTM classifier with one output per label column.
lstm_model = LSTMClassifier(
    vocab_size=len(vocab),
    embedding_dim=100,
    hidden_dim=128,
    output_dim=len(label_cols),
    n_layers=2,
    dropout=0.3
).to(device)

# Total number of scalar parameters across all model tensors.
print(f"\nModel params: {sum(p.numel() for p in lstm_model.parameters())}")

Device: cpu

Model params: 1632326


In [9]:
# Adam optimiser over the LSTM parameters. BCEWithLogitsLoss applies a sigmoid
# and binary cross-entropy to each output independently, which suits
# multi-label targets (presumably the model returns raw logits - confirm in
# src.models).
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

n_epochs = 10

# Per-epoch history, reset each time this cell runs. Note that the model itself
# is NOT reset: re-running this cell continues training the existing weights.
lstm_train_losses = []
lstm_test_losses = []
lstm_train_accs = []
lstm_test_accs = []

# perf_counter is a monotonic clock intended for measuring elapsed intervals.
start_time = time.perf_counter()

try:
    for epoch in range(n_epochs):
        # Each helper returns a (loss, accuracy) pair for one pass over its loader.
        train_loss, train_acc = train_epoch(lstm_model, train_loader, optimizer, criterion, device)
        test_loss, test_acc = evaluate(lstm_model, test_loader, criterion, device)

        lstm_train_losses.append(train_loss)
        lstm_test_losses.append(test_loss)
        lstm_train_accs.append(train_acc)
        lstm_test_accs.append(test_acc)

        print(f'Epoch {epoch+1:02d} | Train Loss: {train_loss:.3f} Acc: {train_acc:.3f} | Test Loss: {test_loss:.3f} Acc: {test_acc:.3f}')
finally:
    # Record the elapsed time even when training is interrupted (e.g. a manual
    # kernel interrupt), so the comparison table can still read
    # `lstm_train_time`. The interrupt itself is not swallowed.
    lstm_train_time = time.perf_counter() - start_time
    print(f"\nTraining time: {lstm_train_time:.2f}s")

Epoch 01 | Train Loss: 0.087 Acc: 0.973 | Test Loss: 0.059 Acc: 0.980
Epoch 02 | Train Loss: 0.058 Acc: 0.980 | Test Loss: 0.052 Acc: 0.982


KeyboardInterrupt: 

In [None]:
# Score the LSTM on the test loader: collect predictions and ground truth, then
# compute overall and per-label metrics with the project helper.
y_pred_lstm, y_true_lstm = get_predictions(lstm_model, test_loader, device)
metrics_lstm = calculate_metrics(y_true_lstm, y_pred_lstm, label_cols)

# Overall metrics, one per line, to four decimals.
print("LSTM Overall Results:")
for metric_name, metric_value in metrics_lstm['overall'].items():
    print(f"{metric_name}: {metric_value:.4f}")

# F1 for each individual label, in label_cols order.
print("\nPer-label F1 scores:")
for label_name in label_cols:
    label_f1 = metrics_lstm[label_name]['f1']
    print(f"{label_name}: {label_f1:.4f}")

## CNN Model

Multi-filter CNN for local pattern detection

In [None]:
# CNN hyper-parameters gathered in one mapping and unpacked into the
# constructor; the model is then moved to the device chosen earlier.
cnn_config = dict(
    vocab_size=len(vocab),
    embedding_dim=100,
    n_filters=100,
    filter_sizes=[3, 4, 5],
    output_dim=len(label_cols),
    dropout=0.5,
)
cnn_model = SimpleCNN(**cnn_config).to(device)

# Total number of scalar parameters across all model tensors.
n_cnn_params = sum(p.numel() for p in cnn_model.parameters())
print(f"Model params: {n_cnn_params}")

In [None]:
# Separate Adam optimiser for the CNN. `criterion` and `n_epochs` are reused
# from the LSTM training cell, so both models train with the same loss and
# epoch budget.
optimizer_cnn = torch.optim.Adam(cnn_model.parameters(), lr=0.001)

# Per-epoch history, reset each time this cell runs. Note that the model itself
# is NOT reset: re-running this cell continues training the existing weights.
cnn_train_losses = []
cnn_test_losses = []
cnn_train_accs = []
cnn_test_accs = []

# perf_counter is a monotonic clock intended for measuring elapsed intervals.
start_time = time.perf_counter()

try:
    for epoch in range(n_epochs):
        # Each helper returns a (loss, accuracy) pair for one pass over its loader.
        train_loss, train_acc = train_epoch(cnn_model, train_loader, optimizer_cnn, criterion, device)
        test_loss, test_acc = evaluate(cnn_model, test_loader, criterion, device)

        cnn_train_losses.append(train_loss)
        cnn_test_losses.append(test_loss)
        cnn_train_accs.append(train_acc)
        cnn_test_accs.append(test_acc)

        print(f'Epoch {epoch+1:02d} | Train Loss: {train_loss:.3f} Acc: {train_acc:.3f} | Test Loss: {test_loss:.3f} Acc: {test_acc:.3f}')
finally:
    # Record the elapsed time even when training is interrupted (e.g. a manual
    # kernel interrupt), so the comparison table can still read
    # `cnn_train_time`. The interrupt itself is not swallowed.
    cnn_train_time = time.perf_counter() - start_time
    print(f"\nTraining time: {cnn_train_time:.2f}s")

In [None]:
# Score the CNN on the test loader: collect predictions and ground truth, then
# compute overall and per-label metrics with the project helper.
y_pred_cnn, y_true_cnn = get_predictions(cnn_model, test_loader, device)
metrics_cnn = calculate_metrics(y_true_cnn, y_pred_cnn, label_cols)

# Overall metrics, one per line, to four decimals.
print("CNN Overall Results:")
for metric_name, metric_value in metrics_cnn['overall'].items():
    print(f"{metric_name}: {metric_value:.4f}")

# F1 for each individual label, in label_cols order.
print("\nPer-label F1 scores:")
for label_name in label_cols:
    label_f1 = metrics_cnn[label_name]['f1']
    print(f"{label_name}: {label_f1:.4f}")

## Comparison

In [None]:
# Plot training curves on a 2x2 grid: one row per model, loss on the left and
# accuracy on the right, each panel showing the train and test series.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# (axis, panel title, train series, test series) for each of the four panels.
panels = [
    (axes[0, 0], 'LSTM Loss', lstm_train_losses, lstm_test_losses),
    (axes[0, 1], 'LSTM Accuracy', lstm_train_accs, lstm_test_accs),
    (axes[1, 0], 'CNN Loss', cnn_train_losses, cnn_test_losses),
    (axes[1, 1], 'CNN Accuracy', cnn_train_accs, cnn_test_accs),
]

# Draw every panel the same way instead of repeating the block four times.
for ax, title, train_curve, test_curve in panels:
    ax.plot(train_curve, label='Train', alpha=0.7)
    ax.plot(test_curve, label='Test', alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()

# Make sure the output directory exists so savefig does not fail on a fresh
# checkout after the expensive training cells have already run.
Path('../outputs').mkdir(parents=True, exist_ok=True)
plt.savefig('../outputs/lstm_cnn_training.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Comparison table: one row per model with its overall metrics and the
# measured training time in seconds.
overall_by_model = {
    'LSTM': metrics_lstm['overall'],
    'CNN': metrics_cnn['overall'],
}

# (table column header, key inside the 'overall' metrics dict)
metric_columns = [
    ('Accuracy', 'accuracy'),
    ('Precision', 'precision'),
    ('Recall', 'recall'),
    ('F1', 'f1'),
]

table = {'Model': list(overall_by_model)}
for column, key in metric_columns:
    table[column] = [scores[key] for scores in overall_by_model.values()]
table['Train Time (s)'] = [lstm_train_time, cnn_train_time]

results = pd.DataFrame(table)

print(results.to_string(index=False))

In [None]:
# Per-label comparison: the F1 score of each model for every label, side by
# side and in label_cols order.
lstm_f1_by_label = [metrics_lstm[name]['f1'] for name in label_cols]
cnn_f1_by_label = [metrics_cnn[name]['f1'] for name in label_cols]

per_label_f1 = pd.DataFrame({
    'Label': label_cols,
    'LSTM F1': lstm_f1_by_label,
    'CNN F1': cnn_f1_by_label,
})

print(per_label_f1.to_string(index=False))

In [None]:
# Save models: persist both trained networks plus the vocabulary and label
# list used to build their inputs and outputs.
output_dir = Path('../outputs')
# Create the directory first so nothing is lost to a FileNotFoundError after
# the long training runs.
output_dir.mkdir(parents=True, exist_ok=True)

# Only the weights (state_dict) are stored; the model classes come from src.models.
torch.save(lstm_model.state_dict(), output_dir / 'lstm_model.pt')
torch.save(cnn_model.state_dict(), output_dir / 'cnn_model.pt')

# NOTE: pickle files should only ever be loaded back from trusted locations.
with open(output_dir / 'vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)
with open(output_dir / 'label_cols.pkl', 'wb') as f:
    pickle.dump(label_cols, f)

print("Models saved")

## Notes

**LSTM observations:**
- **TODO:** add observations once the LSTM run has completed.

**CNN observations:**
- **TODO:** add observations once the CNN run has completed.

**Multi-label challenges:**
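- Class imbalance (some labels are rare: `threat` has only 478 positives among 159,571 comments)
- Label correlation (`toxic` often appears together with `insult`)
- Threshold selection matters (the cutoff that turns a predicted probability into a label affects precision and recall, especially for rare labels)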

**Next steps:**
- Try ensemble methods
- Test BERT
- Experiment with class weights