In [1]:
# import libraries
import gc
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from tqdm import tqdm
import lightning.pytorch as pl

# Start from a clean slate: drop unreachable Python objects, then release any
# cached CUDA blocks still held by PyTorch's allocator.
gc.collect()
torch.cuda.empty_cache()

# Ask CUDA for synchronous kernel launches.
# NOTE(review): this is normally a debugging aid and slows training - confirm
# it is still wanted for regular runs.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

In [2]:
# --- Run configuration -------------------------------------------------------
# Use the GPU only when one is actually present, so DEVICE is always a device
# this machine can use (previously hard-coded to True).
CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if CUDA else "cpu")

# Data / model sizes.
BATCH_SIZE = 4
X_DIM = 1024        # chunk length used when splitting the token stream
EMBED_DIM = 768     # token embedding width
HIDDEN_DIM = 1024   # base hidden width; layer widths are multiples of it
VOCAB_SIZE = 560
LATENT_DIM = HIDDEN_DIM//4

# When False the reconstruction loss is the embedding MSE only.
COMPUTE_LOGITS = False
DROPOUT = 0.2

# Optimisation.
LR = 1e-5
EPOCHS = 3000
# Epoch offset added to the checkpoint file names written during training.
EPOCH_BEGIN = 0

# Snapshot of the hyper-parameters, written to the training log at run start.
PARAMS = {
    'batch': BATCH_SIZE,
    'x_dim': X_DIM,
    'embed_dim': EMBED_DIM,
    'hidden_dim': HIDDEN_DIM,
    'vocab_size': VOCAB_SIZE,
    'latent_dim': LATENT_DIM,
    'dropout': DROPOUT,
    'lr': LR
}

DEVICE


device(type='cuda')

In [3]:
from torch import optim, Tensor
from typing import List, Dict, Any


class Encoder(nn.Module):
    """Feed-forward encoder producing the latent mean and log-variance.

    Layout: a bias-free projection from ``EMBED_DIM`` to ``hidden_dims[0]``,
    then one ``Linear`` + ``Dropout`` stage per entry of ``hidden_dims`` (each
    stage outputs half of that entry's width), a single trailing
    ``LeakyReLU``, and two linear heads (``fc_mean`` / ``fc_logvar``).
    """

    def __init__(self, hidden_dims: List = [HIDDEN_DIM], latent_dim=64):
        """
        :param hidden_dims: (List) hidden stage widths, in encoder order
        :param latent_dim: (Int) output size of the mean / log-variance heads
        """
        super().__init__()

        # Input projection: embedding width -> first hidden width.
        layers = [nn.Linear(EMBED_DIM, hidden_dims[0], bias=False)]

        # Each stage halves its nominal width; the dropout probability shrinks
        # with depth (DROPOUT, DROPOUT/2, DROPOUT/3, ...).
        in_features = hidden_dims[0]
        for depth, width in enumerate(hidden_dims):
            out_features = width // 2
            stage = nn.Sequential(
                nn.Linear(in_features, out_features, bias=False),
                nn.Dropout(DROPOUT / (depth + 1)),
            )
            layers.append(stage)
            in_features = out_features

        self.module = nn.Sequential(*layers, nn.LeakyReLU(0.2))
        # After the loop `in_features` equals hidden_dims[-1] // 2.
        self.fc_mean = nn.Linear(in_features, latent_dim)
        self.fc_logvar = nn.Linear(in_features, latent_dim)

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Encode `input` and return ``(hidden, mean, logvar)`` as a tuple.

        :param input: (Tensor) embedded tokens whose last dimension is EMBED_DIM
        :return: hidden features, latent mean and latent log-variance
        """
        features = self.module(input)

        return features, self.fc_mean(features), self.fc_logvar(features)


class Decoder(nn.Module):
    """Feed-forward decoder mapping a latent vector back to embedding space.

    The widths in ``hidden_dims`` are given in encoder order and are walked in
    reverse here. Layout: a bias-free projection from ``latent_dim`` to half of
    the first (reversed) width, one ``Linear`` + ``Dropout`` stage per width, a
    single ``LeakyReLU``, and a final bias-free projection to ``EMBED_DIM``.
    """

    def __init__(self, hidden_dims: List = [HIDDEN_DIM], latent_dim=64):
        """
        :param hidden_dims: (List) hidden stage widths in *encoder* order; the
            list is not modified
        :param latent_dim: (Int) size of the latent vector fed to the decoder
        """
        super(Decoder, self).__init__()

        # Reverse a copy: an in-place `hidden_dims.reverse()` would flip the
        # caller's list (and the shared default argument) as a side effect, so
        # any later model built from the same list would see reversed widths.
        widths = list(reversed(hidden_dims))
        n_stages = len(widths)

        modules = []

        modules.append(
            nn.Linear(latent_dim, widths[0]//2, bias=False)
        )

        for i in range(0, n_stages):
            modules.append(
                nn.Sequential(
                    nn.Linear(
                        widths[i]//2 if i == 0 else widths[i-1],
                        widths[i],
                        bias=False
                    ),
                    # Dropout grows with depth, reaching DROPOUT at the last stage.
                    nn.Dropout((DROPOUT/n_stages)*(i+1))
                )
            )

        modules.append(
            nn.LeakyReLU(0.2)
        )

        self.module = nn.Sequential(
            *modules,
            nn.Linear(widths[-1], EMBED_DIM, bias=False)
        )

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Decode a latent tensor.

        :param input: (Tensor) latent vectors whose last dimension is latent_dim
        :return: (Tensor) reconstruction with last dimension EMBED_DIM
        """
        result = self.module(input)

        return result


class VAE(pl.LightningModule):
    """Variational auto-encoder over token embeddings.

    Token ids are embedded, encoded to a Gaussian latent, sampled with the
    reparameterization trick and decoded back to embedding space. The decoded
    embedding is additionally projected to a distribution over the vocabulary.
    """

    def __init__(self,
                 latent_dim: int,
                 hidden_dims: List = [HIDDEN_DIM]
                 ) -> None:
        """
        :param latent_dim: (Int) size of the latent vector
        :param hidden_dims: (List) hidden widths in encoder order; not modified
        """
        super(VAE, self).__init__()

        self.latent_dim = latent_dim
        self.emb = nn.Embedding(VOCAB_SIZE, EMBED_DIM, padding_idx=0)
        # Hand each sub-module its own copy so neither can alter the caller's
        # list nor the list seen by the other sub-module.
        self.encoder = Encoder(list(hidden_dims), latent_dim)
        self.decoder = Decoder(list(hidden_dims), latent_dim)
        # Linear shortcut from the latent straight to embedding space; added
        # to the decoder output in `decode`.
        self.z_emb = nn.Linear(latent_dim, EMBED_DIM, bias=False)
        self.proj = nn.Linear(EMBED_DIM, VOCAB_SIZE, bias=True)
        self.ln_out = nn.LayerNorm(VOCAB_SIZE)

        # NOTE(review): this re-initialises every embedding row, including the
        # padding_idx row - confirm a non-zero padding vector is intended.
        self.emb.weight.data.uniform_(-0.1, 0.1)
        self.proj.bias.data.zero_()
        self.proj.weight.data.uniform_(-0.1, 0.1)

    def encode(self, input: Tensor) -> List[Tensor]:
        """Embed token ids and run the encoder.

        :param input: (Tensor) integer token ids
        :return: [embeddings, encoder hidden features, mean, logvar]
        """
        emb = self.emb(input)
        hidden, mean, logvar = self.encoder(emb)

        return [emb, hidden, mean, logvar]

    def decode(self, z: Tensor) -> Tensor:
        """Map latent codes to embedding space (decoder output + linear shortcut).

        :param z: (Tensor) latent codes whose last dimension is latent_dim
        :return: (Tensor) reconstructed embeddings
        """
        result = self.decoder(z) + self.z_emb(z)

        return result

    def reparameterize(self, mean: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mean, var) from
        N(0,1).
        :param mean: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian, same
            shape as `mean`
        :return: (Tensor) sample with the same shape as `mean`
        """
        # std = sqrt(var) = exp(0.5 * log(var))
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)

        return eps * std + mean

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Run the full encode / sample / decode pipeline.

        :param input: (Tensor) integer token ids
        :return: [token probabilities, reconstructed embeddings, input
            embeddings, encoder hidden features, mean, log_var]
        """
        emb, hidden, mean, log_var = self.encode(input)
        z = self.reparameterize(mean, log_var)
        emb_hat = self.decode(z)
        output = self.proj(emb_hat)
        output = self.ln_out(output)
        # `output` leaves this method as probabilities, not logits.
        output = torch.softmax(output, dim=-1)

        return [output, emb_hat, emb, hidden, mean, log_var]

    def kl_loss(self, mean: Tensor, logvar: Tensor) -> Tensor:
        """KL divergence between N(mean, exp(logvar)) and N(0, 1).

        The value is summed over every element of the inputs (no averaging).
        """
        kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

        return kl_loss

    def reconstruction_loss(self, input_emb: Tensor, output_emb: Tensor, input: Tensor, output: Tensor, padding_index=0) -> Tensor:
        """Reconstruction loss in embedding space, optionally plus a token loss.

        :param input_emb: (Tensor) embeddings of the input tokens
        :param output_emb: (Tensor) reconstructed embeddings
        :param input: (Tensor) integer token ids (targets of the token loss)
        :param output: (Tensor) token probabilities as returned by `forward`;
            the last dimension is the vocabulary
        :param padding_index: token id excluded from the token loss
        :return: (Tensor) scalar loss
        """
        loss_emb = F.mse_loss(output_emb, input_emb)

        if not COMPUTE_LOGITS:
            return loss_emb

        targets = input.reshape(-1)

        # With nothing but padding, the masked mean below would be 0/0 = NaN.
        if not torch.any(targets != padding_index):
            return loss_emb

        # `output` holds probabilities, so take the log and use NLL. The
        # previous code took an argmax (which has no gradient) and passed
        # float-cast token ids to cross_entropy as if they were logits.
        log_probs = torch.log(output.clamp_min(1e-12)).reshape(-1, output.size(-1))
        # ignore_index drops padding positions from the average.
        loss = F.nll_loss(log_probs, targets, ignore_index=padding_index)

        recon_loss = loss_emb + loss

        return recon_loss

    def loss_function(self, input_emb: Tensor, output_emb: Tensor, input: Tensor, output: Tensor, mean: Tensor, logvar: Tensor, padding_index=0) -> Tensor:
        """Total loss: reconstruction loss plus the (unweighted) KL term."""
        recon_loss = self.reconstruction_loss(
            input_emb, output_emb, input, output, padding_index)
        kl_loss = self.kl_loss(mean, logvar)
        total_loss = recon_loss + kl_loss

        return total_loss

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and decodes to embedding space.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor) decoded embeddings, one row per sample
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)
        samples = self.decode(z)

        return samples

    def training_step(self, batch, batch_idx):
        """Compute the loss for one randomly chosen row of the batch.

        Only a single row of ``batch[0]`` is used per step; the remaining rows
        of the batch are not used.
        """
        input = batch[0][np.random.randint(0, len(batch[0]))]
        output, emb_hat, emb, hidden, mean, log_var = self(input)
        loss = self.loss_function(
            emb,
            emb_hat,
            input,
            output,
            mean, log_var, padding_index=0
        )

        self.log('train_loss', loss, prog_bar=True, on_step=True)

        return loss

    def configure_optimizers(self):
        """Adam optimizer with a StepLR schedule (x0.7 per scheduler step)."""
        optimizer = optim.Adam(self.parameters(), lr=LR, betas=(0.5, 0.999))
        # The scheduler is also stashed on the trainer under a private name.
        # NOTE(review): nothing in this notebook reads `trainer._scheduler`.
        self.trainer._scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=0.7)

        return [optimizer], [self.trainer._scheduler]


In [4]:
# Hidden widths in encoder order: 4x, 2x, 1x and 0.5x the base width.
model = VAE(
    latent_dim=LATENT_DIM,
    hidden_dims=[HIDDEN_DIM * 4, HIDDEN_DIM * 2, HIDDEN_DIM, HIDDEN_DIM // 2],
)

In [None]:
# Resume from the checkpoint written after epoch EPOCH_BEGIN. The same value
# is used as the offset in the file names of checkpoints saved during training.
EPOCH_BEGIN = 15

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source.
load_dict = torch.load(f'embvae-{VOCAB_SIZE}x{EMBED_DIM}x{HIDDEN_DIM}-{EPOCH_BEGIN}.pth', map_location=DEVICE)
load_keys = load_dict.keys()

# Build the model's state dict once instead of on every loop iteration.
model_state = model.state_dict()

# Parameters missing from the checkpoint keep the model's current values so
# the strict load below does not fail on them.
for k, v in model_state.items():
    if k not in load_keys:
        load_dict[k] = v

model.load_state_dict(load_dict, strict=True)

In [5]:
# Display the module tree to check the layer sizes.
model

VAE(
  (emb): Embedding(560, 768, padding_idx=0)
  (encoder): Encoder(
    (module): Sequential(
      (0): Linear(in_features=768, out_features=4096, bias=False)
      (1): Sequential(
        (0): Linear(in_features=4096, out_features=2048, bias=False)
        (1): Dropout(p=0.2, inplace=False)
      )
      (2): Sequential(
        (0): Linear(in_features=2048, out_features=1024, bias=False)
        (1): Dropout(p=0.1, inplace=False)
      )
      (3): Sequential(
        (0): Linear(in_features=1024, out_features=512, bias=False)
        (1): Dropout(p=0.06666666666666667, inplace=False)
      )
      (4): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=False)
        (1): Dropout(p=0.05, inplace=False)
      )
      (5): LeakyReLU(negative_slope=0.2)
    )
    (fc_mean): Linear(in_features=256, out_features=256, bias=True)
    (fc_logvar): Linear(in_features=256, out_features=256, bias=True)
  )
  (decoder): Decoder(
    (module): Sequential(
      (0): Lin

In [6]:
from pathlib import Path
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import TensorDataset, DataLoader
import json

# Directory holding the tokenised MIDI files (one JSON document per piece).
DATA_DIR = Path('/home/nico/data/ai/models/midi/all/')

# Only files whose name starts with a, b or c are used, in that order.
paths = []
for prefix in ('a', 'b', 'c'):
    paths += list(DATA_DIR.glob(f'{prefix}*_mid.json'))

if not paths:
    raise FileNotFoundError(f'no *_mid.json files found in {DATA_DIR}')

# Concatenate the 'ids' list of every file into one long token stream.
tokens = []

for path in paths:
    # Close each file as soon as it has been parsed.
    with open(path) as token_file:
        tokens += json.load(token_file)['ids']

# Cut the stream into X_DIM-long chunks; the last (possibly shorter) chunk is
# zero-padded by pad_sequence so all rows have the same length.
token_ids = torch.LongTensor(tokens)
chunks = torch.split(token_ids, X_DIM)
ids = pad_sequence(chunks, batch_first=True)
dataset = TensorDataset(ids)
data_loader = DataLoader(dataset, pin_memory=True, batch_size=BATCH_SIZE)

ids.shape


torch.Size([33999, 1024])

In [7]:
from lightning.pytorch.callbacks import Callback
import datetime

# Save a checkpoint every EPOCHS_SAVE epochs.
EPOCHS_SAVE = 1


class TrainCallback(Callback):
    """Writes a run header to the training log and saves periodic checkpoints."""

    def on_train_epoch_start(self, trainer, pl_module):
        """Append a 'NEW RUN' header with the hyper-parameters at the first step."""
        if trainer.is_global_zero:
            if trainer.global_step == 0:
                timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")

                # Use a context manager so the header is flushed and the file
                # handle released immediately.
                with open('./embvae-train_log.txt', 'a') as log:
                    log.write(f'NEW RUN {timestamp}\n{PARAMS}\n')

    def on_train_epoch_end(self, trainer, pl_module):
        """Save the model's state_dict every EPOCHS_SAVE epochs."""
        if trainer.is_global_zero:  # logging & save state_dict
            if (trainer.current_epoch % EPOCHS_SAVE == 0):
                to_save_dict = pl_module.state_dict()

                # Best effort: a failed save is reported but does not stop
                # training.
                try:
                    torch.save(
                        to_save_dict,
                        f'./embvae-{VOCAB_SIZE}x{EMBED_DIM}x{HIDDEN_DIM}-{EPOCH_BEGIN + 1 + trainer.current_epoch}.pth',
                    )
                except Exception as error:
                    print(error)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Read the current learning rate of the last parameter group."""
        if trainer.is_global_zero:
            param_groups = trainer.optimizers[0].param_groups
            # NOTE(review): the value is read but not used anywhere yet.
            lr = param_groups[-1]['lr']

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Read the current learning rate of the last parameter group."""
        if trainer.is_global_zero:
            param_groups = trainer.optimizers[0].param_groups
            # NOTE(review): the value is read but not used anywhere yet.
            lr = param_groups[-1]['lr']


In [8]:
# Trainer settings. Lightning's built-in checkpointing is disabled because
# TrainCallback writes the state_dict files itself.
trainer_options = dict(
    accelerator="auto",
    devices='auto',
    max_epochs=100000,
    log_every_n_steps=100,
    enable_checkpointing=False,
    callbacks=[TrainCallback()],
)

trainer = pl.Trainer(**trainer_options)
trainer.fit(model, train_dataloaders=data_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type      | Params
--------------------------------------
0 | emb     | Embedding | 430 K 
1 | encoder | Encoder   | 14.4 M
2 | decoder | Decoder   | 14.4 M
3 | z_emb   | Linear    | 196 K 
4 | proj    | Linear    | 430 K 
5 | ln_out  | LayerNorm | 1.1 K 
--------------------------------------
29.8 M    Trainable params
0         Non-trainable params
29.8 M    Total params
119.317   Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

In [None]:
# Reconstruct one random training chunk and compare embeddings by eye.
model.to(DEVICE)
model.eval()

# np.random.randint excludes its upper bound, so pass len(dataset) itself;
# `len(dataset)-1` meant the last sample could never be drawn.
sample = dataset[np.random.randint(0, len(dataset))][0]

with torch.no_grad():
    output, emb_hat, emb, hidden, mean, log_var = model(sample.to(DEVICE))

# Reconstructed embeddings next to the input embeddings.
emb_hat, emb