In [1]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
from preprocessing import convert_spectrograms, convert_tensor
from model_ae import Encoder, Decoder, Discriminator
from utils.optimization import WarmupLinearSchedule

In [2]:
model = 'aae'
conv_dim = '1d'
batch_size = 128
learning_rate = 0.001
num_epochs = 100

use_warmup = True
gradient_accumulation_steps = 1
warmup_proportion = 0.1

multi_gpu = True

cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_z = 800 if (conv_dim == '1d') else 1408
if(multi_gpu):
    batch_size = batch_size * torch.cuda.device_count()

In [3]:
sample_data_repo = os.path.join('.', 'wav_data', 'pretrain')

samples_data = glob.glob(os.path.join(sample_data_repo, '**', '*wav'), recursive=True)
samples_data = sorted(samples_data)
len(samples_data)

1440

In [4]:
np.random.seed(42)
idx = np.random.permutation(len(samples_data))
train_idx = idx[:int(len(samples_data)*0.8)]
eval_idx = idx[int(len(samples_data)*0.8):]

In [5]:
train_samples = list(np.array(samples_data)[train_idx])
eval_samples = list(np.array(samples_data)[eval_idx])

In [6]:
len(train_samples), len(eval_samples)

(1152, 288)

In [7]:
X_train = convert_spectrograms(train_samples, conv_dim=conv_dim)
X_eval = convert_spectrograms(eval_samples, conv_dim=conv_dim)

100%|██████████| 1152/1152 [03:24<00:00,  5.63it/s]
100%|██████████| 288/288 [00:51<00:00,  5.58it/s]


In [8]:
X_train = convert_tensor(X_train, device=device)
X_eval = convert_tensor(X_eval, device=device)

In [9]:
train_dataloader = DataLoader(X_train, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
eval_dataloader = DataLoader(X_eval, batch_size=batch_size, num_workers=0, drop_last=True)

In [10]:
if(multi_gpu == True):
    encoder = Encoder(conv_dim=conv_dim)
    decoder = Decoder(conv_dim=conv_dim)
    encoder = torch.nn.DataParallel(encoder).cuda()
    decoder = torch.nn.DataParallel(decoder).cuda()
    if(model == 'aae'):
        discriminator = Discriminator(conv_dim=conv_dim)
        discriminator = torch.nn.DataParallel(discriminator).cuda()
else:
    encoder = Encoder(conv_dim=conv_dim).to(device)
    decoder = Decoder(conv_dim=conv_dim).to(device)
    if(model == 'aae'):
        discriminator = Discriminator(conv_dim=conv_dim).to(device)

In [11]:
def reset_grad():
    encoder.zero_grad()
    decoder.zero_grad()
    if(model == 'aae'):
        discriminator.zero_grad()

In [12]:
loss_func = nn.MSELoss()
enc_opt = optim.Adam(encoder.parameters(), lr=learning_rate)
dec_opt = optim.Adam(decoder.parameters(), lr=learning_rate)
if(model == 'aae'):
    disc_opt = optim.Adam(discriminator.parameters(), lr=learning_rate*0.1)

if(use_warmup == True):
    t_total = len(train_dataloader) // gradient_accumulation_steps * num_epochs
    enc_scheduler = WarmupLinearSchedule(enc_opt, warmup_steps=t_total * warmup_proportion, t_total=t_total)
    dec_scheduler = WarmupLinearSchedule(dec_opt, warmup_steps=t_total * warmup_proportion, t_total=t_total)
    disc_scheduler = WarmupLinearSchedule(disc_opt, warmup_steps=t_total * warmup_proportion, t_total=t_total)

In [13]:
def train(train_dataloader, eval_dataloader, epochs):
    for epoch in range(epochs):
        ### Train step
        tr_recon_loss = 0
        nb_train_steps = 0

        for X_batch in train_dataloader:
            ## Reconstruction
            X = Variable(X_batch)
            if cuda:
                X = X.cuda()

            z_sample = encoder(X)
            X_sample = decoder(z_sample)
            
            recon_loss = loss_func(X_sample, X)

            recon_loss.backward()
            dec_opt.step()
            enc_opt.step()
            reset_grad()
            
            tr_recon_loss += recon_loss.mean().item()
            nb_train_steps += 1

            if(model == 'aae'):
                ## Discriminator
                for _ in range(5):
                    z_fake = Variable(torch.randn(batch_size, n_z))
                    if cuda:
                        z_fake = z_fake.cuda()

                    z_real = encoder(X).view(batch_size, -1)

                    D_fake = discriminator(z_fake)
                    D_real = discriminator(z_real)

                    D_loss = -(torch.mean(D_real) - torch.mean(D_fake))

                    D_loss.backward()
                    disc_opt.step()

                    # Weight clipping
                    for p in discriminator.parameters():
                        p.data.clamp_(-0.01, 0.01)

                    reset_grad()

                ## Generator
                z_real = encoder(X).view(batch_size, -1)
                D_real = discriminator(z_real)

                G_loss = -torch.mean(D_real)

                G_loss.backward()
                enc_opt.step()
                reset_grad()
                
                if(use_warmup == True):
                    enc_scheduler.step()
                    dec_scheduler.step()
                    disc_scheduler.step()
                

        tr_recon_loss = tr_recon_loss / nb_train_steps
            
        ### Evaluate step
        encoder.eval()
        decoder.eval()
        if(model == 'aae'):
            discriminator.eval()
        eval_recon_loss = 0
        nb_eval_steps = 0
        
        for X_batch in eval_dataloader:
            with torch.no_grad():
                z_sample = encoder(X_batch)
                X_sample = decoder(z_sample)
            
            tmp_eval_loss = loss_func(X_sample, X_batch)
            eval_recon_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            
        eval_recon_loss = eval_recon_loss / nb_eval_steps
            
        
        for param_group in enc_opt.param_groups:
            lr = param_group['lr']
        print('epoch: {:3d},    lr={:6f},    loss={:5f},    eval_loss={:5f}'
              .format(epoch+1, lr, tr_recon_loss, eval_recon_loss))

In [14]:
train(train_dataloader, eval_dataloader, num_epochs)

epoch:   1,    lr=0.000100,    loss=0.553253,    eval_loss=0.124530
epoch:   2,    lr=0.000200,    loss=0.124531,    eval_loss=0.122389
epoch:   3,    lr=0.000300,    loss=0.121527,    eval_loss=0.118092
epoch:   4,    lr=0.000400,    loss=0.116297,    eval_loss=0.112291
epoch:   5,    lr=0.000500,    loss=0.110005,    eval_loss=0.106661
epoch:   6,    lr=0.000600,    loss=0.104989,    eval_loss=0.100832
epoch:   7,    lr=0.000700,    loss=0.098941,    eval_loss=0.094782
epoch:   8,    lr=0.000800,    loss=0.092672,    eval_loss=0.088312
epoch:   9,    lr=0.000900,    loss=0.086230,    eval_loss=0.081469
epoch:  10,    lr=0.001000,    loss=0.079145,    eval_loss=0.074454
epoch:  11,    lr=0.000989,    loss=0.072399,    eval_loss=0.068063
epoch:  12,    lr=0.000978,    loss=0.066531,    eval_loss=0.063028
epoch:  13,    lr=0.000967,    loss=0.062055,    eval_loss=0.059557
epoch:  14,    lr=0.000956,    loss=0.059633,    eval_loss=0.057560
epoch:  15,    lr=0.000944,    loss=0.057972,   