# DistilBERT + TextRCNN Pipeline
## Advanced Document Classification with Full Parameter Finetuning

This notebook implements DistilBERT + TextRCNN architecture for document authenticity classification kaggle competition Fake or Real: The Impostor Hunt in Texts.

## Setup and Imports

In [1]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Set device
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('Using MPS (Metal GPU) device')
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print('Using CUDA device')
else:
    device = torch.device('cpu')
    print('Using CPU device')

print(f'Device: {device}')

Using MPS (Metal GPU) device
Device: mps


## TextRCNN Architecture

In [2]:
class TextRCNN(nn.Module):
    """TextRCNN: RNN + CNN + Attention for text classification."""
    
    def __init__(self, hidden_size=768, num_layers=2, num_classes=2, dropout=0.3):
        super(TextRCNN, self).__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        
        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size // 2,  # Bidirectional will double this
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )
        
        # CNN layers for local feature extraction
        self.conv1 = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1)
        
        # Attention mechanism
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            dropout=dropout,
            batch_first=True
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),  # LSTM output is already hidden_size
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes)
        )
        
        # Layer normalization
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, bert_outputs, attention_mask=None):
        # bert_outputs: [batch_size, seq_len, hidden_size]
        batch_size, seq_len, hidden_size = bert_outputs.shape
        
        # Bidirectional LSTM
        lstm_out, _ = self.lstm(bert_outputs)
        # lstm_out: [batch_size, seq_len, hidden_size]
        
        # CNN feature extraction
        # Transpose for CNN: [batch_size, hidden_size, seq_len]
        cnn_input = lstm_out.transpose(1, 2)
        
        # Apply CNN layers with residual connections
        conv1_out = F.relu(self.conv1(cnn_input))
        conv2_out = F.relu(self.conv2(conv1_out))
        conv3_out = F.relu(self.conv3(conv2_out))
        
        # Transpose back: [batch_size, seq_len, hidden_size]
        cnn_out = conv3_out.transpose(1, 2)
        
        # Add residual connection
        cnn_out = self.layer_norm1(cnn_out + lstm_out)
        
        # Selfattention mechanism (simplified to avoid mask issues)
        # Use simple attention without complex masking
        attn_out, _ = self.attention(cnn_out, cnn_out, cnn_out)
        
        # Add residual connection
        attn_out = self.layer_norm2(attn_out + cnn_out)
        
        # Global average pooling
        if attention_mask is not None:
            # Masked average pooling
            masked_output = attn_out * attention_mask.unsqueeze(-1)
            pooled_output = masked_output.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
        else:
            pooled_output = attn_out.mean(dim=1)
        
        # Classification
        logits = self.classifier(pooled_output)
        
        return logits

## Complete DistilBERT + TextRCNN Classifier

In [3]:
class DistilBertTextRCnnClassifier(nn.Module):
    """Complete DistilBERT + TextRCNN classifier."""
    
    def __init__(self, model_name='distilbertbaseuncased', num_classes=2, dropout=0.3):
        super(DistilBertTextRCnnClassifier, self).__init__()
        
        # DistilBERT encoder
        self.bert = DistilBertModel.from_pretrained(model_name)
        
        # TextRCNN classifier
        self.textrcnn = TextRCNN(
            hidden_size=768,  # DistilBERT hidden size
            num_layers=2,
            num_classes=num_classes,
            dropout=dropout
        )
        
        # Freeze BERT initially (will unfreeze during finetuning)
        self.freeze_bert()
        
    def freeze_bert(self):
        """Freeze BERT parameters."""
        for param in self.bert.parameters():
            param.requires_grad = False
        print("BERT parameters frozen")
    
    def unfreeze_bert(self):
        """Unfreeze BERT parameters for finetuning."""
        for param in self.bert.parameters():
            param.requires_grad = True
        print("BERT parameters unfrozen for finetuning")
    
    def forward(self, input_ids, attention_mask):
        # Get BERT outputs
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state
        
        # Pass through TextRCNN
        logits = self.textrcnn(bert_outputs, attention_mask)
        
        return logits

## Custom Dataset

In [4]:
class DocumentDataset(Dataset):
    """Custom dataset for document classification."""
    
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        label = item['label']
        
        # Tokenize text
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

## Data Loading Functions

In [5]:
def load_document_pairs(data_dir):
    """Load document pairs from the data directory."""
    pairs = []
    
    for article_dir in sorted(os.listdir(data_dir)):
        article_path = os.path.join(data_dir, article_dir)
        if os.path.isdir(article_path):
            files = [f for f in os.listdir(article_path) if f.endswith('.txt')]
            if len(files) == 2:
                file1_path = os.path.join(article_path, files[0])
                file2_path = os.path.join(article_path, files[1])
                
                with open(file1_path, 'r', encoding='utf8') as f:
                    content1 = f.read().strip()
                with open(file2_path, 'r', encoding='utf8') as f:
                    content2 = f.read().strip()
                
                pairs.append({
                    'article_id': article_dir,
                    'file1': files[0],
                    'file2': files[1],
                    'content1': content1,
                    'content2': content2,
                    'file1_path': file1_path,
                    'file2_path': file2_path
                })
    
    return pairs

def create_training_data(pairs, labels_df):
    """Create training data with one text per row and proper labels."""
    
    print(f"Creating training data with {len(labels_df)} articles...")
    
    training_data = []
    
    for _, row in labels_df.iterrows():
        article_id = row['id']
        real_text_id = row['real_text_id']
        
        article_folder = f"article_{str(article_id).zfill(4)}"
        pair = None
        
        for p in pairs:
            if p['article_id'] == article_folder:
                pair = p
                break
        
        if pair is None:
            print(f"Warning: Could not find pair for article {article_id}")
            continue
        
        if real_text_id == 1:
            real_content = pair['content1']
            fake_content = pair['content2']
        else:
            real_content = pair['content2']
            fake_content = pair['content1']
        
        training_data.append({
            'text': real_content,
            'label': 1,
            'article_id': article_id,
            'text_type': 'real'
        })
        
        training_data.append({
            'text': fake_content,
            'label': 0,
            'article_id': article_id,
            'text_type': 'fake'
        })
    
    print(f"Created {len(training_data)} training examples")
    print(f"    Real documents: {len([x for x in training_data if x['label'] == 1])}")
    print(f"    Fake documents: {len([x for x in training_data if x['label'] == 0])}")
    
    return training_data

## Training Function

In [6]:
def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=2e-5):
    """Train the DistilBERT + TextRCNN model with improved validation."""
    
    print(" Training DistilBERT + TextRCNN model...")
    
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
    
    # Learning rate scheduler
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
    
    # Training loop
    best_val_acc = 0
    patience = 5
    patience_counter = 0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        train_acc = correct / total
        avg_loss = total_loss / len(train_loader)
        
        # Validation phase with better debugging
        model.eval()
        val_correct = 0
        val_total = 0
        val_loss = 0
        val_predictions = []
        val_true_labels = []
        
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                
                # Store predictions for debugging
                val_predictions.extend(predicted.cpu().numpy())
                val_true_labels.extend(labels.cpu().numpy())
        
        val_acc = val_correct / val_total if val_total > 0 else 0
        avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else 0
        
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Training Loss: {avg_loss:.4f}, Training Acc: {train_acc:.4f}")
        print(f"  Validation Loss: {avg_val_loss:.4f}, Validation Acc: {val_acc:.4f}")
        
        # Debug validation predictions
        if epoch == 0 or val_acc == 0:
            print(f"   Validation Debug:")
            print(f"      Total validation samples: {val_total}")
            print(f"      Correct predictions: {val_correct}")
            print(f"      Prediction distribution: {np.bincount(val_predictions) if len(val_predictions) > 0 else 'No predictions'}")
            print(f"      True label distribution: {np.bincount(val_true_labels) if len(val_true_labels) > 0 else 'No labels'}")
        
        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), 'best_distilbert_textrcnn_model.pth')
            print(f"   New best validation accuracy: {best_val_acc:.4f}")
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print(f"    Early stopping after {patience} epochs without improvement")
            break
        
        # Unfreeze BERT after first few epochs for finetuning
        if epoch == 2:
            model.unfreeze_bert()
            print("   Unfrozen BERT for finetuning")
    
    print(f"\n Best validation accuracy: {best_val_acc:.4f}")
    print(" Training completed!")
    
    return best_val_acc

##  Validation Debugging

This cell helps debug validation issues if they occur.

In [7]:
# Validation debugging helper
def debug_validation(model, val_loader, device):
    """Debug validation issues."""
    model.eval()
    val_correct = 0
    val_total = 0
    val_predictions = []
    val_true_labels = []
    
    print(" Validation Debugging...")
    
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i >= 3:  # Only check first 3 batches
                break
                
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            print(f"\nBatch {i+1}:")
            print(f"  Input shape: {input_ids.shape}")
            print(f"  Labels: {labels}")
            
            try:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                print(f"  Output shape: {outputs.shape}")
                print(f"  Output values: {outputs[0]}")
                
                _, predicted = torch.max(outputs.data, 1)
                print(f"  Predictions: {predicted}")
                
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                
                val_predictions.extend(predicted.cpu().numpy())
                val_true_labels.extend(labels.cpu().numpy())
                
            except Exception as e:
                print(f"   Error: {e}")
    
    print(f"\n Summary:")
    print(f"  Total samples checked: {val_total}")
    print(f"  Correct predictions: {val_correct}")
    print(f"  Prediction distribution: {np.bincount(val_predictions) if len(val_predictions) > 0 else 'No predictions'}")
    print(f"  True label distribution: {np.bincount(val_true_labels) if len(val_true_labels) > 0 else 'No labels'}")
    
    if val_total > 0:
        accuracy = val_correct / val_total
        print(f"  Accuracy: {accuracy:.4f}")
    
    return val_predictions, val_true_labels

# You can call this function if validation fails:
# debug_validation(model, val_loader, device)

## Prediction Function

In [8]:
def predict_test_set(model, tokenizer, test_pairs, max_length=512):
    """Generate predictions on test set."""
    
    print(" Generating test predictions...")
    
    model.eval()
    predictions = []
    
    for i, pair in enumerate(test_pairs):
        article_id = pair['article_id']
        try:
            numeric_id = int(article_id.split('_')[1])
            solution_id = numeric_id
        except (IndexError, ValueError):
            solution_id = i
        
        text1 = pair['content1']
        text2 = pair['content2']
        
        # Tokenize texts
        encoding1 = tokenizer(
            text1,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        
        encoding2 = tokenizer(
            text2,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        
        # Move to device
        input_ids1 = encoding1['input_ids'].to(device)
        attention_mask1 = encoding1['attention_mask'].to(device)
        input_ids2 = encoding2['input_ids'].to(device)
        attention_mask2 = encoding2['attention_mask'].to(device)
        
        # Get predictions
        with torch.no_grad():
            outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1)
            outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2)
            
            probs1 = F.softmax(outputs1, dim=1)
            probs2 = F.softmax(outputs2, dim=1)
            
            pred1 = torch.argmax(outputs1, dim=1).item()
            pred2 = torch.argmax(outputs2, dim=1).item()
            
            real_prob1 = probs1[0][1].item()
            real_prob2 = probs2[0][1].item()
        
        # Determine which file is real
        if pred1 == 1 and pred2 == 0:
            real_text_id = 1
        elif pred1 == 0 and pred2 == 1:
            real_text_id = 2
        else:
            # Use probability
            real_text_id = 1 if real_prob1 > real_prob2 else 2
        
        predictions.append({
            'id': solution_id,
            'real_text_id': real_text_id,
            'text1_pred': pred1,
            'text2_pred': pred2,
            'text1_real_prob': real_prob1,
            'text2_real_prob': real_prob2
        })
        
        if (i + 1) % 100 == 0:
            print(f"Processed {i + 1}/{len(test_pairs)} pairs...")
    
    return predictions

## Main Pipeline Execution

In [9]:
# 1. Load data
print(" Step 1: Loading data...")
labels_df = pd.read_csv('train.csv')
train_pairs = load_document_pairs('train')
test_pairs = load_document_pairs('test')

print(f"Loaded {len(labels_df)} training articles")
print(f"Loaded {len(train_pairs)} training pairs")
print(f"Loaded {len(test_pairs)} test pairs")

 Step 1: Loading data...
Loaded 95 training articles
Loaded 95 training pairs
Loaded 1068 test pairs


In [10]:
# 2. Create training data
print(" Step 2: Creating training data...")
train_data = create_training_data(train_pairs, labels_df)

 Step 2: Creating training data...
Creating training data with 95 articles...
Created 190 training examples
    Real documents: 95
    Fake documents: 95


In [13]:
# 3. Initialize tokenizer and model
print("Step 3: Initializing DistilBERT + TextRCNN...")
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertTextRCnnClassifier(
    model_name='distilbert-base-uncased',
    num_classes=2,
    dropout=0.3
).to(device)

print(f"Model initialized on {device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Step 3: Initializing DistilBERT + TextRCNN...
BERT parameters frozen
Model initialized on mps
Total parameters: 82,015,874
Trainable parameters: 15,652,994


In [14]:
# 4. Create datasets and data loaders
print("Step 4: Creating datasets...")

# Create DocumentDataset instances
train_dataset = DocumentDataset(train_data, tokenizer)
val_dataset = DocumentDataset(train_data, tokenizer)  # Use full data for validation

# Split the dataset
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [train_size, val_size]
)

# Create data loaders
batch_size = 8  # Adjust based on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

Step 4: Creating datasets...
Training batches: 19
Validation batches: 5


In [15]:
# 5. Train the model
print(" Step 5: Training model...")
try:
    best_val_acc = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=15,
        learning_rate=2e-5
    )
    print(f" Training completed successfully! Best validation accuracy: {best_val_acc:.4f}")
except Exception as e:
    print(f" Training failed: {e}")
    print(" Running validation debug...")
    try:
        debug_validation(model, val_loader, device)
    except:
        print("  Debug function not available")
    best_val_acc = 0.0


 Step 5: Training model...
 Training DistilBERT + TextRCNN model...
Epoch 1/15:
  Training Loss: 0.6889, Training Acc: 0.5724
  Validation Loss: 0.6652, Validation Acc: 0.7368
   Validation Debug:
      Total validation samples: 38
      Correct predictions: 28
      Prediction distribution: [ 7 31]
      True label distribution: [17 21]
   New best validation accuracy: 0.7368
Epoch 2/15:
  Training Loss: 0.6438, Training Acc: 0.7039
  Validation Loss: 0.6128, Validation Acc: 0.7632
   New best validation accuracy: 0.7632
Epoch 3/15:
  Training Loss: 0.5896, Training Acc: 0.7697
  Validation Loss: 0.5554, Validation Acc: 0.7632
BERT parameters unfrozen for finetuning
   Unfrozen BERT for finetuning
Epoch 4/15:
  Training Loss: 0.4989, Training Acc: 0.8026
  Validation Loss: 0.4497, Validation Acc: 0.7895
   New best validation accuracy: 0.7895
Epoch 5/15:
  Training Loss: 0.3891, Training Acc: 0.8487
  Validation Loss: 0.3905, Validation Acc: 0.7895
Epoch 6/15:
  Training Loss: 0.2755,

In [16]:
# 6. Load best model
print("Step 6: Loading best model")
model.load_state_dict(torch.load('best_distilbert_textrcnn_model.pth'))
print("Best model loaded")

Step 6: Loading best model
Best model loaded


In [17]:
# 7. Generate predictions
print(" Step 7: Generating test predictions...")
predictions = predict_test_set(model, tokenizer, test_pairs)

 Step 7: Generating test predictions...
 Generating test predictions...
Processed 100/1068 pairs...
Processed 200/1068 pairs...
Processed 300/1068 pairs...
Processed 400/1068 pairs...
Processed 500/1068 pairs...
Processed 600/1068 pairs...
Processed 700/1068 pairs...
Processed 800/1068 pairs...
Processed 900/1068 pairs...
Processed 1000/1068 pairs...


In [18]:
# 8. Create solution file
solution_df = pd.DataFrame(predictions)
submission_df = solution_df[['id', 'real_text_id']].copy()

submission_df = submission_df.sort_values('id').reset_index(drop=True)
submission_df['id'] = submission_df['id'].astype(int)
submission_df['real_text_id'] = submission_df['real_text_id'].astype(int)

solution_file = 'Hunt_In_Text_Solution.csv'
submission_df.to_csv(solution_file, index=False)

print(f"Solution file saved as: {solution_file}")

Solution file saved as: Hunt_In_Text_Solution.csv
