# 🎤 DTLN Model Training on Google Colab
## Vietnamese Speech Enhancement with VIVOS Dataset

This notebook trains the DTLN (Dual-Signal Transformation LSTM Network) model for speech enhancement using the VIVOS dataset.

**Features:**
- ✅ Complete self-contained notebook (all code embedded)
- ✅ Automatic checkpoint saving to Google Drive after each epoch  
- ✅ Resume training from last checkpoint
- ✅ Training progress logging
- ✅ GPU support
- ✅ No external files needed (model.py embedded)

**Dataset:**
- Created from VIVOS (clean speech) + DNS noise
- Expected to be already prepared in the `vivos_datasets/` directory (relative to the Drive working directory)

## 1. Mount Google Drive

In [None]:
from google.colab import drive
import os

# Mount Google Drive
# Exposes the user's Drive under /content/drive for the rest of the notebook.
drive.mount('/content/drive')

# Set working directory
# The process chdir()s into DRIVE_PATH, so every relative path used later
# (datasets, checkpoints, logs) resolves inside this Drive folder.
DRIVE_PATH = '/content/drive/MyDrive/DTLN'
os.makedirs(DRIVE_PATH, exist_ok=True)
os.chdir(DRIVE_PATH)

print(f"✅ Mounted Google Drive")
print(f"📁 Working directory: {os.getcwd()}")

## 2. Install Dependencies

In [None]:
# Install the two audio packages imported by the embedded model code below:
# soundfile (reads the WAV data) and wavinfo (reads WAV header information).
!pip install -q soundfile wavinfo

import tensorflow as tf
# Report the TensorFlow version and the GPUs TensorFlow can see; an empty
# list means no GPU is visible to this runtime.
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

## 3. Configuration

In [None]:
# ==================== CONFIGURATION ====================
# Dataset paths (relative to DRIVE_PATH)
# The "MIX" directories hold the noisy inputs, the "SPEECH" directories the
# clean targets; the generator pairs them by identical file name.
PATH_TO_TRAIN_MIX = 'vivos_datasets/train/noisy'
PATH_TO_TRAIN_SPEECH = 'vivos_datasets/train/clean'
PATH_TO_VAL_MIX = 'vivos_datasets/val/noisy'
PATH_TO_VAL_SPEECH = 'vivos_datasets/val/clean'

# Checkpoint directory
# Relative path, so it is created inside the current working directory.
CHECKPOINT_DIR = 'vivos_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Training run name
# Used as the file-name prefix for checkpoints, the CSV log and the plot.
RUN_NAME = 'DTLN_vivos'

# Model hyperparameters
BATCH_SIZE = 64
MAX_EPOCHS = 50
LEARNING_RATE = 1e-3
# Length in seconds of each training excerpt cut from the audio files.
SAMPLE_LENGTH_SECONDS = 15

# Resume training from checkpoint?
RESUME_FROM_CHECKPOINT = True

# Echo the effective settings so they are recorded in the cell output.
print("✅ Configuration loaded")
print(f"📊 Batch size: {BATCH_SIZE}")
print(f"📊 Max epochs: {MAX_EPOCHS}")
print(f"📊 Learning rate: {LEARNING_RATE}")
print(f"📊 Sample length: {SAMPLE_LENGTH_SECONDS}s")
print(f"💾 Checkpoint dir: {CHECKPOINT_DIR}")
print(f"🔄 Resume training: {RESUME_FROM_CHECKPOINT}")

## 4. DTLN Model Code (Embedded)

The complete DTLN model implementation is embedded directly in the notebook.

In [None]:
import os, fnmatch
import csv
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout, \
    Lambda, Input, Multiply, Layer, Conv1D
from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, \
    EarlyStopping, ModelCheckpoint
import tensorflow as tf
import soundfile as sf
from wavinfo import WavInfoReader
from random import shuffle, seed
import numpy as np


class audio_generator():
    '''
    Exposes a pair of WAV directories (noisy input / clean target) as a
    tf.data.Dataset of fixed-length excerpts. Only single-channel files are
    accepted; the two directories are paired by identical file name.
    '''

    def __init__(self, path_to_input, path_to_s1, len_of_samples, fs, train_flag=False):
        '''
        Store the settings, then immediately scan the input directory and
        build the dataset object.

        path_to_input:  directory with the noisy WAV files
        path_to_s1:     directory with the clean WAV files (same file names)
        len_of_samples: excerpt length in audio samples
        fs:             sampling rate every file must have
        train_flag:     when True, the file order is shuffled on each pass
        '''
        self.path_to_input = path_to_input
        self.path_to_s1 = path_to_s1
        self.len_of_samples = len_of_samples
        self.fs = fs
        self.train_flag = train_flag
        self.count_samples()
        self.create_tf_data_obj()

    def count_samples(self):
        '''
        Collect the *.wav names of the input directory and add up how many
        whole excerpts they contain (counted on the noisy files only).
        '''
        self.file_names = fnmatch.filter(os.listdir(self.path_to_input), '*.wav')
        self.total_samples = 0
        for wav_name in self.file_names:
            header = WavInfoReader(os.path.join(self.path_to_input, wav_name))
            # Only complete excerpts count; the trailing remainder is dropped.
            excerpts_in_file = int(np.fix(header.data.frame_count / self.len_of_samples))
            self.total_samples += excerpts_in_file

    def create_generator(self):
        '''
        Yield (noisy, clean) float32 excerpt pairs, file by file.

        Raises ValueError when a file has the wrong sampling rate or more
        than one channel.
        '''
        if self.train_flag:
            # In-place shuffle: a new file order for every pass over the data.
            shuffle(self.file_names)
        for wav_name in self.file_names:
            noisy, noisy_fs = sf.read(os.path.join(self.path_to_input, wav_name))
            # The clean target is looked up under the same file name.
            speech, speech_fs = sf.read(os.path.join(self.path_to_s1, wav_name))

            if not (noisy_fs == self.fs and speech_fs == self.fs):
                raise ValueError('Sampling rates do not match.')
            if not (noisy.ndim == 1 and speech.ndim == 1):
                raise ValueError('Too many audio channels.')

            # The excerpt count is derived from the noisy signal's length.
            excerpt_count = int(np.fix(noisy.shape[0] / self.len_of_samples))
            for excerpt_idx in range(excerpt_count):
                start = int(excerpt_idx * self.len_of_samples)
                stop = int((excerpt_idx + 1) * self.len_of_samples)
                yield noisy[start:stop].astype('float32'), speech[start:stop].astype('float32')

    def create_tf_data_obj(self):
        '''Wrap create_generator in a tf.data.Dataset of fixed-shape float32 pairs.'''
        excerpt_shapes = (
            tf.TensorShape([self.len_of_samples]),
            tf.TensorShape([self.len_of_samples]),
        )
        self.tf_data_set = tf.data.Dataset.from_generator(
            self.create_generator,
            (tf.float32, tf.float32),
            output_shapes=excerpt_shapes,
            args=None)


class InstantLayerNormalization(Layer):
    '''
    Normalizes each input over its last axis to zero mean and unit variance,
    then applies a learnable per-feature scale (gamma) and offset (beta).
    '''

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = 1e-7
        # Both weights are created in build(), once the feature size is known.
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        '''Create gamma (ones) and beta (zeros), sized to the last input axis.'''
        feature_shape = input_shape[-1:]
        self.gamma = self.add_weight(
            shape=feature_shape, initializer='ones', trainable=True, name='gamma')
        self.beta = self.add_weight(
            shape=feature_shape, initializer='zeros', trainable=True, name='beta')

    def call(self, inputs):
        '''Return the normalized, scaled and shifted inputs.'''
        mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
        centered = inputs - mean
        variance = tf.math.reduce_mean(tf.math.square(centered), axis=[-1], keepdims=True)
        # epsilon keeps the division finite when the variance is zero.
        normalized = centered / tf.math.sqrt(variance + self.epsilon)
        return normalized * self.gamma + self.beta


class DTLN_model():
    '''
    Builds and compiles the DTLN speech-enhancement network.

    Typical use: instantiate, override any of the attributes set in
    __init__ (batchsize, lr, len_samples, ...), call build_DTLN_model() and
    then compile_model(). The resulting Keras model is stored in self.model.
    '''

    def __init__(self):
        '''Set default hyperparameters, fix the random seeds and configure GPUs.'''
        self.cost_function = self.snr_cost
        self.model = []
        # Default parameters
        self.fs = 16000            # sampling rate in Hz
        self.batchsize = 64
        self.len_samples = 15      # excerpt length in seconds
        self.activation = 'sigmoid'  # activation applied to both masks
        self.numUnits = 128        # units per LSTM layer
        self.numLayer = 2          # LSTM layers per separation kernel
        self.blockLen = 512        # frame length in samples
        self.block_shift = 128     # hop between frames in samples
        self.dropout = 0.25        # dropout between stacked LSTM layers
        self.lr = 1e-3
        self.max_epochs = 50
        self.encoder_size = 256    # filters of the learned 1x1 Conv1D encoder
        self.eps = 1e-7

        # Set seeds
        os.environ['PYTHONHASHSEED']=str(42)
        seed(42)
        np.random.seed(42)
        tf.random.set_seed(42)

        # GPU memory growth
        physical_devices = tf.config.experimental.list_physical_devices('GPU')
        for device in physical_devices:
            try:
                tf.config.experimental.set_memory_growth(device, enable=True)
            except RuntimeError as err:
                # TensorFlow refuses to change device settings once the
                # runtime is initialized (e.g. an earlier cell already ran
                # TF ops). Training still works, so report and carry on.
                print(f"⚠️ Could not enable GPU memory growth for {device.name}: {err}")

    @staticmethod
    def snr_cost(s_estimate, s_true):
        '''
        Negative signal-to-noise ratio in dB, computed over the last axis.

        Returns a tensor with the last axis kept at size 1; a better estimate
        gives a more negative value.
        '''
        snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \
            (tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True)+1e-7)
        # log10(snr) expressed through natural logarithms.
        num = tf.math.log(snr)
        denom = tf.math.log(tf.constant(10, dtype=num.dtype))
        loss = -10*(num / (denom))
        return loss

    def lossWrapper(self):
        '''
        Return a Keras-compatible loss(y_true, y_pred) that averages
        self.cost_function over the batch.
        '''
        def lossFunction(y_true,y_pred):
            # cost_function takes (estimate, target), i.e. the reverse of
            # the Keras argument order.
            loss = tf.squeeze(self.cost_function(y_pred,y_true))
            loss = tf.reduce_mean(loss)
            return loss
        return lossFunction

    def stftLayer(self, x):
        '''Frame the signal and return [magnitude, phase] of its real FFT.'''
        frames = tf.signal.frame(x, self.blockLen, self.block_shift)
        stft_dat = tf.signal.rfft(frames)
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        return [mag, phase]

    def fftLayer(self, x):
        '''Return [magnitude, phase] of the real FFT of x after adding an axis at position 1.'''
        frame = tf.expand_dims(x, axis=1)
        stft_dat = tf.signal.rfft(frame)
        mag = tf.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        return [mag, phase]

    def ifftLayer(self, x):
        '''Rebuild the complex spectrum from x = [magnitude, phase] and return its inverse real FFT.'''
        s1_stft = (tf.cast(x[0], tf.complex64) * tf.exp( (1j * tf.cast(x[1], tf.complex64))))
        return tf.signal.irfft(s1_stft)

    def overlapAddLayer(self, x):
        '''Overlap-add the frames in x using a hop of self.block_shift.'''
        return tf.signal.overlap_and_add(x, self.block_shift)

    def seperation_kernel(self, num_layer, mask_size, x, stateful=False):
        '''
        Stack num_layer LSTM layers (dropout between them, none after the
        last) followed by a Dense + activation layer producing a mask with
        mask_size features.
        '''
        for idx in range(num_layer):
            x = LSTM(self.numUnits, return_sequences=True, stateful=stateful)(x)
            if idx<(num_layer-1):
                x = Dropout(self.dropout)(x)
        mask = Dense(mask_size)(x)
        mask = Activation(self.activation)(mask)
        return mask

    def build_DTLN_model(self, norm_stft=False):
        '''
        Build the two-stage model and store it in self.model.

        Stage 1 masks the STFT magnitude; stage 2 masks a learned Conv1D
        encoding of the stage-1 frames. With norm_stft=True the first
        separation kernel sees the normalized log-magnitude instead of the
        raw magnitude. The model summary is printed as a side effect.
        '''
        time_dat = Input(batch_shape=(None, None))
        mag,angle = Lambda(self.stftLayer)(time_dat)

        if norm_stft:
            # The log is wrapped in a Lambda layer because raw TensorFlow
            # ops cannot be applied to symbolic Keras tensors in Keras 3.
            log_mag = Lambda(lambda m: tf.math.log(m + 1e-7))(mag)
            mag_norm = InstantLayerNormalization()(log_mag)
        else:
            mag_norm = mag

        # Stage 1: mask in the STFT domain, keeping the noisy phase.
        mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm)
        estimated_mag = Multiply()([mag, mask_1])
        estimated_frames_1 = Lambda(self.ifftLayer)([estimated_mag,angle])
        # Stage 2: mask in a learned feature domain.
        encoded_frames = Conv1D(self.encoder_size,1,strides=1,use_bias=False)(estimated_frames_1)
        encoded_frames_norm = InstantLayerNormalization()(encoded_frames)
        mask_2 = self.seperation_kernel(self.numLayer, self.encoder_size, encoded_frames_norm)
        estimated = Multiply()([encoded_frames, mask_2])
        decoded_frames = Conv1D(self.blockLen, 1, padding='causal',use_bias=False)(estimated)
        estimated_sig = Lambda(self.overlapAddLayer)(decoded_frames)

        self.model = Model(inputs=time_dat, outputs=estimated_sig)
        # summary() prints by itself and returns None, so it must not be
        # wrapped in print().
        self.model.summary()

    def compile_model(self):
        '''Compile self.model with Adam (gradient norm clipped at 3.0) and the wrapped cost function.'''
        optimizerAdam = keras.optimizers.Adam(learning_rate=self.lr, clipnorm=3.0)
        self.model.compile(loss=self.lossWrapper(), optimizer=optimizerAdam)

print("✅ DTLN model code loaded")

## 5. Check Dataset

In [None]:
import fnmatch

def check_dataset_path(path, name):
    """Report how many *.wav files `path` contains.

    Prints one status line labelled with `name` and returns the file count,
    or 0 when `path` does not exist.
    """
    if not os.path.exists(path):
        print(f"❌ {name}: Path not found - {path}")
        return 0

    wav_count = len(fnmatch.filter(os.listdir(path), '*.wav'))
    print(f"✅ {name}: {wav_count} files found at {path}")
    return wav_count

# Count the WAV files in each of the four dataset directories.
print("🔍 Checking datasets...\n")
train_noisy = check_dataset_path(PATH_TO_TRAIN_MIX, "Training Noisy")
train_clean = check_dataset_path(PATH_TO_TRAIN_SPEECH, "Training Clean")
val_noisy = check_dataset_path(PATH_TO_VAL_MIX, "Validation Noisy")
val_clean = check_dataset_path(PATH_TO_VAL_SPEECH, "Validation Clean")

# A directory is considered usable only if it holds at least one WAV file.
# The pair counts reported below are taken from the noisy directories.
if all(count > 0 for count in (train_noisy, train_clean, val_noisy, val_clean)):
    total_pairs = train_noisy + val_noisy
    print("\n✅ All dataset paths are valid!")
    print(f"\n📊 Dataset info:")
    print(f"   Train pairs: {train_noisy:,}")
    print(f"   Val pairs:   {val_noisy:,}")
    print(f"   Total pairs: {total_pairs:,}")
else:
    print("\n⚠️ Warning: Some dataset paths are missing or empty!")

## 6. Create and Build Model

In [None]:
# Create model instance
model_trainer = DTLN_model()

# Set custom parameters
model_trainer.batchsize = BATCH_SIZE
model_trainer.max_epochs = MAX_EPOCHS
model_trainer.lr = LEARNING_RATE
model_trainer.len_samples = SAMPLE_LENGTH_SECONDS

# Build the model
print("🏗️ Building DTLN model...")
model_trainer.build_DTLN_model(norm_stft=False)
print("\n✅ Model built successfully!")

## 7. Load Checkpoint (if resuming)

In [None]:
import glob
import re

def get_latest_checkpoint(checkpoint_dir, run_name):
    """Find the most recent checkpoint of a training run.

    Looks for per-epoch files named ``{run_name}_epoch_NNN.weights.h5`` in
    `checkpoint_dir` and picks the one with the highest epoch number. When
    no such file exists, falls back to ``{run_name}.weights.h5``.

    Returns:
        (path, epoch): path of the checkpoint and the epoch it was saved
        after; epoch is 0 for the fallback file. (None, 0) when nothing
        was found.
    """
    # The pattern must mirror the file names the checkpoint callback writes:
    # "<run>_epoch_<NNN>.weights.h5" (a single ".weights" suffix).
    pattern = os.path.join(checkpoint_dir, f"{run_name}_epoch_*.weights.h5")

    # Match on the file name only, anchored on the run name, so an
    # "epoch_<n>" inside the directory path or run name is never picked up.
    epoch_re = re.compile(rf'^{re.escape(run_name)}_epoch_(\d+)\.')
    checkpoint_epochs = []
    for cp in glob.glob(pattern):
        match = epoch_re.match(os.path.basename(cp))
        if match:
            checkpoint_epochs.append((int(match.group(1)), cp))

    if checkpoint_epochs:
        latest_epoch, latest_checkpoint = max(checkpoint_epochs)
        return latest_checkpoint, latest_epoch

    # No per-epoch file: accept a plain weights file, whose epoch is unknown.
    main_checkpoint = os.path.join(checkpoint_dir, f"{run_name}.weights.h5")
    if os.path.exists(main_checkpoint):
        return main_checkpoint, 0

    return None, 0

# Check for existing checkpoints
# Epoch index handed to fit(); stays 0 unless checkpoint weights are restored.
initial_epoch = 0

if not RESUME_FROM_CHECKPOINT:
    print("ℹ️ Resume from checkpoint disabled. Starting fresh training...")
else:
    latest_checkpoint, checkpoint_epoch = get_latest_checkpoint(CHECKPOINT_DIR, RUN_NAME)

    if not latest_checkpoint:
        print("ℹ️ No checkpoint found. Starting training from scratch...")
    else:
        print(f"📥 Loading checkpoint from: {latest_checkpoint}")
        print(f"📊 Resuming from epoch: {checkpoint_epoch}")
        try:
            model_trainer.model.load_weights(latest_checkpoint)
        except Exception as e:
            # Best effort: an unreadable or incompatible checkpoint must not
            # block training, so fall back to a fresh start.
            print(f"⚠️ Failed to load checkpoint: {e}")
            print("Starting training from scratch...")
            initial_epoch = 0
        else:
            initial_epoch = checkpoint_epoch
            print("✅ Checkpoint loaded successfully!")

print(f"\n🏁 Will start training from epoch: {initial_epoch + 1}")

## 8. Compile Model

In [None]:
# Attach the Adam optimizer and the SNR-based loss to the built model.
# This runs after the checkpoint cell, so the optimizer is created fresh;
# only the model weights come from a checkpoint.
print("⚙️ Compiling model...")
model_trainer.compile_model()
print("✅ Model compiled successfully!")

## 9. Setup Callbacks

In [None]:
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau, CSVLogger, EarlyStopping
import shutil

class DriveCheckpointCallback(Callback):
    """Save model weights into `checkpoint_dir` after every epoch.

    Per epoch it writes ``{run_name}_epoch_NNN.weights.h5``, copies it to
    ``{run_name}_latest.weights.h5`` and, when the validation loss improved,
    to ``{run_name}_best.weights.h5``. Only the three most recent per-epoch
    files are kept.
    """

    def __init__(self, checkpoint_dir, run_name):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.run_name = run_name
        # NOTE: starts at +inf for every new instance, so after a resume the
        # first epoch that reports a val_loss overwrites the "best" file.
        self.best_val_loss = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        """Write the epoch, latest and (if improved) best checkpoint files.

        Args:
            epoch: zero-based epoch index supplied by Keras.
            logs: metrics of the finished epoch; may be None.
        """
        # Keras may pass logs=None; treat that as "no metrics available".
        logs = logs or {}

        # Save epoch checkpoint (file names are 1-based and zero-padded)
        epoch_checkpoint_path = os.path.join(
            self.checkpoint_dir,
            f"{self.run_name}_epoch_{epoch+1:03d}.weights.h5"
        )

        print(f"\n💾 Saving checkpoint to Google Drive: {epoch_checkpoint_path}")
        self.model.save_weights(epoch_checkpoint_path)

        # Save as latest
        latest_checkpoint_path = os.path.join(
            self.checkpoint_dir,
            f"{self.run_name}_latest.weights.h5"
        )
        shutil.copy(epoch_checkpoint_path, latest_checkpoint_path)

        # Save best model
        # Compare against None explicitly: a validation loss of exactly 0.0
        # is a valid value and must not be mistaken for a missing metric.
        val_loss = logs.get('val_loss')
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            best_checkpoint_path = os.path.join(
                self.checkpoint_dir,
                f"{self.run_name}_best.weights.h5"
            )
            shutil.copy(epoch_checkpoint_path, best_checkpoint_path)
            print(f"⭐ New best model saved! Val Loss: {val_loss:.4f}")

        # Keep only last 3 epoch checkpoints
        self._cleanup_old_checkpoints(keep_last=3)
        print(f"✅ Checkpoint saved successfully!")

    def _cleanup_old_checkpoints(self, keep_last=3):
        """Delete all but the `keep_last` most recently modified epoch files."""
        pattern = os.path.join(self.checkpoint_dir, f"{self.run_name}_epoch_*.weights.h5")
        checkpoints = glob.glob(pattern)

        if len(checkpoints) > keep_last:
            # Oldest first, judged by file modification time.
            checkpoints.sort(key=os.path.getmtime)
            for old_checkpoint in checkpoints[:-keep_last]:
                try:
                    os.remove(old_checkpoint)
                    print(f"🗑️ Removed old checkpoint: {os.path.basename(old_checkpoint)}")
                except Exception as e:
                    # Best effort: a failed delete must not abort training.
                    print(f"⚠️ Failed to remove {old_checkpoint}: {e}")

# Create callbacks
# save_path is a sub-directory named after the run; only the CSV log is
# written there, while the weight files go directly into CHECKPOINT_DIR.
save_path = os.path.join(CHECKPOINT_DIR, RUN_NAME)
os.makedirs(save_path, exist_ok=True)

drive_checkpoint_callback = DriveCheckpointCallback(CHECKPOINT_DIR, RUN_NAME)
# append=True keeps the log of earlier sessions when training is resumed.
csv_logger = CSVLogger(os.path.join(save_path, f'training_{RUN_NAME}.log'), append=True)
# Halve the learning rate after 3 epochs without val_loss improvement.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-10, cooldown=1, verbose=1)
# Stop training after 10 epochs without val_loss improvement.
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode='auto', baseline=None)

callbacks = [
    drive_checkpoint_callback,
    csv_logger,
    reduce_lr,
    early_stopping
]

print("✅ Training callbacks configured!")

## 10. Prepare Data Generators

In [None]:
# Calculate sample length
# Excerpt length in samples: seconds * sampling rate, rounded down to a
# whole multiple of the block shift.
len_in_samples = int(np.fix(
    model_trainer.fs * model_trainer.len_samples / model_trainer.block_shift
) * model_trainer.block_shift)

print(f"🎵 Audio sample length: {len_in_samples} samples ({len_in_samples/model_trainer.fs:.2f} seconds)")
print(f"\n📊 Creating data generators...")

# Create training data generator
# train_flag=True makes the generator shuffle the file order on every pass.
print("   Loading training data...")
generator_input = audio_generator(
    PATH_TO_TRAIN_MIX,
    PATH_TO_TRAIN_SPEECH,
    len_in_samples,
    model_trainer.fs,
    train_flag=True
)

dataset = generator_input.tf_data_set
# Full batches only; repeat() restarts the generator so fit() can draw
# steps_train batches in every epoch.
dataset = dataset.batch(model_trainer.batchsize, drop_remainder=True).repeat()
steps_train = generator_input.total_samples // model_trainer.batchsize

print(f"   ✅ Training samples: {generator_input.total_samples:,}")
print(f"   ✅ Training steps per epoch: {steps_train:,}")

# Create validation data generator  
# train_flag keeps its default (False), so the validation order is not shuffled.
print("\n   Loading validation data...")
generator_val = audio_generator(
    PATH_TO_VAL_MIX,
    PATH_TO_VAL_SPEECH,
    len_in_samples,
    model_trainer.fs
)

dataset_val = generator_val.tf_data_set
dataset_val = dataset_val.batch(model_trainer.batchsize, drop_remainder=True).repeat()
steps_val = generator_val.total_samples // model_trainer.batchsize

print(f"   ✅ Validation samples: {generator_val.total_samples:,}")
print(f"   ✅ Validation steps: {steps_val:,}")

print("\n✅ Data generators ready!")

## 11. Start Training 🚀

In [None]:
# Print the effective training setup before the (long-running) fit call.
print("="*60)
print("🚀 STARTING TRAINING")
print("="*60)
print(f"📊 Configuration:")
print(f"   Batch size: {model_trainer.batchsize}")
print(f"   Max epochs: {model_trainer.max_epochs}")
print(f"   Initial epoch: {initial_epoch + 1}")
print(f"   Learning rate: {model_trainer.lr}")
print(f"   Training steps: {steps_train:,}")
print(f"   Validation steps: {steps_val:,}")
print(f"\n💾 Checkpoints will be saved to: {CHECKPOINT_DIR}")
print("="*60 + "\n")

# Start training
# batch_size is None because the dataset is already batched. Both datasets
# repeat endlessly, so the step counts define where an epoch ends.
# initial_epoch is the value set by the checkpoint cell (0 without a resume).
history = model_trainer.model.fit(
    x=dataset,
    batch_size=None,
    steps_per_epoch=steps_train,
    epochs=model_trainer.max_epochs,
    initial_epoch=initial_epoch,
    verbose=1,
    validation_data=dataset_val,
    validation_steps=steps_val,
    callbacks=callbacks,
)

print("\n" + "="*60)
print("✅ TRAINING COMPLETED!")
print("="*60)

## 12. Plot Training History

In [None]:
import matplotlib.pyplot as plt

# Two side-by-side views of the same training/validation loss curves.
plt.figure(figsize=(12, 4))

# Left: linear scale.
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()
plt.grid(True)

# Right: symmetric-log scale. The loss is the negative SNR in dB, so it is
# below zero whenever the estimate beats 0 dB; a plain 'log' axis cannot show
# non-positive values and would leave this plot empty.
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss (Symlog Scale)')
plt.ylabel('Loss (symlog)')
plt.xlabel('Epoch')
plt.yscale('symlog')
plt.legend()
plt.grid(True)

plt.tight_layout()
# Save next to the checkpoints before showing the figure.
plt.savefig(os.path.join(CHECKPOINT_DIR, f'{RUN_NAME}_training_history.png'), dpi=150)
plt.show()

print(f"✅ Training history plot saved!")

## 13. Save Final Model

In [None]:
# Save final model weights
# Weights of the model as it stands after the last trained epoch; this is
# not necessarily the epoch with the best validation loss.
final_model_path = os.path.join(CHECKPOINT_DIR, f"{RUN_NAME}_final.weights.h5")
model_trainer.model.save_weights(final_model_path)
print(f"✅ Final model saved to: {final_model_path}")

# Recap of the artifacts this notebook writes. The file names are relative
# to CHECKPOINT_DIR, except the training log, which the callback cell places
# in the CHECKPOINT_DIR/RUN_NAME sub-directory.
print("\n" + "="*60)
print("  TRAINING SUMMARY")
print("="*60)
print(f"✅ Training completed successfully!")
print(f"\n📁 All files saved to Google Drive:")
print(f"   Location: {CHECKPOINT_DIR}")
print(f"\n💾 Checkpoint files:")
print(f"   • Best model: {RUN_NAME}_best.weights.h5")
print(f"   • Final model: {RUN_NAME}_final.weights.h5")
print(f"   • Latest checkpoint: {RUN_NAME}_latest.weights.h5")
print(f"\n📈 Training log:")
print(f"   • training_{RUN_NAME}.log")
print(f"\n🎨 Visualization:")
print(f"   • {RUN_NAME}_training_history.png")
print("\n" + "="*60)
print("✨ To resume training later, just run this notebook again!")
print("✨ Make sure RESUME_FROM_CHECKPOINT = True")
print("="*60 + "\n")