<a href="https://colab.research.google.com/github/testgithubprecious/Ml_projects/blob/main/WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install if not already: pip install torch matplotlib numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# ------------------------------
# Hyperparameters
# ------------------------------
sample_rate = 8000      # samples per second (8 kHz)
duration = 1            # clip length in seconds; fractional values are allowed
# np.linspace (used for the time axis) requires an integer sample count, so
# coerce explicitly: a fractional `duration` would otherwise make this a float.
samples = int(round(sample_rate * duration))
channels = 1            # audio channels (1 = mono)
quantization_channels = 256  # number of discrete amplitude classes (2**8)


# ------------------------------
# One-hot encoding helper
# ------------------------------
def one_hot_encode(x, num_classes):
    """Return a float32 one-hot encoding of integer class indices.

    A trailing dimension of size ``num_classes`` is appended to ``x``'s
    shape, e.g. ``(N, T)`` -> ``(N, T, num_classes)``.
    """
    encoded = F.one_hot(x, num_classes=num_classes)
    return encoded.to(dtype=torch.float32)


# ------------------------------
# Dummy sine-wave training data
# ------------------------------
t = np.linspace(0, duration, samples, endpoint=False)  # sample times in seconds
wave = 0.5 * np.sin(2 * np.pi * 440 * t)  # 440 Hz tone (A4), amplitude 0.5
# Linearly map [-1, 1] onto the integer classes [0, quantization_channels - 1].
# Round to the nearest class: a bare astype() truncates, biasing every sample
# down by half a step. The clip guards the range if the amplitude is raised.
# With amplitude 0.5 only the middle of the range is used (about 64..191).
wave = np.clip(
    np.rint((wave + 1) / 2 * (quantization_channels - 1)),
    0,
    quantization_channels - 1,
).astype(np.int64)


# ------------------------------
# WaveNet model definition
# ------------------------------
class WaveNet(nn.Module):
    """Minimal WaveNet-style stack of gated, dilated causal convolutions.

    ``forward`` takes a ``(batch, quantization_channels, time)`` tensor
    (e.g. a one-hot encoded waveform). It returns unnormalised per-class
    logits of the same shape. Every layer is causal: output step ``t``
    depends only on input steps ``<= t``. The time length is preserved, so
    the residual additions line up.

    Args:
        residual_channels: Width of the residual stream.
        dilation_depth: Number of dilated layers; layer ``d`` uses
            dilation ``2 ** d``.
        kernel_size: Temporal kernel size of each dilated convolution.
        quantization_channels: Number of input/output classes. It must match
            the ``num_classes`` used to one-hot encode the input.
    """

    def __init__(
        self,
        residual_channels=32,
        dilation_depth=4,
        kernel_size=2,
        quantization_channels=256,
    ):
        super().__init__()
        self.quantization_channels = quantization_channels
        self.dilated_convs = nn.ModuleList()
        self.residuals = nn.ModuleList()

        for d in range(dilation_depth):
            dilation = 2 ** d
            # No built-in padding: Conv1d's `padding` pads BOTH sides. That
            # would make the output longer than the input (breaking the
            # residual add) and let step t see future samples. Causal
            # left-only padding is applied in forward() instead.
            self.dilated_convs.append(
                nn.Conv1d(
                    residual_channels,
                    2 * residual_channels,  # split into tanh / sigmoid halves
                    kernel_size,
                    dilation=dilation,
                )
            )
            self.residuals.append(nn.Conv1d(residual_channels, residual_channels, 1))

        self.input_conv = nn.Conv1d(quantization_channels, residual_channels, 1)
        self.output = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(residual_channels, quantization_channels, 1),
        )

    def forward(self, x):
        """Map ``(N, quantization_channels, T)`` input to same-shaped logits."""
        x = self.input_conv(x)
        for conv, res in zip(self.dilated_convs, self.residuals):
            # A dilated kernel spans dilation*(kernel_size-1) past steps plus
            # the current one. Left-padding by that much keeps length T and
            # makes the convolution causal.
            causal_pad = conv.dilation[0] * (conv.kernel_size[0] - 1)
            out = conv(F.pad(x, (causal_pad, 0)))
            tanh_out, sigm_out = torch.chunk(out, 2, dim=1)
            z = torch.tanh(tanh_out) * torch.sigmoid(sigm_out)  # gated activation
            x = res(z) + x  # residual connection
        return self.output(x)


# ------------------------------
# Training setup (demo)
# ------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = WaveNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Next-sample prediction: the model reads samples [0, T-2] (one-hot encoded,
# channels-first) and is scored against samples [1, T-1].
wave_tensor = torch.tensor(wave, dtype=torch.long, device=device)[None]  # (1, T)
x_input = one_hot_encode(wave_tensor[:, :-1], quantization_channels).transpose(1, 2)  # (1, C, T-1)
y_target = wave_tensor[:, 1:]  # (1, T-1)

# ------------------------------
# One training step
# ------------------------------
model.train()
optimizer.zero_grad()
output = model(x_input)

# Flatten (N, C, T) logits into (N*T, C) rows to pair with (N*T,) targets.
loss = criterion(
    output.transpose(1, 2).reshape(-1, quantization_channels),
    y_target.flatten(),
)
loss.backward()
optimizer.step()

print(f"WaveNet Training Loss (1 step): {loss.item():.4f}")