In [1]:
import torch
import torch.nn as nn
from torch import Tensor

class Loss(nn.Module):
    """Negative log-likelihood of a monotonic character/phi alignment.

    The model produces ``max_length`` output steps. At every step it either
    emits the next target character or the phi token (``phi_idx``).
    ``scores[b, c, p]`` holds the log-probability of having emitted the first
    ``c`` characters of ``target[b]`` and ``p`` phi tokens in the first
    ``c + p`` steps. The matrix is filled by dynamic programming and, for each
    batch item, the cell that uses all ``max_length`` steps
    (``c = target_lengths[b]``, ``p = max_length - c``) is its sequence
    log-likelihood. The loss is the negation of that, averaged over the batch.

    Args:
        phi_idx: vocabulary index of the phi token.
        device: kept for backward compatibility and stored as ``self.device``;
            tensors are created on ``probs.device`` so the loss works on
            whatever device the model runs on.
        eps: floor applied to probabilities before taking the log, so a zero
            probability gives a large finite penalty instead of ``-inf``/NaN.
    """

    def __init__(self, phi_idx: int, device='cpu', eps: float = 1e-10) -> None:
        super().__init__()
        self.phi_idx = phi_idx
        self.device = device
        self.eps = eps

    def forward(
            self,
            probs: Tensor,
            target: Tensor,
            target_lengths: Tensor
            ) -> Tensor:
        """Compute the mean alignment loss for a batch.

        Args:
            probs: ``(batch, max_length, vocab)`` per-step scores.
                NOTE(review): assumed to be probabilities (e.g. softmax
                output), not log-probabilities — confirm against ``Model``.
            target: ``(batch, n)`` target token ids, padded to at least the
                longest target length.
            target_lengths: ``(batch,)`` number of valid tokens per target.

        Returns:
            Scalar tensor: negative log-likelihood averaged over the batch.

        Raises:
            ValueError: if a target is longer than ``max_length``.
        """
        batch_size, max_length, *_ = probs.shape
        target_lengths = target_lengths.to(probs.device)
        n_chars = int(target_lengths.max().item())
        min_chars = int(target_lengths.min().item())
        if n_chars > max_length:
            raise ValueError(
                f'target length {n_chars} exceeds the {max_length} output steps'
            )
        # Every item consumes all max_length steps, so the shortest target
        # needs the most phi steps; size the phi axis for it.
        n_nulls = max_length - min_chars

        scores = self.get_score_matrix(
            batch_size, n_chars, n_nulls,
            device=probs.device, dtype=probs.dtype
        )
        for c in range(n_chars + 1):
            # Cells with c + p > max_length would read past the last step.
            for p in range(min(n_nulls, max_length - c) + 1):
                if c == 0 and p == 0:
                    continue  # empty prefix: log(1) = 0, already initialised
                scores = self.update_scores(scores, probs, target, p, c)
        return self.calc_loss(scores, target_lengths)

    def calc_loss(self, scores: Tensor, target_lengths: Tensor) -> Tensor:
        """Return the negated mean of each item's full-length alignment score.

        The last phi column belongs to the shortest target, so the total
        number of steps is ``(n_columns - 1) + min(target_lengths)`` and item
        ``b`` ends at ``(target_lengths[b], n_steps - target_lengths[b])``.
        """
        target_lengths = target_lengths.to(scores.device).long()
        batch_idx = torch.arange(scores.shape[0], device=scores.device)
        n_steps = scores.shape[-1] - 1 + target_lengths.min()
        null_idx = n_steps - target_lengths
        final = scores[batch_idx, target_lengths, null_idx]
        return -final.mean()

    def get_score_matrix(
            self, batch_size: int, n_chars: int, n_nulls: int,
            device=None, dtype=None
            ) -> Tensor:
        """Zero-filled ``(batch, n_chars + 1, n_nulls + 1)`` log-score matrix.

        Zero is log(1): the empty alignment at ``[:, 0, 0]`` has probability 1.
        """
        return torch.zeros(
            batch_size, n_chars + 1, n_nulls + 1, device=device, dtype=dtype
        )

    def log(self, x: Tensor) -> Tensor:
        """Natural log with inputs floored at ``self.eps`` to stay finite."""
        return torch.log(torch.clamp(x, min=self.eps))

    def update_scores(
            self, scores: Tensor, probs: Tensor, target: Tensor, p: int, c: int
            ) -> Tensor:
        """Fill cell ``(c, p)`` from its predecessors; returns ``scores``.

        ``scores`` is written in place and returned for convenience.
        """
        if p == 0:
            # Only characters so far: extend the (c - 1, 0) path.
            chars_probs = self.get_chars_probs(probs, target, c, p)
            scores[:, c, p] = scores[:, c - 1, p] + self.log(chars_probs)
            return scores
        if c == 0:
            # Only phi tokens so far: extend the (0, p - 1) path.
            phi_probs = self.get_phi_probs(probs, c, p)
            scores[:, c, p] = scores[:, c, p - 1] + self.log(phi_probs)
            return scores
        chars_probs = self.get_chars_probs(probs, target, c, p)
        phi_probs = self.get_phi_probs(probs, c, p)
        # log(P(phi path) + P(char path)) computed stably in log space.
        scores[:, c, p] = torch.logsumexp(
            torch.stack(
                [scores[:, c, p - 1] + self.log(phi_probs),
                 scores[:, c - 1, p] + self.log(chars_probs)]
            ), dim=0)
        return scores

    def get_phi_probs(self, probs: Tensor, c: int, p: int) -> Tensor:
        """``(batch,)`` probability of phi at step ``c + p - 1``."""
        return probs[:, c + p - 1, self.phi_idx]

    def get_chars_probs(
            self, probs: Tensor, target: Tensor, c: int, p: int
            ) -> Tensor:
        """``(batch,)`` probability of each item's ``c``-th target char at step ``c + p - 1``."""
        step_probs = probs[:, p + c - 1]
        char_ids = target[:, c - 1].long().to(step_probs.device).unsqueeze(-1)
        # gather picks one entry per row, avoiding the (batch, batch)
        # intermediate of index_select + diagonal.
        return step_probs.gather(-1, char_ids).squeeze(-1)

In [2]:
from datetime import datetime

def get_formated_date(when: "datetime | None" = None) -> str:
    """Return a timestamp string of the form ``YYYYMMDD-HHMMSS``.

    Every field is zero-padded, so stamps have a fixed width, sort
    chronologically, and are unambiguous. Unpadded fields rendered both
    2024-01-12 and 2024-11-02 as ``2024112``.

    Args:
        when: moment to format; defaults to the current local time.
    """
    if when is None:
        when = datetime.now()
    return when.strftime('%Y%m%d-%H%M%S')

In [None]:
from pathlib import Path
from networks.model import Model
from data.tokenizer import CharTokenizer, BaseTokenizer
from torch.optim import Optimizer
from data.data import AudioPipeline, DataLoader, TextPipeline
from typing import Callable, Union
from torch.nn import Module
from functools import wraps
import torch
import os

# Device name read by `get_model_args` (passed to `Model`) and by `get_trainer`
# (passed to `Trainer`).
device = "cpu"

class Trainer:
    """Epoch-based training loop that records per-epoch mean losses.

    ``history`` maps ``'train_loss'`` (and ``'test_loss'`` when evaluation
    runs) to a list with one mean batch loss per epoch. The model is called as
    ``model(x, max_len)`` and must return a pair whose first element is passed
    to ``criterion(probs, y, lengths)``.
    """
    __train_loss_key = 'train_loss'
    __test_loss_key = 'test_loss'

    def __init__(
            self,
            criterion: Module,
            optimizer: Optimizer,
            model: Module,
            device: str,
            train_loader: DataLoader,
            test_loader: DataLoader,
            epochs: int,
            length_multiplier: float
    ) -> None:
        """Store the training components.

        Args:
            criterion: called as ``criterion(probs, y, lengths)``; returns a
                scalar loss tensor.
            optimizer: optimizer over ``model``'s parameters.
            model: called as ``model(x, max_len)``.
            device: device each batch's ``x`` and ``y`` are moved to.
            train_loader: iterable of ``(x, y, lengths)`` batches with ``len``.
            test_loader: iterable of ``(x, y, lengths)`` batches with ``len``.
            epochs: number of passes made by :meth:`fit`.
            length_multiplier: ``max_len`` passed to the model is
                ``int(x.shape[1] * length_multiplier)``.
        """
        self.criterion = criterion
        self.optimizer = optimizer
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.epochs = epochs
        self.step_history = dict()
        self.history = dict()
        self.length_multiplier = length_multiplier

    def fit(self, evaluate: bool = False) -> None:
        """Train for ``self.epochs`` epochs and print the latest losses after each.

        Args:
            evaluate: also run :meth:`test` after every epoch. Off by default,
                matching the previous behaviour where the call was disabled.
        """
        for _ in range(self.epochs):
            self.train()
            if evaluate:
                self.test()
            self.print_results()

    def set_train_mode(self) -> None:
        self.model = self.model.train()

    def set_test_mode(self) -> None:
        self.model = self.model.eval()

    def print_results(self) -> None:
        """Print ``key: latest_value`` for every tracked metric, comma-separated."""
        print(', '.join(
            f'{key}: {str(values[-1])}' for key, values in self.history.items()
        ))

    def _record(self, key: str, value: float) -> None:
        """Append ``value`` to the history list stored under ``key``."""
        self.history.setdefault(key, []).append(value)

    def _batch_loss(self, x, y, lengths):
        """Move one batch to the device, run the model and return the criterion loss."""
        x = x.to(self.device)
        y = y.to(self.device)
        x = torch.squeeze(x, dim=1)
        # Computed after the squeeze: for a (batch, 1, ...) input, x.shape[1]
        # before squeezing is the singleton axis, which would give max_len 1.
        # NOTE(review): assumes dim 1 is time after squeezing — confirm
        # against AudioPipeline.
        max_len = int(x.shape[1] * self.length_multiplier)
        probs, _ = self.model(x, max_len)
        return self.criterion(probs, y, lengths)

    def test(self) -> None:
        """Evaluate on ``test_loader`` without gradients and record the mean loss."""
        total_loss = 0
        self.set_test_mode()
        with torch.no_grad():
            for x, y, lengths in self.test_loader:
                total_loss += self._batch_loss(x, y, lengths).item()
        self._record(self.__test_loss_key, total_loss / len(self.test_loader))

    def train(self) -> None:
        """Run one optimisation epoch over ``train_loader`` and record the mean loss."""
        total_loss = 0
        self.set_train_mode()
        for x, y, lengths in self.train_loader:
            self.optimizer.zero_grad()
            loss = self._batch_loss(x, y, lengths)
            loss.backward()
            # Clip after backward(): before it the gradients are still zero,
            # so clipping there had no effect.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
            self.optimizer.step()
            total_loss += loss.item()
        self._record(self.__train_loss_key, total_loss / len(self.train_loader))


def get_model_args(
        vocab_size: int,
        pad_idx: int,
        phi_idx: int,
        sos_idx: int
) -> dict:
    """Assemble the keyword arguments used to construct ``Model``.

    Sub-network hyper-parameters are fixed here. Only the vocabulary-dependent
    values come from the caller. ``device`` is read from the notebook-level
    global of the same name at call time.
    """
    prediction_net = dict(
        vocab_size=vocab_size,
        emb_dim=128,
        pad_idx=pad_idx,
        hidden_size=256,
        n_layers=2,
        dropout=0.2,
    )
    transcription_net = dict(
        input_size=512,
        hidden_size=512,
        n_layers=3,
        dropout=0.3,
        is_bidirectional=True,
    )
    joint_net = dict(
        input_size=512,
        vocab_size=vocab_size,
        mode='multiplicative',
    )
    return dict(
        prednet_params=prediction_net,
        transnet_params=transcription_net,
        joinnet_params=joint_net,
        device=device,
        phi_idx=phi_idx,
        pad_idx=pad_idx,
        sos_idx=sos_idx,
    )


def load_model(vocab_size: int, *args, **kwargs) -> Module:
    """Build ``Model`` from the arguments produced by ``get_model_args``.

    Extra positional/keyword arguments are forwarded to ``get_model_args``.
    """
    model_kwargs = get_model_args(vocab_size, *args, **kwargs)
    return Model(**model_kwargs)


def get_tokenizer(vocab_path='vocab.txt', save_path='tokenizer.json'):
    """Build a ``CharTokenizer`` with special tokens and a vocabulary file.

    Adds phi, pad, sos and eos special tokens and sets the vocabulary to one
    entry per line of ``vocab_path``. It then saves the tokenizer to
    ``save_path`` and returns it.

    Args:
        vocab_path: text file with one vocabulary entry per line.
        save_path: where ``save_tokenizer`` writes the tokenizer.
    """
    tokenizer = CharTokenizer()
    tokenizer = tokenizer.add_phi_token().add_pad_token()
    tokenizer = tokenizer.add_sos_token().add_eos_token()
    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) would
    # mis-decode non-ASCII characters in a character vocabulary.
    with open(vocab_path, 'r', encoding='utf-8') as f:
        # Drop empty entries; a trailing newline would otherwise add '' as a
        # token. Lines that are a single space are kept (non-empty).
        vocab = [token for token in f.read().split('\n') if token]
    tokenizer.set_tokenizer(vocab)
    tokenizer.save_tokenizer(save_path)
    return tokenizer


def get_data_loader(
        file_path: Union[str, Path],
        tokenizer: BaseTokenizer,
        batch_size: int = 8
):
    """Create a ``DataLoader`` over ``file_path`` with fresh audio/text pipelines.

    Args:
        file_path: data file passed to ``DataLoader``.
        tokenizer: tokenizer passed to ``DataLoader``.
        batch_size: batch size passed to ``DataLoader`` (default 8, as before).
    """
    audio_pipeline = AudioPipeline()
    text_pipeline = TextPipeline()
    return DataLoader(
        file_path,
        text_pipeline,
        audio_pipeline,
        tokenizer,
        batch_size
    )


def get_trainer(
        batch_size,
        file_text,
        training_file_path,
        testing_file_path,
        epochs: int = 50,
        lr: float = 0.0001,
        momentum: float = 0.9,
        length_multiplier: float = 1.5
):
    """Assemble tokenizer, data loaders, model, loss and optimizer into a ``Trainer``.

    Args:
        batch_size: batch size for both the training and testing loaders
            (previously ignored in favour of a hard-coded 8).
        file_text: unused; kept so existing callers keep working.
        training_file_path: data file for the training loader.
        testing_file_path: data file for the testing loader.
        epochs: number of training epochs.
        lr: SGD learning rate.
        momentum: SGD momentum.
        length_multiplier: passed to ``Trainer``.
    """
    tokenizer = get_tokenizer()
    phi_idx = tokenizer.special_tokens.phi_id
    pad_idx = tokenizer.special_tokens.pad_id
    sos_idx = tokenizer.special_tokens.sos_id
    vocab_size = tokenizer.vocab_size
    train_loader = get_data_loader(training_file_path, tokenizer, batch_size)
    test_loader = get_data_loader(testing_file_path, tokenizer, batch_size)
    # Same device as the Trainer and Model (the notebook-level `device`).
    criterion = Loss(phi_idx, device=device)
    model = load_model(
        vocab_size,
        pad_idx=pad_idx,
        phi_idx=phi_idx,
        sos_idx=sos_idx
    )
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum
    )
    return Trainer(
        criterion=criterion,
        optimizer=optimizer,
        model=model,
        device=device,
        train_loader=train_loader,
        test_loader=test_loader,
        epochs=epochs,
        length_multiplier=length_multiplier
    )


# Training configuration. `get_trainer` has four required parameters, so they
# are supplied explicitly. Data paths come from the environment to avoid
# hard-coding machine-specific locations.
BATCH_SIZE = 8
TRAINING_FILE_PATH = os.environ.get('TRAINING_FILE_PATH')
TESTING_FILE_PATH = os.environ.get('TESTING_FILE_PATH')
if not TRAINING_FILE_PATH or not TESTING_FILE_PATH:
    raise RuntimeError(
        'Set the TRAINING_FILE_PATH and TESTING_FILE_PATH environment '
        'variables to the training/testing data files before training.'
    )

trainer = get_trainer(
    batch_size=BATCH_SIZE,
    file_text=None,
    training_file_path=TRAINING_FILE_PATH,
    testing_file_path=TESTING_FILE_PATH,
)
trainer.fit()
