# Speech Enhancement Generative Adversarial Network (SEGAN) 🪿

This notebook contains an implementation of the [SEGAN](https://arxiv.org/pdf/1703.09452) paper.

## Imports

In [1]:
import os

from src.segan import SEGAN
from src.util.consts import TASK_1, TASK_2, TASK_3
from src.util.device import set_device

%load_ext autoreload
%autoreload 2

## Setup

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = "2"
device = set_device()

Using device: NVIDIA A100-SXM4-40GB (0)


The following cells select the task levels, set the hyperparameters, and load the training and validation datasets.

In [3]:
levels = TASK_1
hyperparameters = {
    "lr": 0.0001,
    "l1_mag": 100,
    "batch_size": 32,
    "diffusion": True,
}
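The `l1_mag` entry presumably plays the role of the L1 weight λ in the SEGAN objective. The paper trains with least-squares GAN losses: the discriminator pushes $D(x, \tilde{x})$ towards 1 and $D(G(z, \tilde{x}), \tilde{x})$ towards 0, while the generator minimises $\tfrac{1}{2}\mathbb{E}[(D(G(z, \tilde{x}), \tilde{x}) - 1)^2] + \lambda \lVert G(z, \tilde{x}) - x \rVert_1$ with λ = 100, the same value used here. Below is a minimal NumPy sketch of those two losses. It is not the `src.segan` implementation (whose internals are not shown in this notebook); the function names are hypothetical, and the L1 term is assumed to be a mean absolute error rather than a summed norm.

```python
import numpy as np


def discriminator_loss(d_real: np.ndarray, d_fake: np.ndarray) -> float:
    """LSGAN discriminator loss: push D(clean, noisy) to 1 and D(enhanced, noisy) to 0."""
    real_term = 0.5 * np.mean((d_real - 1.0) ** 2)
    fake_term = 0.5 * np.mean(d_fake ** 2)
    return float(real_term + fake_term)


def generator_loss(d_fake: np.ndarray, enhanced: np.ndarray, clean: np.ndarray,
                   l1_mag: float = 100.0) -> float:
    """LSGAN generator loss plus an L1 term weighted by l1_mag (lambda in the paper)."""
    adversarial = 0.5 * np.mean((d_fake - 1.0) ** 2)
    # Assumption: mean absolute error; the paper writes this term as an L1 norm.
    l1 = np.mean(np.abs(enhanced - clean))
    return float(adversarial + l1_mag * l1)


d_real = np.array([0.9, 0.8])  # hypothetical D outputs on (clean, noisy) pairs
d_fake = np.array([0.2, 0.1])  # hypothetical D outputs on (enhanced, noisy) pairs
clean = np.zeros(4)
enhanced = np.full(4, 0.05)
print(discriminator_loss(d_real, d_fake), generator_loss(d_fake, enhanced, clean))
```

With λ = 100 even a small waveform error outweighs the adversarial term, which may be why the first logged generator losses are much larger than the discriminator losses.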

In [None]:
with open("../val_paths.txt", "r") as f:
    val_paths = [line.split(",") for line in f.read().splitlines()]

segan = SEGAN(levels=levels,
              hyperparameters=hyperparameters,
              diffusion=True,
              attention=True,
              spectral_norm=True,
              val_paths=val_paths,
              device=device)

Initialized dataset with 4217 files from 7 tasks
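The layout of `val_paths.txt` is not documented in this notebook; the cell above only shows that each line is split on commas into a list of paths. A minimal stdlib sketch of that parsing follows. The two-column `noisy,clean` format and the file names are hypothetical, used purely for illustration, and the helper `read_val_paths` additionally skips blank lines, which the original list comprehension would turn into `['']` entries.

```python
import io
from typing import TextIO


def read_val_paths(f: TextIO) -> list[list[str]]:
    """Split each non-blank line on commas, as the notebook cell does."""
    return [line.split(",") for line in f.read().splitlines() if line.strip()]


# Hypothetical file contents; the real meaning of each column is an assumption.
example = io.StringIO("noisy/a.wav,clean/a.wav\n\nnoisy/b.wav,clean/b.wav\n")
val_paths = read_val_paths(example)
print(val_paths)  # [['noisy/a.wav', 'clean/a.wav'], ['noisy/b.wav', 'clean/b.wav']]
```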


## Training

In [None]:
segan.learn(num_episodes=3_000)

Starting run 2024-10-21_16-41-22


[34m[1mwandb[0m: Currently logged in as: [33mpascal12[0m. Use [1m`wandb login --relogin`[0m to force relogin


----------------------------------------------
episode: 0
iteration: 0
lr: 0.0001
generator_loss: 12.872396469116211
discriminator_loss: 0.21919359266757965
----------------------------------------------
episode: 0
iteration: 1
lr: 0.0001
generator_loss: 11.111681938171387
discriminator_loss: 0.12269951403141022
----------------------------------------------
episode: 0
iteration: 2
lr: 0.0001
generator_loss: 8.983762741088867
discriminator_loss: 0.09288657456636429
----------------------------------------------
episode: 0
iteration: 3
lr: 0.0001
generator_loss: 7.415934085845947
discriminator_loss: 0.6098226308822632
----------------------------------------------
episode: 0
iteration: 4
lr: 0.0001
generator_loss: 6.425756454467773
discriminator_loss: 0.5815525054931641
----------------------------------------------
episode: 0
iteration: 5
lr: 0.0001
generator_loss: 5.2579498291015625
discriminator_loss: 0.6021736860275269
----------------------------------------------
episode: 0
iterat
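The training output prints one dashed block per iteration with `episode`, `iteration`, `lr`, `generator_loss` and `discriminator_loss`. The run is also logged to wandb, which is likely the more convenient source for plots; if only the printed text is available, it can be parsed back into records. A minimal stdlib sketch, assuming nothing beyond the `key: value` layout visible in the log above (`parse_training_log` is a hypothetical helper, not part of `src.segan`):

```python
import re

# Matches a single "key: value" line, e.g. "generator_loss: 12.87".
FIELD = re.compile(r"^(\w+):\s*(\S+)$")
INT_FIELDS = {"episode", "iteration"}
REQUIRED = {"generator_loss", "discriminator_loss"}


def parse_training_log(text: str) -> list[dict]:
    """Split the printed log into one dict per iteration block."""
    records, current = [], {}
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("---"):
            # A dashed separator closes the previous block.
            if REQUIRED.issubset(current):
                records.append(current)
            current = {}
            continue
        match = FIELD.match(line)
        if match:
            key, value = match.groups()
            current[key] = int(value) if key in INT_FIELDS else float(value)
    # Keep the last block only if it is complete (printed output may be truncated).
    if REQUIRED.issubset(current):
        records.append(current)
    return records


log = """Starting run 2024-10-21_16-41-22
----------------------------------------------
episode: 0
iteration: 0
lr: 0.0001
generator_loss: 12.872396469116211
discriminator_loss: 0.21919359266757965
----------------------------------------------
episode: 0
iterat"""
records = parse_training_log(log)
print([r["generator_loss"] for r in records])
```

Non-matching lines (the run banner, the wandb notice) are ignored, and an incomplete trailing block is dropped rather than guessed.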

## Evaluation

The trained generator is applied to the test set, and `segan.test()` returns the chunk-wise and file-wise reconstruction losses as well as the character error rate (CER).
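The CER is conventionally the character-level Levenshtein distance (insertions, deletions and substitutions) between a reference transcript and a hypothesis, divided by the reference length. How `segan.test()` obtains transcripts and aggregates the CER is not shown here, so the following is only a minimal stdlib sketch of the standard definition; `levenshtein`, `character_error_rate` and `corpus_cer` are hypothetical helpers.

```python
def levenshtein(ref: str, hyp: str) -> int:
    """Minimum number of character insertions, deletions and substitutions."""
    prev = list(range(len(hyp) + 1))  # distances from the empty reference prefix
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # delete r from the reference
                curr[j - 1] + 1,         # insert h from the hypothesis
                prev[j - 1] + (r != h),  # substitute (free when characters match)
            ))
        prev = curr
    return prev[-1]


def character_error_rate(ref: str, hyp: str) -> float:
    """Edit distance normalised by the reference length."""
    if not ref:
        raise ValueError("reference transcript must be non-empty")
    return levenshtein(ref, hyp) / len(ref)


def corpus_cer(pairs: list[tuple[str, str]]) -> float:
    """Total edits over total reference characters, so long files weigh more."""
    edits = sum(levenshtein(ref, hyp) for ref, hyp in pairs)
    chars = sum(len(ref) for ref, _ in pairs)
    return edits / chars


print(character_error_rate("abcd", "abed"))  # 0.25
```

Averaging per-file CERs and pooling edits across files (as `corpus_cer` does) can give different numbers; which one the notebook reports is an open assumption.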

In [None]:
segan.test()

## Save Models

In [None]:
segan.write()