In [1]:
import torch
import torchaudio

from torch.optim import AdamW
from torch.utils.data import DataLoader

from torch import nn

import pandas as pd
import numpy as np

from tqdm.auto import tqdm

from sklearn.model_selection import KFold, StratifiedKFold

from transformers import PreTrainedTokenizerFast

from transformers import get_cosine_schedule_with_warmup

from modules.model import Transformer, train_epoch, eval_epoch
from modules.dataset import AudioDataset
#from modules.tokenizer import tokenize

import random
import os

torch.cuda.is_available()

True

In [2]:
path_to_tokenizer = './tokenizer.json'
path_to_data = './audio_dataset/'

data = pd.read_csv(os.path.join(path_to_data,'df.csv'), usecols=['text','status','path','rate','duration','frames'])
data = data[data.status=='APPROVED'].reset_index(drop=True)
del data['status']
data.text = data.text.apply(lambda x: "".join([char for char in x if char.isalpha() or char==' ']).lower())
data.duration.max()

train_data = data.iloc[:300]
valid_data = data.iloc[70000:]

In [3]:
# import matplotlib.pyplot as plt
# import seaborn as sns
# sns.displot(data.text.str.len())
# plt.show()dd
# sns.displot(data.duration)
# plt.show()

In [4]:
np.percentile(data.text.str.len(), 99.5)

135.0

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = PreTrainedTokenizerFast(tokenizer_file=path_to_tokenizer, 
                                    padding_side ='right',
                                    bos_token = '[SOS]',
                                    eos_token = '[EOS]',
                                    pad_token = '[PAD]',
                                    unk_token = '[UNK]',
                                    mask_token = '[MASK]')


train_dataset = AudioDataset(train_data, path_to_data, tokenizer, n_fft=1024, n_mels=64, center=True, max_tokenized_length=100, max_audio_len=25, sr=16000)
valid_dataset = AudioDataset(valid_data, path_to_data, tokenizer, n_fft=1024, n_mels=64, center=True, max_tokenized_length=100, max_audio_len=25, sr=16000)
model = Transformer(vocab_size=len(tokenizer),
                    n_mels=64,
                    enc_seq_len=25, 
                    dec_seq_len=100,
                    hidden_dim=16, 
                    enc_num_layers=2, 
                    dec_num_layers=2, 
                    num_heads=3, 
                    ff_dim=128, 
                    r_dim=100, 
                    device=device,
                    dropout=0.0, 
                    sr=16000, 
                    n_fft=1024)

In [6]:
train_dataset[0]['attention_mask'].squeeze().sum()

tensor(51)

In [7]:
len(tokenizer.encode(train_dataset[0]['text']))

51

n_fft=1024, win_lenght=1024, hop_lenght=256, n_mels=64, center=True):

In [8]:
torch.cuda.empty_cache()

In [9]:
from torchmetrics.functional import word_error_rate
from torchmetrics.functional.classification import multiclass_accuracy

In [10]:
def cross_validation(model, 
                     dataset, 
                     loss_function,
                     strat_array=None,
                     device=torch.device("cuda"),
                     random_state: int=69, 
                     shuffle: bool=True, 
                     n_folds: int=10, 
                     epochs: int=10, 
                     lr: float=1e-6,
                     start_fold: int=0, 
                     batch_size: int=4,
                     iters_to_accumulate=None,
                     n_accumulated_grads: int = 0):
    random.seed(random_state),
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    torch.cuda.manual_seed_all(random_state)
    
    loss_function.to(device)
    if strat_array:
        kfold = StratifiedKFold(n_folds, shuffle=shuffle, random_state=random_state)
        split = kfold.split(dataset, strat_array)
    else: 
        kfold = KFold(n_folds, shuffle=shuffle, random_state=random_state)
        split = kfold.split(dataset)

    for fold, (train_ids, eval_ids) in enumerate(split):
        if fold >= start_fold:
            print(f'FOLD {fold}')
            print('--------------------------------')

            optimizer = AdamW(
            model.parameters(),
            lr = 1e-3,
        )

            train_subsampler = torch.utils.data.Subset(dataset,  train_ids)
            train_loader = torch.utils.data.DataLoader(
                          train_subsampler, 
                          batch_size=batch_size,
                          shuffle=shuffle, drop_last=True)

            eval_subsampler = torch.utils.data.Subset(dataset,  eval_ids)
            eval_loader = torch.utils.data.DataLoader(
                          eval_subsampler,
                          batch_size=batch_size,
                          shuffle=shuffle, drop_last=True)
            
            total_steps = len(train_loader) * epochs 

            scheduler = get_cosine_schedule_with_warmup(optimizer, 
                                                    num_warmup_steps = 0, # Default value in run_glue.py
                                                    num_training_steps = total_steps)


            for epoch_i in range(epochs):
                train_metrics, t_preds = train_epoch(model, train_loader, dataset.tokenizer, loss_function, optimizer, scheduler, device)
                eval_metrics, preds = eval_epoch(model, eval_loader, dataset.tokenizer, loss_function, device)
                print(f"EPOCH: {epoch_i}")
                print(train_metrics)
                print(eval_metrics)
                print(t_preds)
                print(preds)
        break

In [11]:
tokenizer.pad_token_id

4

In [12]:
string = tokenizer.encode("я люблю дашу")
tokenizer.decode(string)

'[SOS]я люблю дашу [EOS]'

In [14]:
cross_validation(model = model,
                 dataset=train_dataset, 
                 loss_function=nn.CrossEntropyLoss(ignore_index=4), 
                 device=torch.device("cuda"),
                 random_state=69,
                 shuffle=True,
                 batch_size=16)

FOLD 0
--------------------------------


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 0
{'Train Loss': 3.282787002623081, 'Train WER': 1.0158402919769287, 'Train Accuracy': 0.0}
{'Val Loss': 8.24382209777832, 'Val WER': 1.0, 'Val Accuracy': 0.0}
кидтедатедви крполе нам стесврпоеслапокстеросвлвсвкоилпоовкоех наи им са щвм тееом



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 1
{'Train Loss': 1.3623016923666, 'Train WER': 0.9034352898597717, 'Train Accuracy': 0.0}
{'Val Loss': 9.292266845703125, 'Val WER': 1.0, 'Val Accuracy': 0.0}
т лее появилисленовые рорсонажи принявшиелиожесточенвыкрошить стекла и двери нама правих льства



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 2
{'Train Loss': 0.6556443609297276, 'Train WER': 0.34064731001853943, 'Train Accuracy': 0.12109375}
{'Val Loss': 10.979683876037598, 'Val WER': 1.0, 'Val Accuracy': 0.0}
эрик феыпоборолся за шайбу у борта а потом откинул ее свободному маркусу юханссону



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 3
{'Train Loss': 0.27032983768731356, 'Train WER': 0.09256940335035324, 'Train Accuracy': 0.5078125}
{'Val Loss': 10.5862398147583, 'Val WER': 1.0, 'Val Accuracy': 0.0}
забываешь что мы работаем в масках и наших лимоне видно



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 4
{'Train Loss': 0.12142773857340217, 'Train WER': 0.04437323287129402, 'Train Accuracy': 0.6875}
{'Val Loss': 11.389775276184082, 'Val WER': 1.0, 'Val Accuracy': 0.0}
почитай камасутру посгибай женщину в разных плоскостях и ваши отношения выйдут на новую высоту



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 5
{'Train Loss': 0.06546713318675756, 'Train WER': 0.02690771035850048, 'Train Accuracy': 0.7578125}
{'Val Loss': 12.856123924255371, 'Val WER': 1.0, 'Val Accuracy': 0.0}
корнийчук также опроверг слухи оoисключении из партии народного депутата елены шустик



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 6
{'Train Loss': 0.07942213444039226, 'Train WER': 0.033004216849803925, 'Train Accuracy': 0.78125}
{'Val Loss': 9.98002815246582, 'Val WER': 1.0, 'Val Accuracy': 0.0}
вспышка сальмонеллеза в сша может стать причиной для очередного запрета ввоза мяса птицы в россию



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 7
{'Train Loss': 0.035873730783350766, 'Train WER': 0.016270743682980537, 'Train Accuracy': 0.8046875}
{'Val Loss': 13.369752883911133, 'Val WER': 1.0, 'Val Accuracy': 0.0}
в россии в этом плане всетаки наблюдается некоторая стабильность отметила лилия шибанова



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH: 8
{'Train Loss': 0.029486054321751, 'Train WER': 0.012669975869357586, 'Train Accuracy': 0.8203125}
{'Val Loss': 12.50363826751709, 'Val WER': 1.0, 'Val Accuracy': 0.0}
самолет которым управлял опытный и хорошо подготовленный летчик выполнял обычный тренировочный полет



  0%|          | 0/16 [00:00<?, ?it/s]


KeyboardInterrupt

