In [None]:
from config import *
import processing
import models
import torch.optim as optim
import torch.nn as nn
import numpy as np
import math
import torch
import json


In [None]:
# Build the train/test dataloaders from the on-disk dataset directory.
train_dataloader, test_dataloader = processing.get_train_test_dataloaders(
    'F:\\GitHub\\dataset\\np_dataset'
)

# Parse the tokenization description and keep its 'VOCAB_SIZE' entry.
# (open() defaults to text read mode, so the explicit 'r' is not needed.)
with open('F:\\GitHub\\dataset\\midi_dataset\\tokenizations.json') as f:
    tokenizations = json.loads(f.read())

METADATA_VOCAB_SIZE = tokenizations['VOCAB_SIZE']

In [None]:

# Build the model, loss function and optimizer.
# VOCAB_SIZE, N_EMBD, N_LAYER, N_HEAD, BLOCK_SIZE, DROPOUT, DEVICE, LEARNING_RATE
# and EPOCHS are not defined in this notebook; presumably they come from the
# `from config import *` in the import cell.
# Constructor argument order: vocab_size, n_embd, n_layer, n_heads, block_size, dropout, device
model = models.Transformer(VOCAB_SIZE, N_EMBD, N_LAYER, N_HEAD, BLOCK_SIZE, DROPOUT, DEVICE)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop: one optimisation pass over train_dataloader per epoch, followed
# by a gradient-free evaluation pass over test_dataloader.
num_epochs = EPOCHS
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_loss = 0

    for batch_idx, (src, trg, metadata) in enumerate(train_dataloader):
        # Move the batch to the model's device, matching the validation loop
        # below. Without this, a model on a non-CPU device receives tensors on
        # a different device; .to() leaves tensors already on DEVICE unchanged.
        src, trg = src.to(DEVICE), trg.to(DEVICE)

        # Forward pass
        # NOTE(review): training calls model(src) while validation calls
        # model(src, metadata). The forward signature is not visible from this
        # notebook -- confirm which form is intended and make both loops agree.
        output = model(src)

        # Flatten logits to [N, VOCAB_SIZE] and targets to [N], the layout
        # nn.CrossEntropyLoss expects for class-index targets.
        output = output.reshape(-1, VOCAB_SIZE)
        trg = trg.view(-1)

        # Compute loss
        loss = criterion(output, trg)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Report the current batch loss every 10 steps.
        if (batch_idx + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}')

    # Mean of the per-batch losses over the epoch.
    avg_loss = total_loss / len(train_dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

    # Validation loop
    model.eval()  # Set the model to evaluation mode
    val_loss = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for src, trg, metadata in test_dataloader:
            src, trg = src.to(DEVICE), trg.to(DEVICE)
            # NOTE(review): metadata is passed here but not in the training
            # loop above, and it is not moved to DEVICE -- confirm intent.
            output = model(src, metadata)
            output = output.reshape(-1, VOCAB_SIZE)
            trg = trg.view(-1)
            val_loss += criterion(output, trg).item()

    avg_val_loss = val_loss / len(test_dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}')

print("Training complete!")