In [1]:

# #-------------------------------------------------modified..................

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# ------------------------
# Encoder / Decoder
# -------------------------
class Encoder(nn.Module):
    def __init__(self, in_channels=1, z_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, z_dim, kernel_size=4, stride=2, padding=1)
        )

    def forward(self, x):
        return self.net(x)  # (B, D, T_down)

class Decoder(nn.Module):
    def __init__(self, out_channels=1, z_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose1d(z_dim, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(64, out_channels, kernel_size=4, stride=2, padding=1)
        )

    def forward(self, q):
        return self.net(q)

# -------------------------
# Vector Quantizer EMA
# -------------------------
# class VectorQuantizerEMA(nn.Module):
#     def __init__(self, num_embeddings=1024, embedding_dim=256, commitment_cost=0.25, decay=0.99, eps=1e-5):
#         super().__init__()
#         self.num_embeddings = num_embeddings
#         self.embedding_dim = embedding_dim
#         self.commitment_cost = commitment_cost
#         self.decay = decay
#         self.eps = eps

#         embed = torch.randn(embedding_dim, num_embeddings)
#         self.register_buffer('embedding', embed)  # (D, K)
#         self.register_buffer('cluster_size', torch.zeros(num_embeddings))
#         self.register_buffer('embed_avg', embed.clone())

#     def forward(self, z):
#         B, D, T = z.shape
#         flat = z.permute(0,2,1).contiguous().view(-1, D)  # (B*T, D)
#         emb_t = self.embedding.t()  # (K, D)

#         # Compute distances and nearest embeddings
#         distances = flat.pow(2).sum(1, keepdim=True) - 2 * flat @ emb_t.t() + emb_t.pow(2).sum(1).unsqueeze(0)
#         encoding_indices = torch.argmin(distances, dim=1)
#         encodings = F.one_hot(encoding_indices, num_classes=self.num_embeddings).type(flat.dtype)
#         quantized = (encodings @ self.embedding.t()).view(B, T, D).permute(0,2,1).contiguous()

#         # EMA updates (training only)
#         if self.training:
#             with torch.no_grad():  # do EMA updates without tracking autograd
#                 n = encodings.sum(0).detach()
#                 self.cluster_size.mul_(self.decay).add_(n, alpha=1 - self.decay)
#                 embed_sum = flat.t() @ encodings
#                 self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
#                 n = self.cluster_size + self.eps
#                 self.embedding.copy_(self.embed_avg / n.unsqueeze(0))
#         # Compute VQ losses
#         e_latent_loss = F.mse_loss(quantized.detach(), z)
#         q_latent_loss = F.mse_loss(quantized, z.detach())
#         loss = q_latent_loss + self.commitment_cost * e_latent_loss

#         # Straight-through estimator
#         quantized = z + (quantized - z).detach()
#         indices = encoding_indices.view(B, T)
#         return quantized, loss, indices

class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=256,
                 commitment_cost=0.25, decay=0.99, eps=1e-5):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.eps = eps

        # Much better initialization
        embed = torch.randn(num_embeddings, embedding_dim) * 0.1
        self.register_buffer("embedding", embed)            # (K, D)

        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, z):
        """
        z: (B, D, T)
        """
        B, D, T = z.shape
        flat = z.permute(0, 2, 1).contiguous().view(-1, D)  # (B*T, D)

        # ---------- Correct pairwise L2 distances ----------
        emb = self.embedding                           # (K, D)
        dist = (flat.pow(2).sum(dim=1, keepdim=True)
                + emb.pow(2).sum(dim=1)
                - 2 * flat @ emb.t())                  # (B*T, K)

        # ---------- Find nearest embedding ----------
        encoding_indices = torch.argmin(dist, dim=1)
        encodings = F.one_hot(encoding_indices,
                               self.num_embeddings).type(flat.dtype)

        quantized = encodings @ emb        # (B*T, D)
        quantized = quantized.view(B, T, D).permute(0, 2, 1).contiguous()

        # ---------- EMA update ----------
        if self.training:
            with torch.no_grad():

                # cluster size update with smoothing
                cluster_sum = encodings.sum(0)
                self.cluster_size.mul_(self.decay).add_(
                    cluster_sum, alpha=1 - self.decay)

                # embedding sum
                embed_sum = flat.t() @ encodings
                self.embed_avg.mul_(self.decay).add_(
                    embed_sum.t(), alpha=1 - self.decay)

                # Laplace smoothing to prevent collapse
                n = (self.cluster_size + self.eps)
                embed_normalized = self.embed_avg / n.unsqueeze(1)

                self.embedding.copy_(embed_normalized)

        # ---------- Losses ----------
        e_loss = F.mse_loss(z.detach(), quantized)
        q_loss = F.mse_loss(z, quantized.detach())

        vq_loss = q_loss + self.commitment_cost * e_loss

        # Straight-through estimator
        quantized = z + (quantized - z).detach()

        indices = encoding_indices.view(B, T)
        return quantized, vq_loss, indices

# ------------------------
# VQ-VAE wrapper
# -------------------------
class VQVAE(nn.Module):
    def __init__(self, z_dim=256, num_embeddings=1024):
        super().__init__()
        self.encoder = Encoder(1, z_dim)
        self.quantizer = VectorQuantizerEMA(num_embeddings, z_dim)
        self.decoder = Decoder(1, z_dim)

    def forward(self, x):
        z = self.encoder(x)
        quantized, vq_loss, indices = self.quantizer(z)
        x_rec = self.decoder(quantized)
        return x_rec, vq_loss, indices

# ------------------------
# Explicit training step
# -------------------------
# def training_step(model, optimizer, batch_waveform):
#     model.train()
#     optimizer.zero_grad()
#     x_rec, vq_loss, _ = model(batch_waveform)
#     recon_loss = F.l1_loss(x_rec, batch_waveform)
#     loss = recon_loss + vq_loss
#     loss.backward()
#     optimizer.step()
#     return loss.item(), recon_loss.item(), vq_loss.item()

# ------------------------
# Eval helpers (no gradient)
# -------------------------
@torch.no_grad()
def tokenize_audio(model, waveform):
    model.eval()
    z = model.encoder(waveform)
    _, _, indices = model.quantizer(z)
    return indices.cpu().numpy()

@torch.no_grad()
def reconstruct_from_indices(model, indices):
    if isinstance(indices, np.ndarray):
        indices = torch.from_numpy(indices).to(model.quantizer.embedding.device)
    B, T = indices.shape
    D = model.quantizer.embedding.shape[0]
    emb = model.quantizer.embedding.t()
    flat = emb[indices.view(-1)]
    q = flat.view(B, T, D).permute(0,2,1).contiguous()
    x_rec = model.decoder(q)
    return x_rec


import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import torch.optim as optim

# ------------------------
# Assume your VQVAE, Encoder, Decoder, VectorQuantizerEMA are already defined
# ------------------------

# Create dummy dataset: 10 audio samples, each length 1024
B = 10      # batch size
T = 1024    # audio length
dummy_audio = torch.randn(B, 1, T)  # (B,1,T)

# Wrap in a DataLoader
dataset = TensorDataset(dummy_audio)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

# Create model
z_dim = 64           # smaller for testing
num_embeddings = 32  # smaller for testing
model = VQVAE(z_dim=z_dim, num_embeddings=num_embeddings)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# ------------------------
# Single training step on dummy data
# ------------------------
for batch in dataloader:
    x = batch[0].to(device)  # (B,1,T)
    
    model.train()
    optimizer.zero_grad()
    
    x_rec, vq_loss, indices = model(x)
    
    # Reconstruction loss (L1)
    recon_loss = F.l1_loss(x_rec, x)
    
    # Total loss
    loss = recon_loss + vq_loss
    print(loss)
    # Backprop
    loss.backward()
    optimizer.step()
    
    # print("Batch Loss:", loss.item())
    # print("Reconstruction Loss:", recon_loss.item())
    # print("VQ Loss:", vq_loss.item())
    # print("Token shape:", indices.shape)




tensor(0.8597, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.8286, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.8114, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.8041, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.8006, device='cuda:0', grad_fn=<AddBackward0>)


In [2]:
class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, files, segment_length=16000):
        self.files = files
        self.segment_length = segment_length
        self.sr = 16000

    def __getitem__(self, idx):
        wav, sr = torchaudio.load(self.files[idx])
        wav = torchaudio.functional.resample(wav, sr, self.sr)
        wav = wav.mean(dim=0)   # mono
        wav = wav / wav.abs().max()  # normalize
        
        # random crop
        if wav.shape[0] >= self.segment_length:
            start = torch.randint(0, wav.shape[0] - self.segment_length, (1,))
            wav = wav[start:start+self.segment_length]
        else:
            wav = F.pad(wav, (0, self.segment_length - wav.shape[0]))
        
        return wav.unsqueeze(0)

    def __len__(self):
        return len(self.files)


In [3]:

import os
import torch
import torchaudio
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
AUDIO_DIR = "/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original/blues"

# collect audio file paths
audio_files = [
    os.path.join(AUDIO_DIR, f)
    for f in os.listdir(AUDIO_DIR)
    if f.lower().endswith((".wav", ".mp3", ".flac"))
]
print("Found", len(audio_files), "audio files")

audio_files = []
for root, dirs, files in os.walk(AUDIO_DIR):
    for f in files:
        if f.lower().endswith((".wav", ".mp3", ".flac")):
            audio_files.append(os.path.join(root, f))

dataset = AudioDataset(audio_files, segment_length=16000)

loader = DataLoader(dataset, batch_size=8, shuffle=True)


Found 100 audio files


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VQVAE(z_dim=256, num_embeddings=1024).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)


In [9]:

import torch
import torch.nn.functional as F

def multiscale_stft_loss(x, x_rec):
    """
    x, x_rec: (B, 1, T)
    Returns scalar STFT loss
    """
    losses = []
    scales = [
        (2048, 512, 2048),
        (1024, 256, 1024),
        (512, 128, 512),
    ]

    x = x[:, 0]       # (B, T)
    x_rec = x_rec[:, 0]

    for n_fft, hop, win in scales:
        window = torch.hann_window(win).to(x.device).detach()

        X = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop,
            win_length=win,
            window=window,
            return_complex=True
        )

        Xr = torch.stft(
            x_rec,
            n_fft=n_fft,
            hop_length=hop,
            win_length=win,
            window=window,
            return_complex=True
        )

        losses.append((X - Xr).abs().mean())

    return sum(losses)

# import torch
# import torch.nn.functional as F

# class multiscale_stft_loss(torch.nn.Module):
#     def __init__(self, scales=None):
#         super().__init__()

#         if scales is None:
#             scales = [
#                 (2048, 512, 2048),
#                 (1024, 256, 1024),
#                 (512, 128, 512),
#             ]

#         self.scales = scales
#         self.windows = {
#             win: torch.hann_window(win) for (_, _, win) in scales
#         }

#     def stft(self, x, n_fft, hop, win, device):
#         window = self.windows[win].to(device)
#         return torch.stft(
#             x,
#             n_fft=n_fft,
#             hop_length=hop,
#             win_length=win,
#             window=window,
#             return_complex=True,
#         )

#     def forward(self, x, x_rec):
#         """
#         x, x_rec: (B,1,T)
#         """
#         x = x[:, 0]
#         x_rec = x_rec[:, 0]

#         sc_losses = []
#         mag_losses = []

#         for n_fft, hop, win in self.scales:
#             X = self.stft(x, n_fft, hop, win, x.device)
#             Xr = self.stft(x_rec, n_fft, hop, win, x.device)

#             # magnitude spectrograms
#             mag = X.abs()
#             mag_r = Xr.abs()

#             # 1) Spectral Convergence (stabilizes training)
#             sc = torch.norm(mag - mag_r, p="fro") / (torch.norm(mag, p="fro") + 1e-8)
#             sc_losses.append(sc)

#             # 2) Log-magnitude L1 loss (prevents large spikes)
#             mag_l1 = F.l1_loss(torch.log1p(mag), torch.log1p(mag_r))
#             mag_losses.append(mag_l1)

#         # Weighted sum
#         loss = (
#             sum(sc_losses) / len(sc_losses) +
#             sum(mag_losses) / len(mag_losses)
#         )

#         return loss


In [10]:
# num_epochs = 50

# for epoch in range(num_epochs):
#     total_loss = 0.0
#     total_recon = 0.0
#     total_vq = 0.0

#     for batch_idx, batch_waveform in enumerate(loader):

#         # Move batch to device
#         x = batch_waveform.to(device)    # <-- you used batch_waveform but later referenced x
#         optimizer.zero_grad()
#         model.train()

#         # Forward pass
#         x_rec, vq_loss, indices = model(x)

#         # Reconstruction loss
#         recon_loss = F.l1_loss(x_rec, x)

#         # Total loss
#         loss = recon_loss + vq_loss

#         # Backprop
#         loss.backward()
#         optimizer.step()

#         # Track for epoch summary
#         total_loss += loss.item()
#         total_recon += recon_loss.item()
#         total_vq += vq_loss.item()

#         # Print progress
#         if batch_idx % 10 == 0:
#             print(
#                 f"Epoch [{epoch+1}/{num_epochs}], "
#                 f"Batch [{batch_idx}/{len(loader)}], "
#                 f"Loss: {loss.item():.4f}, "
#                 f"Recon: {recon_loss.item():.4f}, "
#                 f"VQ: {vq_loss.item():.4f}"
#             )

#     # End-of-epoch summary
#     print(
#         f"\nEpoch [{epoch+1}/{num_epochs}] SUMMARY:\n"
#         f"  Avg Loss: {total_loss / len(loader):.4f}\n"
#         f"  Avg Recon: {total_recon / len(loader):.4f}\n"
#         f"  Avg VQ: {total_vq / len(loader):.4f}\n"
#     )


num_epochs = 20

for epoch in range(num_epochs):
    total_loss = 0.0
    total_l1 = 0.0
    total_stft = 0.0
    total_vq = 0.0

    for batch_idx, batch_waveform in enumerate(loader):

        x = batch_waveform.to(device)  # (B,1,T)
        optimizer.zero_grad()
        model.train()

        # --------------------
        # Forward pass
        # --------------------
        x_rec, vq_loss, indices = model(x)

        # Waveform L1 loss
        l1_loss = F.l1_loss(x_rec, x)

        # Multi-scale STFT loss
        stft_loss = multiscale_stft_loss(x, x_rec)

        # Combined loss
        loss = l1_loss + 0.1 * stft_loss + vq_loss

        # --------------------
        # Backprop
        # --------------------
        loss.backward()
        optimizer.step()

        # Track for epoch
        total_loss += loss.item()
        total_l1 += l1_loss.item()
        total_stft += stft_loss.item()
        total_vq += vq_loss.item()

        if batch_idx % 10 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] "
                f"Batch [{batch_idx}/{len(loader)}] | "
                f"Loss: {loss.item():.4f} | "
                f"L1: {l1_loss.item():.4f} | "
                f"STFT: {stft_loss.item():.4f} | "
                f"VQ: {vq_loss.item():.4f}"
            )

    # End-of-epoch summary
    print(
        f"\nEPOCH {epoch+1} SUMMARY:\n"
        f"  Avg Loss : {total_loss/len(loader):.4f}\n"
        f"  Avg L1   : {total_l1/len(loader):.4f}\n"
        f"  Avg STFT : {total_stft/len(loader):.4f}\n"
        f"  Avg VQ   : {total_vq/len(loader):.4f}\n"
    )


Epoch [1/20] Batch [0/13] | Loss: 1.1143 | L1: 0.5181 | STFT: 5.9614 | VQ: 0.0000
Epoch [1/20] Batch [10/13] | Loss: 0.9586 | L1: 0.3955 | STFT: 5.5985 | VQ: 0.0033

EPOCH 1 SUMMARY:
  Avg Loss : 1.0040
  Avg L1   : 0.4449
  Avg STFT : 5.5756
  Avg VQ   : 0.0016

Epoch [2/20] Batch [0/13] | Loss: 0.8642 | L1: 0.3400 | STFT: 5.1398 | VQ: 0.0102
Epoch [2/20] Batch [10/13] | Loss: 0.7086 | L1: 0.2063 | STFT: 4.0496 | VQ: 0.0973

EPOCH 2 SUMMARY:
  Avg Loss : 0.7130
  Avg L1   : 0.2209
  Avg STFT : 4.3223
  Avg VQ   : 0.0600

Epoch [3/20] Batch [0/13] | Loss: 0.7198 | L1: 0.1994 | STFT: 4.8607 | VQ: 0.0343
Epoch [3/20] Batch [10/13] | Loss: 0.6047 | L1: 0.1627 | STFT: 4.3629 | VQ: 0.0057

EPOCH 3 SUMMARY:
  Avg Loss : 0.5437
  Avg L1   : 0.1528
  Avg STFT : 3.8292
  Avg VQ   : 0.0081

Epoch [4/20] Batch [0/13] | Loss: 0.4791 | L1: 0.1229 | STFT: 3.4946 | VQ: 0.0067
Epoch [4/20] Batch [10/13] | Loss: 0.3771 | L1: 0.1077 | STFT: 2.6318 | VQ: 0.0062

EPOCH 4 SUMMARY:
  Avg Loss : 0.4719
  Avg

In [12]:

# # import torch
# # import torch.nn.functional as F
# # from torch.utils.data import DataLoader, TensorDataset

# # BATCH_SIZE = 2
# # NUM_SAMPLES = 10
# # T = 1024
# # dummy_audio_data = torch.randn(NUM_SAMPLES, 1, T)  # (10, 1, 1024)

# # # Wrap in TensorDataset and DataLoader
# # dataset = TensorDataset(dummy_audio_data)
# # dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
# # model.eval()
# # reconstructed_audio = reconstruct_from_indices(model, indices)  # (B,1,T)

# # print(f"Original Audio Shape: {dummy_audio.shape}")
# # print(f"Reconstructed Audio Shape: {reconstructed_audio.shape}")

# # # Optional: convert to numpy to listen or visualize
# # reconstructed_audio_np = reconstructed_audio.cpu().numpy()




import torch
import torchaudio
import os

# Make sure your model is already loaded and on the correct device
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Path to your audio file (update with your Kaggle dataset path)
audio_path = "/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original/blues/blues.00000.wav"

# Load audio
waveform, sr = torchaudio.load(audio_path)  # shape: (channels, T_orig)

# Convert to mono and batch dimension: (1, 1, T)
waveform = waveform.mean(dim=0, keepdim=True).unsqueeze(0).to(device)

# Forward pass through VQ-VAE
with torch.no_grad():
    reconstructed, _, indices = model(waveform)

# Save original and reconstructed audio
os.makedirs("/kaggle/working/output_audio", exist_ok=True)

# Original
# torchaudio.save("/kaggle/working/output_audio/original.wav", waveform.squeeze(0).cpu(), sr)

# Reconstructed
torchaudio.save("/kaggle/working/output_audio/reconstructed.wav", reconstructed.squeeze(0).cpu(), sr)

print("Saved original and reconstructed audio in /kaggle/working/output_audio/")
print(f"Original shape: {waveform.shape}, Reconstructed shape: {reconstructed.shape}")


# import torchaudio
# import torch
# import numpy as np

# # pick a file from your dataset
# AUDIO_PATH = "/kaggle/input/gtzan-dataset-music-genre-classification/Data/genres_original/blues/blues.00001.wav"

# model.eval()

# # ----------------------------------------------------
# # 1. Load audio from dataset
# # ----------------------------------------------------
# waveform, sr = torchaudio.load(AUDIO_PATH)     # shape: (C, T)
# waveform = waveform.mean(dim=0, keepdim=True)  # convert to mono → (1, T)
# waveform = waveform.unsqueeze(0)               # → (1, 1, T)
# waveform = waveform.to(device)

# print("Loaded waveform:", waveform.shape)

# # ----------------------------------------------------
# # 2. Convert audio → tokens
# # ----------------------------------------------------
# tokens = tokenize_audio(model, waveform)    # numpy array (1, T_down)
# print("Token shape:", tokens.shape)

# # ----------------------------------------------------
# # 3. Reconstruct audio using ONLY indices
# # ----------------------------------------------------
# reconstructed = reconstruct_from_indices(model, tokens)
# reconstructed = reconstructed.cpu().detach()

# print("Reconstructed shape:", reconstructed.shape)

# # ----------------------------------------------------
# # 4. Save both audios
# # ----------------------------------------------------
# # torchaudio.save("original.wav", waveform[0].cpu(), sample_rate=sr)
# torchaudio.save("reconstructed2.wav", reconstructed[0], sample_rate=sr)

# print("Saved original.wav and reconstructed.wav!")


Saved original and reconstructed audio in /kaggle/working/output_audio/
Original shape: torch.Size([1, 1, 661794]), Reconstructed shape: torch.Size([1, 1, 661792])


In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("andradaolteanu/gtzan-dataset-music-genre-classification")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/gtzan-dataset-music-genre-classification
