In [80]:
from typing import Iterable, Tuple, List
import warnings
# Kept ahead of the third-party imports, exactly as in the original cell.
warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece
import omegaconf
import pytorch_lightning as pl
import pandas as pd
from tqdm.auto import tqdm

from src.models import ConformerLAS, ConformerCTC
from src.metrics import WER

In [9]:
def init_model(model: pl.LightningModule, ckpt_path: str) -> pl.LightningModule:
    """Restore weights into ``model`` and prepare it for inference.

    Args:
        model: Freshly constructed module whose parameters are overwritten.
        ckpt_path: Path to a file that ``torch.load`` turns into an object
            accepted by ``model.load_state_dict``.

    Returns:
        The same ``model`` object, after ``eval()`` and ``freeze()`` were called on it.
    """
    # Always deserialize onto the CPU, regardless of where the checkpoint was saved.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Inference-only setup: evaluation mode first, then freeze.
    model.eval()
    model.freeze()
    return model


def compute_wer(refs: Iterable[str], hyps: Iterable[str]) -> float:
    """Score ``hyps`` against ``refs`` with a fresh ``WER`` metric.

    Args:
        refs: Reference transcripts.
        hyps: Hypothesis transcripts, passed to ``WER.update`` after ``refs``.

    Returns:
        ``.item()`` of the first element returned by ``WER.compute()``.
    """
    metric = WER()
    metric.update(refs, hyps)
    # compute() returns a sequence; only its first entry is reported here.
    wer_value = metric.compute()[0]
    return wer_value.item()


class GreedyDecoderLAS:
    """Greedy, token-by-token decoding with the decoder of a LAS model."""

    def __init__(self, model: ConformerLAS, max_steps=20):
        """Store the model and the maximum number of decoder calls per utterance."""
        self.model = model
        self.max_steps = max_steps

    def __call__(self, encoded: torch.Tensor) -> str:
        """Decode one utterance from its encoder states.

        Args:
            encoded: Encoder output for a single utterance; the notebook passes a
                slice with a leading batch dimension of size 1.

        Returns:
            The tokenizer's decoding of the generated ids. The list handed to
            ``decode`` starts with the BOS id; an EOS id is never appended.
        """
        tokenizer = self.model.decoder.tokenizer
        eos_id = tokenizer.eos_id()
        generated = [tokenizer.bos_id()]

        steps_done = 0
        while steps_done < self.max_steps:
            steps_done += 1

            # Re-run the decoder on the whole prefix collected so far.
            prefix = torch.tensor(generated).unsqueeze(0)
            prefix_len = torch.tensor([prefix.size(-1)])
            output = self.model.decoder(
                encoded=encoded,
                encoded_pad_mask=None,
                target=prefix,
                target_mask=self.model.make_attention_mask(prefix_len),
                target_pad_mask=None,
            )

            # Highest-scoring entry at the last position of the only batch item.
            next_id = output[0, -1].argmax().item()
            if next_id == eos_id:
                break
            generated.append(next_id)

        return tokenizer.decode(generated)

# Single Model

In [10]:
# Evaluation manifest; assigned to conf.val_dataloader.dataset.manifest_name for every model below.
dataset = 'test_opus/farfield/manifest.jsonl'

## LAS

In [11]:
# Load the LAS configuration and point its validation dataloader at the shared evaluation manifest.
conf = omegaconf.OmegaConf.load("./conf/conformer_las.yaml")
conf.val_dataloader.dataset.manifest_name = dataset
# Override the decoder tokenizer setting with this tokenizer model file.
conf.model.decoder.tokenizer = "./data/tokenizer/bpe_1024_bos_eos.model"

# Build the model from the config, load the checkpoint, and switch it to eval/frozen mode (see init_model).
conformer_las = init_model(
    model=ConformerLAS(conf=conf),
    ckpt_path="./data/conformer_las_2epochs.ckpt"
)

In [12]:
# Greedy-decode every validation utterance with the LAS model.
# Fills `refs` (decoded reference transcripts) and `hyps_las` (LAS hypotheses), in dataloader order.
las_decoder = GreedyDecoderLAS(conformer_las)

refs, hyps_las = [], []

for batch in tqdm(conformer_las.val_dataloader()):

    features, features_len, targets, target_len = batch

    # The encoder runs once per batch; decoding below is done one utterance at a time.
    encoded, encoded_len = conformer_las(features, features_len)
    
    for i in range(features.shape[0]):

        # Indexing with [i] (a list) keeps the leading batch dimension at size 1;
        # the second index drops the positions beyond this utterance's encoded length.
        encoder_states = encoded[
            [i],
            :encoded_len[i],
            :
        ]

        # Strip padding from the target ids before turning them back into text.
        ref_tokens = targets[i, :target_len[i]].tolist()

        refs.append(
            conformer_las.decoder.tokenizer.decode(ref_tokens)
        )
        hyps_las.append(
            las_decoder(encoder_states)
        )

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 479/479 [08:09<00:00,  1.02s/it]


In [13]:
# WER of the greedy LAS hypotheses against the decoded references.
compute_wer(refs, hyps_las)

0.42290276288986206

## CTC

In [None]:
# Load both CTC models and estimate their WER (implemented in the cells below).

In [55]:
def decode_ctc_hyps(model: ConformerCTC) -> Tuple[List[str], List[str]]:
    """Decode the whole validation set of a CTC model.

    Args:
        model: CTC model providing ``val_dataloader()``, a forward call that
            returns a 3-tuple, and ``decoder.decode``.

    Returns:
        ``(references, hypotheses)``: two lists filled batch by batch, in
        dataloader order.
    """
    references: List[str] = []
    hypotheses: List[str] = []

    for features, features_len, targets, target_len in tqdm(model.val_dataloader()):
        # The first element of the forward output is not needed for decoding.
        _, encoded_len, preds = model(features, features_len)

        references += model.decoder.decode(targets, target_len)
        # unique_consecutive=True presumably collapses repeated CTC symbols - confirm in the decoder.
        hypotheses += model.decoder.decode(preds, encoded_len, unique_consecutive=True)

    return references, hypotheses

In [56]:
# Base CTC model, evaluated on the same manifest as the LAS model.
conf = omegaconf.OmegaConf.load("./conf/conformer_ctc.yaml")
conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc = init_model(
    model=ConformerCTC(conf=conf),
    ckpt_path="./data/conformer_7epochs_state_dict.ckpt"
)

# NOTE: this rebinds `refs`; from here on the references come from this model's decoder.
refs, hyps_ctc = decode_ctc_hyps(conformer_ctc)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 479/479 [04:42<00:00,  1.70it/s]


In [57]:
# Wide CTC model: same procedure as the base CTC model, with its own config and checkpoint.
conf = omegaconf.OmegaConf.load("./conf/conformer_ctc_wide.yaml")
conf.val_dataloader.dataset.manifest_name = dataset

conformer_ctc_wide = init_model(
    model=ConformerCTC(conf=conf),
    ckpt_path="./data/conformer_wide_7epochs_state_dict.ckpt"
)

# NOTE: `refs` is rebound once more; every later WER cell scores against this list.
refs, hyps_ctc_wide = decode_ctc_hyps(conformer_ctc_wide)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 479/479 [04:48<00:00,  1.66it/s]


In [58]:
# WER of the base CTC model.
compute_wer(refs, hyps_ctc)

0.4247196316719055

In [59]:
# WER of the wide CTC model.
compute_wer(refs, hyps_ctc_wide)

0.3700787425041199

# ROVER: Recognizer Output Voting Error Reduction — 5 points

* [A post-processing system to yield reduced word error rates: Recognizer Output Voting Error Reduction (ROVER)](https://ieeexplore.ieee.org/document/659110)
* [Improved ROVER using Language Model Information](https://www-tlp.limsi.fr/public/asr00_holger.pdf)

Alignment + Voting

![](./images/rover_table.png)

In [65]:
from crowdkit.aggregation.texts import ROVER

In [69]:
# Long-format table for ROVER: one row per (utterance, system) pair.
# 'task' is the utterance index, 'worker' the system name, 'text' that system's hypothesis.
system_names = ['las', 'ctc', 'ctc_wide']

workers, tasks, texts = [], [], []
for task_id, system_hyps in enumerate(zip(hyps_las, hyps_ctc, hyps_ctc_wide)):
    workers.extend(system_names)
    tasks.extend([task_id] * len(system_names))
    texts.extend(system_hyps)

aggregated = pd.DataFrame({
    'worker': workers,
    'task': tasks,
    'text': texts,
})

In [70]:
def tokenizer(s):
    """Split a hypothesis into tokens on single space characters."""
    return s.split(' ')


def detokenizer(tokens):
    """Join tokens back into one space-separated string."""
    return ' '.join(tokens)


# ROVER aligns and votes on token sequences, so it gets both directions of the conversion.
rover = ROVER(tokenizer, detokenizer)

In [72]:
# Aggregate the three systems' hypotheses per task; `.values` keeps only the texts, which are
# compared to `refs` by position below (assumes the result is ordered by task index - TODO confirm).
hyps_rover = rover.fit_predict(aggregated).values

In [73]:
# WER of the ROVER combination of the three systems.
compute_wer(refs, hyps_rover)

0.36017656326293945

# MBR: Minimum Bayes Risk — 5 points


* [Minimum Bayes Risk Decoding and System Combination Based on a Recursion for Edit Distance](https://danielpovey.com/files/csl11_consensus.pdf)
* [mbr-decoding blog-post](https://suzyahyah.github.io/bayesian%20inference/machine%20translation/2022/02/15/mbr-decoding.html)
* [Combination of end-to-end and hybrid models for speech recognition](http://www.interspeech2020.org/uploadfile/pdf/Tue-1-8-4.pdf)

![](./images/mbr_scheme.png)

In [74]:
class GreedyDecoderLASWithProbas(GreedyDecoderLAS):
    """LAS greedy decoder that can also score an externally supplied token sequence."""

    def get_hyp_probability(self, encoded: torch.Tensor, hyp):
        """Return the probability the LAS decoder assigns to ``hyp``.

        The sequence is scored with teacher forcing: at step ``i`` the decoder is
        run on the prefix collected so far, and the softmax probability of
        ``hyp[i]`` at the last position is multiplied into the result.

        Args:
            encoded: Encoder states of one utterance, passed to the decoder unchanged.
            hyp: Sequence of token ids. ``hyp[0]`` is never scored: the prefix
                always starts from the tokenizer's BOS id and scoring starts at
                index 1.

        Returns:
            A 0-dim tensor holding the product of the per-step probabilities
            (``1.0`` when ``hyp`` has fewer than two tokens).
        """
        tokens = [self.model.decoder.tokenizer.bos_id()]
        # Start from a tensor so the return type does not depend on len(hyp).
        score = torch.ones(())

        for i in range(1, len(hyp)):
            tokens_batch = torch.tensor(tokens).unsqueeze(0)
            att_mask = self.model.make_attention_mask(torch.tensor([tokens_batch.size(-1)]))

            distribution = self.model.decoder(
                encoded=encoded, encoded_pad_mask=None,
                target=tokens_batch, target_mask=att_mask, target_pad_mask=None
            )

            # Normalise over the vocabulary axis explicitly; the implicit-dim form of
            # softmax is deprecated and its warning is silenced in this notebook.
            step_probs = F.softmax(distribution[0, -1], dim=-1)
            score = score * step_probs[hyp[i]]
            tokens.append(hyp[i])

        return score

In [75]:
las_decoder = GreedyDecoderLASWithProbas(conformer_las)


def get_hyp_prob_las(hyp, encoder_states, las_decoder=las_decoder, conformer_las=conformer_las):
    """Score a text hypothesis with the LAS model.

    Args:
        hyp: Hypothesis text; tokenized with the LAS decoder's tokenizer.
        encoder_states: Encoder states of the utterance being scored.
        las_decoder: Scorer providing ``get_hyp_probability``; bound to the
            notebook-level decoder when this cell is run.
        conformer_las: Model whose tokenizer supplies the token, BOS and EOS ids.

    Returns:
        Whatever ``las_decoder.get_hyp_probability`` returns for the
        BOS + tokens + EOS id sequence.
    """
    sp = conformer_las.decoder.tokenizer
    token_ids = [sp.bos_id(), *sp.tokenize(hyp), sp.eos_id()]
    return las_decoder.get_hyp_probability(encoder_states, token_ids)

In [83]:
# Character alphabet used to index hypotheses for CTC scoring: a space followed by
# lowercase Cyrillic letters. Assumed to match the CTC models' output vocabulary
# order - TODO confirm against the model config.
labels = list(' абвгдежзийклмнопрстуфхцчшщъыьэюя')
token_to_idx = {token: idx for idx, token in enumerate(labels)}


def get_hyp_prob_ctc(hyp, model, logprobs, encoded_len):
    """Return the probability a CTC model assigns to ``hyp``.

    Args:
        hyp: Hypothesis text; it is indexed character by character through
            ``token_to_idx``.
        model: Object exposing ``ctc_loss(log_probs, targets, input_lengths,
            target_lengths)``.
        logprobs: Model output for one utterance, passed to ``ctc_loss`` unchanged.
        encoded_len: Number of valid frames in ``logprobs``; a Python int or an
            integer tensor.

    Returns:
        ``exp(-loss)`` as a tensor, or ``None`` when ``hyp`` contains a character
        outside ``labels`` (this includes the '⁇' unknown marker) and therefore
        cannot be scored. Interpreting the value as a probability assumes
        ``ctc_loss`` returns an un-normalised negative log-likelihood - TODO confirm.
    """
    # Any out-of-alphabet character makes the hypothesis unscorable; previously only
    # '⁇' was detected and every other unknown character raised KeyError.
    if any(char not in token_to_idx for char in hyp):
        return None

    # Build integer ids directly instead of going through a float tensor.
    hyp_tokens = torch.tensor([token_to_idx[char] for char in hyp], dtype=torch.long)
    # torch.Tensor(n) with a Python int allocates n uninitialised floats;
    # as_tensor keeps the value for both ints and tensors.
    input_len = torch.as_tensor(encoded_len, dtype=torch.long)
    target_len = torch.tensor(len(hyp), dtype=torch.long)

    loss = model.ctc_loss(logprobs, hyp_tokens, input_len, target_len)
    return torch.exp(-loss)

In [85]:
import editdistance
def get_distance_matrix(hyps: List[str]):
    """Build the symmetric matrix of pairwise edit distances between hypotheses.

    Args:
        hyps: Hypotheses to compare. They are handed to ``editdistance.eval`` as
            raw strings, so the distance is presumably character-level - confirm
            this is intended, given that the notebook reports word error rate.

    Returns:
        A ``(len(hyps), len(hyps))`` float array with zeros on the diagonal.
    """
    size = len(hyps)
    matrix = np.zeros((size, size))
    for row, left in enumerate(hyps):
        # Only the upper triangle is computed; each value is mirrored below the diagonal.
        for col in range(row + 1, size):
            matrix[row, col] = matrix[col, row] = editdistance.eval(left, hyps[col])
    return matrix

def get_mbr_hyp(hyps, scores, weights=None):
    """Pick the hypothesis with the lowest weighted expected edit distance.

    For every model, its scores over ``hyps`` are normalised to sum to one and
    used to average the distances from each candidate to all hypotheses. The
    per-model averages are combined with ``weights`` and the candidate with the
    smallest total is returned.

    Args:
        hyps: Candidate hypotheses (strings).
        scores: One sequence per model; ``scores[m][k]`` is model ``m``'s score for
            ``hyps[k]``. The number of models may differ from the number of
            hypotheses.
        weights: Optional per-model weights; defaults to all ones.

    Returns:
        The element of ``hyps`` with minimal total risk (the first one on ties).

    Raises:
        ValueError: If ``hyps`` is empty or ``weights`` does not have one entry
            per model.
    """
    if len(hyps) == 0:
        raise ValueError("get_mbr_hyp() needs at least one hypothesis")
    # Weights belong to models, not hypotheses. Sizing them by len(hyps) only
    # worked while both counts happened to be equal.
    if weights is None:
        weights = np.ones(len(scores))
    if len(weights) != len(scores):
        raise ValueError("weights must contain exactly one entry per model")

    distance_matrix = np.asarray(get_distance_matrix(hyps), dtype=float)

    total_risk = np.zeros(len(hyps))
    for weight, model_scores in zip(weights, scores):
        # float() accepts plain numbers as well as 0-dim tensors / numpy scalars.
        posterior = np.asarray([float(score) for score in model_scores], dtype=float)
        mass = posterior.sum()
        if mass > 0:
            posterior = posterior / mass
        else:
            # All scores vanished (or are NaN): fall back to a uniform posterior
            # instead of dividing by zero.
            posterior = np.full(len(hyps), 1.0 / len(hyps))
        # Row i of the product is the expected distance of candidate i under this model.
        total_risk += weight * (distance_matrix @ posterior)

    return hyps[int(np.argmin(total_risk))]

In [86]:
# Score every system's hypothesis under every model.
# scores_X[n][k] is model X's score for candidate k of utterance n, with candidates
# ordered (LAS, CTC, CTC-wide). CTC scores may be None for text that model cannot score.
scores_las = []
scores_ctc = []
scores_ctc_wide = []

# Position of the current batch's first utterance inside hyps_las / hyps_ctc / hyps_ctc_wide.
# Tracking it directly avoids hard-coding the dataloader's batch size. It assumes the
# dataloader yields utterances in the same order as when the hypotheses were decoded.
first_utt = 0

val_loader = conformer_las.val_dataloader()
for features, features_len, targets, target_len in tqdm(val_loader, total=len(val_loader)):
    # Each model keeps its own encoded lengths so no model is sliced with another model's lengths.
    encoded_las, encoded_len_las = conformer_las(features, features_len)
    logprobs_ctc, encoded_len_ctc, _ = conformer_ctc(features, features_len)
    logprobs_ctc_wide, encoded_len_ctc_wide, _ = conformer_ctc_wide(features, features_len)

    for utt in range(features.shape[0]):
        hyp_idx = first_utt + utt
        candidates = (hyps_las[hyp_idx], hyps_ctc[hyp_idx], hyps_ctc_wide[hyp_idx])

        # Indexing with [utt] (a list) keeps the leading batch dimension at size 1.
        encoder_states_las = encoded_las[[utt], :encoded_len_las[utt], :]

        scores_las.append([
            get_hyp_prob_las(hyp, encoder_states_las)
            for hyp in candidates
        ])
        scores_ctc.append([
            get_hyp_prob_ctc(hyp, conformer_ctc, logprobs_ctc[utt], encoded_len_ctc[utt])
            for hyp in candidates
        ])
        scores_ctc_wide.append([
            get_hyp_prob_ctc(hyp, conformer_ctc_wide, logprobs_ctc_wide[utt], encoded_len_ctc_wide[utt])
            for hyp in candidates
        ])

    first_utt += features.shape[0]

assert len(scores_las) == len(hyps_las), "hypothesis lists and dataloader are out of sync"
        

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 479/479 [17:14<00:00,  2.16s/it]


In [87]:
# Pick one hypothesis per utterance with MBR over the three systems.
hyps_mbr = []
for i in tqdm(range(len(refs))):
    # A CTC score of None marks a hypothesis the CTC models could not score. The code
    # assumes this can only be the LAS hypothesis (index 0), so the LAS hypothesis and
    # the LAS model are both left out and the two CTC systems are combined on their own.
    if any(score is None for score in scores_ctc[i]):
        hyps = (hyps_ctc[i], hyps_ctc_wide[i])
        scores = (scores_ctc[i][1:], scores_ctc_wide[i][1:])
    else:
        hyps = (hyps_las[i], hyps_ctc[i], hyps_ctc_wide[i])
        scores = (scores_las[i], scores_ctc[i], scores_ctc_wide[i])

    hyps_mbr.append(get_mbr_hyp(hyps, scores))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1916/1916 [00:01<00:00, 1423.62it/s]


In [89]:
# WER of the MBR combination of the three systems.
compute_wer(refs, hyps_mbr)

0.35146743059158325

### Итоговый WER:
* `LAS: 0.423`
* `ConformerCTC: 0.425`
* `ConformerCTCWide: 0.370`
* `ROVER над тремя моделями: 0.360`
* `MBR над тремя моделями: 0.351`