In [1]:
# Copyright (c) Xi Chen
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Borrowed from https://github.com/neocxi/pixelsnail-public and ported it to PyTorch

In [2]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [1]:
import pandas as pd
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from vq_vae_2_half_tmb import Model
from transformer import VQVAETransformer
from lr_scheduler import WarmupLinearLRSchedule
from torchvision import utils as vutils
#from utils import load_data, plot_images

import wandb

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
wandb.init()
wandb.run.name = 'transformer_mid'
wandb.run.save()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mqja1998[0m. Use [1m`wandb login --relogin`[0m to force relogin




True

In [3]:
class MelData(Dataset):
    def __init__(self, file_path):
        self.data = []
        genre = ['classical', 'rock', 'electronic', 'pop']
        
        for g in genre:
            for i in range(1, 101):
                for j in range(5):
                    tmp_path = f'{file_path}/{g}/{i}-{j}.csv'
                    try:
                        self.data.append((pd.read_csv(tmp_path), g, i, j))
                    except FileNotFoundError:
                        print(f"{g}-{i}-{j} file is deleted")
                        continue
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        mel, g, i, j = self.data[idx]
        mel = torch.from_numpy(pd.get_dummies(mel).values)
        mel = mel.type(torch.cuda.FloatTensor)
        
        return (g, i, j), mel

class EmotionalData(Dataset):
    def __init__(self, file_path):
        self.data = pd.read_csv(file_path)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tmp_data = self.data.iloc[idx]
        genre, idx = tmp_data[0].split('_')
        emo = tmp_data[1:]
        return idx, genre, torch.FloatTensor(emo)

In [4]:
batch_size = 32

In [5]:
EMO_PATH = "./mean_data.csv"
MEL_ARR_PATH = "./split_mel_array"
SAVE_PATH = "./save_models"
    
mel_arr_data = MelData(MEL_ARR_PATH)
#emo_data = EmotionalData(EMO_PATH)


classical-16-0 file is deleted
classical-16-1 file is deleted
classical-16-2 file is deleted
classical-16-3 file is deleted
classical-16-4 file is deleted
classical-40-0 file is deleted
classical-40-1 file is deleted
classical-40-2 file is deleted
classical-40-3 file is deleted
classical-40-4 file is deleted
classical-57-0 file is deleted
classical-57-1 file is deleted
classical-57-2 file is deleted
classical-57-3 file is deleted
classical-57-4 file is deleted
classical-66-0 file is deleted
classical-66-1 file is deleted
classical-66-2 file is deleted
classical-66-3 file is deleted
classical-66-4 file is deleted
classical-73-0 file is deleted
classical-73-1 file is deleted
classical-73-2 file is deleted
classical-73-3 file is deleted
classical-73-4 file is deleted


In [6]:
mel_arr_data_loader = DataLoader(
        dataset=mel_arr_data, batch_size=batch_size)

#emo_data_loader = DataLoader(
#        dataset=emo_data, batch_size=batch_size, shuffle=True)

In [7]:
def scaled(x):
    return (x + 80.0) / (3.8147e-06 + 80)
def unscaled(x):
    return x * (3.8147e-06 + 80) - 80.0

In [12]:
# extract idices
# torch.Size([32, 10, 128]) torch.Size([32, 20, 256])
def extract_indice(mel_data, model):
    with torch.no_grad():
        for _, mel in mel_data:
            x = scaled(mel)
            x = x[:, :, :-2].unsqueeze(1).to(device)
            _, _, _, ids = model.encode(x)
            try:
                ids_t = torch.cat([ids_t, ids[0]], dim=0)
                ids_m = torch.cat([ids_m, ids[1]], dim=0)
                ids_b = torch.cat([ids_b, ids[2]], dim=0)
            except Exception as e:
                print(e)
                ids_t, ids_m, ids_b = ids[0].clone().detach(), ids[1].clone().detach(), ids[2].clone().detach()
    return ids_t, ids_m, ids_b
            

In [9]:
num_hiddens = 128 #128
num_residual_hiddens = 32
num_residual_layers = 4
embedding_dim = 16 #64
num_embeddings = 128 #512
commitment_cost = 0.25

In [10]:
vqvae = Model(num_hiddens=num_hiddens, 
                  num_residual_layers=num_residual_layers,
                  num_residual_hiddens=num_residual_hiddens,
                  num_embeddings=num_embeddings,
                  embedding_dim=embedding_dim, 
                  commitment_cost=commitment_cost).to(device)

score = 0.009567060507833958
MODEL_PATH = f'{SAVE_PATH}/vqvae2_tmb_split10-{score:.5f}_dict.pt'
vqvae.load_state_dict(torch.load(MODEL_PATH))

<All keys matched successfully>

In [13]:
ids_t, ids_m, ids_b = extract_indice(mel_arr_data_loader, vqvae)
print(ids_t.shape, ids_m.shape, ids_b.shape)

local variable 'ids_t' referenced before assignment
torch.Size([1975, 10, 16]) torch.Size([1975, 20, 32]) torch.Size([1975, 40, 64])


In [14]:
with torch.no_grad():
    quant_t = vqvae.vq_top.embed_code(ids_t[0]).unsqueeze(0)
    quant_t = quant_t.permute(0, 3, 1, 2).contiguous()
    quant_m = vqvae.vq_mid.embed_code(ids_m[0]).unsqueeze(0)
    quant_m = quant_m.permute(0, 3, 1, 2).contiguous()
    quant_b = vqvae.vq_bot.embed_code(ids_b[0]).unsqueeze(0)
    quant_b = quant_b.permute(0, 3, 1, 2).contiguous()
    upsample_t = vqvae.upsample_tx2(quant_t)
    upsample_m = vqvae.upsample_t(quant_m)
    quantized = torch.cat([upsample_t, upsample_m, quant_b], 1)
print(quantized.shape)

torch.Size([1, 48, 40, 64])


In [15]:
import importlib, transformer
importlib.reload(transformer)
from transformer import VQVAETransformer

In [16]:
class TrainTransformer:
    def __init__(self, args, data, lev):
        self.model = VQVAETransformer(args).to(device=args.device)
        self.optim = self.configure_optimizers()
        self.lr_schedule = WarmupLinearLRSchedule(
            optimizer=self.optim,
            init_lr=1e-6,
            peak_lr=args.learning_rate,
            end_lr=0.,
            warmup_epochs=10,
            epochs=args.epochs,
            current_step=args.start_from_epoch
        )
        #self.lr_schedule = AdafactorSchedule(self.optim, initial_lr=1e-2)

        if args.start_from_epoch > 1:
            self.model.load_checkpoint(args.start_from_epoch)
            print(f"Loaded Transformer from epoch {args.start_from_epoch}.")
        
        wandb.watch(self.model)
        self.lev = lev
        self.train(args, data)

    def train(self, args, data):
        train_dataset = data
        len_train_dataset = len(train_dataset)
        step = args.start_from_epoch * len_train_dataset
        for epoch in range(args.start_from_epoch+1, args.epochs+1):
            print(f"Epoch {epoch}:")
            with tqdm(range(len(train_dataset))) as pbar:
                for i, (ids_t, ids_m, ids_b) in zip(pbar, train_dataset):
                    if self.lev == 'top':
                        imgs = ids_t
                    elif self.lev == 'mid':
                        imgs = ids_m
                    elif self.lev == 'bot':
                        imgs = ids_b
                    imgs = imgs.to(device=args.device)
                    logits, target = self.model(imgs)
                    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)).requires_grad_(True)
                    loss.backward()
                    if step % args.accum_grad == 0:
                        self.optim.step()
                        self.optim.zero_grad()
                    step += 1
                    pbar.set_postfix(Transformer_Loss=np.round(loss.cpu().detach().numpy().item(), 4))
                    pbar.update(0)
                    
                    pbar.set_description(
                    (
                        f" lr {self.optim.param_groups[0]['lr']:.6f}\t"
                    )
                    )
                    wandb.log({
                        "Loss": loss,
                        "Learning rate": self.optim.param_groups[0]['lr']
                    })
                self.lr_schedule.step()
            try:
                log, sampled_imgs = self.model.log_images(imgs[0:1])
                vutils.save_image(sampled_imgs.add(1).mul(0.5), os.path.join("results", f"{epoch}.jpg"), nrow=4)
                #plot_images(log)
            except:
                pass
            #if epoch % args.ckpt_interval == 0:
            #    torch.save(self.model.state_dict(), os.path.join("checkpoints", f"transformer_epoch_{epoch}.pt"))
            torch.save(self.model.state_dict(), os.path.join(f"checkpoints", "transformer_current_{self.lev}.pt"))

    def configure_optimizers(self):
        # decay, no_decay = set(), set()
        # whitelist_weight_modules = (nn.Linear,)
        # blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
        # for mn, m in self.model.transformer.named_modules():
        #     for pn, p in m.named_parameters():
        #         fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
        #
        #         if pn.endswith('bias'):
        #             no_decay.add(fpn)
        #
        #         elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
        #             decay.add(fpn)
        #
        #         elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
        #             no_decay.add(fpn)
        #
        # # no_decay.add('pos_emb')
        #
        # param_dict = {pn: p for pn, p in self.model.transformer.named_parameters()}
        #
        # optim_groups = [
        #     {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 4.5e-2},
        #     {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        # ]
        optimizer = torch.optim.Adam(self.model.transformer.parameters(), lr=1e-6, betas=(0.9, 0.96), weight_decay=4.5e-2)
        #optimizer = Adafactor(self.model.parameters(), lr=0.0, scale_parameter=True, relative_step=False)
        #optimizer = Adafactor(self.model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True)
        return optimizer


In [17]:
import argparse
parser = argparse.ArgumentParser(description="VQGAN")
parser.add_argument('--run-name', type=str, default=None)
parser.add_argument('--latent-dim', type=int, default=32, help='Latent dimension n_z.')
parser.add_argument('--image-size', type=int, default=256, help='Image height and width.)')
parser.add_argument('--num-codebook-vectors', type=int, default=8192, help='Number of codebook vectors.')
parser.add_argument('--beta', type=float, default=0.25, help='Commitment loss scalar.')
parser.add_argument('--image-channels', type=int, default=3, help='Number of channels of images.')
parser.add_argument('--dataset-path', type=str, default='./data', help='Path to data.')
parser.add_argument('--checkpoint-path', type=str, default='./checkpoints/last_ckpt.pt', help='Path to checkpoint.')
parser.add_argument('--device', type=str, default="cuda", help='Which device the training is on.')
parser.add_argument('--batch-size', type=int, default=10, help='Batch size for training.')
parser.add_argument('--accum-grad', type=int, default=10, help='Number for gradient accumulation.')
parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train.')
parser.add_argument('--start-from-epoch', type=int, default=1, help='Number of epochs to train.')
parser.add_argument('--ckpt-interval', type=int, default=100, help='Number of epochs to train.')
parser.add_argument('--learning-rate', type=float, default=1e-4, help='Learning rate.')

parser.add_argument('--sos-token', type=int, default=1025, help='Start of Sentence token.')

parser.add_argument('--n-layers', type=int, default=24, help='Number of layers of transformer.')
parser.add_argument('--dim', type=int, default=768, help='Dimension of transformer.')
parser.add_argument('--hidden-dim', type=int, default=3072, help='Dimension of transformer.')
parser.add_argument('--num-image-tokens', type=int, default=256, help='Number of image tokens.')

args = parser.parse_args(args=[])
args.run_name = "<name>"
args.dataset_path = r"C:\Users\dome\datasets\landscape"
args.checkpoint_path = r".\checkpoints"
args.n_layers = 24
args.dim = 512
args.hidden_dim = 3072
args.batch_size = 4
args.accum_grad = 25
args.epochs = 200

args.start_from_epoch = 0

args.num_codebook_vectors = 511
args.num_image_tokens = 20 * 64

wandb.config.update(args)

In [18]:
class IDsData(Dataset):
    def __init__(self, ids_t, ids_m, ids_b):
        self.ids_t = ids_t
        self.ids_m = ids_m
        self.ids_b = ids_b
        
    def __len__(self):
        return len(self.ids_t)
    
    def __getitem__(self, idx):
        return self.ids_t[idx], self.ids_m[idx], self.ids_b[idx]

In [19]:
length = len(ids_t)
ids_data = IDsData(ids_t.view(length, -1), ids_m.view(length, -1), ids_b.view(length, -1))

In [20]:
batch = 1
ids_loader = DataLoader(
        ids_data, batch_size=batch, shuffle=True, drop_last=False
    )

In [21]:
train_transformer = TrainTransformer(args, ids_loader, lev='mid')

Initializing Module Embedding.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Module Linear.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Module Linear.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Module Linear.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Module Linear.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Module Linear.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Module Linear.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Module Linear.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Module Linear.
Initializing Module NonDynamicallyQuantizableLinear.
Initializing Module Linear.
Initializing Mod

 lr 0.000001	: 100%|██████████| 1975/1975 [08:31<00:00,  3.86it/s, Transformer_Loss=5.25]


Epoch 2:


 lr 0.000001	: 100%|██████████| 1975/1975 [08:38<00:00,  3.81it/s, Transformer_Loss=4.09]


Epoch 3:


 lr 0.000011	: 100%|██████████| 1975/1975 [08:39<00:00,  3.80it/s, Transformer_Loss=3.85] 


Epoch 4:


 lr 0.000021	: 100%|██████████| 1975/1975 [08:39<00:00,  3.80it/s, Transformer_Loss=4.04] 


Epoch 5:


 lr 0.000031	: 100%|██████████| 1975/1975 [08:39<00:00,  3.80it/s, Transformer_Loss=2.01]  


Epoch 6:


 lr 0.000041	: 100%|██████████| 1975/1975 [08:39<00:00,  3.80it/s, Transformer_Loss=2.5]   


Epoch 7:


 lr 0.000050	: 100%|██████████| 1975/1975 [08:40<00:00,  3.80it/s, Transformer_Loss=0.311] 


Epoch 8:


 lr 0.000060	: 100%|██████████| 1975/1975 [08:39<00:00,  3.80it/s, Transformer_Loss=0.791] 


Epoch 9:


 lr 0.000070	: 100%|██████████| 1975/1975 [08:40<00:00,  3.79it/s, Transformer_Loss=0.0815]


Epoch 10:


 lr 0.000080	: 100%|██████████| 1975/1975 [08:39<00:00,  3.80it/s, Transformer_Loss=0.841] 


Epoch 11:


 lr 0.000090	: 100%|██████████| 1975/1975 [08:39<00:00,  3.80it/s, Transformer_Loss=0.719] 


Epoch 12:


 lr 0.000100	:  18%|█▊        | 352/1975 [01:33<07:09,  3.78it/s, Transformer_Loss=0.933] 


KeyboardInterrupt: 

In [None]:
torch.save(train_transformer.model.state_dict(), os.path.join("checkpoints", "transformer_current_mid.pt"))