In [1]:
! pip install --no-deps git+https://github.com/salute-developers/GigaAM.git pywer hydra-core kenlm flashlight-text

Collecting git+https://github.com/salute-developers/GigaAM.git
  Cloning https://github.com/salute-developers/GigaAM.git to /tmp/pip-req-build-ckn_yz_w
  Running command git clone --filter=blob:none --quiet https://github.com/salute-developers/GigaAM.git /tmp/pip-req-build-ckn_yz_w
  Resolved https://github.com/salute-developers/GigaAM.git to commit 6a8b511f753670ed38af6529bb89bbdc2191ba6a
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
import os
import matplotlib.pyplot as plt
import torch
from typing import List
import pywer
from tqdm import tqdm
import omegaconf
import gigaam
import IPython
import random
import numpy as np

In [3]:
! pip install datasets==3.6.0



In [4]:
from datasets import load_dataset
from torch.utils.data import DataLoader

In [5]:
! pip list | grep datasets

datasets                              3.6.0
tensorflow-datasets                   4.9.9
vega-datasets                         0.9.0


In [6]:
# FLEURS Russian (ru_ru): all splits are loaded; only 'train' is used below.
fleurs = load_dataset(path="google/fleurs", name="ru_ru")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [7]:
# Inspect the available splits and their sizes.
fleurs

DatasetDict({
    train: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 2562
    })
    validation: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 356
    })
    test: Dataset({
        features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
        num_rows: 775
    })
})

In [8]:
# Columns and row count of the training split.
fleurs['train']

Dataset({
    features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
    num_rows: 2562
})

In [9]:
# First reference text with its original casing and punctuation.
fleurs['train']['raw_transcription'][0]

'Эта идея пришла из Китая, где излюбленным цветком был цвет сливы.'

In [10]:
# Normalised (lower-case, punctuation-free) version of the same text; this
# 'transcription' field is what collate_fn uses as the WER reference.
fleurs['train']['transcription'][0]

'эта идея пришла из китая где излюбленным цветком был цвет сливы'

In [11]:
# On-disk path recorded in the dataset for the first clip.
# NOTE(review): this path lacks the `train/` directory that appears in the file
# path actually transcribed later in this notebook -- do not rely on it as-is.
fleurs['train']['path'][0]

'/root/.cache/huggingface/datasets/downloads/extracted/a2d3fe0258fe892f72f60a3a351ce767d90291abb07f8068cb630764d03dc666/10002203722064200187.wav'

In [12]:
from torch.utils.data import Dataset

In [13]:
class AudioDataset(Dataset):
    """Map-style torch ``Dataset`` over a Hugging Face FLEURS split.

    Each item is a dict with the same keys as the underlying row. The waveform
    (``audio['array']``) becomes a float32 tensor, and the integer metadata
    (``num_samples``, ``lang_id``, ``lang_group_id``) becomes int32 scalar
    tensors. All other fields are passed through unchanged.
    """

    def __init__(self, dataset):
        # Any sized, indexable container of row dicts works (e.g. datasets.Dataset).
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]

        def to_int32(value):
            return torch.tensor(value, dtype=torch.int32)

        return dict(
            id=row['id'],
            num_samples=to_int32(row['num_samples']),
            path=row['path'],
            audio=torch.tensor(row['audio']['array'], dtype=torch.float32),
            transcription=row['transcription'],
            raw_transcription=row['raw_transcription'],
            gender=row['gender'],
            lang_id=to_int32(row['lang_id']),
            language=row['language'],
            lang_group_id=to_int32(row['lang_group_id']),
        )

In [14]:
def collate_fn(batch):
    """Collate AudioDataset items into ``(padded_audio, lengths, transcriptions)``.

    Args:
        batch: list of dicts as produced by ``AudioDataset.__getitem__``; each
            must contain a 1-D ``'audio'`` tensor and a ``'transcription'`` string.

    Returns:
        audio_padded: tensor of shape (batch, max_len), zero right-padded.
        lengths: int32 tensor of shape (batch,) with the number of real samples
            in each waveform.
        transcriptions: list of reference strings, in batch order.
    """
    audio = [item['audio'] for item in batch]
    # Waveforms have different lengths, so right-pad them with zeros to the longest one.
    audio_padded = torch.nn.utils.rnn.pad_sequence(audio, batch_first=True)
    # Take the lengths from the tensors that were actually padded rather than
    # from the 'num_samples' metadata, so they can never disagree (e.g. after
    # the audio column is resampled).
    lengths = torch.tensor([wav.shape[0] for wav in audio], dtype=torch.int32)
    transcriptions = [item['transcription'] for item in batch]
    return audio_padded, lengths, transcriptions

In [15]:
# Split used for every evaluation below (wrapped by AudioDataset / DataLoader).
fleurs_train = fleurs['train']

In [16]:
# Confirm which split (and how many rows) will be evaluated.
fleurs_train

Dataset({
    features: ['id', 'num_samples', 'path', 'audio', 'transcription', 'raw_transcription', 'gender', 'lang_id', 'language', 'lang_group_id'],
    num_rows: 2562
})

In [17]:
# Evaluation loader: one utterance per batch, in a fixed order so that the
# collected references/hypotheses line up with dataset indices and repeated
# runs are reproducible (shuffling has no benefit for evaluation).
custom_dataset = AudioDataset(fleurs_train)
dataloader = DataLoader(
    custom_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=1,
)

In [18]:
# Peek at a single batch to sanity-check the collate_fn output shape.
batch = next(iter(dataloader), None)
if batch is not None:
    wav_batch, wav_lengths, texts = batch
    print(wav_batch.shape)

torch.Size([1, 99840])


# Model

In [19]:
from gigaam import GigaAMASR

CACHE_DIR = os.path.expanduser("~/.cache/gigaam")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): `_download_model` is a private gigaam helper and may change
# between releases; here it yields the model name and the local checkpoint path.
model_name, model_path = gigaam._download_model('ctc', CACHE_DIR)

# The checkpoint bundles the config object ("cfg") with the weights
# ("state_dict"), presumably why weights_only=False is needed. This unpickles
# arbitrary objects -- only load trusted checkpoints.
ckpt = torch.load(model_path, map_location='cpu', weights_only=False)

# Build the model with flash attention disabled.
ckpt["cfg"].encoder.flash_attn = False
model = GigaAMASR(ckpt['cfg'])

# strict=False tolerates key mismatches; report them instead of silently
# evaluating a partially initialised model.
incompatible = model.load_state_dict(ckpt["state_dict"], strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print('⚠️ missing keys:', incompatible.missing_keys)
    print('⚠️ unexpected keys:', incompatible.unexpected_keys)
model = model.eval()

# On a GPU the encoder weights are cast to half precision.
if device.type != "cpu":
    model.encoder = model.encoder.half()

model = model.to(device)

In [20]:
# Where does the model live, in which dtype (of its first parameter), and how big is it?
param = next(iter(model.parameters()))
DEVICE = param.device
n_params = sum(weight.numel() for weight in model.parameters())

print('Model device:', param.device)
print('Model dtype:', param.dtype)
print('Parameters count:', n_params)

Model device: cpu
Model dtype: torch.float32
Parameters count: 232585762


In [21]:
#model

In [22]:
# Sanity check: transcribe one training clip straight from disk and compare it
# with the reference transcription displayed earlier.
# NOTE(review): machine-specific absolute cache path; it contains a `train/`
# directory that the dataset's `path` field lacks.
wav_path_example = (
    '/root/.cache/huggingface/datasets/downloads/extracted/'
    'a2d3fe0258fe892f72f60a3a351ce767d90291abb07f8068cb630764d03dc666/'
    'train/10002203722064200187.wav'
)
model.transcribe(wav_path_example)

  return torch.frombuffer(audio, dtype=torch.int16).float() / 32768.0


'эта идея пришла из китая где излюбленным цветком был цвет сливы'

In [23]:
# Listen to the clip that was just transcribed.
IPython.display.Audio(filename=wav_path_example)

In [24]:
def fix_torch_seed(seed: int = 42):
    """Seed every RNG used in this notebook and request deterministic kernels.

    Args:
        seed: value fed to Python's ``random``, NumPy and PyTorch (CPU and all
            CUDA devices). Defaults to 42.

    Side effects:
        * sets ``os.environ["PYTHONHASHSEED"]``;
        * reseeds the global ``random`` / ``numpy.random`` / ``torch`` generators;
        * puts cuDNN in deterministic, non-benchmark mode;
        * enables ``torch.use_deterministic_algorithms`` in warn-only mode;
        * prints a confirmation line.
    """
    # NOTE(review): PYTHONHASHSEED is normally read only at interpreter start-up,
    # so setting it here presumably affects child processes only.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    torch.use_deterministic_algorithms(True, warn_only=True)

    print(f"✅ Random seed fixed to {seed}")


# Seed all generators once, before the evaluation cells below run.
fix_torch_seed(1234)

✅ Random seed fixed to 1234


In [25]:
# dataloader = DataLoader(fleurs, batch_size=32)

In [26]:
def get_gigaam_logprobs(model, wav_batch, wav_lengths, return_transcriptions=False):
    """Run the GigaAM encoder and CTC head on a padded waveform batch.

    Args:
        model: GigaAM ASR model exposing ``_device``, ``forward``, ``head`` and,
            when ``return_transcriptions`` is True, ``decoding.decode``.
        wav_batch: padded waveforms, presumably (batch, samples) -- TODO confirm.
        wav_lengths: number of valid samples per item, shape (batch,).
        return_transcriptions: also return the output of the model's built-in
            decoder (used as the greedy baseline in this notebook).

    Returns:
        ``(logprobs, encoded_len)`` or ``(logprobs, encoded_len, transcriptions)``.
    """
    # Inference only: without no_grad every call records an autograd graph,
    # wasting memory and time even though the model is in eval mode.
    with torch.no_grad():
        wav_batch = wav_batch.to(model._device)
        wav_lengths = wav_lengths.to(model._device)

        encoded, encoded_len = model.forward(wav_batch, wav_lengths)
        logprobs = model.head(encoded)

        if return_transcriptions:
            transcriptions = model.decoding.decode(model.head, encoded, encoded_len)
            return logprobs, encoded_len, transcriptions
    return logprobs, encoded_len

In [27]:
def calculate_wer_on_dataset(model, dataloader, batch_size=8, num_workers=2, return_transcriptions=False):
    """Compute corpus WER of the model's built-in decoder over a dataloader.

    Args:
        model: GigaAM ASR model (see ``get_gigaam_logprobs``).
        dataloader: yields ``(wav_batch, wav_lengths, texts)`` tuples, as built
            with ``collate_fn``.
        batch_size, num_workers: unused; kept only for backward compatibility.
            Batching is controlled entirely by ``dataloader``.
        return_transcriptions: also return the reference and hypothesis lists.

    Returns:
        ``wer`` or ``(wer, references, hypotheses)``.
    """
    references = []
    hypotheses = []

    # Pure inference: disable autograd for the whole pass.
    with torch.no_grad():
        for wav_batch, wav_lengths, texts in tqdm(dataloader):
            # Only the decoded transcriptions are needed here.
            *_, transcriptions = get_gigaam_logprobs(
                model, wav_batch, wav_lengths, return_transcriptions=True
            )
            references.extend(texts)
            hypotheses.extend(transcriptions)

    wer = pywer.wer(references, hypotheses)
    if return_transcriptions:
        return wer, references, hypotheses
    return wer

In [28]:
# Greedy baseline: WER of the model's built-in decoder on the evaluation loader.
# NOTE: on CPU this ran at ~4.7 s per utterance (interrupted at 434/2562 after ~34 min).
wer, references, hypotheses = calculate_wer_on_dataset(model, dataloader, return_transcriptions=True)
print('\nGreedy WER on FLEURS ru_ru train: ', wer)

 17%|█▋        | 434/2562 [33:52<2:46:07,  4.68s/it]


KeyboardInterrupt: 

# Beam-search decoding with a KenLM language model

In [None]:
import kenlm
from torchaudio.models.decoder import ctc_decoder
import editdistance

# KenLM binary LM and word lexicon consumed by the beam-search decoder.
# NOTE(review): neither file is produced in this notebook -- they must already
# exist under /content (presumably uploaded manually).
LEXICON_PATH = '/content/lexicon.txt'
LM_PATH = '/content/lm_50x50.binary'
lm_model = kenlm.LanguageModel(LM_PATH)

In [None]:
from torchaudio.models.decoder import ctc_decoder


# Decoder vocabulary: GigaAM's CTC vocabulary followed by an extra '|' symbol
# that serves as the CTC blank. list(...) turns a config sequence into a plain list.
# NOTE(review): assumes the CTC head's blank id is len(vocabulary) -- confirm.
TOKENS = list(ckpt['cfg'].decoding.vocabulary) + ['|']

LM_WEIGHT = 2.0    # weight of the KenLM score in the beam-search objective
WORD_SCORE = -0.5  # word insertion score
BEAM_SIZE = 30     # hypotheses kept per step
N_BEST = 10        # hypotheses returned per utterance

beam_search_decoder = ctc_decoder(
    lexicon=LEXICON_PATH,
    tokens=TOKENS,
    nbest=N_BEST,
    beam_size=BEAM_SIZE,
    sil_token=' ',
    blank_token='|',
    lm=LM_PATH,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
)

In [None]:
def decode_indices(labels, model):
    """Convert beam-search token ids into a whitespace-normalised string.

    Args:
        labels: 1-D tensor of token ids (e.g. ``CTCHypothesis.tokens``).
        model: GigaAM model whose ``decoding.tokenizer.decode`` maps ids to text.

    Returns:
        The decoded text, with leading/trailing whitespace stripped and inner
        runs of whitespace collapsed to single spaces.
    """
    text = "".join(model.decoding.tokenizer.decode(labels.cpu().tolist()))
    # Beam-search hypotheses can start/end with (or repeat) the silence token ' ';
    # normalise so hypotheses compare and print cleanly against references.
    return " ".join(text.split())

class BeamSearchEvaluator:
    """Evaluate a GigaAM CTC model with a torchaudio beam-search decoder.

    For every utterance it records:
      * the first (top-ranked) beam-search hypothesis -> ``wer``;
      * the n-best candidate with the smallest word edit distance to the
        reference -> ``oracle_wer``;
      * when ``lm_model`` is given, the candidate with the best per-character
        KenLM score -> ``rescored_wer``.
    """

    def __init__(self, model, beam_search_decoder, lm_model=None):
        self.model = model
        self.beam_search_decoder = beam_search_decoder
        self.lm_model = lm_model

    def set_beam_search_decoder(self, beam_search_decoder):
        self.beam_search_decoder = beam_search_decoder

    def set_lm_model(self, lm_model):
        self.lm_model = lm_model

    def evaluate(self, dataset, batch_size=8, num_workers=2):
        """Decode ``dataset`` with beam search and compute WER metrics.

        Args:
            dataset: map-style dataset of AudioDataset-style items (batched by ``collate_fn``).
            batch_size: utterances per forward pass (zero-padded by ``collate_fn``).
            num_workers: DataLoader worker processes.

        Returns:
            The dict built by ``_compute_metrics``.
        """
        # Fixed order: evaluation does not need shuffling, and a stable order keeps
        # the returned reference/hypothesis lists comparable across runs.
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )
        use_lm = self.lm_model is not None

        refs, hyps, best_hyps, n_bests = [], [], [], []
        rescored_hyps = [] if use_lm else None

        with torch.no_grad():
            for wav_batch, wav_lengths, texts in tqdm(dataloader):
                logprob_batch, encoded_len_batch = get_gigaam_logprobs(self.model, wav_batch, wav_lengths)
                # torchaudio's CTC decoder requires contiguous float32 CPU emissions;
                # the encoder may run in half precision on a GPU.
                emissions = logprob_batch.float().cpu().contiguous()
                beamsearch_result = self.beam_search_decoder(emissions, encoded_len_batch.cpu())

                for ref, result in zip(texts, beamsearch_result):
                    refs.append(ref)
                    (first_hyp, oracle_hyp, nbest_list), rescored_hyp = self._process_nbest(result, ref)
                    hyps.append(first_hyp)
                    best_hyps.append(oracle_hyp)
                    n_bests.append(nbest_list)
                    if use_lm:
                        rescored_hyps.append(rescored_hyp)

        return self._compute_metrics(refs, hyps, best_hyps, rescored_hyps, n_bests)

    def _process_nbest(self, result, ref):
        """Decode one utterance's n-best list and pick first / oracle / LM-best hypotheses.

        Returns:
            ``((first_hyp, oracle_hyp, nbest_list), rescored_hyp)``, where
            ``nbest_list`` holds ``{'hyp': str, 'score': float}`` dicts in decoder
            order. An empty n-best list yields empty strings rather than ``None``.
            ``rescored_hyp`` is ``None`` when no language model is set.
        """
        use_lm = self.lm_model is not None
        first_hyp = ""
        best_hyp = ""
        best_distance = float('inf')
        best_rescored_hyp = "" if use_lm else None
        best_score = -float('inf')
        nbest_list = []
        ref_words = ref.split()

        for j, candidate in enumerate(result):
            curr_hyp = decode_indices(candidate.tokens, self.model)
            nbest_list.append({'hyp': curr_hyp, 'score': candidate.score})

            if j == 0:
                first_hyp = curr_hyp

            distance = editdistance.eval(ref_words, curr_hyp.split())
            if distance < best_distance:
                best_distance = distance
                best_hyp = curr_hyp

            if use_lm:
                # LM log-score normalised by character length; max(..., 1) avoids
                # dividing by zero on an empty hypothesis.
                score = self.lm_model.score(curr_hyp) / max(len(curr_hyp), 1)
                if score > best_score:
                    best_score = score
                    best_rescored_hyp = curr_hyp

        return (first_hyp, best_hyp, nbest_list), best_rescored_hyp

    def _compute_metrics(self, refs, hyps, best_hyps, rescored_hyps, n_bests):
        """Assemble WER / oracle WER (and rescored WER when available) with the raw lists."""
        wer = pywer.wer(refs, hyps)
        oracle_wer = pywer.wer(refs, best_hyps)

        output = {
            'wer': wer,
            'oracle_wer': oracle_wer,
            'references': refs,
            'hypotheses': hyps,
            'oracle_hypotheses': best_hyps,
            'n_bests': n_bests
        }

        if rescored_hyps is not None:
            rescored_wer = pywer.wer(refs, rescored_hyps)
            output.update({
                'rescored_wer': rescored_wer,
                'rescored_hypotheses': rescored_hyps,
            })

        return output

In [None]:
evaluator = BeamSearchEvaluator(model, beam_search_decoder, lm_model=lm_model)
res = evaluator.evaluate(custom_dataset)
print('\nBeam search WER: ', res['wer'])
print('\nBeam search OracleWER: ', res['oracle_wer'])
if 'rescored_wer' in res:
    print('\nLM-rescored WER: ', res['rescored_wer'])
# `wer` holds the greedy-decoding WER computed by calculate_wer_on_dataset earlier.
print('\nGreedy WER: ', wer)

# Тестирование качества работы декодера при варьировании параметра beam_size (количество рассматриваемых гипотез)

In [None]:
# Sweep the beam width with everything else fixed (nbest=10).
BEAM_SIZES = [5, 10, 20, 50]
wer_results = []
oracle_wer_results = []
rescored_wer_results = []

for beam_size in BEAM_SIZES:
    # Separate name so the baseline `beam_search_decoder` is not overwritten.
    sweep_decoder = ctc_decoder(
        lexicon=LEXICON_PATH,
        tokens=TOKENS,
        nbest=10,
        beam_size=beam_size,
        sil_token=' ',
        blank_token='|',
        lm=LM_PATH,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
    evaluator.set_beam_search_decoder(sweep_decoder)
    output = evaluator.evaluate(custom_dataset)
    wer_results.append(output['wer'])
    oracle_wer_results.append(output['oracle_wer'])
    # Present only when the evaluator was given a language model.
    if 'rescored_wer' in output:
        rescored_wer_results.append(output['rescored_wer'])

In [None]:
fig, ax = plt.subplots(figsize=(15, 8))
# Greedy decoding does not depend on the beam size, so draw it as a flat baseline.
# `wer` is the greedy WER from calculate_wer_on_dataset.
ax.plot(BEAM_SIZES, [wer] * len(BEAM_SIZES), label='Greedy WER')
ax.plot(BEAM_SIZES, oracle_wer_results, label='Oracle WER')
ax.plot(BEAM_SIZES, wer_results, label='Beam Search WER')
if len(rescored_wer_results) == len(BEAM_SIZES):
    ax.plot(BEAM_SIZES, rescored_wer_results, label='WER with rescoring')
ax.set_ylabel('WER', fontsize=15)
ax.set_xlabel('BEAM_SIZE', fontsize=15)
ax.set_title('Зависимость WER от BEAM_SIZE', fontsize=15)
ax.legend(fontsize=15)
ax.grid()

# Зависимость качества распознавания от параметра n_best

In [None]:
# Sweep the n-best list size at a fixed beam width.
N_BESTS = [2, 5, 10, 20, 50]
NBEST_SWEEP_BEAM_SIZE = 50
wer_results = []
oracle_wer_results = []
rescored_wer_results = []

for n_best in N_BESTS:
    # Separate name so the baseline `beam_search_decoder` is not overwritten.
    sweep_decoder = ctc_decoder(
        lexicon=LEXICON_PATH,
        tokens=TOKENS,
        nbest=n_best,
        beam_size=NBEST_SWEEP_BEAM_SIZE,
        sil_token=' ',
        blank_token='|',
        lm=LM_PATH,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
    evaluator.set_beam_search_decoder(sweep_decoder)
    output = evaluator.evaluate(custom_dataset)
    wer_results.append(output['wer'])
    oracle_wer_results.append(output['oracle_wer'])
    # Requires the evaluator to have been built with a language model.
    rescored_wer_results.append(output['rescored_wer'])

In [None]:
fig, ax = plt.subplots(figsize=(15, 8))
ax.plot(N_BESTS, wer_results, label='WER')
ax.plot(N_BESTS, oracle_wer_results, label='Oracle WER')
ax.plot(N_BESTS, rescored_wer_results, label='WER with rescoring')

ax.set_ylabel('WER', fontsize=15)
ax.set_xlabel('N_BEST', fontsize=15)
ax.set_title('Зависимость WER от N_BEST', fontsize=15)
ax.legend(fontsize=15)
ax.grid()