# Model

In [1]:
from model import cc, Encoder, Decoder


# Dataset

In [2]:
from dataset import VCTK_Dataset
from dataset import *


# Train

In [5]:
import torch 
import torch.nn as nn

def _run_train(encoder, decoder, criterion, opt, dataloader):
        
    encoder.train()
    decoder.train()
    
    total_loss = 0
    for index, (person, spectrogram) in enumerate(dataloader):
        b = person.shape[0]
        
        opt.zero_grad()
    
        spectrogram = spectrogram.permute(0,2,1).cuda()
        latent = encoder(spectrogram)
        person = person.cuda()
        output = decoder(latent, person)
        
        loss = criterion(spectrogram, output)
        loss.backward()
        
        total_loss += loss.item()*b
        print("\t [{}/{}] train loss:{:.4f}".format(index+1,
                                              len(dataloader),
                                              loss.item()), 
                                          end='  \r')
        opt.step()
        
    return total_loss

def _run_eval(encoder, decoder, criterion, dataloader):
      
    encoder.eval()
    decoder.eval()
    
    with torch.no_grad():
        total_loss = 0
        for index, (person, spectrogram) in enumerate(dataloader):
            b = person.shape[0]

            spectrogram = spectrogram.permute(0,2,1).cuda()
            latent = encoder(spectrogram)
            person = person.cuda()
            output = decoder(latent, person)

            loss = criterion(spectrogram, output)

            total_loss += loss.item()*b
            print("\t [{}/{}] valid loss:{:.4f}".format(index+1,
                                                  len(dataloader),
                                                  loss.item()), 
                                              end='  \r')
                    
    return total_loss
    pass

def train(args, train_dataloader, valid_dataloader):
    """Train a fresh Encoder/Decoder pair and checkpoint it after every epoch.

    The two networks are optimised jointly with Adam on an L1 loss. Each epoch
    runs one training pass and one validation pass, prints the per-sample
    average of both, and writes the encoder and decoder state dicts to files
    whose names embed the epoch index and the average validation loss.

    Args:
        args: Namespace providing ``lr``, ``epoch`` and ``save_path``.
        train_dataloader: Loader for the training pass; its ``dataset`` length
            is used to average the training loss.
        valid_dataloader: Loader for the validation pass; its ``dataset``
            length is used to average the validation loss.
    """
    encoder = cc(Encoder())
    decoder = cc(Decoder())

    criterion = torch.nn.L1Loss()
    # A single optimizer updates the parameters of both networks.
    joint_params = list(encoder.parameters()) + list(decoder.parameters())
    opt = torch.optim.Adam(joint_params, lr=args.lr)

    for epoch in range(args.epoch):
        print(f' Epoch {epoch}')

        train_total = _run_train(encoder, decoder, criterion, opt, train_dataloader)
        train_avg = train_total / len(train_dataloader.dataset)
        print('\t [Info] Avg training loss:{:.5f}'.format(train_avg))

        valid_total = _run_eval(encoder, decoder, criterion, valid_dataloader)
        valid_avg = valid_total / len(valid_dataloader.dataset)
        print('\t [Info] Avg valid loss:{:.5f}'.format(valid_avg))

        # Checkpoint unconditionally: one file per network, per epoch.
        save_path = "{}/epoch_{}_loss_{:.4f}".format(args.save_path, epoch, valid_avg)
        for module, tail in ((encoder, "_enc_.pt"), (decoder, "_dec_.pt")):
            torch.save({'state_dict': module.state_dict()}, save_path + tail)
        print(f'\t [Info] save weights at {save_path}')
        print('-----------------------------------------------')


# Main

In [None]:
import os, warnings, argparse

def parse_args(string=None):
    """Parse the training options and make sure the save directory exists.

    Args:
        string: Sequence of command-line tokens to parse. ``None`` (the
            default) makes argparse read ``sys.argv``; an empty sequence such
            as ``''`` or ``[]`` yields the defaults for every option.

    Returns:
        The ``argparse.Namespace`` with ``lr``, ``epoch``, ``batch_size``,
        ``num_workers`` and ``save_path``.

    Raises:
        FileExistsError: If ``save_path`` exists but is not a directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', default=1e-4,
                        type=float, help='learning rate')
    parser.add_argument('--epoch', default=10,
                        type=int, help='epochs')
    parser.add_argument('--batch-size', default=32,
                        type=int, help='batch size')
    parser.add_argument('--num-workers', default=6,
                        type=int, help='dataloader num workers')
    parser.add_argument('--save-path', default='trained_model',
                        type=str, help='.pth model file save dir')

    # argparse itself falls back to sys.argv when handed None, so no explicit
    # branch on `string` is needed.
    args = parser.parse_args(string)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.save_path, exist_ok=True)
    return args
    
if __name__=='__main__':
    # '' reaches argparse as an empty argument list, so sys.argv is not read
    # and every option keeps its default value.
    args = parse_args('')

    # Silence all Python warnings and expose only device 0 to CUDA.
    warnings.filterwarnings("ignore")
    os.environ['CUDA_VISIBLE_DEVICES'] = "0" #0:1080ti 1:1070

    ## datasets: both splits read the same HDF5 file with seg_len=128
    train_dataset = VCTK_Dataset(
        'preprocess/vctk.h5',
        'preprocess/sample_segments/train_samples',
        seg_len=128,
        mode='train',
    )
    valid_dataset = VCTK_Dataset(
        'preprocess/vctk.h5',
        'preprocess/sample_segments/valid_samples',
        seg_len=128,
        mode='test',
    )

    ## loaders: only the training loader shuffles; args.num_workers is parsed
    ## but deliberately not forwarded to either loader
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size)

    ## train
    train(args, train_dataloader, valid_dataloader)


In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Show the first sample of every training batch as a heat map, one at a time.
for batch in train_dataloader:
    # batch[0] holds the person entries, batch[1] the spectrograms.
    person, spectrogram = batch[0][0], batch[1][0].numpy()

    plt.figure(figsize=(16, 9))
    plt.title(f"person:{person}")
    plt.imshow(spectrogram, cmap="hot")
    plt.show()

    # Wait for the user to press Enter before moving on to the next batch.
    input('resume')

# Test

In [None]:
import numpy as np

from preprocess.tacotron.norm_utils import spectrogram2wav, get_spectrograms
from scipy.io.wavfile import write



In [None]:

# Rebuild both networks and restore the weights saved during training.
encoder = cc(Encoder())
encoder.load_state_dict(torch.load('trained_model/epoch_7_loss_0.0466_enc_.pt')['state_dict'])
decoder = cc(Decoder())
decoder.load_state_dict(torch.load('trained_model/epoch_7_loss_0.0466_dec_.pt')['state_dict'])

# Inference only: switch to evaluation mode so that any layers which behave
# differently during training (e.g. dropout, batch norm) act deterministically.
encoder.eval()
decoder.eval()

# Only the second value returned by get_spectrograms is used.
_, spectrogram = get_spectrograms('/media/D/DLHLP/hw2/Corpus/wav48/p2/p2_008.wav')

# Add a leading axis of size 1, swap the last two axes (as the training loop
# does for its batches) and move the tensor to the GPU.
_input = np.expand_dims(spectrogram, axis=0)
_input = torch.tensor(_input).permute(0,2,1).cuda()

# No gradients are needed here, so skip building the autograd graph.
with torch.no_grad():
    _latent = encoder(_input)

    person = torch.tensor([0]).cuda()
    output = decoder(_latent, person)

# Undo the leading axis and the axis swap, then hand a NumPy array to the vocoder.
output = output.squeeze(dim=0).transpose(1,0).cpu().numpy()

wav_data = spectrogram2wav(output)
write('result/abc.wav', 16000, data=wav_data)

In [None]:
# Display the shapes of the input spectrogram and the decoder output side by side.
spectrogram.shape, output.shape

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
    
# Render `spectrogram` and then `output` as separate heat-map figures.
for _title, _image in (("person: 0 ", spectrogram), ("person: 1 ", output)):
    plt.figure(figsize=(16, 9))
    plt.title(_title)
    plt.imshow(_image, cmap="hot")
    plt.show()
    

# Interpolate

In [6]:
import numpy as np

from preprocess.tacotron.norm_utils import spectrogram2wav, get_spectrograms
from scipy.io.wavfile import write


In [17]:

# Rebuild both networks and restore the weights saved during training.
encoder = cc(Encoder())
encoder.load_state_dict(torch.load('trained_model/epoch_7_loss_0.0466_enc_.pt')['state_dict'])
decoder = cc(Decoder())
decoder.load_state_dict(torch.load('trained_model/epoch_7_loss_0.0466_dec_.pt')['state_dict'])

# Inference only: switch to evaluation mode so that any layers which behave
# differently during training (e.g. dropout, batch norm) act deterministically.
encoder.eval()
decoder.eval()

# Only the second value returned by get_spectrograms is used.
_, spectrogram = get_spectrograms('/media/D/DLHLP/hw2/Corpus/wav48/p1/p1_334.wav')

# Add a leading axis of size 1, swap the last two axes (as the training loop
# does for its batches) and move the tensor to the GPU.
_input = np.expand_dims(spectrogram, axis=0)
_input = torch.tensor(_input).permute(0,2,1).cuda()

# No gradients are needed here, so skip building the autograd graph.
with torch.no_grad():
    _latent = encoder(_input)

    p1 = torch.tensor([0]).cuda()
    p2 = torch.tensor([1]).cuda()
    output = decoder.interpolate(_latent, p1, p2)
    # Alternative: decode with a single person instead of interpolating.
    #output = decoder(_latent, p1)

# Undo the leading axis and the axis swap, then hand a NumPy array to the vocoder.
output = output.squeeze(dim=0).transpose(1,0).cpu().numpy()

wav_data = spectrogram2wav(output)
write('result/abc.wav', 16000, data=wav_data)

In [None]:
# Display the version string of the imported PyTorch package.
torch.__version__