# Whisper Fairseq Adaptation

In [None]:
!pip install torchinfo fairseq transformers huggingface-hub

In [2]:
import math

import torch
import torch.nn as nn
from fairseq.models import FairseqEncoder, FairseqDecoder, FairseqEncoderDecoderModel, register_model
from transformers.models.whisper.modeling_whisper import WhisperPositionalEmbedding, WhisperSdpaAttention
from transformers.activations import GELUActivation
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from fairseq.data import Dictionary
import torchaudio
from datasets import load_dataset
from transformers import AutoProcessor
from torchaudio.transforms import MelSpectrogram
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

In [3]:
# 1. Load Whisper from the Hugging Face Hub.
# `processor.tokenizer` is reused later to build the fairseq dictionary, and the
# weights of `w` are the ones copied into the fairseq-style port.

processor = AutoProcessor.from_pretrained("openai/whisper-large-v3-turbo")
w = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3-turbo")

In [4]:
# 2. Save the model weights.
# Serializes the Hugging Face state_dict to a local file so it can be loaded
# into the fairseq-style model, which uses the same parameter names.

torch.save(w.state_dict(), "whisper_large_v3_turbo_weights.pth")

In [5]:
# 3. Архитектура модели для Fairseq

class WhisperEncoder(FairseqEncoder):
    """Whisper audio encoder: two 1-D convolutions, learned positions, a stack
    of Transformer layers and a final LayerNorm.

    Sub-module names mirror the Hugging Face Whisper encoder so that its
    ``state_dict`` loads without key remapping.

    Args:
        dictionary: fairseq dictionary handed to ``FairseqEncoder``; the
            computation below does not read it.
        embed_dim: model width.
        num_layers: number of encoder layers.
        num_mel_bins: number of input feature channels seen by ``conv1``.
        max_source_positions: size of the learned positional table, i.e. the
            maximum sequence length after the stride-2 convolution.
    """

    def __init__(self, dictionary, embed_dim=1280, num_layers=32,
                 num_mel_bins=128, max_source_positions=1500):
        super().__init__(dictionary)
        self.conv1 = nn.Conv1d(num_mel_bins, embed_dim, kernel_size=3, padding=1)
        # stride=2 halves the time axis (e.g. 3000 frames -> 1500 positions).
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.embed_positions = nn.Embedding(max_source_positions, embed_dim)
        self.layers = nn.ModuleList([
            WhisperEncoderLayer(embed_dim) for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode a batch of input features.

        Args:
            x: batched input for ``nn.Conv1d``; assumed to be
                ``(batch, num_mel_bins, frames)``.

        Returns:
            Tensor of shape ``(batch, ceil(frames / 2), embed_dim)`` after the
            final LayerNorm.

        Raises:
            ValueError: if the downsampled sequence is longer than the
                positional-embedding table.
        """
        # Whisper applies GELU after *each* convolution; the pretrained conv
        # weights expect these non-linearities.
        x = nn.functional.gelu(self.conv1(x))
        x = nn.functional.gelu(self.conv2(x))

        seq_length = x.size(2)
        max_positions = self.embed_positions.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"encoder input yields {seq_length} positions after downsampling, "
                f"but only {max_positions} positional embeddings are available"
            )

        positions = torch.arange(seq_length, device=x.device).unsqueeze(0)  # (1, T)
        # (batch, embed_dim, T) -> (batch, T, embed_dim); the (1, T, embed_dim)
        # positional embeddings broadcast over the batch.
        x = x.permute(0, 2, 1) + self.embed_positions(positions)

        for layer in self.layers:
            x = layer(x)

        return self.layer_norm(x)


# Define WhisperEncoderLayer

class WhisperEncoderLayer(nn.Module):
    """One pre-norm Transformer encoder layer in Whisper's parameter layout.

    Each sub-layer (self-attention, feed-forward) normalizes its *input* and
    adds the residual afterwards. This matches the Hugging Face Whisper layer
    whose weights are loaded into this module.

    Args:
        embed_dim: model width.
        ffn_dim: hidden width of the feed-forward block.
    """

    def __init__(self, embed_dim, ffn_dim=5120):
        super().__init__()
        self.self_attn = WhisperAttention(embed_dim)
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.activation_fn = GELUActivation()
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply self-attention and the feed-forward block, each with a residual.

        Args:
            x: hidden states whose last axis is ``embed_dim``.

        Returns:
            Hidden states with the same shape as ``x``.
        """
        # Pre-norm: LayerNorm is applied before the sub-layer, not after the sum.
        residual = x
        x = self.self_attn(self.self_attn_layer_norm(x))
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = self.fc2(self.activation_fn(self.fc1(x)))
        return residual + x


class WhisperAttention(nn.Module):
    """Multi-head scaled dot-product attention in Whisper's parameter layout.

    Parameter names and bias settings (no bias on ``k_proj``) match the
    Hugging Face Whisper attention module so its weights load directly.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (Whisper large uses 20 heads of
            width 64 at ``embed_dim=1280``).

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads=20):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, key_value_states=None, causal=False):
        """Attend from ``x`` to itself or to ``key_value_states``.

        Args:
            x: queries, ``(batch, tgt_len, embed_dim)`` or unbatched
                ``(tgt_len, embed_dim)``.
            key_value_states: optional source for keys/values (cross-attention),
                with the same rank as ``x``. Defaults to ``x`` (self-attention).
            causal: if True, query position ``t`` may only attend to key
                positions ``<= t``. This is intended for full-sequence
                self-attention (no key/value cache).

        Returns:
            Tensor with the same shape as ``x``.
        """
        if x.dim() == 2:
            # Unbatched input: run as a batch of one and drop the batch axis again.
            kv = None if key_value_states is None else key_value_states.unsqueeze(0)
            return self.forward(x.unsqueeze(0), kv, causal).squeeze(0)

        source = x if key_value_states is None else key_value_states
        batch, tgt_len, _ = x.shape
        src_len = source.size(1)

        # Project, then split channels into heads: (batch, heads, len, head_dim).
        q = self.q_proj(x).view(batch, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(source).view(batch, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(source).view(batch, src_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scale by the per-head width, not by the full model width.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if causal:
            future = torch.triu(
                torch.ones(tgt_len, src_len, device=x.device), diagonal=1
            ).bool()
            scores = scores.masked_fill(future, float("-inf"))
        weights = torch.softmax(scores, dim=-1)

        # Merge heads: (batch, heads, tgt_len, head_dim) -> (batch, tgt_len, embed_dim).
        out = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tgt_len, self.embed_dim)
        return self.out_proj(out)


class WhisperDecoder(FairseqDecoder):
    """Whisper text decoder: token + learned positional embeddings, a stack of
    decoder layers with cross-attention to the encoder, and a final LayerNorm.

    Args:
        dictionary: fairseq dictionary; its length sets the vocabulary size.
        embed_dim: model width.
        num_layers: number of decoder layers.
    """

    def __init__(self, dictionary, embed_dim=1280, num_layers=4):
        super().__init__(dictionary)
        # NOTE(review): 50257 is assumed to be the tokenizer's pad id; confirm
        # against processor.tokenizer.pad_token_id.
        self.embed_tokens = nn.Embedding(len(dictionary), embed_dim, padding_idx=50257)
        # 448 learned positions = maximum decoder sequence length.
        self.embed_positions = WhisperPositionalEmbedding(448, embed_dim)

        self.layers = nn.ModuleList([
            WhisperDecoderLayer(embed_dim) for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, encoder_out):
        """Decode token ids against encoder hidden states.

        Args:
            x: token ids of shape ``(batch, seq_len)``; cast to ``long``.
            encoder_out: encoder hidden states ``(batch, src_len, embed_dim)``.

        Returns:
            Hidden states of shape ``(batch, seq_len, embed_dim)`` (before the
            vocabulary projection).

        Raises:
            ValueError: if ``seq_len`` exceeds the positional-embedding table.
        """
        tokens = x.long()
        seq_length = tokens.size(1)
        max_positions = self.embed_positions.weight.size(0)
        if seq_length > max_positions:
            raise ValueError(
                f"decoder input has {seq_length} tokens, "
                f"but only {max_positions} positional embeddings are available"
            )

        # (batch, T, D) token embeddings + (T, D) positions broadcast over the batch.
        # Indexing the table directly avoids depending on the forward signature of
        # WhisperPositionalEmbedding, which varies across transformers versions.
        hidden = self.embed_tokens(tokens) + self.embed_positions.weight[:seq_length]

        for layer in self.layers:
            hidden = layer(hidden, encoder_out)

        return self.layer_norm(hidden)


class WhisperDecoderLayer(nn.Module):
    """One pre-norm Whisper decoder layer. It runs causal self-attention, then
    cross-attention to the encoder output, then a feed-forward block. Each
    sub-layer has its own input LayerNorm and a residual connection.

    Args:
        embed_dim: model width.
        ffn_dim: hidden width of the feed-forward block.
    """

    def __init__(self, embed_dim, ffn_dim=5120):
        super().__init__()
        self.self_attn = WhisperAttention(embed_dim)
        self.activation_fn = GELUActivation()
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.encoder_attn = WhisperAttention(embed_dim)
        self.encoder_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, encoder_out):
        """Run one decoder layer.

        Args:
            x: decoder hidden states ``(batch, tgt_len, embed_dim)``.
            encoder_out: encoder hidden states ``(batch, src_len, embed_dim)``.

        Returns:
            Hidden states with the same shape as ``x``.
        """
        # Causal self-attention: position t must not see later target tokens.
        residual = x
        x = self.self_attn(self.self_attn_layer_norm(x), causal=True)
        x = residual + x

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        residual = x
        x = self.encoder_attn(self.encoder_attn_layer_norm(x), key_value_states=encoder_out)
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = self.fc2(self.activation_fn(self.fc1(x)))
        return residual + x


class WhisperForConditionalGeneration(nn.Module):
    """Whisper encoder-decoder plus a bias-free projection to vocabulary logits.

    Attribute names (``model``, ``proj_out``) mirror the Hugging Face class so
    its ``state_dict`` loads directly.

    Args:
        dictionary: fairseq dictionary; its length sets the output vocabulary.
        embed_dim: model width shared by encoder, decoder and projection.
        encoder_layers: number of encoder layers.
        decoder_layers: number of decoder layers.
    """

    def __init__(self, dictionary, embed_dim=1280, encoder_layers=32, decoder_layers=4):
        super().__init__()
        self.model = WhisperModel(dictionary, embed_dim=embed_dim, encoder_layers=encoder_layers, decoder_layers=decoder_layers)
        # The projection consumes decoder outputs, so it must use the same width.
        self.proj_out = nn.Linear(in_features=embed_dim, out_features=len(dictionary), bias=False)

    def forward(self, src_tokens, prev_output_tokens):
        """Compute logits for ``prev_output_tokens`` conditioned on ``src_tokens``.

        Args:
            src_tokens: encoder input features (passed to the encoder as-is).
            prev_output_tokens: decoder input token ids.

        Returns:
            Logits whose last axis has size ``len(dictionary)``.
        """
        encoder_out = self.model.encoder(src_tokens)
        decoder_out = self.model.decoder(prev_output_tokens, encoder_out=encoder_out)
        return self.proj_out(decoder_out)


@register_model('whisper-large-v3-turbo')
class WhisperModel(FairseqEncoderDecoderModel):
    """fairseq encoder-decoder model holding a Whisper encoder and decoder.

    Both halves receive the same ``dictionary``. Only the decoder uses it, to
    size its token embedding.

    NOTE(review): fairseq's ``register_model`` is expected to reject a
    duplicate name, so re-running this cell in the same kernel will likely
    raise. Restart the kernel before re-executing.
    """

    def __init__(self, dictionary, embed_dim=1280, encoder_layers=32, decoder_layers=4):
        encoder = WhisperEncoder(dictionary, embed_dim=embed_dim, num_layers=encoder_layers)
        decoder = WhisperDecoder(dictionary, embed_dim=embed_dim, num_layers=decoder_layers)
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        """Build the model from a fairseq task.

        The dictionary sizes the decoder's output vocabulary, so the task's
        target dictionary is preferred. The source dictionary is used only as a
        fallback, since speech tasks may expose no source dictionary at all.
        """
        dictionary = getattr(task, "target_dictionary", None)
        if dictionary is None:
            dictionary = task.source_dictionary
        return cls(dictionary=dictionary)

In [6]:
# 4. Transfer the token vocabulary into a fairseq Dictionary.
# Goal: fairseq index i must equal Hugging Face token id i, because the embedding
# and projection rows loaded from the checkpoint are indexed by token id.

fairseq_dictionary = Dictionary()
# Dictionary() pre-registers fairseq's own special symbols. Clear *all* of its
# lookup tables (not just `symbols`) so no stale entries shift or shadow tokens.
fairseq_dictionary.symbols = []
fairseq_dictionary.count = []
fairseq_dictionary.indices = {}
tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()

# get_vocab() is a token -> id mapping with no guaranteed iteration order, so
# insert tokens sorted by id and verify each lands on its own id.
for token, token_id in sorted(vocab.items(), key=lambda item: item[1]):
    index = fairseq_dictionary.add_symbol(token)
    if index != token_id:
        raise ValueError(
            f"token {token!r} got fairseq index {index} but has tokenizer id {token_id}; "
            "the vocabulary ids are not contiguous from 0"
        )

# Point fairseq's special indices at the tokenizer's own special tokens instead
# of the positions 0-3 left over from Dictionary().
for attr, special_id in (
    ("bos_index", tokenizer.bos_token_id),
    ("pad_index", tokenizer.pad_token_id),
    ("eos_index", tokenizer.eos_token_id),
    ("unk_index", tokenizer.unk_token_id),
):
    if special_id is not None:
        setattr(fairseq_dictionary, attr, special_id)

print(f"Number of tokens in Fairseq dictionary: {len(fairseq_dictionary)}")

Number of tokens in Fairseq dictionary: 51866


In [7]:
# 5. Build the fairseq-style model around the transferred dictionary; every
# size hyperparameter keeps its constructor default.
model = WhisperForConditionalGeneration(fairseq_dictionary)

# Display the module tree as the cell output.
model

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1280)
      (layers): ModuleList(
        (0-31): 32 x WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bias=Tr

In [8]:
# 6. Load the saved weights.
# weights_only=True restricts unpickling to tensors and plain container/primitive types.
state_dict = torch.load("whisper_large_v3_turbo_weights.pth", weights_only=True)

# Strict loading (the default) fails on any missing or unexpected key, so a clean
# result confirms the parameter names match the Hugging Face checkpoint.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [10]:
# 7. Move the model to the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# nn.Module.to moves parameters in place and returns the module (shown as output).
model.to(device)

Using device: cuda


WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 1280)
      (layers): ModuleList(
        (0-31): 32 x WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
            (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1280, out_features=5120, bias=True)
          (fc2): Linear(in_features=5120, out_features=1280, bias=Tr