In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import json

In [6]:
from helpers.acl_imdb_dataset import build_dataloader
from model import SentimentModel

In [7]:
BATCH_SIZE = 16
EPOCHS = 100
EARLY_STOP_PATIENCE = 5


In [8]:
train_dataset_path = "/home/admyyh/python_workspace/advml/rnn_sa/aclImdb/train"
test_dataset_path = "/home/admyyh/python_workspace/advml/rnn_sa/aclImdb/test"
vocab_path = "/home/admyyh/python_workspace/advml/rnn_sa/word_vocab_norm_None_stop_False.json"

train_loader, validation_loader = build_dataloader(
    root_folder=train_dataset_path, vocab_file=vocab_path,
    batch_size=32,  # Using batch size 8 for testing
    shuffle=True,
    split=True, val_ratio=0.1, random_seed=42,
)

test_loader = build_dataloader(
    root_folder=test_dataset_path, vocab_file=vocab_path,
    batch_size=32,  # Using batch size 8 for testing
    shuffle=False,
    split=False
)

Dataset split: 22500 training samples, 2500 validation samples


In [9]:
with open(vocab_path, 'r') as f:
    vocab_data = json.load(f)


In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SentimentModel(vocab_size=len(vocab_data['word_vocab']))
model.to(device)

SentimentModel(
  (embed): Embedding(99928, 64, padding_idx=0)
  (dropout_embed): Dropout(p=0.5, inplace=False)
  (lstm): LSTM(64, 64, batch_first=True)
  (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (dropout_lstm): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=64, out_features=2, bias=True)
)

In [11]:
criterion = nn.CrossEntropyLoss()  # tag classification loss
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

In [12]:
best_val_loss = float('inf')

# Training loop
for epoch in range(EPOCHS):
    # Training phase
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    for inputs, targets, lengths in train_loader:
        inputs, targets, lengths = inputs.to(device), targets.to(device), lengths.to(device)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        logits = model(inputs, lengths)
        
        # Calculate loss
        loss = criterion(logits, targets)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping (moved inside the loop)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Update weights
        optimizer.step()
        
        # Track statistics
        train_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(logits, 1)
        train_total += targets.size(0)
        train_correct += (predicted == targets).sum().item()
    
    # Calculate average training metrics
    avg_train_loss = train_loss / len(train_loader.dataset)
    train_accuracy = 100 * train_correct / train_total
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for inputs, targets, lengths in validation_loader:
            inputs, targets, lengths = inputs.to(device), targets.to(device), lengths.to(device)
            
            # Forward pass
            logits = model(inputs, lengths)
            
            # Calculate loss
            loss = criterion(logits, targets)
            
            # Track statistics
            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(logits, 1)
            val_total += targets.size(0)
            val_correct += (predicted == targets).sum().item()
    
    # Calculate average validation metrics
    avg_val_loss = val_loss / len(validation_loader.dataset)
    val_accuracy = 100 * val_correct / val_total
    
    # Print epoch results
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")
    
    # # Early stopping check
    # if avg_val_loss < best_val_loss:
    #     best_val_loss = avg_val_loss
    #     patience_counter = 0
    #     # Save the best model
    #     torch.save(model.state_dict(), 'best_sentiment_model.pt')
    #     print("  Saved best model!")
    # else:
    #     patience_counter += 1
    #     print(f"  No improvement for {patience_counter} epochs")
        
    # if patience_counter >= EARLY_STOP_PATIENCE:
    #     print(f"Early stopping triggered after {epoch+1} epochs")
    #     break


Epoch 1/100
  Train Loss: 0.6949 | Train Acc: 53.53%
  Val Loss: 0.6626 | Val Acc: 61.64%
Epoch 2/100
  Train Loss: 0.6283 | Train Acc: 66.35%
  Val Loss: 0.5641 | Val Acc: 74.84%
Epoch 3/100
  Train Loss: 0.5714 | Train Acc: 72.65%
  Val Loss: 0.6788 | Val Acc: 67.60%
Epoch 4/100
  Train Loss: 0.5238 | Train Acc: 76.26%
  Val Loss: 0.5407 | Val Acc: 75.60%
Epoch 5/100
  Train Loss: 0.4472 | Train Acc: 80.99%
  Val Loss: 0.6122 | Val Acc: 79.76%
Epoch 6/100
  Train Loss: 0.3963 | Train Acc: 83.50%
  Val Loss: 0.4857 | Val Acc: 82.40%
Epoch 7/100
  Train Loss: 0.3578 | Train Acc: 85.41%
  Val Loss: 0.4175 | Val Acc: 85.36%
Epoch 8/100
  Train Loss: 0.3220 | Train Acc: 87.24%
  Val Loss: 0.4394 | Val Acc: 86.04%
Epoch 9/100
  Train Loss: 0.2846 | Train Acc: 89.11%
  Val Loss: 0.3947 | Val Acc: 88.04%
Epoch 10/100
  Train Loss: 0.2538 | Train Acc: 90.33%
  Val Loss: 0.5191 | Val Acc: 86.92%
Epoch 11/100
  Train Loss: 0.2291 | Train Acc: 91.51%
  Val Loss: 0.4442 | Val Acc: 87.92%
Epoch 12

KeyboardInterrupt: 