### Inference of Speech-Language Model

В этой домашней работе предлагается ознакомиться с тем, как работает инференс языковой модели с обученным аудио выходом. В примере используется модель Mini-Omni, для большего понимания принципа работы модели вы можете изучить оригинальную статью: [arxiv](https://arxiv.org/abs/2408.16725). Так же можете обратиться к оригинальному коду: [github](https://github.com/gpt-omni/mini-omni). 

В ноутбуке [tts-hw.ipynb](tts-hw.ipynb) требуется реализовать несколько полезных фичей для поддержки батчевого инференса и использования conversational модели в режиме TTS (с форсированием текстового выхода). Не меняйте код в используемых модулях, постарайтесь не менять основной код в ноутбуке вне блоков `YOUR CODE HERE`.

Задание:
1. **3 балла**: Реализовать kv-кеширование. Возможно, вы захотите заранее предусмотреть, чтобы ваша реализация работала и для пункта 2.
2. **3 балла**: Реализовать батчевый инференс. Продумайте, как лучше добавлять паддинги при составлении входного батча.
3. **4 балла**: Реализовать инференс с форсированием текстового выхода, чтобы можно было использовать модель как модель TTS. Вы можете добавить свои примеры текстов для озвучки. Такой режим так же должен работать в батчевом инференсе, продумайте, как правильно дополнить создание батча в этом случае. 

Это задание вы можете выполнять где вам удобно - локально, с использованием google-colab (для этого в ноутбуке можно склонировать этот репозиторий и добавить в `sys.path` пути до используемых модулей) или kaggle-notebooks. 



In [25]:
!pip install snac soundfile omegaconf tokenizers gdown ipython;



In [26]:
from typing import List, Tuple, Optional, Any
from pathlib import Path
import math
from dataclasses import dataclass
from IPython.display import Audio, display

import soundfile as sf
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn import functional as F
from omegaconf import OmegaConf as om
from tqdm import tqdm
from snac import SNAC

from tokenizer import Tokenizer
from snac_utils import reconscruct_snac, reconstruct_tensors
from model import apply_rope, build_rope_cache, RMSNorm, LLaMAMLP
from download_model import download_model

In [31]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
out_dir = Path(f"outputs")
if not out_dir.exists():
    out_dir.mkdir(exist_ok=True)
ckpt_dir = Path(f"checkpoint")

download_model(ckpt_dir)

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

In [45]:
class KVCache(nn.Module):
    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.register_buffer(
            "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
        )
        self.register_buffer(
            "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
        )

    def forward(
        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    
        # move the buffer to the activation dtype for when AMP is used
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
    
        # k, v: [batch_size, n_heads, seq_len, head_size]
        # input_pos: [batch_size, seq_len]
        # self.k, self.v: [batch_size, n_heads, max_seq_len, head_size]
    
        # Создаём индексы для scatter
        batch_size, n_heads, seq_len, head_size = k.shape
    
        # Расширяем input_pos для всех heads и head_size
        # input_pos: [batch_size, seq_len] -> [batch_size, n_heads, seq_len, head_size]
        input_pos_expanded = input_pos.unsqueeze(1).unsqueeze(-1).expand(batch_size, n_heads, seq_len, head_size)
    
        # Используем scatter для обновления кеша
        self.k.scatter_(2, input_pos_expanded, k)
        self.v.scatter_(2, input_pos_expanded, v)
    
        # Возвращаем кеш до максимальной использованной позиции
        max_pos = input_pos.max().item()
    
        return self.k[:, :, :max_pos + 1], self.v[:, :, :max_pos + 1]

    def reset_parameters(self) -> None:
        torch.nn.init.zeros_(self.k)
        torch.nn.init.zeros_(self.v)
        

def build_mask_cache(
    max_seq_length: int, device: Optional[torch.device] = None
) -> torch.Tensor:
    ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
    return torch.tril(ones)

In [33]:
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int, n_query_groups: int, head_size: int, add_qkv_bias: bool, rope_n_elem: int, bias: bool=False) -> None:
        super().__init__()
        shape = (n_head + 2 * n_query_groups) * head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = nn.Linear(n_embd, shape, bias=add_qkv_bias)
        # output projection
        # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
        self.proj = nn.Linear(
            head_size * n_head, n_embd, bias=bias
        )
        # disabled by default
        self.kv_cache: Optional[KVCache] = None
        self.n_head = n_head
        self.n_query_groups = n_query_groups
        self.head_size = head_size
        self.rope_n_elem = rope_n_elem

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, T, C = (
            x.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `n_query_groups`)
        q_per_kv = self.n_head // self.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(
            B, T, self.n_query_groups, total_qkv, self.head_size
        )
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # maybe repeat k and v if for the non multi-head attention cases
        # training: flash attention requires it
        # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
        if self.n_query_groups != self.n_head and (
            input_pos is None or self.n_query_groups != 1
        ):
            k = k.expand(
                B, self.n_query_groups, q_per_kv, T, self.head_size
            )
            v = v.expand(
                B, self.n_query_groups, q_per_kv, T, self.head_size
            )

        q = q.reshape(B, -1, T, self.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.head_size)  # (B, nh_v, T, hs)

        q_roped = apply_rope(q[..., : self.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.rope_n_elem :]), dim=-1)

        if input_pos is not None:
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            k, v = self.kv_cache(input_pos, k, v)

        y = self.scaled_dot_product_attention(q, k, v, mask)

        y = y.reshape(
            B, T, self.head_size * self.n_head
        )  # re-assemble all head outputs side by side

        # output projection
        return self.proj(y)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        scale = 1.0 / math.sqrt(self.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)
    
    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> KVCache:
        heads = 1 if self.n_query_groups == 1 else self.n_head
        v_shape = (batch_size, heads, max_seq_length, self.head_size)
        if rope_cache_length is None:
            if self.rotary_percentage != 1.0:
                raise TypeError(
                    "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
                )
            k_shape = v_shape
        else:
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.head_size - self.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)

In [34]:
class Block(nn.Module):
    def __init__(self, n_embd: int, norm_eps: float, n_head: int, n_query_groups: int, head_size: int, intermediate_size: int, add_qkv_bias: bool, rope_n_elem: int, bias: bool) -> None:
        super().__init__()
        self.norm_1 = RMSNorm(n_embd, eps=norm_eps)
        self.attn = CausalSelfAttention(n_embd, n_head, n_query_groups, head_size, add_qkv_bias, rope_n_elem, bias)
        self.norm_2 = RMSNorm(n_embd, eps=norm_eps)
        self.mlp = LLaMAMLP(n_embd, intermediate_size, bias)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x_normed = self.norm_1(x)
        attention_output = self.attn(x_normed, cos, sin, mask, input_pos)

        x = attention_output + x
        x = self.mlp(self.norm_2(x)) + x
        return x

In [46]:
class GPT(nn.Module):
    def __init__(
            self,
            n_q: int,
            n_embd: int,
            n_layer: int,
            n_head: int,
            n_query_groups: int,
            head_size: int,
            intermediate_size: int,
            add_qkv_bias: int,
            bias: bool,
            block_size: int,
            norm_eps: float,
            tie_word_embeddings: bool,
            scale_embeddings: bool,
            text_vocab_size: int,
            audio_vocab_size: int,
            lm_head_bias: bool,
            rotary_percentage: float,
            rope_condense_ratio: int,
            rope_base: int,
            **kwargs
            ) -> None:
        super().__init__()
        self.lm_heads = nn.ModuleList([nn.Linear(n_embd, text_vocab_size, bias=lm_head_bias)] + [nn.Linear(n_embd, audio_vocab_size, bias=lm_head_bias) for _ in range(n_q)])
        self.embeds = nn.ModuleList([nn.Embedding(text_vocab_size, n_embd)] + [nn.Embedding(audio_vocab_size, n_embd) for _ in range(n_q)])
        self.ln = RMSNorm(n_embd, eps=norm_eps)
        self.rope_n_elem = int(rotary_percentage * head_size)
        self.transformer = nn.ModuleList(Block(
            n_embd, 
            norm_eps, 
            n_head, 
            n_query_groups, 
            head_size, 
            intermediate_size, 
            add_qkv_bias, 
            self.rope_n_elem, 
            bias
            ) for _ in range(n_layer))
        
        self.block_size = block_size
        self.scale_embeddings = scale_embeddings
        self.audio_vocab_size = audio_vocab_size
        self.text_vocab_size = text_vocab_size
        self.rope_condense_ratio = rope_condense_ratio
        self.rope_base = rope_base
        self.n_embd = n_embd
        self.n_q = n_q
        self.max_seq_length = block_size
        self.mask_cache: Optional[torch.Tensor] = None
        if tie_word_embeddings:
            for lm_head, emb in zip(self.lm_heads, self.embeds):
               lm_head.weight = emb.weight

    @property
    def max_seq_length(self) -> int:
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory
        """
        if value > self.block_size:
            raise ValueError(
                f"Cannot attend to {value}, block size is only {self.block_size}"
            )
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # first call
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        # override
        elif value != self.cos.size(0):
            self.cos, self.sin = self.rope_cache(device=self.cos.device)
        # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
        # if the kv cache is expected

    def reset_parameters(self) -> None:
        # Trigger resetting the rope-cache
        self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        bs, _, T = input_ids.shape
    
        if self.max_seq_length < T:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )
    
        if input_pos is not None:  # use the kv cache
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
    
            # Для cos/sin берём значения по позициям
            # Предполагаем, что все элементы батча имеют одинаковые позиции
            # (сдвинуты на одинаковое количество паддингов)
            # Берём позиции из первого элемента батча
            positions = input_pos[0]  # [seq_len]
    
            cos = self.cos[positions].unsqueeze(0).expand(bs, -1, -1)  # [bs, seq_len, rope_dim]
            sin = self.sin[positions].unsqueeze(0).expand(bs, -1, -1)

            # Создаём маску    
            seq_len = input_pos.shape[1]
    
            if seq_len > 1:
                # Для префилла - causal маска
                mask = self.mask_cache[None, None, :seq_len, :seq_len]
            else:
                # Для генерации по одному токену маска не нужна
                mask = None
        else:
            cos = self.cos[None, :T, :].repeat(bs, 1, 1)
            sin = self.sin[None, :T, :].repeat(bs, 1, 1)
    
            mask = None
    
        x = 0
    
        for i, emb in enumerate(self.embeds):
            x += emb(input_ids[:, i])

        x = x / (self.n_q + 1)
    
        if self.scale_embeddings:    
            x = x * (self.n_embd**0.5)
    
        for block in self.transformer:
            x = block(x, cos, sin, mask, input_pos)
    
        x_ori = x
        x_ori = self.ln(x_ori)
    
        logits = []
    
        for lm_head in self.lm_heads:
            logits.append(lm_head(x_ori))
    
        return logits

    def rope_cache(
        self, device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.rope_n_elem,
            device=device,
            condense_ratio=self.rope_condense_ratio,
            base=self.rope_base,
        )

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        # initialize the kv cache for all blocks
        for block in self.transformer:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )

        if self.mask_cache is None or self.mask_cache.size(-1) != max_seq_length:
            # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
            # for the kv-cache support (only during inference), we only create it in that situation
            self.mask_cache = build_mask_cache(max_seq_length, device)

    def clear_kv_cache(self) -> None:
        self.mask_cache = None
        for block in self.transformer:
            block.attn.kv_cache = None


In [36]:
@dataclass
class Config:
    n_q: int = 7
    batch_size: int = 1
    text_vocabsize: int = 151936
    text_specialtokens: int = 64
    audio_vocabsize: int = 4096
    audio_specialtokens: int = 64
    max_seq_length: int = 1024
    temperature: float = 0.9
    top_k: int = 1
    top_p: float = 1.0
    device: str = "cuda:0"
    
    def __post_init__(self):
        self.text_eos = self.text_vocabsize
        self.text_pad = self.text_vocabsize + 1
        self.text_input_bos = self.text_vocabsize + 2
        self.text_answer_bos = self.text_vocabsize + 3

        self.audio_eos = self.audio_vocabsize
        self.audio_pad = self.audio_vocabsize + 1
        self.audio_input_bos = self.audio_vocabsize + 2
        self.audio_answer_bos = self.audio_vocabsize + 3


In [37]:
def load_model(ckpt_dir: Path, inference_config: Config):
    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(inference_config.device)
    text_tokenizer = Tokenizer(ckpt_dir)
    model_config = om.load(ckpt_dir / "model_config.yaml")
    model = GPT(n_q=inference_config.n_q, **model_config)
    state_dict = torch.load(ckpt_dir / "model.pth")
    model.load_state_dict(state_dict, strict=True)
    model.to(inference_config.device).eval()
    return model, text_tokenizer, snacmodel

In [38]:

def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Example:
    # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
    # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Keep at least 1 token always to prevent the case where no token is selected
    # In this case the most probable one is always kept
    sorted_indices_to_remove[..., -1:] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, float("-inf"))
    return logits


def sample(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
) -> torch.Tensor:
    if top_p < 0.0 or top_p > 1.0:
        raise ValueError(f"top_p must be in [0, 1], got {top_p}")
    # optionally crop the logits to only the top k options
    if top_k is not None:
        v, i = torch.topk(logits, min(top_k, logits.size(-1)))
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
    # optionally scale the logits and sample from a probability distribution
    if temperature > 0.0 or top_p > 0.0:
        if temperature > 0.0:
            logits = logits / temperature
        # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
        if 0.0 < top_p < 1.0:
            logits = sample_top_p(logits, top_p)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)
    return torch.argmax(logits, dim=-1, keepdim=True)


def next_token(
    model: GPT,
    input_ids: torch.Tensor,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    all_logits = model(input_ids, input_pos=input_pos)
    new_input_ids = []
    for logits in all_logits:
        new_input_ids.append(sample(logits[:, -1, :], **kwargs))
    return torch.cat(new_input_ids, dim=1)


@torch.inference_mode()
def generate(
    model: GPT,
    input_ids: torch.Tensor,
    speech_input_ids: torch.Tensor,
    *,
    forced_input_ids: Optional[torch.Tensor] = None,
    max_returned_tokens: int = 2048,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_a: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    use_kv_cache: bool = True,
):
    device = input_ids.device
    bs, T = input_ids.shape
    
    if speech_input_ids is None:
        speech_input_ids = torch.full(
            (bs, model.n_q, T), pad_id_a, dtype=torch.long, device=device
        )
    input_ids = torch.cat([input_ids[:, None], speech_input_ids], 1)
    generated_input_ids = torch.zeros((bs, 1 + model.n_q, 0), dtype=torch.long, device=device)
    
    found_eos = torch.zeros((bs, 1 + model.n_q), device=device, dtype=torch.bool)
    pad_token = torch.tensor(
        [pad_id_t] + [pad_id_a] * model.n_q,
        device=device,
        dtype=torch.long,
    )[None, :].repeat(bs, 1)
    eos_token = torch.tensor(
        [eos_id_t] + [eos_id_a] * model.n_q,
        device=device,
        dtype=torch.long,
    )[None, :].repeat(bs, 1)
    
    
    if use_kv_cache:
        # Инициализируем kv-cache с размером батча
        model.set_kv_cache(batch_size=bs, device=device)

        # Создаём input_pos для префилла (обработки всей входной последовательности)
        # Это индексы позиций от 0 до T-1 для каждого элемента батча
        input_pos = torch.arange(0, T, device=device).unsqueeze(0).expand(bs, -1)
    else:
        input_pos = None
    
    pred_ids = next_token(
        model,
        input_ids,
        input_pos=input_pos,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    for step in tqdm(range(max_returned_tokens - T)):
        if forced_input_ids is not None and forced_input_ids.shape[-1] > 1:
            # В режиме TTS мы форсируем текстовый выход
            # Берём следующий токен из forced_input_ids вместо сэмплирования

            # Определяем, какой токен forced sequence мы должны использовать
            # step - это текущий шаг генерации
            if step < forced_input_ids.shape[-1]:
                # Заменяем текстовую часть pred_ids на forced токен
                # pred_ids имеет shape [bs, 1 + n_q], где первый элемент - текст
                pred_ids[:, 0] = forced_input_ids[:, step]
            else:
                # Если forced_input_ids закончились, то текстовая часть должна быть EOS
                # Это означает, что генерация текста завершена, продолжаем только аудио
                pass
                
        pred_ids[found_eos] = pad_token[found_eos]
        found_eos = torch.logical_or(found_eos, pred_ids == eos_token)
        if found_eos.all():
            break
        
        generated_input_ids = torch.cat([generated_input_ids, pred_ids[..., None]], dim=-1)
        
        if use_kv_cache:
            input_pos = input_pos[:, -1:] + 1

        pred_ids = next_token(
            model,
            pred_ids[..., None] if use_kv_cache else torch.cat([input_ids, generated_input_ids], dim=-1),
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
    return generated_input_ids[:, 0], generated_input_ids[:, 1:]

In [39]:
def collate(input_ids: List[torch.Tensor], text_pad: int, speech_input_ids: List[torch.Tensor], audio_pad: int, forced_input_ids: Optional[List[torch.Tensor]] = None):

    # Паддим text input_ids до максимальной длины в батче слева
    # Паддинг слева важен для корректной работы causal attention
    input_ids = pad_sequence(
        [ids.flip(0) for ids in input_ids],  # Переворачиваем для pad_sequence
        batch_first=True,
        padding_value=text_pad
    ).flip(1)  # Возвращаем обратно - теперь паддинг слева

    # Паддим speech input_ids аналогично (для каждого из n_q слоёв)
    # speech_input_ids - список тензоров размера [n_q, seq_len]
    speech_input_ids_padded = []

    for q_idx in range(speech_input_ids[0].shape[0]):
        q_layer = pad_sequence(
            [ids[q_idx].flip(0) for ids in speech_input_ids],
            batch_first=True,
            padding_value=audio_pad
        ).flip(1)
        speech_input_ids_padded.append(q_layer)
    speech_input_ids = torch.stack(speech_input_ids_padded, dim=1)  # [bs, n_q, seq_len]

    # Если есть forced_input_ids (для TTS режима), паддим их тоже
    if forced_input_ids is not None:
        forced_input_ids = pad_sequence(
            [ids.flip(0) for ids in forced_input_ids],
            batch_first=True,
            padding_value=text_pad
        ).flip(1)

    return input_ids, forced_input_ids, speech_input_ids
    

def get_input_ids(text_tokenizer, config: Config, text: str, text_answer: Optional[str] = None):
    text_tokens = text_tokenizer.encode(text)
    input_ids = torch.tensor([config.text_input_bos] + text_tokens.tolist() + [config.text_eos] + [config.text_answer_bos])
    
    if text_answer:
        text_answer_tokens = text_tokenizer.encode(text_answer)
        forced_input_ids = torch.tensor(text_answer_tokens.tolist() + [config.text_eos])
    else:
        forced_input_ids = None
    
    speech_input_ids = []
    for _ in range(config.n_q):
        speech_input_ids.append(torch.tensor([config.audio_pad] * (len(text_tokens) + 2) + [config.audio_answer_bos]))
    speech_input_ids = torch.stack(speech_input_ids, 0)
    return input_ids, forced_input_ids, speech_input_ids

In [40]:
def unschedule_codes(delayed_codes: torch.Tensor, shifts: List[int]) -> torch.Tensor:
    speech_len = delayed_codes.shape[-1] - sum(shifts)
    codes = torch.zeros_like(delayed_codes[..., :speech_len])

    cum_shift = 0
    for i, shift in enumerate([0] + shifts):
        cum_shift += shift
        codes[:, i] = delayed_codes[:, i, cum_shift : cum_shift + speech_len]
        
    return codes


def postprocess_speech_codes(
    codes: torch.Tensor, codebook_size: int, shifts: List[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    codes = unschedule_codes(codes, shifts)

    masks = ((0 <= codes) & (codes < codebook_size)).all(1)
    codes_clean = [codes_it[:, mask].T for codes_it, mask in zip(codes, masks)]
    lengths = torch.tensor(
        [len(codes_it) for codes_it in codes_clean], dtype=torch.long, device=codes.device
    ).clamp(min=1)
    codes_clean = pad_sequence(codes_clean, batch_first=True).transpose(1, 2)
    if codes_clean.shape[-1] == 0:
        codes_clean = F.pad(codes_clean, (0, 1))
    return codes_clean, lengths

In [41]:
config = Config(
    batch_size=1,
    max_seq_length=1024,
    temperature=0.9,
    top_k=100,
    top_p=1.0,
    device=device
    )
model, text_tokenizer, snacmodel = load_model(ckpt_dir, config)

In [42]:
@torch.inference_mode()
def decode_output(
    texts_in: List[str],
    tokens_t: torch.Tensor,
    tokens_a: torch.Tensor,
    lengths_a: torch.Tensor,
    save_paths: List[Path]
    ) -> List[np.ndarray]:
    audios = []
    for i in range(len(tokens_t)):
        text_tensor = tokens_t[i]
        if config.text_eos in text_tensor:
            text_tensor = text_tensor[:int(torch.nonzero(text_tensor == config.text_eos)[0])]
        text_out = text_tokenizer.decode(text_tensor).strip()

        audiolist = reconscruct_snac(tokens_a[i, :, :int(lengths_a[i].item())].tolist())
        audio = reconstruct_tensors(audiolist)

        audio_hat = snacmodel.decode(audio).squeeze().cpu().numpy()
        sf.write(
            save_paths[i],
            audio_hat,
            24000,
        )
        print(f"input: {texts_in[i]}")
        print(f"out: {text_out}")
        audios.append(audio_hat)
    return audios    

## 0. plain inference (0 points)

Здесь ничего делать не надо, просто запустите, убедитесь, что всё работает

In [43]:
config.batch_size = 1

texts = [
    "What is your name?",
    "How are you feeling today?",
    "Can you describe your surroundings?",
    "What did you do yesterday?",
    "What is your favorite book and why?",
    "How do you make a cup of tea?",
    "What is the weather like today?",
    "Can you explain the concept of time?",
    "Can you tell me a joke?",
    "Tell me the history of Civil War in the US"
]

audios = []
for idx in range(len(texts)):
    input_ids, _, speech_input_ids = get_input_ids(text_tokenizer, config, texts[idx])
    tokens_t, tokens_a = generate(
        model,
        input_ids.to(config.device).unsqueeze(0),
        speech_input_ids.to(config.device).unsqueeze(0),
        max_returned_tokens=config.max_seq_length,
        temperature=config.temperature,
        top_k=config.top_k,
        top_p=config.top_p,
        eos_id_t=config.text_eos,
        eos_id_a=config.audio_eos,
        pad_id_t=config.text_pad,
        pad_id_a=config.audio_pad,
        use_kv_cache=False,
    )
    tokens_a, lengths = postprocess_speech_codes(tokens_a, config.audio_vocabsize, [1] * (config.n_q - 1))
    save_paths = [(out_dir / f"0_{idx:02d}.wav")]
    audios.extend(decode_output([texts[idx]], tokens_t, tokens_a, lengths, save_paths))
    
#for audio in audios:
display(Audio(audios[0], rate=24000))

## 1. KV-cache (3 points)

Здесь вам нужно реализовать механизм kv-кеширования, чтобы при каждом следующем шаге декодирования не нужно было пересчитывать kv-фичи для префикса входного тензора. 

Заполните пропуски в классе `KVCache`, `GPT` и первый пропуск в функции `generate`:
```python
if use_kv_cache:
    #####
    # YOUR CODE HERE
    #####
    pass
else:
    input_pos = None
```

In [47]:
config.batch_size = 1

texts = [
    "What is your name?",
    "How are you feeling today?",
    "Can you describe your surroundings?",
    "What did you do yesterday?",
    "What is your favorite book and why?",
    "How do you make a cup of tea?",
    "What is the weather like today?",
    "Can you explain the concept of time?",
    "Can you tell me a joke?",
    "Tell me the history of Civil War in the US"
]

audios = []
for idx in range(len(texts)):
    input_ids, _, speech_input_ids = get_input_ids(text_tokenizer, config, texts[idx])
    
    model.set_kv_cache(batch_size=config.batch_size, device=config.device)
    
    tokens_t, tokens_a = generate(
        model,
        input_ids.to(config.device).unsqueeze(0),
        speech_input_ids.to(config.device).unsqueeze(0),
        max_returned_tokens=config.max_seq_length,
        temperature=config.temperature,
        top_k=config.top_k,
        top_p=config.top_p,
        eos_id_t=config.text_eos,
        eos_id_a=config.audio_eos,
        pad_id_t=config.text_pad,
        pad_id_a=config.audio_pad,
        use_kv_cache=True,
    )
    tokens_a, lengths = postprocess_speech_codes(tokens_a, config.audio_vocabsize, [1] * (config.n_q - 1))
    save_paths = [(out_dir / f"1_{idx:02d}.wav")]
    audios.extend(decode_output([texts[idx]], tokens_t, tokens_a, lengths, save_paths))
    
    model.clear_kv_cache()
    
# for audio in audios:
display(Audio(audios[0], rate=24000))

 22%|██▏       | 219/1016 [00:28<01:45,  7.56it/s]


input: What is your name?
out: My search engine research and maintain your device surgery, but is a "Why of 1.


100%|██████████| 1015/1015 [02:14<00:00,  7.55it/s]


input: How are you feeling today?
out: I am a website, "Of


 24%|██▍       | 242/1015 [00:31<01:39,  7.76it/s]


input: Can you describe your surroundings?
out: Sure. AI systems such a major people include being a young.


 47%|████▋     | 474/1015 [01:00<01:09,  7.79it/s]


input: What did you do yesterday?
out: I'm the process the most of the most of its best suited advertising is a jewelry, which design a time consuming overloading, and the following the most of that is a small, or that we need to find its use a great for your brand launch and need more advanced artificial intelligence, or need to help and you can


100%|██████████| 1013/1013 [02:14<00:00,  7.52it/s]


input: What is your favorite book and why?
out: My husband's a simple, or build a big, I'm no longer than a website and other organizations such a child care, but I can be


100%|██████████| 1012/1012 [02:27<00:00,  6.88it/s]


input: How do you make a cup of tea?
out: Making "A romantic shopping and can be


 43%|████▎     | 437/1014 [00:56<01:14,  7.71it/s]


input: What is the weather like today?
out: According to make sure the task that in a video. However, there is a single, and would you can be


 29%|██▉       | 294/1013 [00:37<01:32,  7.78it/s]


input: Can you explain the concept of time?
out: Sure, but not beet, but is a good work can be


100%|██████████| 1014/1014 [02:12<00:00,  7.63it/s]


input: Can you tell me a joke?
out: I aming a popular platforms that task management system is another technique and I am that is a specific situations with that's a new skills and it is based on the time and the main tasks that's a great and the system engineering capabilities of the answer your destination, as you are several aspects of the human intelligence is always involves several aspects and should


100%|██████████| 1011/1011 [02:12<00:00,  7.65it/s]


input: Tell me the history of Civil War in the US
out: The sentiment analysis and use the person, so it is a successful advertising campaigns and "s, but while the product, and new languages and to start by the company activities like or the algorithm has been


## 2. batch inference (3 points)

В этом пункте вам нужно реализовать батчевый инференс. Подсказка: нужно уделить особое внимание тому, как составляется батч из нескольких входных примеров, и как сделать так, чтобы паддинги не учитывались моделью при генерации.

Заполните пропуски в функции `collate`, убедитесь, что ваша реализация kv-cache не сломалась, код работает корректно и результаты генерации получаются адекватные.

In [48]:
config.batch_size = 4

texts = [
    "What is your name?",
    "How are you feeling today?",
    "Can you describe your surroundings?",
    "What did you do yesterday?",
    "What is your favorite book and why?",
    "How do you make a cup of tea?",
    "What is the weather like today?",
    "Can you explain the concept of time?",
    "Can you tell me a joke?",
    "Tell me the history of Civil War in the US"
]

audios = []
for batch_id in range(math.ceil(len(texts) / config.batch_size)):
    input_ids, speech_input_ids = [], []
    ids = range(batch_id * config.batch_size, min(len(texts), (batch_id + 1) * config.batch_size))
    for idx in ids:
        input_ids_item, _, speech_input_ids_item = get_input_ids(text_tokenizer, config, texts[idx])
        input_ids.append(input_ids_item)
        speech_input_ids.append(speech_input_ids_item)
        
    input_ids, _, speech_input_ids = collate(input_ids, config.text_pad, speech_input_ids, config.audio_pad)
    
    model.set_kv_cache(batch_size=config.batch_size, device=config.device)
    
    tokens_t, tokens_a = generate(
        model,
        input_ids.to(config.device),
        speech_input_ids.to(config.device),
        max_returned_tokens=config.max_seq_length,
        temperature=config.temperature,
        top_k=config.top_k,
        top_p=config.top_p,
        eos_id_t=config.text_eos,
        eos_id_a=config.audio_eos,
        pad_id_t=config.text_pad,
        pad_id_a=config.audio_pad,
        use_kv_cache=True
    )
    tokens_a, lengths = postprocess_speech_codes(tokens_a, config.audio_vocabsize, [1] * (config.n_q - 1))
    save_paths = [(out_dir / f"2_{idx:02d}.wav") for idx in ids]
    audios.extend(decode_output([texts[idx] for idx in ids], tokens_t, tokens_a, lengths, save_paths))

    model.clear_kv_cache()
    
# for audio in audios:
display(Audio(audios[0], rate=24000))

100%|██████████| 1015/1015 [03:54<00:00,  4.33it/s]


input: What is your name?
out: My strategies conquered
input: How are you feeling today?
out: I consider the process the potential risks, as the original, while this can beed.
input: Can you describe your surroundings?
out: I'm using the world search engine rankings of your website usage tools and is a chatbots and use natural language-based learning and on the data analysis of the book, and you should beo.
input: What did you do yesterday?
out: Yesterday, but the structure or provide a different countries that key points and take advantage.


100%|██████████| 1012/1012 [03:57<00:00,  4.25it/s]


input: What is your favorite book and why?
out: My
input: How do you make a cup of tea?
out: To show that is a high cost-effectically, so you can be curious and the 2. 1. However, I can be
input: What is the weather like today?
out: According attitudes Cuban
input: Can you explain the concept of time?
out: Time


100%|██████████| 1011/1011 [03:34<00:00,  4.71it/s]


input: Can you tell me a joke?
out: I
input: Tell me the history of Civil War in the US
out: The context and that you should be


## 3. text answer forcing (4 points)

В этом задании вам нужно добиться того, чтобы модель произносила заранее написанный текст, а не генерировала текст на ходу. То есть заставить её работать в режиме синтеза.

Дополните функцию `collate`, заполните пропуск в функции `generate`, убедитесь что модель способна работать в режиме синтеза, добавьте свои примеры.

In [49]:
config.batch_size = 4

text_answers = ["Some very beatiful text voiced by a stupid model", "The majestic toaster hummed a quiet symphony of forgotten melodies while floating above a sea of confused calculators."]
texts = ["Write arbitrary text."] * len(text_answers)

audios = []
for step in range(math.ceil(len(texts) / config.batch_size)):
    input_ids, forced_input_ids, speech_input_ids = [], [], []
    ids = range(step * config.batch_size, min(len(texts), (step + 1) * config.batch_size))
    for idx in ids:
        input_ids_item, forced_input_ids_item, speech_input_ids_item = get_input_ids(text_tokenizer, config, texts[idx], text_answers[idx])
        input_ids.append(input_ids_item)
        forced_input_ids.append(forced_input_ids_item)
        speech_input_ids.append(speech_input_ids_item)
        
    input_ids, forced_input_ids, speech_input_ids = collate(input_ids, config.text_pad, speech_input_ids, config.audio_pad, forced_input_ids)
    
    model.set_kv_cache(batch_size=config.batch_size, device=config.device)
    
    tokens_t, tokens_a = generate(
        model,
        input_ids.to(config.device),
        speech_input_ids.to(config.device),
        forced_input_ids=forced_input_ids.to(config.device),
        max_returned_tokens=config.max_seq_length,
        temperature=config.temperature,
        top_k=config.top_k,
        top_p=config.top_p,
        eos_id_t=config.text_eos,
        eos_id_a=config.audio_eos,
        pad_id_t=config.text_pad,
        pad_id_a=config.audio_pad,
        use_kv_cache=True,
    )
    tokens_a, lengths = postprocess_speech_codes(tokens_a, config.audio_vocabsize, [1] * (config.n_q - 1))
    save_paths = [(out_dir / f"4_{i:02d}.wav") for i in ids]
    audios.extend(decode_output([texts[i] for i in ids], tokens_t, tokens_a, lengths, save_paths))

    model.clear_kv_cache()
    
for audio in audios:
    display(Audio(audios[0], rate=24000))

 61%|██████    | 619/1017 [02:06<01:21,  4.90it/s]


input: Write arbitrary text.
out: Some very beatiful text voiced by a stupid model
input: Write arbitrary text.
out: The majestic toaster hummed a quiet symphony of forgotten melodies while floating above a sea of confused calculators.
