# Import modules

In [1]:
import torch
from torch.utils.data import ConcatDataset, DataLoader
import torch.optim as optim

In [2]:
!pip install speechbrain

Collecting speechbrain
  Downloading speechbrain-1.0.2-py3-none-any.whl.metadata (23 kB)
Collecting hyperpyyaml (from speechbrain)
  Downloading HyperPyYAML-1.2.2-py3-none-any.whl.metadata (7.6 kB)
Collecting ruamel.yaml>=0.17.28 (from hyperpyyaml->speechbrain)
  Downloading ruamel.yaml-0.18.8-py3-none-any.whl.metadata (22 kB)
Collecting ruamel.yaml.clib>=0.2.7 (from ruamel.yaml>=0.17.28->hyperpyyaml->speechbrain)
  Downloading ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.7 kB)
Downloading speechbrain-1.0.2-py3-none-any.whl (824 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m824.8/824.8 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading HyperPyYAML-1.2.2-py3-none-any.whl (16 kB)
Downloading ruamel.yaml-0.18.8-py3-none-any.whl (117 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.5/117.5 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ruamel.yaml.clib-0.2.12-cp310-cp31

# Dataset

In [None]:
# convert_mp3_to_wav("Vietnamese-Speech-to-Text-datasets//Common-Voice")
# convert_mp3_to_wav("Vietnamese-Speech-to-Text-datasets//ViVOS")

## Load dataset

In [3]:
import json
import torch
import itertools
import re
import string

# from dataset.text.word_decomposition import is_Vietnamese, decompose_non_vietnamese_word, compose_word


class Vocabulary:
    def __init__(self, paths):
        self.make_vocab(paths)
        self.make_idx()

    def __len__(self):
        return len(self.vocab) + 5

    def __getitem__(self, word):
        return self.word2idx.get(word, self.unk_idx)

    def __contains__(self, word):
        return word in self.word2idx

    def make_vocab(self, paths):
        self.vocab = set()
        self.max_len = 0
        for path in paths:
            with open(path, mode="r", encoding="utf-8") as file:
                data = json.load(file)
            data = list(data.items())
            for index in range(len(data)):
                sentence = data[index][1]['script']
                tokens = sentence.strip().split()
                if len(tokens) > self.max_len:
                    self.max_len = len(tokens)
                self.vocab.update(tokens)
        self.vocab = list(self.vocab)

    def make_idx(self):
        self.word2idx = {w: i for i, w in enumerate(self.vocab, 5)}
        self.pad_idx = 0
        self.bos_idx = 1
        self.eos_idx = 2
        self.blank_idx = 3
        self.unk_idx = 4
        self.idx2word = {i: w for w, i in self.word2idx.items()}

    def encode_script(self, script):
        encoded_script = []
        script = script.lower()
        pattern = f"[{re.escape(string.punctuation)}]"
        script = re.sub(pattern, "", script)
        words = script.split()
        for word in words:
            encoded_word = self.word2idx.get(word, self.unk_idx)
            encoded_script.append(encoded_word)
        return torch.tensor(encoded_script).long()


# class PhonemeVocabv2:
#     def __init__(self):
#         self.pad_idx = 0
#         self.bos_idx = 1
#         self.eos_idx = 2
#         self.blank_idx = 3

#         onsets = [
#             'ngh', 'tr', 'th', 'ph', 'nh', 'ng', 'kh',
#             'gi', 'gh', 'ch', 'q', 'đ', 'x', 'v', 't',
#             's', 'r', 'n', 'm', 'l', 'k', 'h', 'g', 'd',
#             'c', 'b'
#         ]
#         rhymes = [
#             # a
#             "a", "ac", "ach", "ai",
#             "am", "an", "ang", "anh",
#             "ao", "ap", "at", "ay", "au",
#             # ă
#             "ă", "ăc", "ăm", "ăn", "ăng", "ăp", "ăt",
#             # â
#             "â", "âc", "âm", "ân", "âng",
#             "âp", "ât", "âu", "ây",
#             # e
#             "e", "ec", "em", "en",
#             "eng", "eo", "ep", "et",
#             # ê
#             "ê", "êch", "êm", "ên",
#             "ênh", "êp", "êt", "êu",
#             # i
#             "i", "ia", "ich", "iêc", "iêm", "iên",
#             "iêng", "iêp", "iêt", "iêu", "im", "in",
#             "inh", "ip", "it", "iu",
#             # o
#             "o", "oa", "oac", "oach", "oai",
#             "oam", "oan", "oang", "oanh",
#             "oao", "oap", "oat", "oay",
#             "oăc", "oăm", "oăn", "oăng",
#             "oăt", "oc", "oe", "oen","oeo",
#             "oet", "oi", "om", "on", "ong",
#             "ooc", "oong", "op", "ot",
#             # ô
#             "ô", "ôc", "ôi",
#             "ôm", "ôn", "ông",
#             "ôp", "ôt",
#             # ơ
#             "ơ", "ơi", "ơm",
#             "ơn", "ơp", "ơt",
#             # u
#             "u", "ua", "uân", "uâng", "uât",
#             "uây", "uc", "uê", "uêch", "uênh",
#             "ui", "um", "un", "ung", "uơ", "uôc",
#             "uôi", "uôm", "uôn", "uông", "uôt",
#             "up", "ut", "uy", "uya", "uych",
#             "uyên", "uyêt", "uyn", "uynh",
#             "uyp", "uyt", "uyu",

#             "uach", "uai", "uan", "uang", "uanh", "uao", "uat", "uau", "uay",
#             "uăc", "uăm", "uăn", "uăng", "uăp", "uăt", "uâc", "uât",
#             "ue", "uen", "ueo", "uet", "uên", "uêt", "uêu",

#             # ư
#             "ư", "ưa", "ưc", "ưi",
#             "ưng", "ươc", "ươi",
#             "ươm", "ươn", "ương",
#             "ươp", "ươt", "ươu",
#             "ưt", "ưu",
#             # y
#             "y", "yêm", "yên",
#             "yêng", "yêt", "yêu"
#         ]
#         codas = ['ng', 'nh', 'ch', 'u', 'n', 'o', 'p', 'c', 'm', 'y', 'i', 't']
#         tones = ['<huyền>', '<sắc>', '<ngã>', '<hỏi>', '<nặng>']
#         phonemes = onsets + rhymes + codas + tones
#         self.phoneme2idx = {
#             phoneme: idx for idx, phoneme in enumerate(phonemes, start=4)
#         }
#         self.idx2phoneme = {idx: phoneme for phoneme, idx in self.phoneme2idx.items()}

#     def __len__(self):
#         return len(self.phoneme2idx) + 4

#     def encode_script(self, script: str):
#         script = script.lower()
#         pattern = f"[{re.escape(string.punctuation)}]"
#         script = re.sub(pattern, "", script)
#         words = script.split()
#         word_components = []
#         is_Vietnamese_words = []

#         for word in words:
#             if word == "quí":
#                 word = "quý"
#             is_Vietnamese_word, components = is_Vietnamese(word)
#             is_Vietnamese_words.append(is_Vietnamese_word)
#             if is_Vietnamese_word:
#                 word_components.append(components)
#             else:
#                 word_components.append(decompose_non_vietnamese_word(word))

#         phoneme_script = []
#         for ith in range(len(words)):
#             word_component = word_components[ith]
#             if is_Vietnamese_words[ith]:
#                 onset, medial, nucleus, coda, tone = word_component
#                 vowel = compose_word(None, medial, nucleus, coda, None)
#                 phoneme_script.extend([
#                     self.phoneme2idx[onset] if onset else self.blank_idx,
#                     self.phoneme2idx[vowel] if vowel else self.blank_idx,
#                     self.phoneme2idx[tone] if tone else self.blank_idx,
#                     self.blank_idx])
#             else:
#                 for char in word_component:
#                     onset, medial, nucleus, coda, tone = char
#                     vowel = compose_word(None, medial, nucleus, coda, None)
#                     phoneme_script.extend([
#                         self.phoneme2idx[onset] if onset else self.blank_idx,
#                         self.phoneme2idx[vowel] if vowel else self.blank_idx,
#                         self.phoneme2idx[tone] if tone else self.blank_idx,
#                         self.blank_idx])

#         vec = torch.tensor(phoneme_script[:-1]).long()  # remove the last blank token
#         return vec

#     def decode_script(self, tensor_script: torch.Tensor):
#         '''
#             tensorscript: (1, seq_len)
#         '''
#         # remove duplicated token
#         tensor_script = tensor_script.squeeze(0).long().tolist()
#         script = [self.idx2phoneme[idx] for idx in tensor_script]
#         script = ' '.join([k for k, _ in itertools.groupby(script)])

In [4]:
import torch.nn as nn
import torchaudio
from torchaudio import transforms
from torch.utils.data import Dataset
import json
import os
import torch.nn.functional as F


class MelSpectrogram(nn.Module):
    def __init__(self, sample_rate=16000, n_mels=64, win_length=400, hop_length=512, n_ffts=1024):
        # sample_rate: tần suất lấy mẫu (16000 điểm/s)
        # n_mels: chiều cao của Mel Spectrogram?
        # win_length: (nếu chiều dài của data là 800 => 2 window)
        # hop_length: bước nhảy (stride)
        # n_ffts: chiều dài mỗi Time-Section

        super(MelSpectrogram, self).__init__()

        self.transform = transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=n_mels, win_length=win_length,
            hop_length=hop_length,
            n_fft=n_ffts
        )

    def forward(self, input):
        return self.transform(input)


class MFCC(nn.Module):
    def __init__(self, sample_rate=16000, n_mfcc=50, n_mels=64, win_length=400, hop_length=512, n_ffts=1024):
        super(MFCC, self).__init__()

        self.transform = transforms.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mfcc,
            melkwargs={
                "n_mels": n_mels,
                "n_fft": n_ffts,
                "hop_length": hop_length
            }
        )

    def forward(self, input):
        return self.transform(input)


class CustomSpeech(Dataset):
    def __init__(self, audio_directory, data_directory, target_sample_rate, num_samples, transformation):
        super(CustomSpeech, self).__init__()

        self.audio_directory = audio_directory

        with open(data_directory, mode="r", encoding="utf-8") as file:
            data = json.load(file)
        self.data = list(data.items())

        if transformation == "Mel Spectrogram":
            self.transformation = MelSpectrogram()
        elif transformation == "MFCC":
            self.transformation = MFCC()

        self.target_sample_rate = target_sample_rate
        self.num_samples = num_samples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        audio_path = self._get_audio_path(index)
        audio_wave, sample_rate = torchaudio.load(audio_path)  # (num_channels, origin_num_samples), int
        resampled_signal = self._resample_if_necessary(audio_wave, sample_rate)  # (num_channels, resampled_num_samples)
        cut_signal = self._cut_down_if_necessary(resampled_signal)  # (num_channels, num_samples)
        padded_signal = self._pad_if_necessary(cut_signal)  # (num_channels, num_samples)
        mel = self.transformation(padded_signal)  # (num_channels, num_mels, num_samples)
        mel = mel.squeeze()  # (num_mels, num_samples)
        mel = mel.permute(1, 0)  # (num_samples, num_mels)

        len_signal = cut_signal.size()[1]

        return mel, len_signal

    def _get_audio_path(self, index):
        file = (self.data[index][1]['voice']).replace(".mp3", ".wav")
        return os.path.join(self.audio_directory, file)

    def _resample_if_necessary(self, audio_wave, sample_rate):
        if sample_rate != self.target_sample_rate:
            resampler = transforms.Resample(sample_rate, self.target_sample_rate)
            audio_wave = resampler(audio_wave)
        return audio_wave

    def _cut_down_if_necessary(self, resampled_signal):
        cut_signal = resampled_signal.clone()
        if resampled_signal.size()[1] > self.num_samples:
            cut_signal = cut_signal[:, :self.num_samples]
        return cut_signal

    def _pad_if_necessary(self, cut_signal):
        padded_signal = cut_signal.clone()
        if padded_signal.size()[1] < self.num_samples:
            signal_length = padded_signal.size()[1]
            num_missing_samples = self.num_samples - signal_length
            last_dim_padding = (0, num_missing_samples)
            padded_signal = F.pad(padded_signal, last_dim_padding)
        return padded_signal


from torch.utils.data import Dataset
import json
# from dataset.text.vocabulary import PhonemeVocabv2
# import re
# import string
import torch
# import numpy as np

class CustomText(Dataset):
    def __init__(self, data_directory, vocab, max_len):
        super(CustomText, self).__init__()

        with open(data_directory, mode="r", encoding="utf-8") as file:
            data = json.load(file)
        self.data = list(data.items())

        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        script = self.get_script(index)
        script = self.encode(script)
        script = self.cut_down_if_necessary(script)
        len_script = len(script)
        script = self.add_token(script)
        script = self.pad_if_necessary(script)
        return script, len_script

    def get_script(self, index):
        return self.data[index][1]["script"]

    # def _lower(self, script):
    #     lowered_script = script.lower()
    #     return lowered_script

    # def _remove_special_characters(self, lowered_script):
    #     pattern = f"[{re.escape(string.punctuation)}]"
    #     removed_script = re.sub(pattern, "", lowered_script)
    #     return removed_script

    # def _tokenize(self, removed_script):
    #     tokenized_script = self.vocab.tokenize(removed_script.split()[0])
    #     for word in removed_script.split()[1:]:
    #         tokenized_word = self.vocab.tokenize(word)
    #         tokenized_script = np.vstack((tokenized_script, tokenized_word))
    #     return tokenized_script

    def encode(self, script):
        return self.vocab.encode_script(script)

    def cut_down_if_necessary(self, script):
        if len(script) > self.max_len - 2:
            script = script[:self.max_len - 2]
        return script

    def add_token(self, script):
        return torch.cat((torch.tensor([self.vocab.bos_idx]), script, torch.tensor([self.vocab.eos_idx])))

    def pad_if_necessary(self, script):
        pad_value = torch.tensor([self.vocab.pad_idx])
        while len(script) < self.max_len:
            script = torch.cat((script, pad_value))
        return script

from torch.utils.data import Dataset
# from dataset.speech.speech import CustomSpeech
# from dataset.text.text import CustomText


class CustomDataset(Dataset):
    def __init__(self, audio_directory, data_directory, vocab, target_sample_rate, num_samples, transformation, max_len):
        super(CustomDataset, self).__init__()

        self.speech = CustomSpeech(audio_directory, data_directory, target_sample_rate, num_samples, transformation)
        self.script = CustomText(data_directory, vocab, max_len)

    def __len__(self):
        return len(self.speech)

    def __getitem__(self, index):
        return self.speech[index], self.script[index]


In [5]:
os.mkdir('ViVOS')
!gdown 1oV2v0RBHX_Rqvra0YUrgoBgND0QC64Pc -O ViVOS/train.json
!gdown 1obDaRybTfcOaGrl6mhAb5bMpMESs2lr5 -O ViVOS/test.json
!gdown 1JoUgZ6uGPb5_iZTDinjF5pRzUvhk-4-n -O ViVOS/voices.zip

Downloading...
From: https://drive.google.com/uc?id=1oV2v0RBHX_Rqvra0YUrgoBgND0QC64Pc
To: /content/ViVOS/train.json
100% 2.28M/2.28M [00:00<00:00, 140MB/s]
Downloading...
From: https://drive.google.com/uc?id=1obDaRybTfcOaGrl6mhAb5bMpMESs2lr5
To: /content/ViVOS/test.json
100% 135k/135k [00:00<00:00, 63.2MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1JoUgZ6uGPb5_iZTDinjF5pRzUvhk-4-n
From (redirected): https://drive.google.com/uc?id=1JoUgZ6uGPb5_iZTDinjF5pRzUvhk-4-n&confirm=t&uuid=d1f6a2ca-48db-4817-8627-ab1b4d345d1c
To: /content/ViVOS/voices.zip
100% 1.48G/1.48G [00:23<00:00, 62.1MB/s]


In [6]:
!unzip "/content/ViVOS/voices.zip" -d "/content/ViVOS/voices"

[1;30;43mKết quả truyền trực tuyến bị cắt bớt đến 5000 dòng cuối.[0m
  inflating: /content/ViVOS/voices/voices/VIVOSSPK42_299.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK35_257.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK40_149.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK07_R153.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK21_101.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK07_R147.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK21_115.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK11_R035.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK29_073.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK34_050.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK35_243.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK06_R153.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK31_101.wav  
  inflating: /content/ViVOS/voices/voices/VIVOSSPK10_R021.wav  
  inflating: /content/ViVOS/voices/voices/

In [7]:
paths = [
    # "Common-Voice/train.json",
    # "Common-Voice/dev.json",
    # "Common-Voice/test.json",
    "ViVOS/train.json",
    "ViVOS/test.json"
]

In [8]:
SAMPLE_RATE = 16000
NUM_SAMPLE = 55000  # Chọn tạm để test (nên chọn lại)
TRANSFORMATION = "MFCC"
VOCAB = Vocabulary(paths)
MAX_LEN = VOCAB.max_len

In [None]:
# common_voice_train = CustomDataset(
#     audio_directory="Vietnamese-Speech-to-Text-datasets//Common-Voice//converted voices",
#     data_directory="Vietnamese-Speech-to-Text-datasets//Common-Voice//train.json",
#     vocab=VOCAB,
#     target_sample_rate=SAMPLE_RATE,
#     num_samples = NUM_SAMPLE,
#     transformation=TRANSFORMATION,
#     max_len=MAX_LEN
# )

# common_voice_dev = CustomDataset(
#     audio_directory="Vietnamese-Speech-to-Text-datasets//Common-Voice//converted voices",
#     data_directory="Vietnamese-Speech-to-Text-datasets//Common-Voice//dev.json",
#     vocab=VOCAB,
#     target_sample_rate=SAMPLE_RATE,
#     num_samples = NUM_SAMPLE,
#     transformation=TRANSFORMATION,
#     max_len=MAX_LEN
# )

# common_voice_test = CustomDataset(
#     audio_directory="Vietnamese-Speech-to-Text-datasets//Common-Voice//converted voices",
#     data_directory="Vietnamese-Speech-to-Text-datasets//Common-Voice//test.json",
#     vocab=VOCAB,
#     target_sample_rate=SAMPLE_RATE,
#     num_samples = NUM_SAMPLE,
#     transformation=TRANSFORMATION,
#     max_len=MAX_LEN
# )

In [9]:
vivos_train = CustomDataset(
    audio_directory="ViVOS/voices/voices",
    data_directory="ViVOS/train.json",
    vocab=VOCAB,
    target_sample_rate=SAMPLE_RATE,
    num_samples = NUM_SAMPLE,
    transformation=TRANSFORMATION,
    max_len=MAX_LEN
)

vivos_test = CustomDataset(
    audio_directory="ViVOS/voices/voices",
    data_directory="ViVOS/test.json",
    vocab=VOCAB,
    target_sample_rate=SAMPLE_RATE,
    num_samples = NUM_SAMPLE,
    transformation=TRANSFORMATION,
    max_len=MAX_LEN
)

## Load dataloader

In [10]:
BATCH_SIZE = 64

In [None]:
# dataset = ConcatDataset([common_voice_train, common_voice_dev, common_voice_test, vivos_train, vivos_test])

In [None]:
# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])
# train_dataset, dev_dataset = torch.utils.data.random_split(train_dataset, [0.8, 0.2])

In [None]:
# train_loader = DataLoader(
#     train_dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=True
# )
# dev_loader = DataLoader(
#     dev_dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=True
# )
# test_loader = DataLoader(
#     test_dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=True
# )

In [11]:
train_loader = DataLoader(
    vivos_train,
    batch_size=BATCH_SIZE,
    shuffle=True
)
dev_loader = DataLoader(
    vivos_test,
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Train model

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm
from speechbrain.nnet.loss.transducer_loss import TransducerLoss

# Assuming device is set up as follows:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Acoustic model (Encoder)
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout_prob):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout_prob = dropout_prob

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            bias=True,
            batch_first=True,
            dropout=self.dropout_prob,
            bidirectional=True
        )

    def forward(self, inputs):  # inputs.size(): (batch_size, num_samples, num_mels)
        outputs, _ = self.lstm(inputs)  # outputs.size(): (batch_size, num_samples, hidden_size * 2)
        return outputs

# Language model (Decoder)
class Decoder(nn.Module):
    def __init__(self, vocab, hidden_size, num_layers, dropout_prob):
        super(Decoder, self).__init__()
        self.vocab = vocab
        self.embedding_size = len(self.vocab) - 4
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout_prob = dropout_prob
        self.start_symbol = self.vocab.bos_idx

        self.embedding = nn.Embedding(
            num_embeddings=len(self.vocab),
            embedding_dim=self.embedding_size,
            padding_idx=self.vocab.pad_idx
        )

        self.dropout = nn.Dropout(dropout_prob)

        self.lstm = nn.LSTM(
            input_size=self.embedding_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            bias=True,
            batch_first=True,
            dropout=self.dropout_prob,
            bidirectional=True
        )

    def forward(self, inputs):  # inputs.size(): (batch_size, text_len)
        batch_size = inputs.size()[0]
        text_len = inputs.size()[1]
        hidden_state = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, device=inputs.device)
        cell = torch.zeros(self.num_layers * 2, batch_size, self.hidden_size, device=inputs.device)
        outputs = []
        for i in range(text_len):
            if i == 0:
                decoder_input = (torch.tensor([self.start_symbol] * batch_size, device=inputs.device)).unsqueeze(-1)  # decoder_input.size(): (batch_size, 1)
            else:
                decoder_input = inputs[:, i].unsqueeze(-1)  # decoder_input.size(): (batch_size, 1)
            output, hidden_state, cell = self.forward_one_step(decoder_input, hidden_state, cell)
            outputs.append(output)
        outputs = torch.stack(outputs, dim=1)  # outputs.size(): (batch_size, text_len + 1, hidden_size * 2)
        outputs = outputs.squeeze(2)

        return outputs

    def forward_one_step(
            self,
            input,  # input.size(): (batch_size, 1)
            previous_hidden_state,  # hidden_state.size(): num_layers * 2, batch_size, hidden_size
            previos_cell
        ):
        embeddeding = (self.dropout(self.embedding(input)))  # embeddeding.size(): (batch_size, 1, embedding_size)
        output, (hidden_state, cell) = self.lstm(embeddeding, (previous_hidden_state, previos_cell))
        # output.size(): (batch_size, 1, 2 * hidden_size)
        # hidden_state.size(): (2 * num_layers, batch_size, hidden_size)

        return output, hidden_state, cell

# Joiner
class Joiner(nn.Module):
    MODES = {
        "multiplicative": lambda outputs_encoder, outputs_decoder: outputs_encoder * outputs_decoder,
        "additive": lambda outputs_encoder, outputs_decoder: outputs_encoder + outputs_decoder,
        # "cat": lambda outputs_encoder, outputs_decoder: torch.cat((outputs_encoder, outputs_decoder), dim=1)
    }

    def __init__(self, hidden_size, vocab, mode):
        super(Joiner, self).__init__()

        self.vocab = vocab
        self.join = self.MODES[mode]

        self.linear = nn.Linear(
            in_features=hidden_size * 2,
            out_features=len(self.vocab)
        )

    def forward(
            self,
            outputs_encoder,  # outputs_encoder.size(): (batch_size, num_samples, hidden_size * 2)
            outputs_decoder,  # outputs_decoder.size(): (batch_size, text_len + 1, hidden_size * 2)
        ):
        outputs_encoder = outputs_encoder.unsqueeze(2)  # outputs_encoder.size(): (batch_size, num_samples, 1, hidden_size * 2)
        outputs_decoder = outputs_decoder.unsqueeze(1)  # outputs_decoder.size(): (batch_size, 1, text_len + 1, hidden_size * 2)

        outputs = self.join(outputs_encoder, outputs_decoder)  # outputs.size(): (batch_size, num_samples, text_len + 1, hidden_size * 2)

        outputs = self.linear(outputs)  # outputs.size(): (batch_size, num_samples, text_len + 1, len_vocab)

        outputs = F.softmax(outputs, dim=-1)

        return outputs

# RNNT Model
class RNNT(nn.Module):
    def __init__(self, vocab, input_size, device, hidden_size = 512, num_layers_encoder = 5, num_layers_decoder = 2, dropout_prob = 0.2, mode = "additive"):
        super(RNNT, self).__init__()
        self.vocab = vocab
        self.device = device

        self.encoder = Encoder(
            input_size,
            hidden_size = hidden_size,
            num_layers = num_layers_encoder,
            dropout_prob = dropout_prob
        ).to(self.device)

        self.decoder = Decoder(
            vocab=self.vocab,
            hidden_size = hidden_size,
            num_layers = num_layers_decoder,
            dropout_prob = dropout_prob
        ).to(self.device)

        self.joiner = Joiner(
            hidden_size=hidden_size,
            vocab=self.vocab,
            mode=mode
        ).to(self.device)

    def compute_loss(self, inputs, signal_len, text_len, targets):
        outputs_encoder = self.encoder(inputs)
        outputs_decoder = self.decoder.forward(targets)
        outputs_joiner = self.joiner(outputs_encoder, outputs_decoder)

        transducer_loss = TransducerLoss(self.vocab.blank_idx)
        loss = transducer_loss(outputs_joiner, targets, signal_len, text_len)
        return loss

    def greedy_decode(self, signals, signal_lens, max_len):
        outputs_batch = []
        batch_size = signals.size()[0]
        outputs_encoder = self.encoder(signals).to(self.device)
        for b in range(batch_size):
            t = 0
            u = 0
            outputs = [self.decoder.start_symbol]
            hidden_state = self.decoder.initial_state.unsqueeze(0)
            while t < signal_lens[b] and u < max_len:
                decoder_input = torch.tensor([outputs[-1]], device=self.device)
                output, hidden_state = self.decoder.forward_one_step(decoder_input, hidden_state)
                feature_t = outputs_encoder[b, t]
                output_joiner = self.joiner.forward(feature_t, output)
                argmax = output_joiner.max(-1)[1].item()
                if argmax == self.vocab.blank_idx:
                    t += 1
                else:  # argmax == a label
                    u += 1
                    outputs.append(argmax)
            outputs_batch.append(outputs[1:-1])
        return outputs_batch

    # def forward(self, inputs, max_len): => def decode?
    #     batch_size, num_samples, num_mels = inputs.size()

    #     outputs_encoder = self.encoder(inputs)  # outputs_encoder.size(): (batch_size, num_samples, hidden_size * 2)
    #     hidden_state, cell = self.decoder.init_hidden_state(batch_size)

    #     item = self._init_sos(batch_size)  # item.size(): (batch_size, 1, 3)
    #     counter = self._init_counter(batch_size)  # counter.size(): (batch_size)

    #     counter_ceil = num_samples - 1
    #     # term_state = torch.zeros(batch_size)
    #     t = 0

    #     while True:
    #         t += 1

    #         output_encoder = (outputs_encoder[range(batch_size), counter, :]).unsqueeze(1)  # signal.size(): (batch_size, 1, hidden_size * 2)
    #         prob_consonant, prob_vowel, prob_tone, hidden_state, cell = self._predict_next(item, output_encoder, hidden_state, cell)
    #         # prob_consonant.size(): (batch_size, num_consonants)
    #         # prob_vowel.size(): (batch_size, num_vowels)
    #         # prob_tone.size(): (batch_size, num_tones)
    #         # hidden_size.size(): (num_layers * 2, batch_size, hidden_size)
    #         # cell.size(): (num_layers * 2, batch_size, hidden_size)

    #         # prob_words = torch.cat((prob_consonant, prob_vowel, prob_tone), dim=1)
    #         # prob_words = prob_words.unsqueeze(1)


    #         predicted_word = torch.cat(
    #             (
    #                 torch.argmax(prob_consonant, dim=-1).unsqueeze(-1),  # .size(): (batch_size, 1)
    #                 torch.argmax(prob_vowel, dim=-1).unsqueeze(-1),  # .size(): (batch_size, 1)
    #                 torch.argmax(prob_tone, dim=-1).unsqueeze(-1)  # .size(): (batch_size, 1)
    #             ),
    #             dim=1
    #         )  # predicted_word.size(): (batch_size, 3)
    #         predicted_word = predicted_word.unsqueeze(1) # predicted_word.size(): (batch_size, 1, 3)

    #         if t == 1:
    #             # results = prob_words
    #             predictions = predicted_word  # predictions.size(): (batch_size, 1, 3)
    #         else:
    #             # results = torch.cat([results, prob_words], dim=1)
    #             predictions = torch.cat([predictions, predicted_word], dim=1)  # predictions.size(): (batch_size, t, 3)

    #         is_blank = self._is_blank(batch_size, predicted_word)  # is_blank.size(): (batch_size, 1)
    #         item = self._update_item(is_blank, item, predicted_word)  # item.size(): (batch_size, 1, 3)

    #         counter, update_mask = self._update(is_blank, counter, counter_ceil)

    #         if (update_mask.sum().item() == batch_size) or (t == max_len):  # counter tới max của tất cả trong batch hoặc đạt max len
    #             break

    #     return predictions  # predictions.size(): (batch_size, max_len, 3)

    # def _init_sos(self, batch_size):
    #     return torch.LongTensor(np.array([self.vocab.get_index("<sos>")] * batch_size))  # (batch_size, 1, 3)

    # def _init_counter(self, batch_size):
    #     return np.zeros(batch_size, dtype=int)

    # def _predict_next(
    #         self,
    #         item,  # item.size(): (batch_size, 1, 3)
    #         output_encoder,  # output_encoder.size(): (batch_size * num_samples, hidden_size * 2)
    #         hidden_state,
    #         cell
    #     ):
    #     output, hidden_state, cell = self.decoder(item, hidden_state, cell)
    #     prob_consonant, prob_vowel, prob_tone = self.joiner(output_encoder, output)
    #     return prob_consonant, prob_vowel, prob_tone, hidden_state, cell

    # def _is_blank(self, batch_size, predicted_word):
    #     is_blank = torch.zeros((batch_size, 1), dtype=bool)
    #     for i, predicted_word in enumerate(predicted_word):
    #         is_blank[i] = torch.equal(predicted_word, torch.LongTensor(self.vocab.get_index("<blank>")))
    #     return is_blank

    # def _update_item(self, is_blank, item, predicted_word):
    #     return ((is_blank * item.squeeze()) + (~is_blank * predicted_word.squeeze())).unsqueeze(1)

    # def _update(self, is_blank, counter, counter_ceil):
    #     counter = counter + is_blank.squeeze().numpy()
    #     counter, update_mask = self._clip_counter(counter, counter_ceil)
    #     # term_state = self._update_termination_state(term_state, update_mask, t)
    #     return counter, update_mask

    # def _clip_counter(self, counter, counter_ceil):
    #     update_mask = counter >= counter_ceil  # update_mask.shape: (batch_size,) (bool)
    #     upper_bounded = update_mask * counter_ceil  # upper_bounded.shape: (batch_size,)
    #     kept_counter = (counter < counter_ceil) * counter  # kept_counter.shape: (batch_size,)
    #     return upper_bounded + kept_counter, update_mask

    # # def _update_termination_state(self, term_state, update_mask, t):
    # #     is_unended = term_state == 0
    # #     to_update = is_unended & update_mask
    # #     return term_state + to_update * t


In [13]:
from torch.optim import Adam
from tqdm import tqdm
import torch


class Trainer:
    def __init__(self, model, lr, vocab, device):
        self.model = model
        self.lr = lr
        self.optimizer = Adam(model.parameters(), lr=self.lr)
        self.vocab = vocab
        self.device = device

    def train(self, train_loader, max_len, print_interval=20):
        train_loss = 0
        num_samples = 0
        self.model.train()
        for index, batch in enumerate(tqdm(train_loader)):
            (signals, signal_lens), (scripts, text_lens) = batch
            signals = signals.to(self.device)
            signal_lens = signal_lens.to(self.device)
            scripts = scripts.to(self.device)
            text_lens = text_lens.to(self.device)
            batch_size = signals.size()[0]
            num_samples += batch_size
            loss = self.model.compute_loss(signals, signal_lens, text_lens, scripts)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item() * batch_size

            if index % print_interval == 0:
                self.model.eval()
                guesses = self.model.greedy_decode(signals, signal_lens, max_len)
                print("\n")
                for b in range(2):
                    print("guess:", self.vocab.decode_script(guesses[b]))
                    print("truth:", self.vocab.decode_script(scripts[b, :text_lens[b]]))

        train_loss /= num_samples
        return train_loss

    def test(self, dev_loader, max_len, print_interval=10):
        test_loss = 0
        num_samples = 0
        self.model.eval()
        with torch.no_grad():
            for idx, batch in tqdm(dev_loader):
                (signals, signal_lens), (scripts, text_lens) = batch
                signals = signals.to(self.device)
                signal_lens = signal_lens.to(self.device)
                scripts = scripts.to(self.device)
                text_lens = text_lens.to(self.device)
                batch_size = signals.size()[0]
                num_samples += batch_size
                loss = self.model.compute_loss(signals, signal_lens, text_lens, scripts)
                test_loss += loss.item() * batch_size
                if idx % print_interval == 0:
                    print("\n")
                    print("guess:", self.vocab.decode_script(self.model.greedy_decode(signals, signal_lens, max_len)[0]))
                    print("truth:", self.vocab.decode_script(scripts[0, :text_lens[0]]))
                    print("")
        test_loss /= num_samples
        return test_loss

In [14]:
rnnt = RNNT(
    vocab=VOCAB,
    input_size=50,
    device=torch.device('cpu')
)

In [15]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = Trainer(model=rnnt, lr=0.0003, vocab=VOCAB, device=torch.device('cpu'))
NUM_EPOCHS = 10
train_losses = []
test_losses = []

for epoch in range(NUM_EPOCHS):
    train_loss = trainer.train(train_loader=train_loader, max_len=MAX_LEN)
    test_loss = trainer.test(dev_loader=dev_loader, max_len=MAX_LEN)
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    print("Epoch %d: train loss = %f, test loss = %f" % (epoch, train_loss, test_loss))

  0%|          | 0/183 [01:10<?, ?it/s]


ValueError: Found inputs tensors to be on [device(type='cpu'), device(type='cpu'), device(type='cpu'), device(type='cpu')] while needed to be on a 'cuda' device to use the transducer loss.