<a href="https://colab.research.google.com/github/syedmahmoodiagents/Speech/blob/main/WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import numpy as np

In [None]:
# Notebook-wide configuration.
AUDIO_PATH = "dummy_audio.wav"  # audio file handed to librosa.load in the loading cell
SAMPLE_RATE = 16000  # passed as `sr` to librosa (load + melspectrogram) and as the playback rate
N_MELS = 80  # number of mel bands requested from librosa.feature.melspectrogram
MU = 256  # presumably the number of mu-law levels; same value as the mu=256 default of quantize() -- TODO confirm

In [None]:
def mu_law_encode(x, mu=256):
    """Apply mu-law companding to a tensor.

    Args:
        x: tensor of amplitudes; values in [-1, 1] map to [-1, 1].
        mu: number of quantization levels; the companding constant is ``mu - 1``.

    Returns:
        Tensor with the same shape as ``x`` holding the companded values.
    """
    mu = mu - 1
    # float() keeps the constant a floating-point tensor instead of relying on
    # implicit integer -> float promotion inside log1p.
    log_mu = torch.log1p(torch.tensor(float(mu)))
    return torch.sign(x) * torch.log1p(mu * torch.abs(x)) / log_mu


def quantize(x, mu=256):
    """Convert amplitudes to integer mu-law class indices in [0, mu - 1].

    Args:
        x: tensor of amplitudes; values outside [-1, 1] are clipped first so the
            resulting indices can never fall outside the valid class range.
        mu: number of quantization levels (classes).

    Returns:
        ``torch.int64`` tensor with the same shape as ``x``.
    """
    x = torch.clamp(x, -1.0, 1.0)
    encoded = mu_law_encode(x, mu)
    # +0.5 rounds to the nearest level; a bare .long() would truncate and bias
    # every sample downwards by up to one level.
    return ((encoded + 1) / 2 * (mu - 1) + 0.5).long()

In [None]:

class CausalConv1d(nn.Module):
    """1-D convolution whose output at time t only sees inputs at times <= t.

    The input is padded on the left by ``(kernel_size - 1) * dilation`` zeros
    and then convolved without further padding, so the output has the same
    length as the input.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: convolution kernel width.
        dilation: spacing between kernel taps.
    """

    def __init__(self, in_ch, out_ch, kernel_size, dilation):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_ch, out_ch,
            kernel_size,
            dilation=dilation
        )

    def forward(self, x):
        """Return the causal convolution of ``x`` (channels on dim 1, time on dim 2)."""
        # Left-only padding: padding both sides and slicing with `[:-pad]`
        # yields an empty tensor when pad == 0 (kernel_size == 1) and computes
        # `pad` output frames that are immediately thrown away.
        if self.pad > 0:
            x = F.pad(x, (self.pad, 0))
        return self.conv(x)

In [None]:

class WaveNetBlock(nn.Module):
    """Gated residual block: dilated causal conv, tanh/sigmoid gate, 1x1 heads.

    Args:
        channels: width of the residual stream (input and both outputs).
        kernel_size: kernel width of the dilated causal convolution.
        dilation: dilation of the causal convolution.
    """

    def __init__(self, channels, kernel_size, dilation):
        super().__init__()
        # Twice the channels: one half feeds tanh, the other half feeds sigmoid.
        self.dilated = CausalConv1d(channels, 2 * channels, kernel_size, dilation)
        self.residual = nn.Conv1d(channels, channels, kernel_size=1)
        self.skip = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x, cond):
        """Return ``(residual_output, skip_output)`` for input ``x``.

        ``cond`` is added to the 2 * channels pre-activation, so it must be
        broadcastable to that shape.
        """
        pre_gate = self.dilated(x) + cond
        filter_half, gate_half = torch.chunk(pre_gate, 2, dim=1)
        gated = filter_half.tanh() * gate_half.sigmoid()
        residual_out = x + self.residual(gated)
        skip_out = self.skip(gated)
        return residual_out, skip_out


In [None]:

class WaveNet(nn.Module):
    """Stack of gated dilated causal blocks with mel and speaker conditioning.

    Args:
        residual_channels: width of the residual/skip streams.
        kernel_size: kernel width used by every dilated convolution.
        dilations: iterable of dilation factors, one block per entry.
        mel_channels: channel count of the local conditioning input.
        num_classes: number of output classes (channels of the returned logits).
        num_speakers: size of the speaker embedding table.

    Attributes:
        receptive_field: ``1 + (kernel_size - 1) * sum(dilations)``.
    """

    def __init__(
        self, residual_channels=64, kernel_size=2,
        dilations=(1, 2, 4, 8, 16, 32, 64, 128),
        mel_channels=80, num_classes=256, num_speakers=1
    ):
        super().__init__()

        # Immutable default (no shared mutable list) and a single
        # materialisation: `dilations` is consumed twice below, which would
        # silently give a wrong receptive field for a one-shot iterable.
        dilations = tuple(dilations)

        # Lifts the single-channel waveform into the residual stream.
        self.input = nn.Conv1d(1, residual_channels, 1)

        # Conditioning is projected to 2 * residual_channels because each block
        # adds it to its pre-gate activation (tanh half + sigmoid half).
        self.local_cond = nn.Conv1d(
            mel_channels, 2 * residual_channels, 1
        )

        self.global_embed = nn.Embedding(
            num_speakers, 2 * residual_channels
        )

        self.blocks = nn.ModuleList([
            WaveNetBlock(residual_channels, kernel_size, d)
            for d in dilations
        ])

        self.output = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(residual_channels, residual_channels, 1),
            nn.ReLU(),
            nn.Conv1d(residual_channels, num_classes, 1)
        )

        self.receptive_field = 1 + (kernel_size - 1) * sum(dilations)

    def forward(self, x, mel, speaker_id):
        """Compute per-timestep class logits.

        Args:
            x: waveform tensor with one channel on dim 1 and time on dim 2.
            mel: conditioning tensor with ``mel_channels`` on dim 1 and a time
                axis that must broadcast against ``x``.
            speaker_id: integer tensor of indices into the speaker embedding.

        Returns:
            Logits with ``num_classes`` channels on dim 1.
        """
        x = self.input(x)
        cond = self.local_cond(mel)
        # Trailing singleton axis lets the speaker vector broadcast over time.
        g = self.global_embed(speaker_id).unsqueeze(-1)
        cond = cond + g

        skips = []
        for block in self.blocks:
            x, skip = block(x, cond)
            skips.append(skip)

        return self.output(sum(skips))


In [None]:
# Load the clip at SAMPLE_RATE and peak-normalise it.
audio, _ = librosa.load(AUDIO_PATH, sr=SAMPLE_RATE)
peak = np.max(np.abs(audio))
# An all-zero (silent) clip has peak 0; dividing by it would fill the waveform
# with NaNs, so normalisation is skipped in that case.
if peak > 0:
    audio = audio / peak
# Add two leading singleton axes -> shape (1, 1, num_samples).
audio = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

In [None]:
# Mel spectrogram of the waveform, converted to dB and given a leading
# singleton axis. hop_length=1 is requested so the frame axis has a comparable
# length to the waveform.
waveform_np = audio.numpy().squeeze()
mel_power = librosa.feature.melspectrogram(
    y=waveform_np,
    sr=SAMPLE_RATE,
    n_mels=N_MELS,
    hop_length=1,
)
mel_db = librosa.power_to_db(mel_power)
mel = torch.tensor(mel_db, dtype=torch.float32).unsqueeze(0)

In [None]:
# Trim waveform and mel to the shorter of the two along the last (time) axis.
T = min(audio.shape[-1], mel.shape[-1])
audio, mel = audio[..., :T], mel[..., :T]

In [None]:
# Build the model from the notebook configuration instead of WaveNet's
# hard-coded defaults, so changing N_MELS (already used by the mel cell) or MU
# does not leave the network with mismatched layer sizes.
model = WaveNet(mel_channels=N_MELS, num_classes=MU)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# Single speaker: index 0 of the speaker embedding table.
speaker_id = torch.tensor([0])

In [None]:
# Next-sample prediction: the input is the waveform without its last sample,
# the target is the quantized waveform shifted one step ahead.
x = audio[:, :, :-1]
# Pass the configured MU explicitly rather than relying on quantize()'s default.
y = quantize(audio[:, :, 1:], mu=MU)

In [None]:
# Forward pass: drop the last mel frame so the conditioning has the same length
# as x, then score the logits against the class-index targets.
mel_in = mel[:, :, :-1]
logits = model(x, mel_in, speaker_id)
targets = y.squeeze(1)  # remove the channel axis of the target tensor
loss = criterion(logits, targets)

In [None]:
# Clear gradients left over from any earlier backward pass; without this,
# re-running the training cells accumulates gradients into the same buffers.
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("Training loss:", loss.item())
print("Receptive field:", model.receptive_field)

Training loss: 5.58225679397583
Receptive field: 256


In [None]:

def _mu_law_decode(indices, num_classes):
    """Map class indices in [0, num_classes - 1] back to amplitudes in [-1, 1]."""
    mu = num_classes - 1
    y = 2.0 * indices.float() / mu - 1.0
    log_mu = torch.log1p(torch.tensor(float(mu)))
    return torch.sign(y) * torch.expm1(torch.abs(y) * log_mu) / mu


@torch.no_grad()
def generate(model, mel, speaker_id):
    """Autoregressively sample a waveform conditioned on ``mel``.

    Args:
        model: callable as ``model(x, mel, speaker_id)`` returning logits with
            classes on dim 1 and time on dim 2; must expose ``receptive_field``.
        mel: conditioning tensor whose last axis sets the output length.
        speaker_id: speaker index tensor forwarded to the model.

    Returns:
        Tensor of shape ``(1, 1, length)``. The first ``receptive_field``
        samples are left at zero and serve as the initial context.
    """
    model.eval()
    length = mel.shape[-1]
    context = model.receptive_field
    x = torch.zeros(1, 1, length, device=mel.device)

    for t in range(context, length):
        # Only the last `receptive_field` samples can influence the final
        # logit, so feeding just that window avoids re-running the whole
        # prefix on every step.
        start = t - context
        logits = model(x[:, :, start:t], mel[:, :, start:t], speaker_id)
        probs = F.softmax(logits[:, :, -1], dim=-1)
        sample = torch.multinomial(probs, 1)
        # The last logit (input index t - 1) predicts sample t, so the draw is
        # written at t -- not t + 1 -- and is visible to the next step. The
        # class index is mu-law expanded so that x stays on the same amplitude
        # scale as the waveform the model was fed during training.
        x[0, 0, t] = _mu_law_decode(sample.squeeze(), probs.shape[-1])

    return x

In [None]:
# Sample a waveform with the same conditioning used for training.
generated_audio = generate(model, mel=mel, speaker_id=speaker_id)
print("Generated waveform shape:", generated_audio.size())


Generated waveform shape: torch.Size([1, 1, 16000])


In [None]:
# Show the generated tensor; the recorded output below elides the middle values.
generated_audio

tensor([[[ 0.0000,  0.0000,  0.0000,  ..., -0.5703,  0.7656,  0.2500]]])

In [None]:
from IPython.display import Audio
import numpy as np

In [None]:
# Drop the singleton axes and hand the samples to the notebook audio player.
audio_np = torch.squeeze(generated_audio).numpy()
Audio(data=audio_np, rate=SAMPLE_RATE)