# Sequence prediction with the new dataset preparation

I now want to do proper next-step sequence prediction instead of only mapping the right hand to the left hand.

## for training

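I use a single sequence that starts with an SOS token. The model input is this sequence without its last step, and the prediction target is the same sequence shifted by one step (i.e. without its first step), so every position has to predict the next snapshot.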

For this I use the new dataset.

In [1]:
# imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

import math
import numpy as np

from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt
import random

In [2]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
device_name = "cuda" if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Load Data

In [3]:
from data_preperation import dataset_snapshot
from transformer_decoder_training.dataprep_transformer import dataprep_1
from sklearn.model_selection import train_test_split

# Load the hand-separated MIDI files as snapshots, restrict them to the 88 piano
# notes and finally compress every track down to 12 keys.
# NOTE(review): 0.05 is presumably the snapshot interval in seconds - confirm
# against dataset_snapshot.process_dataset_multithreaded.
dataset_as_snapshots = dataset_snapshot.compress_existing_dataset_to_12keys(
    dataset_snapshot.filter_piano_range(
        dataset_snapshot.process_dataset_multithreaded(
            "/home/falaxdb/Repos/minus1/datasets/maestro_v3_split/hands_split_into_seperate_midis",
            0.05,
        )
    )
)

Processed dataset (1038/1038): 100%|██████████| 1038/1038 [00:14<00:00, 72.08it/s]


Processed 1038 of 1038 files


In [4]:
# Song-level split: first set 30% of the songs aside, then halve that remainder
# into validation and test songs (fixed random_state keeps the split repeatable)
train_data, temp_data = train_test_split(
    dataset_as_snapshots, test_size=0.3, random_state=42, shuffle=True
)
val_data, test_data = train_test_split(
    temp_data, test_size=0.5, random_state=42, shuffle=True
)

# Sanity check: number of songs per split
for split_label, split_songs in (
    ("Train data:", train_data),
    ("test data:", test_data),
    ("val data:", val_data),
):
    print(split_label, len(split_songs))

# Print the track shapes of the first training song only
for song in train_data[:1]:
    for track in song:
        print(track.shape)

Train data: 363
test data: 78
val data: 78
(8312, 12)
(8312, 12)


## Create Dataset

In [5]:
# Special tokens. Their width has to match one data snapshot
# (24 values here - presumably 2 tracks x 12 keys; confirm against the dataset).
token_shape = (1, 24)
sos_token = np.full(token_shape, 1)  # start-of-sequence marker: all ones
pad_token = torch.tensor(np.full(token_shape, 2), device=device)  # padding marker: all twos

# Batching / windowing parameters
batch_size, seq_length, stride = 64, 512, 256

In [6]:
# create dataset + dataloader
from torch.utils.data import DataLoader
from transformer_decoder_training.dataset_transformer.dataset_2 import AdvancedPianoDataset

# Build one dataset per split; all of them share window length, stride and SOS token
train_dataset, val_dataset, test_dataset = (
    AdvancedPianoDataset(songs, seq_length, stride, sos_token)
    for songs in (train_data, val_data, test_data)
)

print("Check length of datasets. should roughly match split ratio")
for split_name, split_dataset in (
    ("train dataset:", train_dataset),
    ("val dataset:", val_dataset),
    ("test dataset:", test_dataset),
):
    print(split_name, len(split_dataset))
print("")

# One DataLoader per split: incomplete final batches are dropped everywhere,
# and only the training data is shuffled
loader_options = dict(batch_size=batch_size, drop_last=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Inspect the first training batch:
# - the SOS token should be the first element of every sequence
# - every element should be twice as wide as a single track snapshot
batch = next(iter(train_loader), None)
if batch is not None:
    first_sequence = batch[0]

    print("Visualize shape of batch:")
    print("shape of one batch:", batch.shape)
    print("==============")

    print("Test for sos token as first token in sequence")
    print("First token in seq:", first_sequence[0])
    print("=============")

    print("Test print one snapshot:")
    print("First half of values should be left hand, second half should be right hand")
    print(first_sequence[1])


Check length of datasets. should roughly match split ratio
train dataset: 13440
val dataset: 3237
test dataset: 3179

Visualize shape of batch:
shape of one batch: torch.Size([64, 513, 24])
Test for sos token as first token in sequence
First token in seq: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.])
Test print one snapshot:
First half of values should be left hand, second half should be right hand
tensor([0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0.])


# Initialize Model

In [7]:
# Hyperparameters
learning_rate = 1e-3  # learning rate for the optimizer
nepochs = 20          # number of epochs for training
hidden_size = 256     # embedding size
num_layers = 8        # number of transformer blocks
num_heads = 8         # heads per multi-head attention layer

In [8]:
from transformer_decoder_training.models.transformer_decoder_1 import Transformer

# Build the model and move it to the selected device.
# NOTE(review): num_emb=24 presumably is the width of one snapshot vector -
# confirm against the Transformer implementation.
model = Transformer(
    num_emb=24,
    num_layers=num_layers,
    hidden_size=hidden_size,
    num_heads=num_heads,
)
model = model.to(device)

# Adam optimizer over all model parameters, using the learning rate set above
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Loss: binary cross-entropy, because the targets are multi-hot vectors
# (several entries can be active in one step).
# Don't forget the parentheses - the loss class has to be instantiated.
loss_fn = nn.BCELoss()

In [9]:
# Total number of scalar parameters in the model
num_model_params = sum(param.numel() for param in model.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params // 1_000_000))

-This Model Has 6330648 (Approximately 6 Million) Parameters!


# Training

In [10]:
def train_loop(model, opt, loss_fn, dataloader, pad_token, device):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Each batch is used for next-step prediction: the model receives every step
    except the last one and has to predict the same sequence shifted by one step
    (every step except the first one).

    Args:
        model: Called as ``model(input_sequences, pad_token)``; must provide ``train()``.
        opt: Optimizer providing ``zero_grad()`` and ``step()``.
        loss_fn: Called as ``loss_fn(pred, expected_sequence)``; must return a
            scalar tensor that supports ``backward()``.
        dataloader: Iterable of batch tensors whose second dimension is the step
            axis. It is iterated exactly once and does not need to support ``len()``.
        pad_token: Passed through to the model unchanged. No padding mask is
            applied to the loss in this function.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches (for example when
            ``drop_last=True`` and the dataset is smaller than one batch).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        # Move data to the target device
        src_sequence = batch.to(device)

        # Input: all steps but the last. Target: the same sequence one step ahead.
        input_sequences = src_sequence[:, :-1]
        expected_sequence = src_sequence[:, 1:]

        # Forward pass and (unmasked) loss over all positions
        pred = model(input_sequences, pad_token)
        loss = loss_fn(pred, expected_sequence)

        # Backpropagation
        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches ourselves instead of calling len(dataloader): this also works
    # for iterables without __len__ and gives a clear error for an empty loader.
    if num_batches == 0:
        raise ValueError("train_loop received a dataloader that yields no batches")

    return total_loss / num_batches

def validation_loop(model, loss_fn, dataloader, pad_token, device):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    Uses the same next-step setup as training: the model receives every step
    except the last one and is scored against the sequence shifted by one step.

    Args:
        model: Called as ``model(input_sequences, pad_token)``; must provide ``eval()``.
        loss_fn: Called as ``loss_fn(pred, expected_sequence)``; must return a scalar tensor.
        dataloader: Iterable of batch tensors indexed as (batch, step, feature).
            It is iterated exactly once and does not need to support ``len()``.
        pad_token: Passed through to the model unchanged. No padding mask is
            applied to the loss in this function.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches (for example when
            ``drop_last=True`` and the dataset is smaller than one batch).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for batch in dataloader:
            # Move data to the target device
            src_sequence = batch.to(device)

            # Input: all steps but the last. Target: the same sequence one step ahead.
            input_sequences = src_sequence[:, :-1, :]
            expected_sequence = src_sequence[:, 1:, :]

            pred = model(input_sequences, pad_token)
            loss = loss_fn(pred, expected_sequence)

            total_loss += loss.item()
            num_batches += 1

    # Count batches ourselves instead of calling len(dataloader): this also works
    # for iterables without __len__ and gives a clear error for an empty loader.
    if num_batches == 0:
        raise ValueError("validation_loop received a dataloader that yields no batches")

    return total_loss / num_batches

In [11]:
from timeit import default_timer as timer
# Number of epochs for the training runs in this notebook
NUM_EPOCHS = 21

for epoch in range(1, NUM_EPOCHS + 1):
    # Only the training pass is timed; validation runs after the timer has stopped
    start_time = timer()
    train_loss = train_loop(
        model=model,
        opt=optimizer,
        loss_fn=loss_fn,
        dataloader=train_loader,
        pad_token=pad_token,
        device=device,
    )
    end_time = timer()

    val_loss = validation_loop(
        model=model,
        loss_fn=loss_fn,
        dataloader=val_loader,
        pad_token=pad_token,
        device=device,
    )

    epoch_summary = (
        f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
        f"Epoch time = {(end_time - start_time):.3f}s"
    )
    print(epoch_summary)

Epoch: 1, Train loss: 0.183, Val loss: 0.128, Epoch time = 41.395s
Epoch: 2, Train loss: 0.127, Val loss: 0.125, Epoch time = 42.138s
Epoch: 3, Train loss: 0.123, Val loss: 0.122, Epoch time = 42.126s
Epoch: 4, Train loss: 0.120, Val loss: 0.119, Epoch time = 42.132s
Epoch: 5, Train loss: 0.119, Val loss: 0.118, Epoch time = 42.125s
Epoch: 6, Train loss: 0.117, Val loss: 0.117, Epoch time = 42.103s
Epoch: 7, Train loss: 0.116, Val loss: 0.115, Epoch time = 42.102s
Epoch: 8, Train loss: 0.114, Val loss: 0.114, Epoch time = 42.102s
Epoch: 9, Train loss: 0.113, Val loss: 0.112, Epoch time = 42.090s
Epoch: 10, Train loss: 0.111, Val loss: 0.110, Epoch time = 42.103s
Epoch: 11, Train loss: 0.109, Val loss: 0.109, Epoch time = 42.109s
Epoch: 12, Train loss: 0.108, Val loss: 0.107, Epoch time = 42.099s
Epoch: 13, Train loss: 0.106, Val loss: 0.107, Epoch time = 42.101s
Epoch: 14, Train loss: 0.105, Val loss: 0.105, Epoch time = 42.096s
Epoch: 15, Train loss: 0.104, Val loss: 0.105, Epoch time

### Last training output:
Learning rate: 1e-4

Epoch: 1, Train loss: 0.126, Val loss: 0.133, Epoch time = 44.974s  
Epoch: 2, Train loss: 0.124, Val loss: 0.132, Epoch time = 44.998s  
Epoch: 3, Train loss: 0.123, Val loss: 0.130, Epoch time = 44.866s  
Epoch: 4, Train loss: 0.121, Val loss: 0.129, Epoch time = 44.857s  
Epoch: 5, Train loss: 0.120, Val loss: 0.127, Epoch time = 44.856s  
Epoch: 6, Train loss: 0.118, Val loss: 0.126, Epoch time = 44.866s  
Epoch: 7, Train loss: 0.118, Val loss: 0.125, Epoch time = 44.853s  
Epoch: 8, Train loss: 0.117, Val loss: 0.125, Epoch time = 44.867s  
Epoch: 9, Train loss: 0.116, Val loss: 0.124, Epoch time = 44.860s  
Epoch: 10, Train loss: 0.115, Val loss: 0.123, Epoch time = 44.859s 

Learning rate: 1e-3

## Training and Validation Losses over 200 Epochs

| Epoch | Train Loss | Val Loss | Epoch Time (s) |
|-------|------------|----------|----------------|
| 1     | 0.197      | 0.127    | 43.844         |
| 2     | 0.128      | 0.123    | 43.571         |
| 3     | 0.125      | 0.120    | 43.708         |
| 4     | 0.122      | 0.118    | 43.733         |
| 5     | 0.120      | 0.116    | 43.740         |
| 6     | 0.119      | 0.115    | 43.730         |
| 7     | 0.117      | 0.113    | 43.722         |
| 8     | 0.116      | 0.113    | 43.719         |
| 9     | 0.114      | 0.111    | 43.728         |
| 10    | 0.113      | 0.109    | 43.719         |
| 11    | 0.111      | 0.108    | 43.738         |
| 12    | 0.110      | 0.107    | 43.723         |
| 13    | 0.109      | 0.106    | 43.586         |
| 14    | 0.107      | 0.105    | 43.714         |
| 15    | 0.106      | 0.104    | 43.714         |
| 16    | 0.105      | 0.104    | 43.703         |
| 17    | 0.104      | 0.103    | 43.713         |
| 18    | 0.103      | 0.103    | 43.705         |
| 19    | 0.102      | 0.103    | 43.701         |
| 20    | 0.102      | 0.102    | 43.698         |
| 21    | 0.101      | 0.101    | 43.714         |
| 22    | 0.100      | 0.101    | 43.695         |
| 23    | 0.099      | 0.101    | 43.698         |
| 24    | 0.098      | 0.101    | 43.684         |
| 25    | 0.097      | 0.101    | 43.685         |
| 26    | 0.096      | 0.101    | 43.688         |
| 27    | 0.096      | 0.101    | 43.689         |
| 28    | 0.095      | 0.101    | 43.659         |
| 29    | 0.094      | 0.101    | 43.661         |
| 30    | 0.093      | 0.101    | 43.662         |
| 31    | 0.092      | 0.102    | 43.668         |
| 32    | 0.091      | 0.101    | 43.658         |
| 33    | 0.090      | 0.102    | 43.659         |
| 34    | 0.089      | 0.102    | 43.665         |
| 35    | 0.088      | 0.103    | 43.664         |
| 36    | 0.087      | 0.103    | 43.660         |
| 37    | 0.086      | 0.104    | 43.665         |
| 38    | 0.085      | 0.104    | 43.661         |
| 39    | 0.084      | 0.105    | 43.701         |
| 40    | 0.083      | 0.106    | 43.710         |
| 41    | 0.082      | 0.107    | 43.714         |
| 42    | 0.081      | 0.107    | 43.705         |
| 43    | 0.080      | 0.108    | 43.702         |
| 44    | 0.079      | 0.109    | 43.697         |
| 45    | 0.078      | 0.109    | 43.692         |
| 46    | 0.076      | 0.111    | 43.698         |
| 47    | 0.075      | 0.111    | 43.700         |
| 48    | 0.074      | 0.113    | 43.697         |
| 49    | 0.073      | 0.114    | 43.711         |
| 50    | 0.072      | 0.116    | 43.702         |
| 51    | 0.071      | 0.116    | 43.701         |
| 52    | 0.070      | 0.118    | 43.712         |
| 53    | 0.069      | 0.119    | 43.704         |
| 54    | 0.069      | 0.119    | 43.711         |
| 55    | 0.068      | 0.121    | 43.694         |
| 56    | 0.067      | 0.122    | 43.701         |
| 57    | 0.066      | 0.124    | 43.705         |
| 58    | 0.065      | 0.125    | 43.698         |
| 59    | 0.064      | 0.126    | 43.710         |
| 60    | 0.063      | 0.127    | 43.712         |
| 61    | 0.063      | 0.128    | 43.720         |
| 62    | 0.062      | 0.129    | 43.714         |
| 63    | 0.061      | 0.129    | 43.710         |
| 64    | 0.061      | 0.132    | 43.722         |
| 65    | 0.060      | 0.134    | 43.725         |
| 66    | 0.059      | 0.135    | 43.719         |
| 67    | 0.059      | 0.136    | 43.728         |
| 68    | 0.058      | 0.137    | 43.718         |
| 69    | 0.058      | 0.138    | 43.716         |
| 70    | 0.057      | 0.139    | 43.728         |
| 71    | 0.057      | 0.141    | 43.725         |
| 72    | 0.056      | 0.142    | 43.707         |
| 73    | 0.056      | 0.142    | 43.722         |
| 74    | 0.055      | 0.145    | 43.727         |
| 75    | 0.055      | 0.144    | 43.729         |
| 76    | 0.054      | 0.147    | 43.723         |
| 77    | 0.054      | 0.146    | 43.727         |
| 78    | 0.053      | 0.148    | 43.729         |
| 79    | 0.053      | 0.149    | 43.730         |
| 80    | 0.052      | 0.150    | 43.729         |
| 81    | 0.052      | 0.151    | 43.725         |
| 82    | 0.052      | 0.153    | 43.724         |
| 83    | 0.051      | 0.153    | 43.731         |
| 84    | 0.051      | 0.155    | 43.727         |
| 85    | 0.051      | 0.155    | 43.730         |
| 86    | 0.050      | 0.155    | 43.722         |
| 87    | 0.050      | 0.157    | 43.727         |
| 88    | 0.049      | 0.158    | 43.724         |
| 89    | 0.049      | 0.160    | 43.727         |
| 90    | 0.049      | 0.160    | 43.720         |
| 91    | 0.049      | 0.161    | 43.727         |
| 92    | 0.048      | 0.162    | 43.728         |
| 93    | 0.048      | 0.163    | 43.687         |
| 94    | 0.048      | 0.164    | 43.690         |
| 95    | 0.047      | 0.165    | 43.691         |
| 96    | 0.047      | 0.168    | 43.693         |
| 97    | 0.047      | 0.167    | 43.686         |
| 98    | 0.047      | 0.164    | 43.693         |
| 99    | 0.046      | 0.169    | 43.690         |
| 100   | 0.046      | 0.169    | 43.700         |
| 101   | 0.046      | 0.170    | 43.717         |
| 102   | 0.046      | 0.170    | 43.726         |
| 103   | 0.045      | 0.172    | 43.721         |
| 104   | 0.045      | 0.172    | 43.730         |
| 105   | 0.045      | 0.173    | 43.723         |
| 106   | 0.045      | 0.174    | 43.718         |
| 107   | 0.044      | 0.175    | 43.730         |
| 108   | 0.044      | 0.175    | 43.725         |
| 109   | 0.044      | 0.177    | 43.726         |
| 110   | 0.044      | 0.178    | 43.720         |
| 111   | 0.044      | 0.178    | 43.729         |
| 112   | 0.043      | 0.179    | 43.727         |
| 113   | 0.043      | 0.179    | 43.735         |
| 114   | 0.043      | 0.180    | 43.736         |
| 115   | 0.043      | 0.181    | 43.726         |
| 116   | 0.043      | 0.182    | 43.731         |
| 117   | 0.043      | 0.182    | 43.733         |
| 118   | 0.042      | 0.183    | 43.731         |
| 119   | 0.042      | 0.183    | 43.724         |
| 120   | 0.042      | 0.185    | 43.733         |
| 121   | 0.042      | 0.185    | 43.728         |
| 122   | 0.042      | 0.186    | 43.735         |
| 123   | 0.041      | 0.186    | 43.724         |
| 124   | 0.041      | 0.188    | 43.726         |
| 125   | 0.041      | 0.187    | 43.722         |
| 126   | 0.041      | 0.189    | 43.721         |
| 127   | 0.041      | 0.188    | 43.729         |
| 128   | 0.041      | 0.190    | 43.725         |
| 129   | 0.041      | 0.191    | 43.736         |
| 130   | 0.040      | 0.190    | 43.726         |
| 131   | 0.040      | 0.191    | 43.721         |
| 132   | 0.040      | 0.191    | 43.736         |
| 133   | 0.040      | 0.191    | 43.728         |
| 134   | 0.040      | 0.193    | 43.728         |
| 135   | 0.040      | 0.194    | 43.717         |
| 136   | 0.040      | 0.194    | 43.731         |
| 137   | 0.039      | 0.196    | 43.730         |
| 138   | 0.039      | 0.197    | 43.733         |
| 139   | 0.039      | 0.197    | 43.722         |
| 140   | 0.039      | 0.198    | 43.736         |
| 141   | 0.039      | 0.198    | 43.733         |
| 142   | 0.039      | 0.198    | 43.721         |
| 143   | 0.039      | 0.199    | 43.727         |
| 144   | 0.039      | 0.200    | 43.724         |
| 145   | 0.038      | 0.199    | 43.731         |
| 146   | 0.038      | 0.202    | 43.731         |
| 147   | 0.038      | 0.201    | 43.739         |
| 148   | 0.038      | 0.200    | 43.738         |
| 149   | 0.038      | 0.204    | 43.725         |
| 150   | 0.038      | 0.204    | 43.725         |
| 151   | 0.038      | 0.202    | 43.725         |
| 152   | 0.038      | 0.204    | 43.704         |
| 153   | 0.038      | 0.206    | 43.699         |
| 154   | 0.037      | 0.206    | 43.706         |
| 155   | 0.037      | 0.207    | 43.701         |
| 156   | 0.037      | 0.206    | 43.692         |
| 157   | 0.037      | 0.206    | 43.700         |
| 158   | 0.037      | 0.206    | 43.694         |
| 159   | 0.037      | 0.209    | 43.710         |
| 160   | 0.037      | 0.208    | 43.707         |
| 161   | 0.037      | 0.209    | 43.707         |
| 162   | 0.037      | 0.209    | 43.701         |
| 163   | 0.037      | 0.209    | 43.693         |
| 164   | 0.036      | 0.212    | 43.704         |
| 165   | 0.036      | 0.212    | 43.691         |
| 166   | 0.036      | 0.211    | 43.708         |
| 167   | 0.036      | 0.212    | 43.707         |
| 168   | 0.036      | 0.211    | 43.700         |
| 169   | 0.036      | 0.211    | 43.699         |
| 170   | 0.036      | 0.212    | 43.721         |
| 171   | 0.036      | 0.213    | 43.730         |
| 172   | 0.036      | 0.215    | 43.720         |
| 173   | 0.036      | 0.214    | 43.717         |
| 174   | 0.036      | 0.217    | 43.716         |
| 175   | 0.035      | 0.215    | 43.727         |
| 176   | 0.035      | 0.215    | 43.711         |
| 177   | 0.035      | 0.218    | 43.713         |
| 178   | 0.035      | 0.218    | 43.722         |
| 179   | 0.035      | 0.217    | 43.714         |
| 180   | 0.035      | 0.217    | 43.716         |
| 181   | 0.035      | 0.219    | 43.707         |
| 182   | 0.035      | 0.216    | 43.717         |
| 183   | 0.035      | 0.219    | 43.601         |
| 184   | 0.035      | 0.219    | 43.601         |
| 185   | 0.035      | 0.221    | 43.604         |
| 186   | 0.035      | 0.220    | 43.604         |
| 187   | 0.035      | 0.222    | 43.602         |
| 188   | 0.034      | 0.222    | 43.599         |
| 189   | 0.034      | 0.223    | 43.598         |
| 190   | 0.034      | 0.224    | 43.608         |
| 191   | 0.034      | 0.223    | 43.606         |
| 192   | 0.034      | 0.223    | 43.593         |
| 193   | 0.034      | 0.225    | 43.609         |
| 194   | 0.034      | 0.223    | 43.607         |
| 195   | 0.034      | 0.224    | 43.604         |
| 196   | 0.034      | 0.225    | 43.609         |
| 197   | 0.034      | 0.224    | 43.605         |
| 198   | 0.034      | 0.226    | 43.607         |
| 199   | 0.034      | 0.225    | 43.600         |
| 200   | 0.034      | 0.225    | 43.609         |

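-> About 21 epochs is probably optimal: the validation loss bottoms out at roughly 0.101 around epochs 21-30 and rises afterwards, while the training loss keeps falling (overfitting).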



In [12]:
# see: https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html#save-and-load-the-model

# Store only the model's state_dict (the weights), not the whole module object
model_v6_path = "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models/model_1_notebook_v6.pth"
torch.save(model.state_dict(), model_v6_path)

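# Now training without a sigmoid output layer, using BCEWithLogitsLoss

The model now outputs raw logits. During training the sigmoid is applied inside `BCEWithLogitsLoss`; at inference time it has to be applied to the model output explicitly.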

In [13]:
from transformer_decoder_training.models.transformer_decoder_2 import Transformer as Transformer_2
# Second model variant: Transformer without the sigmoid output
model = Transformer_2(
    num_emb=24,
    num_layers=num_layers,
    hidden_size=hidden_size,
    num_heads=num_heads,
)
model = model.to(device)

# Fresh Adam optimizer for the new model, same learning rate as before
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Loss: binary cross-entropy on logits, suitable for multi-hot targets.
# Don't forget the parentheses - the loss class has to be instantiated.
loss_fn = nn.BCEWithLogitsLoss()

In [14]:
for epoch in range(1, NUM_EPOCHS + 1):
    # Only the training pass is timed; validation runs after the timer has stopped
    start_time = timer()
    train_loss = train_loop(
        model=model,
        opt=optimizer,
        loss_fn=loss_fn,
        dataloader=train_loader,
        pad_token=pad_token,
        device=device,
    )
    end_time = timer()

    val_loss = validation_loop(
        model=model,
        loss_fn=loss_fn,
        dataloader=val_loader,
        pad_token=pad_token,
        device=device,
    )

    epoch_summary = (
        f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
        f"Epoch time = {(end_time - start_time):.3f}s"
    )
    print(epoch_summary)

Epoch: 1, Train loss: 0.196, Val loss: 0.129, Epoch time = 41.117s
Epoch: 2, Train loss: 0.127, Val loss: 0.126, Epoch time = 41.439s
Epoch: 3, Train loss: 0.124, Val loss: 0.122, Epoch time = 42.159s
Epoch: 4, Train loss: 0.121, Val loss: 0.120, Epoch time = 42.133s
Epoch: 5, Train loss: 0.119, Val loss: 0.119, Epoch time = 42.143s
Epoch: 6, Train loss: 0.118, Val loss: 0.118, Epoch time = 42.134s
Epoch: 7, Train loss: 0.117, Val loss: 0.116, Epoch time = 42.145s
Epoch: 8, Train loss: 0.115, Val loss: 0.115, Epoch time = 42.127s
Epoch: 9, Train loss: 0.114, Val loss: 0.114, Epoch time = 42.123s
Epoch: 10, Train loss: 0.113, Val loss: 0.112, Epoch time = 41.722s
Epoch: 11, Train loss: 0.111, Val loss: 0.111, Epoch time = 39.951s
Epoch: 12, Train loss: 0.110, Val loss: 0.110, Epoch time = 42.101s
Epoch: 13, Train loss: 0.109, Val loss: 0.109, Epoch time = 42.108s
Epoch: 14, Train loss: 0.107, Val loss: 0.108, Epoch time = 42.119s
Epoch: 15, Train loss: 0.106, Val loss: 0.107, Epoch time

In [None]:
# Full 21-epoch training log of the run above (model without sigmoid,
# BCEWithLogitsLoss), kept as a plain string literal. The first lines match the
# truncated cell output of that run. This cell does not train or compute anything.
"""
Epoch: 1, Train loss: 0.196, Val loss: 0.129, Epoch time = 41.117s
Epoch: 2, Train loss: 0.127, Val loss: 0.126, Epoch time = 41.439s
Epoch: 3, Train loss: 0.124, Val loss: 0.122, Epoch time = 42.159s
Epoch: 4, Train loss: 0.121, Val loss: 0.120, Epoch time = 42.133s
Epoch: 5, Train loss: 0.119, Val loss: 0.119, Epoch time = 42.143s
Epoch: 6, Train loss: 0.118, Val loss: 0.118, Epoch time = 42.134s
Epoch: 7, Train loss: 0.117, Val loss: 0.116, Epoch time = 42.145s
Epoch: 8, Train loss: 0.115, Val loss: 0.115, Epoch time = 42.127s
Epoch: 9, Train loss: 0.114, Val loss: 0.114, Epoch time = 42.123s
Epoch: 10, Train loss: 0.113, Val loss: 0.112, Epoch time = 41.722s
Epoch: 11, Train loss: 0.111, Val loss: 0.111, Epoch time = 39.951s
Epoch: 12, Train loss: 0.110, Val loss: 0.110, Epoch time = 42.101s
Epoch: 13, Train loss: 0.109, Val loss: 0.109, Epoch time = 42.108s
Epoch: 14, Train loss: 0.107, Val loss: 0.108, Epoch time = 42.119s
Epoch: 15, Train loss: 0.106, Val loss: 0.107, Epoch time = 42.126s
Epoch: 16, Train loss: 0.105, Val loss: 0.106, Epoch time = 42.116s
Epoch: 17, Train loss: 0.104, Val loss: 0.105, Epoch time = 42.104s
Epoch: 18, Train loss: 0.103, Val loss: 0.105, Epoch time = 42.111s
Epoch: 19, Train loss: 0.102, Val loss: 0.104, Epoch time = 42.111s
Epoch: 20, Train loss: 0.101, Val loss: 0.104, Epoch time = 42.113s
Epoch: 21, Train loss: 0.100, Val loss: 0.103, Epoch time = 42.102s
"""

The model does not seem to behave significantly differently during training -> inference still has to be tested.



In [15]:
# Store the state_dict (weights) of the model trained without the sigmoid output
model_v6_no_sigmoid_path = "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models/model_1_notebook_v6_no_sigmoid.pth"
torch.save(model.state_dict(), model_v6_no_sigmoid_path)