In [133]:
import os
import requests
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import librosa
from librosa.filters import mel as librosa_mel_fn
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.nn.utils.rnn import pad_sequence
#print mel spectrogram using librosa
import librosa.display
import matplotlib.pyplot as plt

In [134]:
# Define the URL for the LJSpeech dataset
LJSPEECH_URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
LJSPEECH_PATH = "LJSpeech-1.1.tar.bz2"

In [135]:
# Download the LJSpeech dataset
response = requests.get(LJSPEECH_URL, stream=True)
file_size = int(response.headers['Content-Length'])

In [136]:
# with open(LJSPEECH_PATH, 'wb') as file:
#     for data in tqdm(response.iter_content(), total=file_size, unit="B", unit_scale=True, desc="Downloading LJSpeech"):
#         file.write(data)
# 
# # Unzip the downloaded dataset
# if LJSPEECH_PATH.endswith(".tar.bz2"):
#     import tarfile
#     with tarfile.open(LJSPEECH_PATH, 'r:bz2') as archive:
#         archive.extractall()
#         print("Extraction Complete!")
# else:
#     print("Unknown format: Cannot extract!")

In [301]:
config = {
    "mel": {
        "frame_length": 1024,
        "n_fft": 1024,
        "num_mels": 80,
        "sample_rate": 22050,
        "win_length": 1024,
        "hop_length": 256,
        "fmin": 0,
        "fmax": 8000,
    },
    "segment_length": 8192,
}

In [314]:
def mel_spec(y, config):
    # Get the mel-spectrogram
    mel_spec = librosa.feature.melspectrogram(
        y=y,
        sr=config["mel"]["sample_rate"],
        n_fft=config["mel"]["n_fft"],
        hop_length=config["mel"]["hop_length"],
        win_length=config["mel"]["win_length"],
        window="hann",
        center=True,
        pad_mode="edge",
        power=2.0,
        n_mels=config["mel"]["num_mels"],
        fmin=config["mel"]["fmin"],
        fmax=config["mel"]["fmax"],
    )
    log_mel_spec = librosa.power_to_db(mel_spec, ref=1.0, amin=1e-5, top_db=None)
    return log_mel_spec
    

In [319]:
mel_basis = {}
hann_window = {}

def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)

def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output

def mel_spectrogram(y, center=False):
    num_mels = config["mel"]["num_mels"]
    n_fft = config["mel"]["n_fft"]
    hop_size = config["mel"]["hop_length"]
    win_size = config["mel"]["win_length"]
    fmin = config["mel"]["fmin"]
    fmax = config["mel"]["fmax"]
    sampling_rate = config["mel"]["sample_rate"]
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=False)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec



In [320]:
def load_wav(path, config):
    # Load the wav file
    wav, _ = librosa.load(path, sr=config["mel"]["sample_rate"])
    return wav

In [324]:
# test sample
wav = load_wav('LJSpeech-1.1/wavs/LJ001-0001.wav', config)
wav = wav[:8192]
wav = torch.from_numpy(wav).float()
wav = wav.unsqueeze(0)
mel = mel_spectrogram(wav)
print(wav.shape)
print(mel.shape)

torch.Size([1, 8192])
torch.Size([1, 80, 32])


Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at ../aten/src/ATen/native/SpectralOps.cpp:862.)
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]


In [296]:
def plot_mel_spectrogram(mel):
    mel = mel.numpy()
    plt.figure(figsize=(10, 4))
    plt.imshow(np.flip(mel, axis=0), cmap='inferno', aspect='auto')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Mel Spectrogram')
    plt.xlabel('Time Frame')
    plt.ylabel('Mel Frequency Bin')
    plt.tight_layout()
    plt.show()

In [341]:
root_dir = 'LJSpeech-1.1/wavs'
class MelDataset(Dataset):
    def __init__(self, root_dir):
        self.files_list = [f for f in os.listdir(root_dir) if f.endswith(".wav")]
        self.segment_size = config["segment_length"]
    
    def __len__(self):
        return len(self.files_list)
    
    def __getitem__(self, idx):
        audio_file = self.files_list[idx]
        wav = load_wav(os.path.join(root_dir, audio_file), config)
        # implemenmt same above code using numpy
        if wav.shape[0] >= self.segment_size:
            start = np.random.randint(0, wav.shape[0] - self.segment_size + 1, (1,)).item()
            wav = wav[start:start+self.segment_size]
        else:
            pad_amount = self.segment_size - wav.size(0)
            wav = np.pad(wav, (0, pad_amount), 'constant')
        

        # wav = wav[:,None,:]
        wav = wav[None,:]
        wav = torch.from_numpy(wav).float()
        # For mel, compute the corresponding segment
        mel = mel_spectrogram(wav)
        mel = mel.squeeze(0)
        return wav, mel


In [342]:
dataset = MelDataset(root_dir)

In [343]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ResidualBlock, self).__init__()
        self.conv1 =weight_norm(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding='same', dilation=1)) 
        self.conv2 =weight_norm(nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, padding='same', dilation=1)) 
        
        self.conv3 =weight_norm(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding='same', dilation=3)) 
        self.conv4 = weight_norm(nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, padding='same', dilation=1))
        
        self.conv5 =weight_norm(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding='same', dilation=5)) 
        self.conv6 = weight_norm(nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, padding='same', dilation=1))
    
    def forward(self, x):
        residual = x
        x = F.leaky_relu(self.conv1(residual))
        x = F.leaky_relu(self.conv2(x))
        residual = x + residual
        x = F.leaky_relu(self.conv3(residual))
        x = F.leaky_relu(self.conv4(x))
        residual = x + residual
        x = F.leaky_relu(self.conv5(residual))
        x = F.leaky_relu(self.conv6(x))
        residual = x + residual
        return residual

In [344]:
class DescriminatorBlock(nn.Module):
    def __init__(self, in_channels=1):
        super(DescriminatorBlock,self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(in_channels=in_channels, out_channels=16, kernel_size=15, stride=1, padding=7))
        self.conv2 = weight_norm(nn.Conv1d(in_channels=16, out_channels=64, kernel_size=41, stride=4, padding=20, groups=4))
        
        self.conv3 = weight_norm(nn.Conv1d(in_channels=64, out_channels=256, kernel_size=41, stride=4, padding=20, groups=16))
        self.conv4 = weight_norm(nn.Conv1d(in_channels=256, out_channels=1024, kernel_size=41, stride=4, padding=20, groups=64))
        
        self.conv5 = weight_norm(nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=41, stride=4, padding=20, groups=256))
        self.conv6 = weight_norm(nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=5, stride=1, padding=2))
        self.conv7 = weight_norm(nn.Conv1d(in_channels=1024, out_channels=1, kernel_size=3, stride=1, padding=1))
    def forward(self, x):
        layer_1 = F.leaky_relu(self.conv1(x))
        # print("============ layer 1",layer_1.shape)
        layer_2 = F.leaky_relu(self.conv2(layer_1))
        # print("============ layer 2",layer_2.shape)
        layer_3 = F.leaky_relu(self.conv3(layer_2))
        # print("============ layer 3",layer_3.shape)
        layer_4 = F.leaky_relu(self.conv4(layer_3))
        # print("============ layer 4",layer_4.shape)
        layer_5 = F.leaky_relu(self.conv5(layer_4))
        # print("============ layer 5",layer_5.shape)
        layer_6 = F.leaky_relu(self.conv6(layer_5))
        # print("============ layer 6",layer_6.shape)
        result = self.conv7(layer_6)
        # print("============ result",result.shape)
        return [layer_1, layer_2, layer_3, layer_4, layer_5, layer_6, result]

In [371]:
class UpSampler(nn.Module):
    def __init__(self, up_sampling_factor, in_channels, out_channels, kernel_size):
        super(UpSampler, self).__init__()
        self.up_sampling_factor = up_sampling_factor
        self.conv_t = weight_norm(nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sampling_factor, padding=(kernel_size-up_sampling_factor)//2))
        self.res_block = ResidualBlock(out_channels, out_channels)
        
    def forward(self, x):
        x = F.leaky_relu(self.conv_t(x))
        x = self.res_block(x)
        x = F.leaky_relu(x)
        return x

In [372]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(in_channels=80, out_channels=512, kernel_size=7, stride=1, padding=3))
        self.up_sampler_1 = UpSampler(8, 512, 256, 16)
        self.up_sampler_2 = UpSampler(8, 256, 128,16)
        self.up_sampler_3 = UpSampler(2, 128, 64,4)
        self.up_sampler_4 = UpSampler(2, 64, 32,4)
        self.conv_out = weight_norm(nn.Conv1d(in_channels=32, out_channels=1, kernel_size=7, stride=1, padding=3))
    def forward(self,x):
        print("======= gen 1", x.shape)
        x = F.leaky_relu(self.conv1(x))
        print("======= gen 2", x.shape)
        x = self.up_sampler_1(x)
        print("======= gen 3", x.shape)
        x = self.up_sampler_2(x)
        print("======= gen 4", x.shape)
        x = self.up_sampler_3(x)
        print("======= gen 5", x.shape)
        x = self.up_sampler_4(x)
        print("======= gen 6", x.shape)
        x = self.conv_out(x)
        return x

In [373]:
class Descriminator(nn.Module):
    def __init__(self):
        super(Descriminator, self).__init__()
        self.desc_block_1 = DescriminatorBlock()
        self.desc_block_2 = DescriminatorBlock()
        self.desc_block_3 = DescriminatorBlock()
        self.avg_pool_1 = nn.AvgPool1d(kernel_size=2, stride=2)
        self.avg_pool_2 = nn.AvgPool1d(kernel_size=2, stride=2)
        
    def forward(self, x):
        # print("==========input", x.shape)
        out_1  = self.desc_block_1(x)
        # print("==== desc 1", out_1[-1].shape)
        x = self.avg_pool_1(x)
        # print("==== avg pool 1", x.shape)
        out_2 = self.desc_block_2(x)
        # print("==== desc 2", out_2[-1].shape)
        x = self.avg_pool_2(x)
        out_3 = self.desc_block_3(x)
        # print("==== desc 3", out_3[-1].shape)
        return [out_1, out_2, out_3]

In [374]:
def collate_fn(batch):
    # A batch will be a list of N tuples: [(wav_1, mel_1), (wav_2, mel_2), ...]
    wavs, mels = zip(*batch)
    # Pads sequence of variable length
    wavs_pad = pad_sequence(wavs, batch_first=True, padding_value=0)
    mels_pad = pad_sequence([mel.t() for mel in mels], batch_first=True).transpose(1, 2)
    
    return wavs_pad[:,None,:], mels_pad

In [375]:
# prepare data loader and split train and test ratio
dataset = MelDataset(root_dir='LJSpeech-1.1/test')
batch_size = 4
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=1)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [376]:
# load single sample and break
for i, data in enumerate(train_loader):
    wav, mel = data
    print("*********************")
    print(wav.shape)
    break

*********************
torch.Size([4, 1, 8192])


In [377]:
def generator_loss(fake_disc_output):
    loss = 0
    gen_losses = []
    for i in range(len(fake_disc_output)):
        result = fake_disc_output[i][-1]
        l = torch.mean((1-result)**2)
        gen_losses.append(l)
        loss += l
    return loss, gen_losses

In [378]:
def descriminator_loss(real_disc_output, fake_disc_output):
    loss = 0
    disc_losses = []
    for i in range(len(real_disc_output)):
        real_result = real_disc_output[i][-1]
        fake_result = fake_disc_output[i][-1]
        l = torch.mean((1-real_result)**2) + torch.mean((0-fake_result)**2)
        disc_losses.append(l)
        loss += l
    return loss, disc_losses

In [379]:
def feature_matching_loss(real_disc_output, fake_disc_output):
    loss = 0
    fm_losses = []
    assert len(real_disc_output) == len(fake_disc_output)
    # print(fake_disc_output.shape)
    len_except_last = len(real_disc_output)
    for i in range(len_except_last):
        for j in range(len(real_disc_output[i])-1):
            real_result = real_disc_output[i][j]
            fake_result = fake_disc_output[i][j]
            # l = torch.mean((real_result-fake_result)**2)
            #mae loss
            l = torch.mean(torch.abs(real_result-fake_result))
            fm_losses.append(l)
            loss += l
    return loss, fm_losses

In [380]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
generator = Generator().to(device)
descriminator = Descriminator().to(device)
generator_optimizer = Adam(generator.parameters(), lr=0.0001)
descriminator_optimizer = Adam(descriminator.parameters(), lr=0.0001)

In [381]:
def train():
    for epoch in range(100):
        for i, data in enumerate(train_loader):
            original_wav, mel = data
            original_wav = original_wav.to(device)
            mel = mel.to(device)
            gen_out = generator(mel)
            print("gen_out shape: ", gen_out.shape)
            real_disc_output = descriminator(original_wav)
            fake_disc_output = descriminator(gen_out)
            gen_loss, gen_losses = generator_loss(fake_disc_output)
            disc_loss, disc_losses = descriminator_loss(real_disc_output, fake_disc_output)
            fm_loss, fm_losses = feature_matching_loss(real_disc_output, fake_disc_output)
            total_gen_loss = gen_loss + 10*fm_loss
            generator_optimizer.zero_grad()
            total_gen_loss.backward(retain_graph=True)
            generator_optimizer.step()
            descriminator_optimizer.zero_grad()
            disc_loss.backward(retain_graph=True)
            descriminator_optimizer.step()
            print("epoch: {}, iteration: {}, gen_loss: {}, disc_loss: {}, fm_loss: {}".format(epoch, i, gen_loss, disc_loss, fm_loss))
            # print("gen_losses: {}, disc_losses: {}, fm_losses: {}".format(gen_losses, disc_losses, fm_losses))
            # break
        # break

In [382]:
train()

gen_out shape:  torch.Size([4, 1, 8192])
epoch: 0, iteration: 0, gen_loss: 3.0089802742004395, disc_loss: 3.009544849395752, fm_loss: 0.0986751988530159
gen_out shape:  torch.Size([4, 1, 8192])
epoch: 0, iteration: 1, gen_loss: 2.5753393173217773, disc_loss: 2.592524290084839, fm_loss: 0.1021510437130928
gen_out shape:  torch.Size([4, 1, 8192])
epoch: 0, iteration: 2, gen_loss: 2.1884853839874268, disc_loss: 2.2558443546295166, fm_loss: 0.11406855285167694
gen_out shape:  torch.Size([4, 1, 8192])
epoch: 1, iteration: 0, gen_loss: 1.8323485851287842, disc_loss: 1.9847863912582397, fm_loss: 0.0710679143667221
gen_out shape:  torch.Size([4, 1, 8192])
epoch: 1, iteration: 1, gen_loss: 1.4940266609191895, disc_loss: 1.7735192775726318, fm_loss: 0.048993177711963654
gen_out shape:  torch.Size([4, 1, 8192])
epoch: 1, iteration: 2, gen_loss: 1.1810543537139893, disc_loss: 1.6325737237930298, fm_loss: 0.08496694266796112
gen_out shape:  torch.Size([4, 1, 8192])
epoch: 2, iteration: 0, gen_loss:

In [249]:
8192/8

1024.0

In [250]:
conv1 = weight_norm(nn.Conv1d(in_channels=80, out_channels=512, kernel_size=7, stride=1, padding=3))
rand_input = torch.randn(1, 80, 8192)