In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import json

In [None]:
from helpers.acl_imdb_dataset import build_dataloader
from model import SentimentModel

In [None]:
# Seed PyTorch's global random number generator before anything stochastic
# runs, so torch-driven randomness (e.g. weight initialisation) repeats across
# a fresh "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Training hyperparameters.
BATCH_SIZE = 16          # intended mini-batch size for the data loaders
EPOCHS = 100             # upper bound on the number of passes over the training set
EARLY_STOP_PATIENCE = 5  # epochs without validation-loss improvement tolerated before stopping


In [None]:
# Single place to edit when the project lives somewhere else; the three paths
# below are all derived from it.
PROJECT_ROOT = "/home/admyyh/python_workspace/advml/rnn_sa"
train_dataset_path = f"{PROJECT_ROOT}/aclImdb/train"
test_dataset_path = f"{PROJECT_ROOT}/aclImdb/test"
vocab_path = f"{PROJECT_ROOT}/word_vocab_norm_None_stop_False.json"

# With split=True the helper hands back a (train, validation) pair of loaders.
# val_ratio and random_seed are forwarded to it to control that split
# (presumably 10% of the training folder is held out -- confirm in
# helpers.acl_imdb_dataset).
# The batch size comes from the config cell so there is one value to tune.
train_loader, validation_loader = build_dataloader(
    root_folder=train_dataset_path,
    vocab_file=vocab_path,
    batch_size=BATCH_SIZE,
    shuffle=True,
    split=True,
    val_ratio=0.1,
    random_seed=42,
)

# With split=False a single loader is returned; shuffling is off for the test set.
test_loader = build_dataloader(
    root_folder=test_dataset_path,
    vocab_file=vocab_path,
    batch_size=BATCH_SIZE,
    shuffle=False,
    split=False,
)

In [None]:
# Read the vocabulary JSON. The encoding is pinned to UTF-8 so that decoding
# does not depend on the platform's locale default.
with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
    vocab_data = json.load(vocab_file)


In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# vocab_size is the number of entries stored under the 'word_vocab' key of the
# vocabulary file loaded earlier.
num_words = len(vocab_data['word_vocab'])
model = SentimentModel(vocab_size=num_words)

# Left as the last expression so the notebook shows the model summary.
model.to(device)

In [None]:
# Cross-entropy between the model's class logits and the integer class labels.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters, with a small L2 penalty via weight_decay.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and report its metrics.

    Args:
        model: Module called as ``model(inputs, lengths)`` and returning
            per-class logits.
        loader: Iterable yielding ``(inputs, targets, lengths)`` batches.
        criterion: Loss called as ``criterion(logits, targets)``; its value is
            treated as a per-batch mean.
        device: Device the batch tensors are moved to.
        optimizer: When given, the pass is a training pass: the model is put
            in train mode and updated after every batch. When ``None``, the
            model is put in eval mode and gradients are disabled.

    Returns:
        ``(mean_loss, accuracy_percent)`` computed over the samples actually
        seen during this pass.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.set_grad_enabled(training):
        for inputs, targets, lengths in loader:
            # NOTE(review): lengths is moved to the device as well; if the model
            # packs sequences internally it may need them on the CPU -- confirm
            # against SentimentModel.
            inputs, targets, lengths = inputs.to(device), targets.to(device), lengths.to(device)

            if training:
                optimizer.zero_grad()

            logits = model(inputs, lengths)
            loss = criterion(logits, targets)

            if training:
                loss.backward()
                # Clip after backward and before the step so the update uses
                # the clipped gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Weight the batch-mean loss by the number of labels in the batch,
            # so a smaller final batch does not skew the epoch average.
            batch_count = targets.size(0)
            loss_sum += loss.item() * batch_count
            num_correct += (logits.argmax(dim=1) == targets).sum().item()
            num_seen += batch_count

    if num_seen == 0:
        raise ValueError("loader produced no samples")

    # Divide by the samples counted above rather than len(loader.dataset):
    # the two differ when the loader only serves part of its dataset.
    return loss_sum / num_seen, 100 * num_correct / num_seen


best_val_loss = float('inf')
patience_counter = 0  # consecutive epochs without a new best validation loss

for epoch in range(EPOCHS):
    avg_train_loss, train_accuracy = _run_epoch(
        model, train_loader, criterion, device, optimizer=optimizer
    )
    avg_val_loss, val_accuracy = _run_epoch(
        model, validation_loader, criterion, device
    )

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

    # Checkpoint whenever validation loss reaches a new minimum; otherwise
    # count the epoch towards the early-stopping patience. A NaN loss compares
    # False and therefore also counts as "no improvement".
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_sentiment_model.pt')
        print("  Saved best model!")
    else:
        patience_counter += 1
        print(f"  No improvement for {patience_counter} epochs")

    if patience_counter >= EARLY_STOP_PATIENCE:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break
