# Jetzt mit allen 88 Tasten

In [1]:
# imports
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

In [2]:
# Prefer the CUDA device when PyTorch reports one as usable; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
from data_preperation import dataset_snapshot
from transformer_decoder_training.dataprep_transformer import dataprep_1
from sklearn.model_selection import train_test_split

# Load the dataset as snapshots, then keep only the 88 piano notes.
# NOTE(review): the source directory is a hard-coded absolute path on one machine;
# the second argument (0.05) is presumably the snapshot interval — TODO confirm.
dataset_as_snapshots = dataset_snapshot.filter_piano_range(
    dataset_snapshot.process_dataset_multithreaded(
        "/home/falaxdb/Repos/minus1/datasets/maestro_v3_split/hands_split_into_seperate_midis",
        0.05,
    )
)




Processed dataset (1038/1038): 100%|██████████| 1038/1038 [00:14<00:00, 73.11it/s]


Processed 1038 of 1038 files


In [4]:
# Split the songs in two steps with a fixed seed:
#   1) 70 % train / 30 % held out,
#   2) the held-out part is halved into validation and test.
train_data, temp_data = train_test_split(
    dataset_as_snapshots,
    test_size=0.3,
    shuffle=True,
    random_state=42,
)
val_data, test_data = train_test_split(
    temp_data,
    test_size=0.5,
    shuffle=True,
    random_state=42,
)

In [5]:
# Special tokens. Their width (176) has to match the feature dimension of the data.
# sos_token stays a NumPy array filled with 1 (presumably the start-of-sequence marker).
sos_token = np.full(shape=(1, 176), fill_value=1)
# pad_token is filled with 2 and converted to a tensor that lives on `device`.
pad_token = torch.tensor(np.full(shape=(1, 176), fill_value=2), device=device)

# Batching / windowing parameters used when building the datasets and loaders
stride = 256
seq_length = 512
batch_size = 64

In [6]:
# create dataset + dataloader
from torch.utils.data import DataLoader
from transformer_decoder_training.dataset_transformer.dataset_2 import AdvancedPianoDataset

# One AdvancedPianoDataset per split, all built with the same seq_length / stride / sos_token.
train_dataset, val_dataset, test_dataset = (
    AdvancedPianoDataset(split, seq_length, stride, sos_token)
    for split in (train_data, val_data, test_data)
)

# One DataLoader per split. Only the training loader shuffles; every loader uses drop_last=True.
# NOTE(review): drop_last=True on the validation/test loaders means the final partial batch
# is never evaluated, and a split smaller than batch_size yields an empty loader.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size, shuffle=do_shuffle, drop_last=True)
    for subset, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [7]:
# initialize model

# --- Optimisation settings ---
learning_rate = 0.001  # learning rate for the optimizer
# Number of epochs for training.
# NOTE(review): the training cell in this notebook sets its own NUM_EPOCHS = 21
# and does not read this value.
nepochs = 20

# --- Model dimensions ---
num_emb = 176      # input size
hidden_size = 256  # embedding size
num_layers = 8     # number of transformer blocks
num_heads = 8      # MultiheadAttention heads

In [8]:
from transformer_decoder_training.models.transformer_decoder_1 import Transformer

# Build the transformer from the hyperparameters defined above and move it to `device`.
model = Transformer(
    num_emb=num_emb,
    num_layers=num_layers,
    hidden_size=hidden_size,
    num_heads=num_heads,
).to(device)

# Adam over all model parameters, using the configured learning rate
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# The loss has to cope with multi-hot encoded target vectors, hence binary cross-entropy.
# Do not forget the parentheses: the loss class must be instantiated.
# NOTE(review): nn.BCELoss expects inputs in [0, 1]; presumably the model output is
# already passed through a sigmoid — confirm.
loss_fn = nn.BCELoss()

# Training

In [9]:
def train_loop(model, opt, loss_fn, dataloader, pad_token, device):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Every batch is split along dim 1 into an input (all steps except the last)
    and a target (all steps except the first), so the model is trained to
    predict the next step of the sequence.

    Args:
        model: Module invoked as ``model(input_sequences, pad_token)``; it is
            switched to training mode.
        opt: Optimizer that updates the model parameters after every batch.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        dataloader: Iterable yielding batch tensors. It does not need to
            support ``len()``; batches are counted while iterating.
        pad_token: Forwarded unchanged to ``model``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches (previously this
            surfaced as a bare ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        # Move data to the target device
        src_sequence = batch.to(device)

        # Input and target are the same sequence shifted by one step
        input_sequences = src_sequence[:, :-1]
        expected_sequence = src_sequence[:, 1:]

        # Generate predictions
        pred = model(input_sequences, pad_token)

        # NOTE(review): the loss is computed over every position; padding
        # positions are NOT masked out. TODO: apply a mask derived from
        # pad_token if padded positions should be ignored.
        loss = loss_fn(pred, expected_sequence)

        # Backpropagation
        opt.zero_grad()
        loss.backward()
        opt.step()

        # .item() already returns a plain Python float detached from the graph
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loop received a dataloader that yields no batches")

    return total_loss / num_batches

def validation_loop(model, loss_fn, dataloader, pad_token, device):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    Batches are split exactly as in training: the input is every step except
    the last, the target is every step except the first (dim 1).

    Args:
        model: Module invoked as ``model(input_sequences, pad_token)``; it is
            switched to evaluation mode.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        dataloader: Iterable yielding batch tensors. It does not need to
            support ``len()``; batches are counted while iterating.
        pad_token: Forwarded unchanged to ``model``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches (previously this
            surfaced as a bare ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in dataloader:
            # Move data to the target device
            src_sequence = batch.to(device)

            # Input and target are the same sequence shifted by one step
            input_sequences = src_sequence[:, :-1]
            expected_sequence = src_sequence[:, 1:]

            # Generate predictions
            pred = model(input_sequences, pad_token)

            # Loss over the full (unflattened) prediction; padding is not masked
            loss = loss_fn(pred, expected_sequence)

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validation_loop received a dataloader that yields no batches")

    return total_loss / num_batches

In [10]:
from timeit import default_timer as timer
# NOTE(review): this is independent of `nepochs` (20) defined with the other hyperparameters.
NUM_EPOCHS = 21

# Epochs are numbered from 1. Only the training pass is timed; validation runs
# after the timer has been stopped.
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = timer()
    train_loss = train_loop(
        model, optimizer, loss_fn, train_loader, pad_token, device
    )
    end_time = timer()

    val_loss = validation_loop(model, loss_fn, val_loader, pad_token, device)

    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")

Epoch: 1, Train loss: 0.061, Val loss: 0.044, Epoch time = 43.794s
Epoch: 2, Train loss: 0.029, Val loss: 0.024, Epoch time = 43.577s
Epoch: 3, Train loss: 0.023, Val loss: 0.022, Epoch time = 43.529s
Epoch: 4, Train loss: 0.022, Val loss: 0.022, Epoch time = 43.496s
Epoch: 5, Train loss: 0.021, Val loss: 0.021, Epoch time = 43.515s
Epoch: 6, Train loss: 0.021, Val loss: 0.020, Epoch time = 43.473s
Epoch: 7, Train loss: 0.021, Val loss: 0.020, Epoch time = 43.474s
Epoch: 8, Train loss: 0.020, Val loss: 0.020, Epoch time = 43.493s
Epoch: 9, Train loss: 0.020, Val loss: 0.020, Epoch time = 43.503s
Epoch: 10, Train loss: 0.020, Val loss: 0.019, Epoch time = 43.481s
Epoch: 11, Train loss: 0.020, Val loss: 0.019, Epoch time = 43.502s
Epoch: 12, Train loss: 0.019, Val loss: 0.019, Epoch time = 43.455s
Epoch: 13, Train loss: 0.019, Val loss: 0.019, Epoch time = 43.458s
Epoch: 14, Train loss: 0.019, Val loss: 0.019, Epoch time = 43.489s
Epoch: 15, Train loss: 0.019, Val loss: 0.019, Epoch time

In [11]:
# Persist only the model's state_dict (not the whole module object), following
# https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html#save-and-load-the-model
# NOTE(review): the destination is a hard-coded absolute path on one machine.
torch.save(
    model.state_dict(),
    "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models/model_1_notebook_v7.pth",
)