In [1]:
import numpy as np
import torch

from data_preperation import dataset_snapshot
from transformer_decoder_training.dataprep_transformer import dataprep_1
from sklearn.model_selection import train_test_split

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the MIDI files as snapshots and immediately restrict them to the 88 piano notes
midi_directory = "/home/falaxdb/Repos/minus1/datasets/maestro_v3_split/hands_split_into_seperate_midis"
dataset_as_snapshots = dataset_snapshot.filter_piano_range(
    dataset_snapshot.process_dataset_multithreaded(midi_directory, 0.05, amount=20)
)

# Split the songs 70 / 15 / 15: 30% is held out first, then halved into validation and test.
# The fixed seed keeps the split identical between runs.
split_seed = 42
train_data, temp_data = train_test_split(
    dataset_as_snapshots, test_size=0.3, random_state=split_seed, shuffle=True
)
val_data, test_data = train_test_split(
    temp_data, test_size=0.5, random_state=split_seed, shuffle=True
)

Processed dataset (40/40): 100%|██████████| 40/40 [00:01<00:00, 32.14it/s]


Processed 40 of 40 files


In [2]:
from torch.utils.data import DataLoader
from transformer_decoder_training.dataset_transformer.dataset_2 import AdvancedPianoDataset

# Special tokens: one row each, as wide as a data snapshot so they fit the data
token_dim = 176
sos_token = np.full((1, token_dim), 1)
# The padding token is built as an array and converted to a tensor on the target device
pad_token = torch.tensor(np.full((1, token_dim), 2), device=device)

# Batching / windowing parameters
batch_size = 64
seq_length = 512
stride = 256

# One dataset per split, all built with the same window settings and start token
train_dataset, val_dataset, test_dataset = (
    AdvancedPianoDataset(split_songs, seq_length, stride, sos_token)
    for split_songs in (train_data, val_data, test_data)
)

# Every loader drops the last incomplete batch; only the training loader shuffles
shared_loader_args = dict(batch_size=batch_size, drop_last=True)
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_args)

In [3]:
# Embedding Size
hidden_size = 256
# input size
num_emb = 176
# Number of transformer blocks
num_layers = 8
# MultiheadAttention Heads
num_heads = 8

from transformer_decoder_training.models.transformer_decoder_1 import Transformer

checkpoint_path = "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models/model_1_notebook_v7.pth"

model = Transformer(num_emb=num_emb, num_layers=num_layers, hidden_size=hidden_size, num_heads=num_heads).to(device)
# map_location remaps the stored tensors onto `device`. Without it, a checkpoint saved
# from a CUDA run cannot be deserialized when this notebook falls back to the CPU.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch to evaluation mode; kept as the last expression so the module summary is displayed
model.eval()

Transformer(
  (embedding): Linear(in_features=176, out_features=256, bias=True)
  (pos_emb): SinusoidalPosEmb()
  (blocks): ModuleList(
    (0-7): 8 x TransformerBlock(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (multihead_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ELU(alpha=1.0)
        (2): Linear(in_features=1024, out_features=256, bias=True)
      )
    )
  )
  (fc_out): Linear(in_features=256, out_features=176, bias=True)
  (sigmoid): Sigmoid()
)

In [4]:
from transformer_decoder_training.inference.inference_2 import inference

# Number of leading snapshots given to the model as context; everything after it is
# the truth sequence that inference() receives as the continuation.
context_length = 200

# Only one single sequence is processed, so fetch exactly one batch.
# next() raises StopIteration on an empty loader instead of silently leaving
# placeholder values behind that would only fail in a later cell.
batch = next(iter(train_loader))

# take the first sequence of the batch ...
original_complete_sequence = batch[0]
# ... and blow it up to a batch of one again
sequence = torch.unsqueeze(original_complete_sequence, 0)
print(sequence.shape)

# split into context sequence and truth sequence
context_seq = sequence[:, :context_length]
print("context seq shape:", context_seq.shape)
continuing_seq = sequence[:, context_length:]
print("continuing seq shape:", continuing_seq.shape)

# NOTE(review): the meaning of the 0.1 argument is defined by inference() — confirm there
output_tokens, harmony_output_tokens, last_input_sequence = inference(model, context_seq, continuing_seq, 0.1, pad_token, device)

# join the list of generated tokens into one tensor
generated_continuing_sequence_complete = torch.cat(output_tokens)
print("Output tokens:", generated_continuing_sequence_complete.shape)
print("last input sequence:", last_input_sequence.shape)

torch.Size([1, 513, 176])
context seq shape: torch.Size([1, 200, 176])
continuing seq shape: torch.Size([1, 313, 176])
Tokens to generate: 313
iteration: 0
Next token before splitting: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda

In [5]:
# Drop the leading batch dimension from both sequences
context_seq, last_input_sequence = (
    torch.squeeze(context_seq, 0),
    torch.squeeze(last_input_sequence, 0),
)

In [6]:
# Convert back to midi
from data_visualization import snapshot_to_midi

# Split every sequence into the two parts that are exported as melody and harmony tracks
ori_complete_seq_mel, ori_complete_seq_har = snapshot_to_midi.split_snapshots_in_sequence(original_complete_sequence.cpu().numpy())
context_seq_mel, context_seq_har = snapshot_to_midi.split_snapshots_in_sequence(context_seq.cpu().numpy())
last_input_seq_mel, last_input_seq_har = snapshot_to_midi.split_snapshots_in_sequence(last_input_sequence.cpu().numpy())

print("last input seq harmony snapshot:", last_input_seq_har[-1])

# Keep each track next to its name so the two lists handed to the MIDI writer cannot
# get out of sync (dicts preserve insertion order). The "seqeunce" typo in the last
# two track names is corrected here.
named_tracks = {
    "Original complete sequence melody": ori_complete_seq_mel,
    "Original complete sequence harmony": ori_complete_seq_har,
    "Context sequence melody": context_seq_mel,
    "Context sequence harmony": context_seq_har,
    "Last input sequence melody": last_input_seq_mel,
    "Last input sequence harmony": last_input_seq_har,
}
track_names = list(named_tracks)
tracks = list(named_tracks.values())

for track in tracks:
    print(track.shape)

# create midi file
# NOTE(review): 0.05 is the same value passed when the dataset was processed into
# snapshots — presumably the two must match; confirm against snapshot_to_midi.
snapshot_to_midi.create_midi_from_snapshots(tracks, track_names, 0.05, "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/midi_outputs", "model_1_notebook_v7.mid")

last input seq harmony snapshot: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
(513, 88)
(513, 88)
(200, 88)
(200, 88)
(513, 88)
(513, 88)
MIDI file saved to /home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/midi_outputs/model_1_notebook_v7.mid
