In [None]:
# Download two files from Google Drive by file ID into the current directory.
# NOTE(review): presumably these are the `model.py` / `utils.py` modules imported
# further down — confirm against the Drive files.
!gdown 1Zv2rByeWFgZGvyxkZkf0Ff37QTCLoJGm
!gdown 1yWFeAkA7cfoFBIab115qFJCbw456MJmS

In [None]:
# Use the %pip magic (not !pip) so the packages are installed into the
# environment of the running kernel rather than whatever `pip` is first on PATH.
%pip install torch torchaudio torchmetrics

In [3]:
import os
import gc
import argparse
import torchaudio
import torch
import torch.nn.functional as F
import zipfile
import json

from torch import nn
from torchmetrics.text.wer import WordErrorRate
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from model import ConformerEncoder, LSTMDecoder
from utils import *

In [4]:
# Create the dataset folders. exist_ok=True keeps this cell idempotent, so
# re-running the notebook does not fail with FileExistsError.
os.makedirs('common_voice', exist_ok=True)
os.makedirs('ViVOS', exist_ok=True)

In [None]:
# Download the Common Voice metadata splits (dev/test/train JSON) and the
# zipped audio into the common_voice/ folder.
!gdown 1enxwIwmqJ7d4uDMlQyaL3eSU_4CsI6PE -O common_voice/dev.json
!gdown 1BdVj69KW4YKqykfxCYa3vrS5R-HVEOKi -O common_voice/test.json
!gdown 15MnfOokem_HomGxxwn1i4xs8lrFQwR6i -O common_voice/train.json
!gdown 1wUw8JcTcjGbKWdKKbtpnx10Acd9fn9BS -O common_voice/voices.zip

In [None]:
# Download the ViVOS metadata splits (train/test JSON) and the zipped audio
# into the ViVOS/ folder.
!gdown 1oV2v0RBHX_Rqvra0YUrgoBgND0QC64Pc -O ViVOS/train.json
!gdown 1obDaRybTfcOaGrl6mhAb5bMpMESs2lr5 -O ViVOS/test.json
!gdown 1JoUgZ6uGPb5_iZTDinjF5pRzUvhk-4-n -O ViVOS/voices.zip

In [11]:
def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` is a classic argparse pitfall: ``bool("False")`` is True
    because any non-empty string is truthy. This converter understands the
    usual spellings and rejects anything else with a proper argparse error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Training configuration. Defaults are used when the notebook is run without
# extra command-line arguments; `parse_known_args` ignores the arguments that
# the Jupyter kernel itself passes on sys.argv.
parser = argparse.ArgumentParser("conformer")
parser.add_argument('--data_dir', type=str, default='./data', help='location to download data')
parser.add_argument('--checkpoint_path', type=str, default='model_best.pt', help='path to store/load checkpoints')
parser.add_argument('--load_checkpoint', action='store_true', default=False, help='resume training from checkpoint')
parser.add_argument('--train_set', type=str, default='train-clean-100', help='train dataset')
parser.add_argument('--test_set', type=str, default='test-clean', help='test dataset')
parser.add_argument('--batch_size', type=int, default=4, help='batch size')
# The type and help text of the next two options were swapped: the step count is
# an integer and the LR ratio is fractional (type=int rejected e.g. "0.05").
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps for LR scheduler')
parser.add_argument('--peak_lr_ratio', type=float, default=0.01, help='Multiply by sqrt(d_model) to get max_lr')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id (optional)')
parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
parser.add_argument('--report_freq', type=int, default=100, help='training objective report frequency')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--use_amp', action='store_true', default=False, help='use mixed precision to train')
parser.add_argument('--attention_heads', type=int, default=4, help='number of heads to use for multi-head attention')
parser.add_argument('--d_input', type=int, default=80, help='dimension of the input (num filter banks)')
parser.add_argument('--d_encoder', type=int, default=128, help='dimension of the encoder')
parser.add_argument('--d_decoder', type=int, default=256, help='dimension of the decoder')
parser.add_argument('--encoder_layers', type=int, default=16, help='number of conformer blocks in the encoder')
parser.add_argument('--decoder_layers', type=int, default=1, help='number of decoder layers')
parser.add_argument('--conv_kernel_size', type=int, default=31, help='size of kernel for conformer convolution blocks')
parser.add_argument('--feed_forward_expansion_factor', type=int, default=2, help='expansion factor for conformer feed forward blocks')
# Fractional factor (default 0.3), so it must be parsed as float, not int.
parser.add_argument('--feed_forward_residual_factor', type=float, default=.3, help='residual factor for conformer feed forward blocks')
parser.add_argument('--dropout', type=float, default=0.05, help='dropout factor for conformer model')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='model weight decay (corresponds to L2 regularization)')
parser.add_argument('--variational_noise_std', type=float, default=1e-5, help='std of noise added to model weights for regularization')
parser.add_argument('--num_workers', type=int, default=2, help='num_workers for the dataloader')
# `_str2bool` instead of `bool` so that "--smart_batch False" really disables it.
parser.add_argument('--smart_batch', type=_str2bool, default=True, help='Use smart batching for faster training')
parser.add_argument('--accumulate_iters', type=int, default=1, help='Number of iterations to accumulate gradients')
args, unknown = parser.parse_known_args()

In [None]:
def _move_to_device(encoder, decoder, char_decoder, criterion):
    """Place the models and the loss on the GPU when CUDA is available.

    Returns:
        ``(gpu, encoder, decoder, char_decoder, criterion)``; ``gpu`` is True
        when the modules were moved with ``.cuda()``, False when left on CPU.
    """
    print("Setting up GPU environment...")
    if not torch.cuda.is_available():
        print("GPU not available, using CPU")
        return False, encoder, decoder, char_decoder, criterion

    print('Using GPU')
    torch.cuda.set_device(args.gpu)
    criterion = criterion.cuda()
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    char_decoder = char_decoder.cuda()
    torch.cuda.empty_cache()
    return True, encoder, decoder, char_decoder, criterion


def _run_training(encoder, decoder, char_decoder, criterion, optimizer, scheduler,
                  train_loader, test_loader, gpu):
    """Run the epoch loop: train, validate, and checkpoint on improvement.

    Creates the gradient scaler, optionally restores a checkpoint (when
    ``args.load_checkpoint`` is set), then iterates epochs up to ``args.epochs``.
    A checkpoint is written whenever the validation loss does not get worse.
    """
    # Mixed Precision Setup
    if args.use_amp:
        print('Using Mixed Precision')
    grad_scaler = torch.amp.GradScaler(enabled=args.use_amp)

    # Initialize Checkpoint
    print("Initializing checkpoint...")
    if args.load_checkpoint:
        start_epoch, best_loss = load_checkpoint(encoder, decoder, optimizer, scheduler, args.checkpoint_path)
        print(f'Resuming training from checkpoint starting at epoch {start_epoch}.')
    else:
        start_epoch = 0
        best_loss = float('inf')

    # Train Loop
    print("Starting training loop...")
    optimizer.zero_grad()
    for epoch in range(start_epoch, args.epochs):
        torch.cuda.empty_cache()
        print(f"Starting epoch {epoch + 1} / {args.epochs}")

        # Variational noise for regularization
        add_model_noise(encoder, std=args.variational_noise_std, gpu=gpu)
        add_model_noise(decoder, std=args.variational_noise_std, gpu=gpu)

        # Train/Validation loops
        print("Training step...")
        wer, loss = train(encoder, decoder, char_decoder, optimizer, scheduler, criterion, grad_scaler, train_loader, args, gpu=gpu)
        print("Validation step...")
        valid_wer, valid_loss = validate(encoder, decoder, char_decoder, criterion, test_loader, args, gpu=gpu)

        print(f'Epoch {epoch} - Valid WER: {valid_wer}%, Valid Loss: {valid_loss}, Train WER: {wer}%, Train Loss: {loss}')

        # Save checkpoint
        if valid_loss <= best_loss:
            print('Validation loss improved, saving checkpoint.')
            best_loss = valid_loss
            save_checkpoint(encoder, decoder, optimizer, scheduler, valid_loss, epoch+1, args.checkpoint_path)

    print("Training loop completed successfully.")


def main(base_dir='/kaggle/working'):
    """Build the combined Common Voice + ViVOS dataset and train the model.

    Steps: read the JSON metadata, unzip the audio archives, pair each audio
    file with its transcription, split 80/20 into train/test, build the
    Conformer encoder / LSTM decoder, and run the training loop. Hyperparameters
    come from the module-level ``args``.

    Args:
        base_dir: Directory that contains the ``common_voice`` and ``ViVOS``
            folders. Defaults to the Kaggle working directory that was
            previously hard-coded.

    If the training loop raises, the error is printed and training is retried
    once; a second failure propagates. Errors during data/model setup are not
    retried and propagate directly (previously they were masked by an
    UnboundLocalError raised from the recovery code).
    """

    #Load data
    def read_json(file_path):
        """Read a JSON file and return the parsed data."""
        # Explicit UTF-8: the transcripts are Vietnamese and must not depend on
        # the platform's default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data

    def unzip_voices(zip_path, extract_to):
        """Extract the audio archive at ``zip_path`` into ``extract_to``."""
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
            print(f"Đã giải nén tệp vào: {extract_to}")

        # Handle a nested "voices" folder inside the archive by moving its
        # entries one level up and removing the now-empty folder.
        nested_path = os.path.join(extract_to, 'voices')
        if os.path.isdir(nested_path):
            for item in os.listdir(nested_path):
                nested_item_path = os.path.join(nested_path, item)
                final_destination = os.path.join(extract_to, item)
                os.rename(nested_item_path, final_destination)
            os.rmdir(nested_path)
            print(f"Đã xử lý thư mục lồng nhau tại: {extract_to}")

    def create_voice_id_map(dataset):
        """Map audio IDs (file name without .wav/.mp3) to their metadata entry."""
        return {
            value['voice'].replace('.wav', '').replace('.mp3', ''): value
            for key, value in dataset.items() if isinstance(value, dict) and 'voice' in value
        }

    common_voice_root = os.path.join(base_dir, 'common_voice')
    vivos_root = os.path.join(base_dir, 'ViVOS')

    common_voice_train = read_json(os.path.join(common_voice_root, 'train.json'))
    common_voice_dev = read_json(os.path.join(common_voice_root, 'dev.json'))
    common_voice_test = read_json(os.path.join(common_voice_root, 'test.json'))
    vivos_train = read_json(os.path.join(vivos_root, 'train.json'))
    vivos_test = read_json(os.path.join(vivos_root, 'test.json'))

    # The original splits are merged; a fresh 80/20 split is drawn below.
    common_voice_dataset = {**common_voice_train, **common_voice_dev, **common_voice_test}
    vivos_dataset = {**vivos_train, **vivos_test}

    print(f"Số lượng mẫu common_voice: {len(common_voice_dataset)}")
    print(f"Số lượng mẫu vivos: {len(vivos_dataset)}")

    common_voice_map = create_voice_id_map(common_voice_dataset)
    vivos_map = create_voice_id_map(vivos_dataset)

    class VoiceDataset(Dataset):
        def __init__(self, merged_data):
            """Wrap a list of ``(audio_path, transcription)`` pairs."""
            self.merged_data = merged_data

        def __len__(self):
            return len(self.merged_data)

        def __getitem__(self, idx):
            """Load the audio for item ``idx`` and return it with its transcription."""
            audio_path, transcription = self.merged_data[idx]
            # NOTE(review): the sample rate is discarded and no resampling is
            # done here — confirm that preprocess_example copes with the two
            # corpora's native sample rates.
            audio, _ = torchaudio.load(audio_path)
            # Six-element tuple with placeholders; presumably the layout that
            # preprocess_example expects — TODO confirm against utils.
            return audio, '', transcription, '', '', ''

    common_voice_zip_path = os.path.join(common_voice_root, 'voices.zip')
    vivos_zip_path = os.path.join(vivos_root, 'voices.zip')

    common_voice_data_dir = os.path.join(common_voice_root, 'voices')
    vivos_data_dir = os.path.join(vivos_root, 'voices')

    unzip_voices(common_voice_zip_path, common_voice_data_dir)
    unzip_voices(vivos_zip_path, vivos_data_dir)

    def merge_audio_with_text(audio_dir, text_map):
        """Pair audio files with transcriptions by ID, keeping only matched, non-empty ones."""
        merged_data = []

        audio_files = [f for f in os.listdir(audio_dir) if f.endswith(('.wav', '.mp3'))]

        for audio_file in audio_files:
            audio_id = audio_file.replace('.wav', '').replace('.mp3', '')

            if audio_id in text_map:
                text_data = text_map[audio_id]
                transcription = text_data['script']

                if transcription:
                    audio_path = os.path.join(audio_dir, audio_file)
                    merged_data.append((audio_path, transcription))

        return merged_data

    common_voice_merged = merge_audio_with_text(common_voice_data_dir, common_voice_map)
    vivos_merged = merge_audio_with_text(vivos_data_dir, vivos_map)

    full_dataset = common_voice_merged + vivos_merged
    print(f"Kích thước tập dữ liệu tổng hợp: {len(full_dataset)}")

    voice_dataset = VoiceDataset(full_dataset)

    train_size = int(0.8 * len(voice_dataset))
    test_size = len(voice_dataset) - train_size
    train_data, test_data = random_split(voice_dataset, [train_size, test_size])

    print(f"Train data size: {len(train_data)}")
    print(f"Test data size: {len(test_data)}")

    print("Starting data loader setup...")

    # NOTE: args.smart_batch and args.num_workers are currently not honoured;
    # regular shuffled batching with in-process loading is always used.
    print("Using regular batching...")
    # Pinned memory only helps host-to-GPU copies; requesting it without CUDA
    # just produces a warning.
    use_pinned_memory = torch.cuda.is_available()
    train_loader = DataLoader(dataset=train_data,
                              pin_memory=use_pinned_memory,
                              num_workers=0,
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=lambda x: preprocess_example(x, 'train'))

    test_loader = DataLoader(dataset=test_data,
                             pin_memory=use_pinned_memory,
                             num_workers=0,
                             batch_size=args.batch_size,
                             shuffle=False,
                             collate_fn=lambda x: preprocess_example(x, 'valid'))

    # Declare Models
    print("Initializing models...")
    encoder = ConformerEncoder(
        d_input=args.d_input,
        d_model=args.d_encoder,
        num_layers=args.encoder_layers,
        conv_kernel_size=args.conv_kernel_size,
        dropout=args.dropout,
        feed_forward_residual_factor=args.feed_forward_residual_factor,
        feed_forward_expansion_factor=args.feed_forward_expansion_factor,
        num_heads=args.attention_heads)

    decoder = LSTMDecoder(
        d_encoder=args.d_encoder,
        d_decoder=args.d_decoder,
        num_layers=args.decoder_layers)

    char_decoder = GreedyCharacterDecoder().eval()

    # NOTE(review): blank=8 must match the blank index of the TextTransform
    # vocabulary defined in utils — confirm.
    criterion = nn.CTCLoss(blank=8, zero_infinity=True)

    # Print model size
    model_size(encoder, 'Encoder')
    model_size(decoder, 'Decoder')

    # Move the modules to their final device BEFORE building the optimizer, as
    # recommended by the PyTorch documentation for Module.cuda().
    gpu, encoder, decoder, char_decoder, criterion = _move_to_device(
        encoder, decoder, char_decoder, criterion)

    # Optimizer and scheduler
    optimizer = torch.optim.AdamW(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=0.01,
        betas=(.9, .98),
        eps=1e-05 if args.use_amp else 1e-09,
        weight_decay=args.weight_decay
        )

    scheduler = TransformerLrScheduler(optimizer, args.d_encoder, args.warmup_steps)

    try:
        _run_training(encoder, decoder, char_decoder, criterion, optimizer, scheduler,
                      train_loader, test_loader, gpu)
    except Exception as e:
        import traceback
        print("An error occurred:")
        traceback.print_exc()
        print("Error details:", str(e))

        # Best-effort recovery: set up the GPU again and retry training once.
        # Everything referenced here is guaranteed to exist because setup
        # happens before the try block. A second failure propagates.
        gpu, encoder, decoder, char_decoder, criterion = _move_to_device(
            encoder, decoder, char_decoder, criterion)
        _run_training(encoder, decoder, char_decoder, criterion, optimizer, scheduler,
                      train_loader, test_loader, gpu)

def train(encoder, decoder, char_decoder, optimizer, scheduler, criterion, grad_scaler, train_loader, args, gpu=True):
  ''' Run a single training epoch.

  For every batch the scheduler is advanced, the encoder/decoder are run under
  autocast, and the loss is back-propagated through the gradient scaler. The
  optimizer steps once every ``args.accumulate_iters`` batches.

  Args:
    encoder, decoder: models being trained (put into train mode here).
    char_decoder: callable turning decoder outputs into index sequences.
    optimizer, scheduler: optimizer and LR scheduler; the scheduler is stepped
      once per batch.
    criterion: loss called as ``criterion(log_probs, labels, input_lengths, label_lengths)``.
    grad_scaler: gradient scaler providing ``scale``/``unscale_``/``step``/``update``.
    train_loader: iterable of
      ``(spectrograms, labels, input_lengths, label_lengths, references, mask)``.
    args: namespace with ``use_amp``, ``accumulate_iters`` and ``report_freq``.
    gpu: when True, batch tensors are moved to CUDA.

  Returns:
    ``(average WER in percent, average loss)`` over the epoch.
  '''

  wer = WordErrorRate()
  error_rate = AvgMeter()
  avg_loss = AvgMeter()
  text_transform = TextTransform()

  # Parameter list used for gradient clipping; it does not change between steps.
  clip_params = list(encoder.parameters()) + list(decoder.parameters())

  encoder.train()
  decoder.train()
  for i, batch in enumerate(train_loader):
    scheduler.step()
    gc.collect()
    spectrograms, labels, input_lengths, label_lengths, references, mask = batch

    # Move to GPU
    if gpu:
      spectrograms = spectrograms.cuda()
      labels = labels.cuda()
      input_lengths = torch.tensor(input_lengths).cuda()
      label_lengths = torch.tensor(label_lengths).cuda()
      mask = mask.cuda()

    # Update models
    with torch.amp.autocast('cuda', enabled=args.use_amp):

      outputs = encoder(spectrograms, mask)

      outputs = decoder(outputs)
      if torch.any(torch.isnan(outputs)):
        print("NaN detected in outputs before softmax")

      loss = criterion(F.log_softmax(outputs, dim=-1).transpose(0, 1), labels, input_lengths, label_lengths)
      print(f"Step {i+1} - Loss: {loss.item()}")
      if torch.isnan(loss).any():
        print(f"NaN detected in loss at step {i+1}")

    # Divide by the accumulation length so the accumulated gradient is the
    # average over the micro-batches rather than their sum (no-op when
    # accumulate_iters == 1). The reported loss below stays unscaled.
    grad_scaler.scale(loss / args.accumulate_iters).backward()

    if (i+1) % args.accumulate_iters == 0:
      # Gradients are still multiplied by the scaler's factor at this point;
      # unscale them first so max_norm applies to the true gradient norm, and
      # clip once per optimizer step instead of once per micro-batch.
      grad_scaler.unscale_(optimizer)
      torch.nn.utils.clip_grad_norm_(clip_params, max_norm=1.0)
      grad_scaler.step(optimizer)
      grad_scaler.update()
      optimizer.zero_grad()
    avg_loss.update(loss.detach().item())

    # Predict words, compute WER
    inds = char_decoder(outputs.detach())
    predictions = [text_transform.int_to_text(sample) for sample in inds]
    error_rate.update(wer(predictions, references) * 100)

    # Print metrics and predictions
    if (i+1) % args.report_freq == 0:
      print(f'Step {i+1} - Avg WER: {error_rate.avg}%, Avg Loss: {avg_loss.avg}')
      print('Sample Predictions: ', predictions)
    # Drop per-batch references early to keep peak memory down.
    del spectrograms, labels, input_lengths, label_lengths, references, outputs, inds, predictions
  return error_rate.avg, avg_loss.avg

def validate(encoder, decoder, char_decoder, criterion, test_loader, args, gpu=True):
  ''' Score the model on the held-out loader without updating any weights.

  Both models are switched to eval mode. Each batch is run under ``no_grad``
  (forward pass and loss additionally under autocast), then the outputs are
  decoded to text to measure the word error rate.

  Returns:
    ``(average WER in percent, average loss)`` over all batches.
  '''

  loss_meter = AvgMeter()
  wer_meter = AvgMeter()
  wer_metric = WordErrorRate()
  to_text = TextTransform()

  encoder.eval()
  decoder.eval()
  for batch in test_loader:
    gc.collect()
    spectrograms, labels, input_lengths, label_lengths, references, mask = batch

    # Transfer the batch to the GPU when requested
    if gpu:
      spectrograms, labels = spectrograms.cuda(), labels.cuda()
      input_lengths = torch.tensor(input_lengths).cuda()
      label_lengths = torch.tensor(label_lengths).cuda()
      mask = mask.cuda()

    with torch.no_grad():
      with torch.amp.autocast('cuda', enabled=args.use_amp):
        logits = decoder(encoder(spectrograms, mask))
        log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
        batch_loss = criterion(log_probs, labels, input_lengths, label_lengths)
      loss_meter.update(batch_loss.item())

      # Decode index sequences to text and score them against the references
      hypotheses = [to_text.int_to_text(ids) for ids in char_decoder(logits.detach())]
      wer_meter.update(wer_metric(hypotheses, references) * 100)
  return wer_meter.avg, loss_meter.avg


# Entry point: run the whole pipeline when this code is executed directly
# (in a notebook cell `__name__` is '__main__', so training starts on execution).
if __name__ == '__main__':
  main()