In [1]:
import sys
sys.path.append('..')
import configs.common as cc
import configs.mamba as cm
import models
import processing
import torch.optim as optim
import torch.nn as nn
import torch

In [2]:
# Directory holding the pre-processed dataset; passed to the project loader as-is.
# NOTE(review): this is an absolute, machine-specific path — consider moving it
# into the shared config so the notebook can run on other machines.
DATASET_DIR = 'F:\\GitHub\\dataset\\np_dataset'

# Build the train/test dataloaders from that directory (project helper).
(train_dataloader,
 test_dataloader) = processing.get_train_test_dataloaders(DATASET_DIR)

# Metadata vocabulary size as reported by the processing module. It is not
# referenced by any of the cells shown in this notebook export.
METADATA_VOCAB_SIZE = processing.get_metadata_vocab_size()

In [4]:
# Train `model` with Adam + cross-entropy for the configured number of epochs,
# reporting the mean training loss and the mean validation loss after each epoch.
#
# NOTE(review): `model` is not created in any cell visible in this notebook
# (execution counts jump from In [2] to In [4]); it must already exist in the
# kernel before this cell is run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=cc.config.values.learning_rate)

# Resolve loop-invariant settings once instead of on every batch.
device = cc.config.values.device
num_epochs = cc.config.values.epochs
num_train_batches = len(train_dataloader)
num_val_batches = len(test_dataloader)

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_loss = 0

    for batch_idx, (src, trg, metadata) in enumerate(train_dataloader):
        # Move the batch to the configured device, exactly as the validation
        # loop does. This is a no-op when the loader already yields tensors on
        # that device, and avoids a device-mismatch error when it does not.
        src, trg = src.to(device), trg.to(device)

        # Forward pass
        output = model(src)

        # Flatten logits to [N, vocab_size] and targets to [N] for the loss.
        # `reshape` is used for both: unlike `view`, it also accepts
        # non-contiguous tensors.
        output = output.reshape(-1, cc.vocab_size)
        trg = trg.reshape(-1)

        # Compute loss
        loss = criterion(output, trg)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Progress report every 10 batches.
        if (batch_idx + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{num_train_batches}], Loss: {loss.item():.4f}')

    avg_loss = total_loss / num_train_batches
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

    # Validation pass: same forward/loss computation, without gradient
    # tracking and without any parameter updates.
    model.eval()  # Set the model to evaluation mode
    val_loss = 0
    with torch.no_grad():
        for src, trg, metadata in test_dataloader:
            src, trg = src.to(device), trg.to(device)
            output = model(src)
            output = output.reshape(-1, cc.vocab_size)
            trg = trg.reshape(-1)
            val_loss += criterion(output, trg).item()

    avg_val_loss = val_loss / num_val_batches
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}')

print("Training complete!")

Epoch [1/200], Step [10/14], Loss: 5.3650
Epoch [1/200], Average Loss: 5.7254
Epoch [1/200], Validation Loss: 4.8195
Epoch [2/200], Step [10/14], Loss: 3.6662
Epoch [2/200], Average Loss: 3.8624
Epoch [2/200], Validation Loss: 4.0298
Epoch [3/200], Step [10/14], Loss: 3.2212
Epoch [3/200], Average Loss: 3.3043
Epoch [3/200], Validation Loss: 3.4899
Epoch [4/200], Step [10/14], Loss: 2.4423
Epoch [4/200], Average Loss: 2.8837
Epoch [4/200], Validation Loss: 3.3795
Epoch [5/200], Step [10/14], Loss: 3.0231
Epoch [5/200], Average Loss: 2.8595
Epoch [5/200], Validation Loss: 3.4604
Epoch [6/200], Step [10/14], Loss: 2.9644
Epoch [6/200], Average Loss: 2.6939
Epoch [6/200], Validation Loss: 2.8522
Epoch [7/200], Step [10/14], Loss: 2.6556
Epoch [7/200], Average Loss: 2.6921
Epoch [7/200], Validation Loss: 2.8363
Epoch [8/200], Step [10/14], Loss: 2.4233
Epoch [8/200], Average Loss: 2.6693
Epoch [8/200], Validation Loss: 3.1781
Epoch [9/200], Step [10/14], Loss: 2.5328
Epoch [9/200], Average

KeyboardInterrupt: 

In [8]:
# Inspect the first 40 entries of the arg-max of `output` along its
# second-to-last axis. `output` is whatever the last forward pass in the
# training cell left behind in the kernel.
# NOTE(review): `output` is 2-D here after the reshape to (-1, cc.vocab_size),
# so dim=-2 reduces over the flattened position axis and the values are
# position indices, one per vocabulary entry. If per-position predicted token
# ids were intended, the reduction should be over dim=-1 — confirm.
torch.argmax(output, dim=-2)[:40]

tensor([821, 821, 821, 821, 821, 821, 681, 821, 821, 821, 759, 821, 821, 893,
        912, 893, 524, 420, 280, 420, 352, 512, 512, 420, 420, 518, 512, 420,
        512, 512, 280, 510, 512, 512, 510, 512, 450, 512, 384, 512],
       device='cuda:0')