In [1]:
import os
import glob
from shutil import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import librosa
import soundfile as sf
from scipy.fftpack import dct, idct
import copy

In [2]:
# --------- MDCT/IMDCT (approximate, using DCT-IV) ---------
def mdct(x, n_fft):
    # Pad to multiple of n_fft
    pad = (n_fft - len(x) % n_fft) % n_fft
    x = np.pad(x, (0, pad), mode='constant')
    frames = librosa.util.frame(x, frame_length=n_fft, hop_length=n_fft//2).T
    # DCT-IV
    return dct(frames, type=2, axis=1, norm='ortho')

def imdct(X, n_fft):
    # Inverse DCT-IV
    frames = idct(X, type=2, axis=1, norm='ortho')
    # Overlap-add
    hop = n_fft // 2
    out = np.zeros((frames.shape[0] * hop + hop,))
    for i, frame in enumerate(frames):
        out[i*hop:i*hop+n_fft] += frame
    return out

# Add this class
class SelfAttention2d(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels, 1)
        self.key = nn.Conv2d(in_channels, in_channels, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.shape
        q = self.query(x).reshape(B, C, -1)
        k = self.key(x).reshape(B, C, -1)
        v = self.value(x).reshape(B, C, -1)
        attn = self.softmax(torch.bmm(q.transpose(1,2), k) / (C ** 0.5))
        out = torch.bmm(attn, v.transpose(1,2)).transpose(1,2)
        out = out.reshape(B, C, H, W)
        return out + x

# --------- Dataset with slicing and normalization ---------
class AudioMDCTDataset(Dataset):
    def __init__(self, file_list, rate=10000, feats=256, duration=3.3, total_seconds=26, hop_size=1, normalize=True):
        self.rate = rate
        self.feats = feats
        self.duration = duration
        self.segments = []
        for file in file_list:
            for offset in range(0, total_seconds, hop_size):
                self.segments.append((file, offset))
        self.normalize = normalize
        self.mean = 0.0
        self.std = 1.0
        if normalize:
            self._compute_stats()

    def _compute_stats(self):
        specs = []
        for file, offset in self.segments[:min(100, len(self.segments))]:  # sample for stats
            audio, sr = librosa.load(file, sr=self.rate, offset=offset, duration=self.duration)
            audio_fill = np.zeros(int(self.rate * self.duration), dtype=np.float32)
            audio_fill[:len(audio)] = audio
            spec = mdct(audio_fill, self.feats)
            specs.append(spec)
        specs = np.stack(specs)
        self.mean = specs.mean()
        self.std = specs.std() + 1e-8

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        file, offset = self.segments[idx]
        audio, sr = librosa.load(file, sr=self.rate, offset=offset, duration=self.duration)
        audio_fill = np.zeros(int(self.rate * self.duration), dtype=np.float32)
        audio_fill[:len(audio)] = audio
        spec = mdct(audio_fill, self.feats)
        if self.normalize:
            spec = (spec - self.mean) / self.std
        spec = np.expand_dims(spec, axis=0)  # (1, feats, time)
        return torch.tensor(spec, dtype=torch.float32)

# --------- Sinusoidal Embedding ---------
class SinusoidalEmbedding(nn.Module):
    def __init__(self, embedding_dims=32, min_freq=1.0, max_freq=1000.0):
        super().__init__()
        self.embedding_dims = embedding_dims
        self.min_freq = min_freq
        self.max_freq = max_freq

    def forward(self, x):
        device = x.device
        frequencies = torch.exp(
            torch.linspace(np.log(self.min_freq), np.log(self.max_freq), self.embedding_dims // 2, device=device)
        )
        angular_speeds = 2.0 * np.pi * frequencies
        x = x.unsqueeze(-1)
        embeddings = torch.cat([torch.sin(angular_speeds * x), torch.cos(angular_speeds * x)], dim=-1)
        return embeddings

# --------- U-Net Blocks ---------
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.same_channels = in_channels == out_channels
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if not self.same_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        residual = x if self.same_channels else self.res_conv(x)
        x = self.bn(x)
        x = F.silu(self.conv1(x))
        x = self.conv2(x)
        return x + residual

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, block_depth):
        super().__init__()
        self.blocks = nn.Sequential(*[
            ResidualBlock(in_channels if i == 0 else out_channels, out_channels)
            for i in range(block_depth)
        ])
        self.pool = nn.AvgPool2d(2)

    def forward(self, x, skips):
        x = self.blocks(x)
        skips.append(x)
        x = self.pool(x)
        return x

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, block_depth):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.blocks = nn.Sequential(*[
            ResidualBlock(in_channels if i == 0 else out_channels, out_channels)
            for i in range(block_depth)
        ])

    def forward(self, x, skips):
        x = self.upsample(x)
        skip = skips.pop()
        # Crop skip to match x's spatial size
        if skip.shape[2:] != x.shape[2:]:
            min_h = min(skip.shape[2], x.shape[2])
            min_w = min(skip.shape[3], x.shape[3])
            skip = skip[:, :, :min_h, :min_w]
            x = x[:, :, :min_h, :min_w]
        x = torch.cat([x, skip], dim=1)
        x = self.blocks(x)
        return x

class UNet(nn.Module):
    def __init__(self, widths, block_depth, in_channels=1, embedding_dims=32, dim1=256, dim2=128):
        super().__init__()
        self.embedding = SinusoidalEmbedding(embedding_dims)
        self.input_conv = nn.Conv2d(in_channels + embedding_dims, widths[0], kernel_size=1)
        self.downs = nn.ModuleList()
        for i in range(len(widths) - 1):
            self.downs.append(DownBlock(widths[i], widths[i+1], block_depth))
        self.mid = nn.Sequential(*[ResidualBlock(widths[-1], widths[-1]) for _ in range(block_depth)])
        self.attention = SelfAttention2d(widths[-1])
        self.ups = nn.ModuleList()
        for i in range(len(widths) - 2, -1, -1):
            self.ups.append(UpBlock(widths[i+1]*2, widths[i], block_depth))
        self.output_conv = nn.Conv2d(widths[0], in_channels, kernel_size=1)

    def forward(self, x, noise_var):
        batch, _, h, w = x.shape
        # Ensure noise_var is (batch,)
        if noise_var.dim() > 1:
            noise_var = noise_var.view(batch)
        e = self.embedding(noise_var.to(x.device))
        e = e.unsqueeze(-1).unsqueeze(-1)  # (batch, embedding_dims, 1, 1)
        e = e.expand(-1, -1, h, w)
        x = torch.cat([x, e], dim=1)
        x = self.input_conv(x)
        skips = []
        for down in self.downs:
            x = down(x, skips)
        x = self.mid(x)
        x = self.attention(x)
        for up in self.ups:
            x = up(x, skips)
        x = self.output_conv(x)
        return x

In [3]:
# --------- Losses ---------
def spectral_norm_loss(pred, real):
    norm_real = torch.norm(real, dim=(2,3)) + 1e-6
    norm_pred = torch.norm(pred, dim=(2,3)) + 1e-6
    return torch.mean(torch.abs(norm_real - norm_pred) / norm_real)

def time_derivative_loss(pred, real, window=1):
    real_deriv = real[:, :, :-window, :] - real[:, :, window:, :]
    pred_deriv = pred[:, :, :-window, :] - pred[:, :, window:, :]
    return F.mse_loss(real_deriv, pred_deriv)

# --------- Diffusion Schedule ---------
def diffusion_schedule(diffusion_times, min_signal_rate=0.02, max_signal_rate=0.95):
    start_angle = torch.acos(torch.tensor(max_signal_rate))
    end_angle = torch.acos(torch.tensor(min_signal_rate))
    diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
    signal_rates = torch.cos(diffusion_angles)
    noise_rates = torch.sin(diffusion_angles)
    return noise_rates, signal_rates

# --------- Training Loop ---------
def train(model, dataloader, optimizer, device, ema_model=None, ema_decay=0.999, epochs=1, mean=0.0, std=1.0, spec_mod=0.0, dx_mod=0.0):
    model.train()
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        for batch in tqdm(dataloader):
            batch = batch.to(device)
            noises = torch.randn_like(batch)
            batch_size = batch.size(0)
            diffusion_times = torch.rand(batch_size, 1, 1, 1, device=device)
            noise_rates, signal_rates = diffusion_schedule(diffusion_times)
            noisy = signal_rates * batch + noise_rates * noises

            optimizer.zero_grad()
            # pred_noises = model(noisy, noise_rates.squeeze(-1).squeeze(-1))
            pred_noises = model(noisy, noise_rates.flatten())
            # Crop pred and noises to the same shape
            min_h = min(pred_noises.shape[2], noises.shape[2])
            min_w = min(pred_noises.shape[3], noises.shape[3])
            pred_noises = pred_noises[:, :, :min_h, :min_w]
            noises = noises[:, :, :min_h, :min_w]
            batch = batch[:, :, :min_h, :min_w]

            loss = loss_fn(pred_noises, noises)
            if spec_mod > 0:
                loss += spec_mod * spectral_norm_loss(pred_noises, noises)
            if dx_mod > 0:
                loss += dx_mod * time_derivative_loss(pred_noises, noises)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                if ema_model is not None:
                    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                        ema_param.data.mul_(ema_decay).add_(param.data, alpha=1 - ema_decay)
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# --------- Reverse Diffusion (DDIM) ---------
def reverse_diffusion(model, initial_noise, diffusion_steps, device):
    model.eval()
    num_examples = initial_noise.shape[0]
    step_size = 1.0 / diffusion_steps
    next_noisy_data = initial_noise
    for step in tqdm(range(diffusion_steps)):
        noisy_data = next_noisy_data
        diffusion_times = torch.ones((num_examples, 1), device=device) - step * step_size
        noise_rates, signal_rates = diffusion_schedule(diffusion_times)
        pred_noises = model(noisy_data, noise_rates)
        # --- Crop all tensors to the minimum shape ---
        min_h = min(noisy_data.shape[2], pred_noises.shape[2])
        min_w = min(noisy_data.shape[3], pred_noises.shape[3])
        noisy_data = noisy_data[:, :, :min_h, :min_w]
        pred_noises = pred_noises[:, :, :min_h, :min_w]
        noise_rates_ = noise_rates.view(-1,1,1,1).expand(-1,1,min_h,min_w)
        signal_rates_ = signal_rates.view(-1,1,1,1).expand(-1,1,min_h,min_w)
        pred_data = (noisy_data - noise_rates_ * pred_noises) / signal_rates_
        next_diffusion_times = diffusion_times - step_size
        next_noise_rates, next_signal_rates = diffusion_schedule(next_diffusion_times)
        next_noise_rates_ = next_noise_rates.view(-1,1,1,1).expand(-1,1,min_h,min_w)
        next_signal_rates_ = next_signal_rates.view(-1,1,1,1).expand(-1,1,min_h,min_w)
        next_noisy_data = next_signal_rates_ * pred_data + next_noise_rates_ * pred_noises
    return pred_data

# --------- Generation ---------
def generate(model, shape, device, diffusion_steps=1000, mean=0.0, std=1.0):
    with torch.no_grad():
        initial_noise = torch.randn(shape).to(device)
        generated_data = reverse_diffusion(model, initial_noise, diffusion_steps, device)
        # Denormalize
        generated_data = generated_data * std + mean
        return generated_data.cpu()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print(f"Using device: {device}")

genres = [
    "blues", "classical", "country", "disco", "hiphop",
    "jazz", "metal", "pop", "reggae", "rock"
]

samples_per_genre = 10
results = {}

for genre in genres:
    print(f"Training and generating for genre: {genre}")
    music_files = glob.glob(f"../genres/{genre}/*.au")
    if len(music_files) == 0:
        continue

    dataset = AudioMDCTDataset(
        music_files, rate=10000, feats=256, duration=30.0, total_seconds=30, hop_size=1, normalize=True
    )
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

    # Get shape and normalization stats
    for test_batch in dataloader:
        shape = test_batch.shape
        break
    mean = dataset.mean
    std = dataset.std

    model = UNet(widths=[128, 128, 128, 128], block_depth=2,
                    in_channels=1, dim1=shape[2], dim2=shape[3]).to(device)
    ema_model = copy.deepcopy(model)
    ema_decay = 0.999
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

    num_total_examples = len(dataset) // shape[0]
    for epoch in range(10):  # You can increase epochs as needed
        print(f"Epoch {epoch+1}/10")
        train(model, dataloader, optimizer, device, ema_model=ema_model, ema_decay=ema_decay, epochs=1, mean=mean, std=std, spec_mod=0.1, dx_mod=0.1)

    print(f"Generating {samples_per_genre} samples for genre: {genre}")
    samples = generate(ema_model, (samples_per_genre, 1, shape[2], shape[3]), device, diffusion_steps=1000, mean=mean, std=std)
    results[genre] = samples

    # --- Save generated samples as audio files ---
    os.makedirs(f"generated/{genre}", exist_ok=True)
    for i, spec in enumerate(samples):
        spec_np = spec.squeeze().cpu().numpy()
        audio = imdct(spec_np, 256)
        audio = audio[:int(30.0 * 10000)]  # Ensure 30 seconds
        sf.write(f"generated/{genre}/sample_{i+1}.wav", audio, 10000)

Using device: cuda
Training and generating for genre: blues
Epoch 1/10


100%|██████████| 163/163 [29:10<00:00, 10.74s/it]


Epoch 1, Loss: 0.1469702124595642
Epoch 2/10


100%|██████████| 163/163 [29:07<00:00, 10.72s/it]


Epoch 1, Loss: 0.08760330080986023
Epoch 3/10


100%|██████████| 163/163 [29:03<00:00, 10.69s/it]


Epoch 1, Loss: 0.11753316223621368
Epoch 4/10


100%|██████████| 163/163 [29:03<00:00, 10.70s/it]


Epoch 1, Loss: 0.14250564575195312
Epoch 5/10


100%|██████████| 163/163 [28:48<00:00, 10.61s/it]


Epoch 1, Loss: 0.09125635027885437
Epoch 6/10


100%|██████████| 163/163 [28:48<00:00, 10.60s/it]


Epoch 1, Loss: 0.07614340633153915
Epoch 7/10


100%|██████████| 163/163 [28:53<00:00, 10.64s/it]


Epoch 1, Loss: 0.15747524797916412
Epoch 8/10


100%|██████████| 163/163 [28:55<00:00, 10.65s/it]


Epoch 1, Loss: 0.11004752665758133
Epoch 9/10


100%|██████████| 163/163 [28:55<00:00, 10.65s/it]


Epoch 1, Loss: 0.11295919120311737
Epoch 10/10


100%|██████████| 163/163 [28:59<00:00, 10.67s/it]


Epoch 1, Loss: 0.11356737464666367
Generating 10 samples for genre: blues


100%|██████████| 1000/1000 [04:40<00:00,  3.57it/s]


Training and generating for genre: classical
Epoch 1/10


100%|██████████| 163/163 [29:08<00:00, 10.72s/it]


Epoch 1, Loss: 0.16803498566150665
Epoch 2/10


100%|██████████| 163/163 [28:00<00:00, 10.31s/it]


Epoch 1, Loss: 0.2830210030078888
Epoch 3/10


100%|██████████| 163/163 [27:51<00:00, 10.25s/it]


Epoch 1, Loss: 0.20761965215206146
Epoch 4/10


100%|██████████| 163/163 [27:54<00:00, 10.27s/it]


Epoch 1, Loss: 0.13871783018112183
Epoch 5/10


100%|██████████| 163/163 [27:52<00:00, 10.26s/it]


Epoch 1, Loss: 0.15734903514385223
Epoch 6/10


100%|██████████| 163/163 [27:51<00:00, 10.26s/it]


Epoch 1, Loss: 0.10454849898815155
Epoch 7/10


100%|██████████| 163/163 [27:50<00:00, 10.25s/it]


Epoch 1, Loss: 0.1480257511138916
Epoch 8/10


100%|██████████| 163/163 [27:51<00:00, 10.26s/it]


Epoch 1, Loss: 0.15202981233596802
Epoch 9/10


100%|██████████| 163/163 [27:50<00:00, 10.25s/it]


Epoch 1, Loss: 0.12792852520942688
Epoch 10/10


100%|██████████| 163/163 [27:53<00:00, 10.27s/it]


Epoch 1, Loss: 0.11239919066429138
Generating 10 samples for genre: classical


100%|██████████| 1000/1000 [04:32<00:00,  3.67it/s]


Training and generating for genre: country
Epoch 1/10


100%|██████████| 163/163 [27:51<00:00, 10.25s/it]


Epoch 1, Loss: 0.15719641745090485
Epoch 2/10


100%|██████████| 163/163 [27:53<00:00, 10.27s/it]


Epoch 1, Loss: 0.1147167980670929
Epoch 3/10


  1%|          | 2/163 [00:20<27:59, 10.43s/it]


KeyboardInterrupt: 