In [1]:
# Pin pip below 24.1: the installed omegaconf 2.0.6 (presumably pulled in by fairseq)
# declares the non-standard specifier "PyYAML>=5.1.*", which pip 24.1+ enforces against
# (see the DEPRECATION warning printed below).
!pip install pip==24.0

[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m

In [None]:
# Install the libraries used by the cells below.
# NOTE(review): only `datasets` is pinned; consider pinning fairseq/transformers too, since
# the Transformers deprecation warnings in the outputs depend on the installed version.
# NOTE(review): `torchinfo` is not used in the cells shown, while `soundfile` and the project
# package `speechgpt` are imported later but not installed here — presumably preinstalled /
# already on the path; confirm.
!pip install torchinfo fairseq transformers huggingface-hub datasets==3.0.1

In [3]:
from fairseq.models import BaseFairseqModel, register_model
from torch import Tensor
from typing import Optional, Dict

from speechgpt.models.whisper.model import HuggingFaceWhisperModel
# TODO: import your LLM model here (the second stage of the ASR -> LLM cascade)


# Minimal stand-in for fairseq's args/config object.
class Args:
    """Attribute bag used as the model config when no fairseq args are supplied.

    Keyword arguments become instance attributes, e.g. ``Args(x=1).x == 1``.
    ``Args()`` with no arguments is an empty object, exactly as before.
    """

    def __init__(self, **kwargs):
        # Store every supplied option as a plain attribute so code that reads
        # ``args.<name>`` (or ``getattr(args, name, default)``) can see it.
        self.__dict__.update(kwargs)

@register_model("asr-llm-cascade-model")
class AsrLlmCascadeModel(BaseFairseqModel):
    """ASR -> LLM cascade wrapped as a fairseq model.

    Only the ASR stage (a ``HuggingFaceWhisperModel``) is wired up so far; the LLM
    stage is still a TODO, so ``forward`` and ``generate`` return the ASR output.

    NOTE(review): ``register_model`` writes to fairseq's global model registry, so
    re-running this cell in the same kernel may raise or keep the previously
    registered class depending on the fairseq version — restart the kernel after
    editing this class.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.asr = None
        self.load_models(args)

    def load_models(self, args):
        """Build the sub-models of the cascade from ``args``."""
        # Pass the args we were given (not the ``Args`` class itself) so that any
        # configuration supplied to ``build_model`` actually reaches the ASR model.
        self.asr = HuggingFaceWhisperModel.build_model(args, None)
        # TODO: build the LLM stage here, e.g. self.llm = <your model>

    @classmethod
    def build_model(cls, args=None, task=None):
        """fairseq entry point: create the model, defaulting to an empty ``Args``.

        Args:
            args: config object; ``None`` means "use an empty ``Args()``".
            task: accepted for fairseq compatibility; not used.
        """
        # ``is None`` instead of ``or``: a config object that defines ``__len__``
        # is falsy when empty but should still be honoured rather than replaced.
        if args is None:
            args = Args()
        return cls(args)

    def forward(
        self,
        src_tokens: Tensor,
        tgt_tokens: Optional[Tensor] = None,
        src_lengths: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ):
        """Forward pass (may be used during training; use ``generate`` for inference).

        All arguments are passed positionally, unchanged, to the ASR model.

        Returns:
            The ASR model's output (to be replaced by the LLM output once the LLM
            stage is added).
        """
        whisper_output = self.asr(src_tokens, tgt_tokens, src_lengths, incremental_state)
        # TODO: feed the ASR output to the LLM, e.g. llm_output = self.llm(whisper_output, ...)

        # TODO: return the LLM output instead of the raw ASR output.
        return whisper_output

    def generate(self, input_tokens=None, text=False, skip_special_tokens=True, file=None, **kwargs):
        """Run inference through the cascade.

        Args:
            input_tokens: model input for the ASR stage; may be ``None`` when
                ``file`` is given.
            text: forwarded to the ASR ``generate``. In this notebook's test cells,
                ``text=True`` returned a list of decoded strings and the default
                returned a tensor of token ids.
            skip_special_tokens: forwarded to the ASR ``generate``.
            file: path to an audio file, forwarded to the ASR ``generate``.
            **kwargs: extra generation options forwarded to the ASR ``generate``.

        Returns:
            Whatever the ASR ``generate`` returns (to be replaced by the LLM output).

        Raises:
            ValueError: if both ``input_tokens`` and ``file`` are ``None``.
        """
        if input_tokens is None and file is None:
            # ValueError subclasses Exception, so existing ``except Exception`` callers still work.
            raise ValueError("input_tokens or file must not be None")

        whisper_output = self.asr.generate(input_tokens, text, skip_special_tokens, file, **kwargs)
        # TODO: feed the ASR output to the LLM, e.g. llm_output = self.llm(whisper_output, ...)

        # TODO: return the LLM output instead of the raw ASR output.
        return whisper_output

In [4]:
# Build the cascade with the default (empty) Args config; no fairseq task is needed.
cascade = AsrLlmCascadeModel.build_model(args=None, task=None)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


# Тесты: проверяем, что всё работает

In [7]:
# Load a small test dataset and prepare Whisper input features from its first clip.

from datasets import load_dataset
import torch
import soundfile as sf

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
sampling_rate = sample["sampling_rate"]

# 1-D float32 tensor of the raw samples, passed to the processor un-batched.
audio = torch.tensor(sample["array"], dtype=torch.float32)

inputs = cascade.asr.processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
input_features = inputs["input_features"]
# Later cells pass `waveform` to the model. It holds the processor's input features,
# not raw audio, so it is kept only as an alias of `input_features`.
waveform = input_features

# Save the raw clip to disk for the file-based generation cells.
sf.write("audio.wav", sample["array"], sampling_rate)

In [8]:
# 1. Generate token ids from the prepared input features.

cascade.generate(input_tokens=waveform)

Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


tensor([[50258, 50259, 50360, 50364,  2221,    13,  2326,   388,   391,   307,
           264, 50244,   295,   264,  2808,  5359,    11,   293,   321,   366,
          5404,   281,  2928,   702, 14943,    13,  6966,   307,  2221,    13,
          2326,   388,   391,   311,  9060,  1570,  1880,   813,   702,  1871,
            13,   634,  5112,   505,   300,   412,   341, 42729,  3196,   295,
           264,  1064,    11,   365,  5272,   293, 12904,  9256,   450, 10539,
           949,   505,    11,  1034,  4680, 10117,   490,  3936,   293,  1080,
          3542,  5160,   881, 26336,   281,   264,  1575,    13,   634,   575,
         12525, 22618,  1968,  6144, 35617, 20084,  1756,   311,   589,   307,
           534, 10281,   934,   439,    11]])

In [9]:
# 2. Generate decoded text from the same input features.

cascade.generate(input_tokens=waveform, text=True)

[" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"]

In [11]:
# 3. Forward pass: Whisper input features plus the tokenizer's decoder prefix tokens.

# A single-row batch of the tokenizer's prefix tokens.
prompt_ids = torch.tensor([cascade.asr.processor.tokenizer.prefix_tokens])
in_features = cascade.asr.processor(
    sample["array"],
    sampling_rate=sample["sampling_rate"],
    return_tensors="pt",
)["input_features"]
cascade(src_tokens=in_features, tgt_tokens=prompt_ids)

(tensor([[[ 2.4991,  1.7260, -0.9168,  ...,  0.1925,  2.3552, -0.5627],
          [-0.7743, -0.2981, -2.3134,  ..., -2.1932, -2.1076, -3.5427]]],
        grad_fn=<UnsafeViewBackward0>),
 None,
 None,
 None)

In [12]:
# 4. Generate text directly from the saved audio file.

cascade.generate(text=True, file="audio.wav")

[" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"]

In [13]:
# 5. Generate token ids directly from the saved audio file.

cascade.generate(file="audio.wav", text=False)

tensor([[50258, 50259, 50360, 50364,  2221,    13,  2326,   388,   391,   307,
           264, 50244,   295,   264,  2808,  5359,    11,   293,   321,   366,
          5404,   281,  2928,   702, 14943,    13,  6966,   307,  2221,    13,
          2326,   388,   391,   311,  9060,  1570,  1880,   813,   702,  1871,
            13,   634,  5112,   505,   300,   412,   341, 42729,  3196,   295,
           264,  1064,    11,   365,  5272,   293, 12904,  9256,   450, 10539,
           949,   505,    11,  1034,  4680, 10117,   490,  3936,   293,  1080,
          3542,  5160,   881, 26336,   281,   264,  1575,    13,   634,   575,
         12525, 22618,  1968,  6144, 35617, 20084,  1756,   311,   589,   307,
           534, 10281,   934,   439,    11]])