In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam
from sklearn.metrics import accuracy_score
from dataset import News
from embedding import GloVe
from model import BiLSTM, AttentionBiLSTM
from utils import save_history, save_checkpoint, load_checkpoint, load_config
from train import train, evaluate

In [16]:
def main():
    """Train the attention BiLSTM news classifier and report test metrics.

    Steps, in order:
      1. Load the configuration, print it, and pick CUDA when available.
      2. Build the GloVe embedding, the ``AttentionBiLSTM`` model, an Adam
         optimizer and a cross-entropy criterion.
      3. Build ``News`` datasets / ``DataLoader`` objects for the train,
         validation and test splits.
      4. Optionally restore model, optimizer and history from
         ``load_checkpoint()`` (called with its default arguments).
      5. Run ``epochs`` epochs of ``train`` + ``evaluate``; whenever the
         validation accuracy is the best recorded so far, save a checkpoint.
      6. Save the history, then evaluate on the test split twice: once with
         the final weights and once with the best checkpoint's weights.

    Returns:
        None. Results are printed and written under ``assets/``.

    Notes:
        * When a checkpoint is restored, the loop still runs ``epochs``
          further epochs and the printed epoch counter restarts at 1.
        * NOTE(review): the resume step uses ``load_checkpoint()``'s default
          path, while best checkpoints are written to ``checkpoint_path``;
          confirm these are meant to differ.
    """
    config = load_config()
    print(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path = "data"
    epochs = 50
    # Single source of truth for the artifact paths used below.
    checkpoint_path = "assets/checkpoint_attention-multi-lstm.pt"
    history_path = "assets/history_attention-multi-lstm.pkl"

    history = {
        'train': {
            'loss': list(),
            'accuracy': list(),
        },
        'val': {
            'loss': list(),
            'accuracy': list(),
        }
    }

    # The configuration printed by this notebook has no 'Multi-Attention-BiLSTM'
    # section; the matching one is 'Attention-Multi-LSTM'. Keep honouring the
    # original key when it exists and fall back otherwise, so the lookup no
    # longer raises KeyError with the configuration shown in the output.
    model_key = 'Multi-Attention-BiLSTM'
    if model_key not in config:
        model_key = 'Attention-Multi-LSTM'

    glove = GloVe(**config['GloVe'])
    model = AttentionBiLSTM(glove, **config[model_key])
    model.to(device)
    optimizer = Adam(model.parameters(), **config['ADAM'])
    criterion = nn.CrossEntropyLoss()

    train_dataset = News(path, glove, split='train')
    val_dataset = News(path, glove, split='validation')
    test_dataset = News(path, glove, split='test')
    train_dataloader = DataLoader(train_dataset, collate_fn=train_dataset.collate_fn, **config['train'])
    val_dataloader = DataLoader(val_dataset, collate_fn=val_dataset.collate_fn, **config['validation'])
    test_dataloader = DataLoader(test_dataset, collate_fn=test_dataset.collate_fn, **config['test'])

    # Resume from a previous run when a checkpoint is available.
    checkpoint = load_checkpoint()
    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        history = checkpoint['history']

    for epoch in range(epochs):
        train_loss, train_accuracy = train(model, train_dataloader, optimizer, criterion, device)
        val_loss, val_accuracy = evaluate(model, val_dataloader, criterion, device)

        # Compare against the accuracies recorded *before* this epoch. With an
        # empty history the current epoch is the best by definition; treating
        # it as "not best" (as before) meant no checkpoint was written when the
        # first epoch was never beaten, breaking the best-model evaluation.
        previous_val_accuracies = history['val']['accuracy']
        is_best = not previous_val_accuracies or val_accuracy > max(previous_val_accuracies)

        history['train']['loss'].append(train_loss)
        history['train']['accuracy'].append(train_accuracy)
        history['val']['loss'].append(val_loss)
        history['val']['accuracy'].append(val_accuracy)

        if is_best:
            save_checkpoint(epoch, model, optimizer, history, path=checkpoint_path)

        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"Loss: {train_loss:.4e}, Accuracy: {train_accuracy:6.2f}, Validation Loss: {val_loss:.4e}, Validation Accuracy: {val_accuracy:6.2f}")

    save_history(history, path=history_path)

    # Test metrics with the weights left in the model after the last epoch.
    test_loss, test_accuracy = evaluate(model, test_dataloader, criterion, device)
    print("Test - Final Model")
    print(f"Test Loss: {test_loss:.4e}, Test Accuracy: {test_accuracy:6.2f}")

    # Test metrics with the best-validation weights. A resumed run that never
    # improves on the restored history writes no checkpoint here, so guard the
    # load the same way the resume step above does.
    checkpoint = load_checkpoint(path=checkpoint_path)
    if not checkpoint:
        print(f"Test - Best Model skipped: no checkpoint found at {checkpoint_path}")
        return
    model.load_state_dict(checkpoint['model_state_dict'])
    test_loss, test_accuracy = evaluate(model, test_dataloader, criterion, device)
    print("Test - Best Model")
    print(f"Test Loss: {test_loss:.4e}, Test Accuracy: {test_accuracy:6.2f}")

In [17]:
# Run the training, validation and test pipeline.
main()

{'ADAM': {'betas': (0.9, 0.999), 'eps': 1e-08, 'lr': 0.001, 'weight_decay': 0.001}, 'SGD': {'lr': 0.01, 'weight_decay': 1e-05, 'momentum': 0.9}, 'GloVe': {'variant': 'glove-wiki-gigaword-300'}, 'LSTM': {'bidirectional': True, 'freeze': True, 'hidden_size': 64, 'dropout_rate': 0.3, 'n_classes': 3, 'num_layers': 1}, 'Multi-LSTM': {'bidirectional': True, 'freeze': True, 'hidden_size': 64, 'dropout_rate': 0.5, 'n_classes': 3, 'num_layers': 2}, 'Attention-LSTM': {'bidirectional': True, 'freeze': True, 'hidden_size': 64, 'dropout_rate': 0.5, 'n_classes': 3, 'num_layers': 1}, 'Attention-Multi-LSTM': {'bidirectional': True, 'freeze': True, 'hidden_size': 64, 'dropout_rate': 0.5, 'n_classes': 3, 'num_layers': 2}, 'train': {'batch_size': 32, 'shuffle': True}, 'validation': {'batch_size': 32, 'shuffle': False}, 'test': {'batch_size': 32, 'shuffle': False}}
Epoch 1/50
Loss: 8.2297e-01, Accuracy:   0.62, Validation Loss: 8.9650e-01, Validation Accuracy:   0.55
Epoch 2/50
Loss: 7.5343e-01, Accuracy: