In [1]:
import time
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy

from vqvae import vqvae

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
with open('dataset/TRAIN.pkl', 'rb') as file:
    train_data = pickle.load(file)

with open('dataset/TEST.pkl', 'rb') as file:
    test_data = pickle.load(file)

val_data = test_data[:len(test_data) // 2]
test_data = test_data[len(test_data) // 2:]

In [4]:
X_train = train_data[:, 1] / 2400
X_test = test_data[:, 1] / 2400
X_val = val_data[:, 1] / 2400

y_train = train_data[:, 0]
y_test = test_data[:, 0]
y_val = val_data[:, 0]

In [5]:
train_mask = y_train != 7
X_train = X_train[train_mask]
y_train = y_train[train_mask]

test_mask = y_test != 7
X_test = X_test[test_mask]
y_test = y_test[test_mask]

val_mask = y_val != 7
X_val = X_val[val_mask]
y_val = y_val[val_mask]

In [6]:
X_train_lens = torch.tensor([len(x) // 4 for x in X_train], dtype=torch.long, device=device)
X_val_lens = torch.tensor([len(x) // 4 for x in X_val], dtype=torch.long, device=device)
X_test_lens = torch.tensor([len(x) // 4 for x in X_test], dtype=torch.long, device=device)

In [7]:
def pad_or_truncate(x, max_len):
    if len(x) > max_len:
        return x[:max_len]
    else:
        return np.pad(x, (0, max_len - len(x)), mode='constant', constant_values=x[-1])

In [8]:
max_len = max(len(x) for x in X_test)
max_len = max_len - (max_len % 4)

X_train = np.array([pad_or_truncate(x, max_len) for x in X_train])
X_test = np.array([pad_or_truncate(x, max_len) for x in X_test])
X_val = np.array([pad_or_truncate(x, max_len) for x in X_val])

In [9]:
X_train = torch.tensor(X_train, dtype=torch.float).to(device)
y_train = np.array(y_train, dtype=np.int64)
y_train = torch.tensor(y_train, dtype=torch.long).to(device)

In [10]:
X_test = torch.tensor(X_test, dtype=torch.float).to(device)
y_test = np.array(y_test, dtype=np.int64)
y_test = torch.tensor(y_test, dtype=torch.long).to(device)

In [11]:
X_val = torch.tensor(X_val, dtype=torch.float).to(device)
y_val = np.array(y_val, dtype=np.int64)
y_val = torch.tensor(y_val, dtype=torch.long).to(device)

In [12]:
X_test.shape

torch.Size([374, 692])

In [13]:
print(X_train.shape)
print(y_train.shape)

torch.Size([1106, 692])
torch.Size([1106])


In [14]:
tokenizer = torch.load('checkpoints/vqvae_1000.pth', map_location=device, weights_only=False)

In [15]:
z = tokenizer.encoder(X_train)
vq_loss, quantized_train, perplexity, embedding_weight, encoding_indices, encodings = tokenizer.vq(z)
quantized_train /= 10

In [16]:
z = tokenizer.encoder(X_val)
vq_loss, quantized_val, perplexity, embedding_weight, encoding_indices, encodings = tokenizer.vq(z)
quantized_val /= 10

In [17]:
from transformer import Transformer

In [18]:
decoder = Transformer(
    d_in=64,           # input feature dimension
    d_model=256,       # hidden dimension
    nhead=4,           # number of attention headshttps://mtg.github.io/IAM-tutorial-ismir22/landing.html
    d_hid=256,         # FFN size
    nlayers=4,         # number of layers
    seq_in_len=173,    # input sequence length
    batch_first=True
).to(device)

In [19]:
train_dataset = torch.utils.data.TensorDataset(quantized_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)

val_dataset = torch.utils.data.TensorDataset(quantized_val, y_val)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64)

In [20]:
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)

In [None]:
best_model = None
best_val_loss = float('inf')

for epoch in range(100):
    decoder.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    val_loss = 0
    val_correct = 0
    val_total = 0

    for xtrain_batch, ytrain_batch in tqdm(train_loader):
        optimizer.zero_grad()
        xtrain = xtrain_batch.permute(2, 0, 1)
        #mask_batch = mask_batch.permute(1, 0)
        output = decoder(xtrain)
        loss = criterion(output, ytrain_batch)
        loss.backward(retain_graph=True)
        optimizer.step()

        train_loss += loss.item()
        preds = output.argmax(dim=1)
        train_correct += (preds == ytrain_batch).sum().item()
        train_total += ytrain_batch.size(0)

    for xval_batch, yval_batch in val_loader:
        decoder.eval()
        with torch.no_grad():
            xval = xval_batch.permute(2, 0, 1)
            #mask_val_batch = mask_val_batch.permute(1, 0)
            #for i in mask_val_batch:
             #   print(i) 
            val_output = decoder(xval)
            #print(val_output)
            val_batch_loss = criterion(val_output, yval_batch)

            val_loss += val_batch_loss.item()
            val_preds = val_output.argmax(dim=1)
            val_correct += (val_preds == yval_batch).sum().item()
            val_total += yval_batch.size(0)

    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)

    train_accuracy = 100 * train_correct / train_total
    val_accuracy = 100 * val_correct / val_total

    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}% "
          f"| Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model = deepcopy(decoder)


100%|███████████████████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.05it/s]


Epoch 1 | Train Loss: 2.1443 | Train Acc: 15.55% | Val Loss: 2.0986 | Val Acc: 14.95%


100%|███████████████████████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.09it/s]


Epoch 2 | Train Loss: 2.0463 | Train Acc: 16.00% | Val Loss: 2.0640 | Val Acc: 13.32%


 11%|█████████▎                                                                          | 2/18 [00:00<00:05,  3.00it/s]

In [24]:
torch.save(best_model, os.path.join(f'checkpoints/transformerdropout.pth'))

In [25]:
z = tokenizer.encoder(X_test[:100])
vq_loss, quantized, perplexity, embedding_weight, encoding_indices, encodings = tokenizer.vq(z)

quantized = quantized.permute(2, 0, 1)
output = best_model(quantized) 

In [26]:
out = torch.nn.Softmax(dim=1)(output)
a, predicted = torch.max(out, 1)

In [27]:
correct_predictions = (predicted == y_test[:100]).sum().item()
correct_predictions /= 100

In [28]:
correct_predictions

0.73