# CMGAN

### Imports

In [1]:
import librosa
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
import torch.nn as nn
import torch.nn.functional as F

### Device

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


### GPU State

In [3]:
import torch
torch.cuda.empty_cache()
# In terminal
!nvidia-smi


Fri Jul 11 20:14:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 561.03                 Driver Version: 561.03         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   49C    P0             16W /   95W |       0MiB /   8188MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

### Conformer Architecture

In [4]:
class ConformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4, conv_kernel_size=31):
        super().__init__()
        self.ffn1 = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(dim * 4, dim)
        )

        self.mha = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

        self.conv = nn.Sequential(
            nn.Conv1d(dim, dim * 2, kernel_size=1),
            nn.GLU(dim=1),
            nn.Conv1d(dim, dim, kernel_size=conv_kernel_size, padding=conv_kernel_size // 2, groups=dim),
            nn.BatchNorm1d(dim),
            nn.SiLU(),
            nn.Conv1d(dim, dim, kernel_size=1),
        )

        self.ffn2 = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(dim * 4, dim)
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):  # x: [B, T, dim]
        x = x + 0.5 * self.ffn1(x)
        attn_out, _ = self.mha(x, x, x)
        x = x + attn_out

        conv_in = x.transpose(1, 2)  # [B, dim, T]
        conv_out = self.conv(conv_in).transpose(1, 2)  # [B, T, dim]
        x = x + conv_out

        x = x + 0.5 * self.ffn2(x)
        return self.norm(x)

### Generator Architecture

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Dilated DenseNet Block ---
class DilatedDenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate=16, num_layers=4, out_channels=None):
        super().__init__()
        self.num_layers = num_layers
        self.growth_rate = growth_rate

        self.layers = nn.ModuleList()
        channels = in_channels
        for i in range(num_layers):
            self.layers.append(
                nn.Conv2d(channels, growth_rate, kernel_size=3, dilation=2**i, padding=2**i)
            )
            channels += growth_rate

        self.output_conv = nn.Conv2d(channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            x_cat = torch.cat(features, dim=1)
            out = self.relu(layer(x_cat))
            features.append(out)
        x_cat = torch.cat(features, dim=1)
        return self.output_conv(x_cat)

# --- CMGAN Generator ---
class CMGANGenerator(nn.Module):
    def __init__(self, n_fft=512, hop_length=128, conformer_dim=64, num_blocks=4):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.register_buffer("window", torch.hann_window(self.n_fft), persistent=False)

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),              # Pre-projection
            nn.ReLU(),
            DilatedDenseBlock(in_channels=64, out_channels=128),     # Dilated DenseNet
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),           # Refinement
            nn.ReLU(),
            nn.Conv2d(128, conformer_dim, kernel_size=1)             # Output projection
        )

        # Conformer blocks
        self.conformers = nn.Sequential(
            *[ConformerBlock(dim=conformer_dim) for _ in range(num_blocks)]
        )

        # Mask decoder
        self.mask_decoder = nn.Sequential(
            DilatedDenseBlock(in_channels=conformer_dim, out_channels=conformer_dim),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(conformer_dim, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Complex decoder
        self.complex_decoder = nn.Sequential(
            DilatedDenseBlock(in_channels=conformer_dim, out_channels=conformer_dim),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(conformer_dim, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 2, kernel_size=1)  # real & imag
        )

    def forward(self, x):
        B, _, T = x.shape

        # STFT
        stft = torch.stft(x.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, return_complex=True)

        mag = stft.abs()      # [B, F, T']
        real = stft.real
        imag = stft.imag
        phase = stft.angle()

        # Stack input features [B, 3, F, T']
        input_feat = torch.stack([mag, real, imag], dim=1)

        # Encode
        x_encoded = self.encoder(input_feat)              # [B, C, F, T']
        x_pooled = torch.mean(x_encoded, dim=2)           # [B, C, T']
        x_seq = x_pooled.permute(0, 2, 1)                 # [B, T', C]

        # Conformer blocks
        x_seq = self.conformers(x_seq)                    # [B, T', C]

        # Reshape for decoding
        x = x_seq.permute(0, 2, 1).unsqueeze(2)           # [B, C, 1, T']
        x = x.expand(-1, -1, mag.shape[1], -1)            # [B, C, F, T']

        # Decoders
        mask = self.mask_decoder(x)                       # [B, 1, F, T"]
        complex_pred = self.complex_decoder(x)            # [B, 2, F, T"]

        # Resize to match mag shape
        if mask.shape[-1] != mag.shape[-1]:
            mask = F.interpolate(mask, size=mag.shape[-1], mode='bilinear', align_corners=False)
        if mask.shape[-2] != mag.shape[-2]:
            mask = F.interpolate(mask, size=mag.shape[-2:], mode='bilinear', align_corners=False)
        if complex_pred.shape[-2:] != mag.shape[-2:]:
            complex_pred = F.interpolate(complex_pred, size=mag.shape[-2:], mode='bilinear', align_corners=False)

        # Masked magnitude + phase reconstruction
        enhanced_mag = mag.unsqueeze(1) * mask            # [B, 1, F, T']
        masked_real = enhanced_mag.squeeze(1) * torch.cos(phase)
        masked_imag = enhanced_mag.squeeze(1) * torch.sin(phase)
        masked_complex = torch.complex(masked_real, masked_imag)

        # iSTFT
        enhanced_waveform = torch.istft(masked_complex, n_fft=self.n_fft, hop_length=self.hop_length, length=T, window=self.window)

        return {
            'enhanced_waveform': enhanced_waveform.unsqueeze(1),
            'predicted_complex': complex_pred,
            'mask': mask
        }

### Metric discriminator

In [6]:
import torch
import torch.nn as nn

class MetricDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d((1, 1)),  # → (B, 128, 1, 1)
            nn.Flatten(),                  # → (B, 128)
            nn.Linear(128, 1),             # → (B, 1)
            nn.Sigmoid()
        )

    def forward(self, clean_mag, enhanced_mag):
        x = torch.stack([clean_mag, enhanced_mag], dim=1)  # → (B, 2, F, T)
        return self.net(x)  # → (B, 1)


### TF Loss (Ltf)

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TFLoss(nn.Module):
    def __init__(self, n_fft=512, hop_length=128, alpha=0.7):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.alpha = alpha  # α in the paper

    def forward(self, enhanced, clean):
        # Create Hann window on the same device as input
        window = torch.hann_window(self.n_fft).to(enhanced.device)

        # STFT with window
        est_stft = torch.stft(enhanced.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, 
                              return_complex=True, window=window)
        clean_stft = torch.stft(clean.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, 
                                return_complex=True, window=window)

        # Magnitude loss (L_Mag)
        mag_enh = est_stft.abs()
        mag_clean = clean_stft.abs()
        l_mag = F.mse_loss(mag_enh, mag_clean)

        # Complex loss (L_RI)
        l_ri = F.mse_loss(est_stft.real, clean_stft.real) + F.mse_loss(est_stft.imag, clean_stft.imag)

        # Final TF loss
        l_tf = self.alpha * l_mag + (1 - self.alpha) * l_ri
        return l_tf


### GAN Loss (Lgan)

In [8]:
class GANLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, D, clean_mag, enhanced_mag, qpesq_score=None):
        pred_fake = D(clean_mag, enhanced_mag)  # → [B, 1]
        pred_real = D(clean_mag, clean_mag)     # → [B, 1]

        # LGAN: Generator loss (wants D to output 1 for fake)
        g_loss = F.mse_loss(pred_fake, torch.ones_like(pred_fake))

        # LD: Discriminator loss
        real_loss = F.mse_loss(pred_real, torch.ones_like(pred_real))

        if qpesq_score is None:
            qpesq_score = torch.zeros_like(pred_fake)  # → match shape [B, 1]

        fake_loss = F.mse_loss(pred_fake, qpesq_score)
        d_loss = real_loss + fake_loss

        return g_loss, d_loss



### Time Domain Loss (Ltime)

In [9]:
class TimeDomainLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss()

    def forward(self, enhanced, clean):
        return self.l1(enhanced, clean)


CMGAN Loss (Lg)

In [10]:
class CMGANLoss(nn.Module):
    def __init__(self, gamma1=1.0, gamma2=1.0, gamma3=1.0, alpha=0.7, n_fft=512, hop_length=128):
        super().__init__()
        self.tf_loss = TFLoss(alpha=alpha, n_fft=n_fft, hop_length=hop_length)
        self.gan_loss = GANLoss()
        self.time_loss = nn.L1Loss()
        self.gamma1 = gamma1
        self.gamma2 = gamma2
        self.gamma3 = gamma3
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(self, enhanced, clean, D, qpesq_score=None):
        # Time-frequency loss
        l_tf = self.tf_loss(enhanced, clean)

        # Time-domain loss
        l_time = self.time_loss(enhanced, clean)

        # GAN loss
        window = torch.hann_window(self.n_fft).to(enhanced.device)
        est_stft = torch.stft(enhanced.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, return_complex=True, window=window)
        clean_stft = torch.stft(clean.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, return_complex=True, window=window)

        est_mag = est_stft.abs()
        clean_mag = clean_stft.abs()

        g_loss, _ = self.gan_loss(D, clean_mag, est_mag, qpesq_score)

        # Final combined generator loss
        l_gen = self.gamma1 * l_tf + self.gamma2 * g_loss + self.gamma3 * l_time
        return l_gen


Discriminator Loss (Ld)

In [11]:
import torch
import torch.nn.functional as F

def compute_discriminator_loss(D, enhanced, clean, qpesq_score=None, n_fft=512, hop_length=128):
    # Create Hann window on the same device
    window = torch.hann_window(n_fft).to(enhanced.device)

    # Compute STFTs
    est_stft = torch.stft(enhanced.squeeze(1), n_fft=n_fft, hop_length=hop_length,
                          return_complex=True, window=window)
    clean_stft = torch.stft(clean.squeeze(1), n_fft=n_fft, hop_length=hop_length,
                            return_complex=True, window=window)

    # Get magnitude spectrograms
    est_mag = est_stft.abs()
    clean_mag = clean_stft.abs()

    # Discriminator outputs
    pred_real = D(clean_mag, clean_mag)      # shape: [B, 1]
    pred_fake = D(clean_mag, est_mag)        # shape: [B, 1]

    # Targets must match shape [B, 1]
    real_target = torch.ones_like(pred_real)
    fake_target = qpesq_score if qpesq_score is not None else torch.zeros_like(pred_fake)

    # Losses
    real_loss = F.mse_loss(pred_real, real_target)
    fake_loss = F.mse_loss(pred_fake, fake_target)

    # Final loss
    d_loss = real_loss + fake_loss
    return d_loss



In [12]:
g_loss_fn = CMGANLoss(gamma1=1.0, gamma2=1.0, gamma3=1.0).cuda()

In [13]:
print(device)

cuda


In [14]:
def train(generator, discriminator, dataloader, g_loss_fn, g_optimizer, d_optimizer, epochs=5):
    for epoch in range(epochs):
        generator.train()
        discriminator.train()

        for batch in dataloader:
            noisy = batch['noisy'].cuda()
            clean = batch['clean'].cuda()

            # Generator forward
            outputs = generator(noisy)
            enhanced = outputs['enhanced_waveform']

            if epoch % 4 ==0:
                # Discriminator update
                d_optimizer.zero_grad()
                d_loss = compute_discriminator_loss(discriminator, enhanced.detach(), clean)
                d_loss.backward()
                d_optimizer.step()

            # Generator update
            g_optimizer.zero_grad()
            g_loss = g_loss_fn(enhanced, clean, discriminator)
            g_loss.backward()
            g_optimizer.step()

        print(f"[Epoch {epoch+1}] G Loss: {g_loss.item():.4f} | D Loss: {d_loss.item():.4f}")


In [15]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import librosa

class ChunkedAudioDataset(Dataset):
    def __init__(self, noisy_dir, clean_dir, chunk_size=16000, sr=16000, max_files=500):
        self.chunk_size = chunk_size
        self.sr = sr
        self.pairs = []

        noisy_files = sorted(Path(noisy_dir).glob("*.wav"))[:max_files]

        for file in noisy_files:
            filename = file.name
            clean_file = Path(clean_dir) / filename
            if not clean_file.exists():
                continue

            duration = librosa.get_duration(path=str(file))
            total_samples = int(duration * sr) // chunk_size * chunk_size
            num_chunks = total_samples // chunk_size

            for i in range(num_chunks):
                self.pairs.append((file, clean_file, i * chunk_size))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        noisy_path, clean_path, offset = self.pairs[idx]
        noisy, _ = librosa.load(noisy_path, sr=self.sr, offset=offset / self.sr, duration=self.chunk_size / self.sr)
        clean, _ = librosa.load(clean_path, sr=self.sr, offset=offset / self.sr, duration=self.chunk_size / self.sr)

        noisy = torch.tensor(noisy, dtype=torch.float32).unsqueeze(0)
        clean = torch.tensor(clean, dtype=torch.float32).unsqueeze(0)

        return {'noisy': noisy, 'clean': clean}

dataset = ChunkedAudioDataset(
    noisy_dir=r"E:\noisy_sound",
    clean_dir=r"D:\vs_code_python\Sperctromorph_GANs\data\clean_testset_wav",
    chunk_size=16000
)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

generator = CMGANGenerator().cuda()
discriminator = MetricDiscriminator().cuda()
g_loss_fn = CMGANLoss(gamma1=1.0, gamma2=1.0, gamma3=1.0).cuda()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

#train(generator, discriminator, dataloader, g_loss_fn, g_optimizer, d_optimizer, epochs=50)



In [16]:
# Save generator
#torch.save(generator.state_dict(), "cmgan_generator.pth")

# Optional: Save discriminator too
#torch.save(discriminator.state_dict(), "cmgan_discriminator.pth")


In [17]:
generator.load_state_dict(torch.load("cmgan_generator.pth", map_location='cpu'))  # or 'cuda' if using GPU
discriminator.load_state_dict(torch.load("cmgan_discriminator.pth", map_location='cpu'))


  generator.load_state_dict(torch.load("cmgan_generator.pth", map_location='cpu'))  # or 'cuda' if using GPU
  discriminator.load_state_dict(torch.load("cmgan_discriminator.pth", map_location='cpu'))


<All keys matched successfully>

In [18]:
import torchaudio
import torch
from pathlib import Path

def enhance_audio(generator, input_path, output_path, sr=16000, device='cuda'):
    generator.eval()
    input_path = Path(input_path)
    output_path = Path(output_path)

    # Load and preprocess
    waveform, file_sr = torchaudio.load(str(input_path))
    if file_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=file_sr, new_freq=sr)
        waveform = resampler(waveform)

    waveform = waveform.mean(dim=0, keepdim=True)  # convert to mono if stereo
    waveform = waveform.unsqueeze(0).to(device)    # [1, 1, T]

    # Forward pass through CMGAN
    with torch.no_grad():
        outputs = generator(waveform)
        enhanced = outputs['enhanced_waveform']     # [1, 1, T]
        enhanced = enhanced.squeeze(0).cpu()        # [1, T]

    # Save the output
    torchaudio.save(str(output_path), enhanced, sample_rate=sr)
    print(f"[✓] Enhanced audio saved to: {output_path}")


In [21]:
enhance_audio(generator, r"E:\filtered_output_freq_removal.wav", r"E:\final_cleaned.wav")

[✓] Enhanced audio saved to: E:\final_cleaned.wav


In [None]:
# Save
#torch.save(generator.state_dict(), "cmgan_generator.pth")
#torch.save(discriminator.state_dict(), "cmgan_discriminator.pth")

In [None]:
import torch
import torchaudio

# Dummy clean and enhanced waveforms
waveform = torch.randn(2, 1, 16000).cuda()  # (B=2, C=1, T=16000)
n_fft = 512
hop_length = 128
window = torch.hann_window(n_fft).cuda()

# STFT
clean_stft = torch.stft(waveform.squeeze(1), n_fft=n_fft, hop_length=hop_length,
                        return_complex=True, window=window)
enhanced_stft = torch.stft(waveform.squeeze(1), n_fft=n_fft, hop_length=hop_length,
                           return_complex=True, window=window)

clean_mag = clean_stft.abs()
enhanced_mag = enhanced_stft.abs()

# Pass through discriminator
with torch.no_grad():
    out = discriminator(clean_mag, enhanced_mag)
    print("Discriminator output shape:", out.shape)



Discriminator output shape: torch.Size([2, 1])
