<a href="https://colab.research.google.com/github/syedmahmoodiagents/Speech/blob/main/SimpleWaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import librosa
import soundfile as sf

In [None]:

# Numerical-precision demo: three ways of taking a logarithm of small inputs.
x_small = torch.tensor([1e-15, 0.0, 1.0])

# Evaluate every variant first, then report them in a fixed order.
y_loge = x_small.log()              # plain natural log; log(0) evaluates to -inf
y_log_direct = (1 + x_small).log()  # the 1e-15 term is lost when added to 1, so the result is 0
y_log1p = x_small.log1p()           # log1p(x) keeps the tiny argument's contribution

print(f"log_e result: {y_loge}")
print(f"log(1 + x) result: {y_log_direct}")
print(f"log1p result: {y_log1p}")

log_e result: tensor([-34.5388,     -inf,   0.0000])
log(1 + x) result: tensor([0.0000, 0.0000, 0.6931])
log1p result: tensor([1.0000e-15, 0.0000e+00, 6.9315e-01])


In [None]:
# Element-wise sign: each entry is replaced by -1 or +1 here; magnitudes are discarded.
torch.tensor([-2.3, 1.8, -0.001, 0.7, -0.8, 9.8]).sign()

tensor([-1.,  1., -1.,  1., -1.,  1.])

In [None]:
def mu_law_encode(x, mu=256):
    """Apply mu-law companding to a waveform tensor.

    Computes ``sign(x) * log1p((mu - 1) * |x|) / log1p(mu - 1)`` after
    clamping ``x`` to ``[-1, 1]``, so the result also lies in ``[-1, 1]``.

    Args:
        x: Tensor of samples; values outside ``[-1, 1]`` are clamped.
        mu: Number of quantization levels; the formula uses ``mu - 1``.

    Returns:
        Tensor of companded values with the same shape as ``x``.
    """
    levels = mu - 1
    clipped = torch.clamp(x, -1.0, 1.0)
    # The normaliser is built as a tensor in the input's dtype so that
    # torch.log1p can be applied to it.
    log_norm = torch.log1p(torch.tensor(levels, dtype=clipped.dtype))
    compressed = torch.log1p(levels * torch.abs(clipped))
    return torch.sign(clipped) * compressed / log_norm

def quantize(x, mu=256):
    """Convert a waveform to integer mu-law class indices in ``[0, mu - 1]``.

    The input is companded with ``mu_law_encode`` (which clamps to
    ``[-1, 1]``), rescaled to ``[0, mu - 1]`` and rounded to the nearest level.

    Args:
        x: Tensor of samples, nominally in ``[-1, 1]``.
        mu: Number of quantization levels (classes).

    Returns:
        ``torch.int64`` tensor of class indices with the same shape as ``x``.
    """
    x = mu_law_encode(x, mu)
    # ``.long()`` truncates, so add 0.5 first to round to the nearest level.
    # Plain truncation shifts every sample half a step downward and only hits
    # the top class when the input is exactly +1.0. The maximum value is
    # (mu - 1) + 0.5, which still truncates to mu - 1, so indices stay in range.
    return ((x + 1) / 2 * (mu - 1) + 0.5).long()

In [None]:
import torch.nn as nn

class CausalConv1d(nn.Module):
    """1-D convolution whose output at step ``t`` only sees inputs at steps ``<= t``.

    The wrapped ``nn.Conv1d`` pads both ends of the sequence by
    ``(kernel_size - 1) * dilation`` samples. ``forward`` keeps only the first
    ``T`` output steps (``T`` being the input length), which discards the
    trailing steps that were computed from right-hand padding.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Width of the convolution kernel.
        dilation: Spacing between kernel taps.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels, out_channels,
            kernel_size,
            padding=self.padding,
            dilation=dilation
        )

    def forward(self, x):
        """Apply the causal convolution.

        Args:
            x: Tensor of shape ``(batch, in_channels, time)``.

        Returns:
            Tensor of shape ``(batch, out_channels, time)``.
        """
        out = self.conv(x)
        # Trim by input length rather than ``[:-self.padding]``: with
        # kernel_size == 1 the padding is 0 and ``[:-0]`` would yield an empty
        # tensor instead of the full sequence.
        return out[:, :, :x.size(-1)]  # remove future leakage


In [None]:
class WaveNetBlock(nn.Module):
    """Gated residual layer built on a dilated causal convolution.

    The causal convolution emits ``2 * channels`` feature maps that are split
    into a tanh "filter" half and a sigmoid "gate" half. Their product feeds
    two separate 1x1 convolutions: one for the residual path and one for the
    skip path.

    Args:
        channels: Channel count of the block's input and both outputs.
        kernel_size: Kernel width of the causal convolution.
        dilation: Dilation of the causal convolution.
    """

    def __init__(self, channels, kernel_size, dilation):
        super().__init__()
        # A single convolution produces both the filter and the gate halves.
        self.conv = CausalConv1d(channels, 2 * channels, kernel_size, dilation)
        # 1x1 projections of the gated activation.
        self.residual = nn.Conv1d(channels, channels, kernel_size=1)
        self.skip = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x):
        """Return ``(x + residual, skip)`` for the input tensor ``x``."""
        filter_half, gate_half = torch.chunk(self.conv(x), 2, dim=1)
        activated = filter_half.tanh() * gate_half.sigmoid()
        return x + self.residual(activated), self.skip(activated)


In [None]:
class WaveNet(nn.Module):
    """Stack of gated residual blocks that outputs per-timestep class logits.

    Dilations cycle through ``1, 2, 4, ..., 2 ** (layers_per_block - 1)`` and
    the cycle repeats ``num_blocks`` times. The skip outputs of all layers are
    summed, passed through ReLU and projected to ``num_classes`` channels by
    two 1x1 convolutions.

    Args:
        in_channels: Channels of the input waveform tensor.
        channels: Width of the residual and skip paths.
        kernel_size: Kernel width of every dilated convolution.
        num_blocks: Number of dilation cycles.
        layers_per_block: Layers per dilation cycle.
        num_classes: Number of output classes per time step.

    Attributes:
        dilations: Dilation of each layer, in stacking order.
        receptive_field: Number of input samples that can influence one
            output step, ``(kernel_size - 1) * sum(dilations) + 1``.

    Raises:
        ValueError: If ``num_blocks`` or ``layers_per_block`` is below 1.
    """

    def __init__(self, in_channels=1, channels=64, kernel_size=2, num_blocks=3, layers_per_block=10, num_classes=256):
        super().__init__()

        # Without any layer there are no skip outputs to sum, and forward()
        # would fail later with an opaque error; reject the configuration here.
        if num_blocks < 1 or layers_per_block < 1:
            raise ValueError(
                "WaveNet needs at least one layer: "
                f"num_blocks={num_blocks}, layers_per_block={layers_per_block}"
            )

        self.input_conv = nn.Conv1d(in_channels, channels, kernel_size=1)

        # Keep the dilation schedule on the module so it can be inspected.
        self.dilations = [
            2 ** i
            for _ in range(num_blocks)
            for i in range(layers_per_block)
        ]
        self.receptive_field = (kernel_size - 1) * sum(self.dilations) + 1

        self.blocks = nn.ModuleList(
            [WaveNetBlock(channels, kernel_size, dilation) for dilation in self.dilations]
        )

        self.relu = nn.ReLU()
        self.output_conv = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(channels, num_classes, kernel_size=1)
        )

    def forward(self, x):
        """Compute class logits for every time step.

        Args:
            x: Tensor of shape ``(batch, in_channels, time)``.

        Returns:
            Logits of shape ``(batch, num_classes, time)``.
        """
        x = self.input_conv(x)

        # Accumulate skip outputs in layer order as they are produced.
        skip_total = None
        for block in self.blocks:
            x, skip = block(x)
            skip_total = skip if skip_total is None else skip_total + skip

        out = self.relu(skip_total)
        return self.output_conv(out)


In [None]:

# Synthesize a short clip of uniform noise so the notebook runs without
# external data, write it to disk, then read it back with librosa.
dummy_audio_path = 'dummy_audio.wav'
sr = 16000
duration = 5 # seconds

# A seeded generator makes the waveform identical on every Restart & Run All.
y_dummy = np.random.default_rng(0).uniform(low=-1.0, high=1.0, size=sr * duration).astype(np.float32)
sf.write(dummy_audio_path, y_dummy, sr)

# Load at the same rate the file was written with instead of repeating the literal.
y, sr = librosa.load(dummy_audio_path, sr=sr)

In [None]:

# Wrap the NumPy waveform as a float32 tensor and prepend two singleton axes:
# [Batch=1, channel=1 (mono), length of vector].
audio_input_tensor = torch.from_numpy(y).float()[None, None, :]

In [None]:

# Keep only the first 16000 positions along the last (time) axis.
clipped_audio_input_tensor = audio_input_tensor[..., :16000]

print(f"Clipped audio input tensor shape: {clipped_audio_input_tensor.shape}")


Clipped audio input tensor shape: torch.Size([1, 1, 16000])


In [None]:
# Seed PyTorch before building the model so the random weight initialisation,
# and therefore the loss reported further down, is repeatable across kernel
# restarts.
torch.manual_seed(0)

model = WaveNet()  # default hyper-parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class targets

In [None]:
# Next-sample prediction pairs: the target at step t is the quantized class of
# sample t + 1, and the input drops the final sample so both have equal length.
# NOTE(review): `y` held the NumPy waveform in an earlier cell and is rebound
# here to the target tensor.
y = quantize(clipped_audio_input_tensor[:, :, 1:])      # next sample
x = clipped_audio_input_tensor[:, :, :-1]

In [None]:
# One optimisation step on the single (input, target) pair.

# Clear gradients left over from any earlier run of this cell. backward() adds
# to the existing .grad buffers, so without this a re-run would step on the sum
# of old and new gradients.
optimizer.zero_grad()

logits = model(x)
# Drop the channel axis of the targets so they carry one class index per
# (batch, time) position, matching the logits' class axis at dim 1.
loss = criterion(logits, y.squeeze(1))

loss.backward()
optimizer.step()

print(loss)

tensor(5.5642, grad_fn=<NllLoss2DBackward0>)
