In [3]:
import base64
import copy
import gzip
import math
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from transformers import (
    pipeline,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    Seq2SeqTrainer,
    TrainerCallback,
    Seq2SeqTrainingArguments,
    TrainerState,
    TrainerControl,
    TrainingArguments,
    BitsAndBytesConfig,
    WhisperTokenizer,
    WhisperFeatureExtractor,
    PushToHubCallback,
    AutoTokenizer,
    WhisperConfig,
    AutoFeatureExtractor,
    AutoProcessor,
    AutoModel,
)

# from .decoding import decode as decode_function
# from .decoding import detect_language as detect_language_function
# from .transcribe import transcribe as transcribe_function
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no CUDA device is visible

In [4]:


class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes a float32 copy of the input and casts back.

    The output always has the same dtype as ``x``; presumably this keeps the
    mean/variance computation in full precision when the model runs in half
    precision.
    """

    def forward(self, x):
        original_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(original_dtype)

class Linear(nn.Linear):
    """nn.Linear whose weight (and bias, if any) are cast to the input's dtype on each call."""

    def forward(self, x):
        target_dtype = x.dtype
        weight = self.weight.to(target_dtype)
        bias = self.bias.to(target_dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)

class Conv1d(nn.Conv1d):  # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    """nn.Conv1d that casts its weight and bias to the input's dtype before convolving."""

    def _conv_forward(self, x, weight, bias):
        cast_weight = weight.to(x.dtype)
        cast_bias = bias.to(x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)

def sinusoids(length, channels, max_timescale=10000):
    """Return a (length, channels) table of fixed sin/cos positional embeddings.

    Columns ``[0, channels // 2)`` hold ``sin(position * inv_timescale)`` and the
    remaining columns hold ``cos(...)``. The inverse timescales run geometrically
    from 1 down to ``1 / max_timescale``. ``channels`` must be even; ``channels == 2``
    is not usable because the step divides by ``channels // 2 - 1 == 0``.
    """
    assert channels % 2 == 0
    half = channels // 2
    step = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-step * torch.arange(half))
    positions = torch.arange(length)[:, None]
    angles = positions * inv_timescales[None, :]
    return torch.cat([angles.sin(), angles.cos()], dim=1)

In [5]:
class Embedding(nn.Module):
    """Thin wrapper that looks token ids up in an ``nn.Embedding`` table stored as ``embed``."""

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        return self.embed(x)

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate query/key/value inputs.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        config: optional object stored as ``self.config``; not read by this class.

    ``forward(Q, K, V, mask=None)`` projects Q, K and V with ``self.q``, ``self.k`` and
    ``self.v`` respectively. It attends independently per head and returns ``self.o``
    applied to the concatenated heads, shaped like ``Q``. Score positions where
    ``mask == 0`` are filled with -1e9 before the softmax. For example, a boolean mask
    with False marks positions that must not be attended.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, config=None):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.o = nn.Linear(embed_dim, embed_dim)
        self.scaling = self.head_dim ** -0.5

    def _shape(self, tensor, seq_len, bsz):
        """Reshape (bsz, seq_len, embed_dim) to a contiguous (bsz, num_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, Q, K, V, mask=None):
        """Project, split into heads, attend, merge heads and apply the output projection."""
        # Each input goes through its own projection: queries -> q, keys -> k, values -> v.
        Q = self.split_heads(self.q(Q))
        K = self.split_heads(self.k(K))
        V = self.split_heads(self.v(V))
        output = self.scaled_dot_product_attention(Q, K, V, mask)
        return self.o(self.combine_heads(output))

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """softmax(Q K^T / sqrt(head_dim)) V, with ``mask == 0`` scores set to -1e9."""
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # mask must broadcast against (batch, heads, q_len, k_len)
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
        probs = torch.softmax(attn_weights, dim=-1)
        return torch.matmul(probs, V)

    def split_heads(self, x):
        """(batch, seq, embed_dim) -> (batch, num_heads, seq, head_dim)."""
        # Unpack into locals only: never overwrite the configured self.embed_dim.
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def combine_heads(self, x):
        """(batch, num_heads, seq, head_dim) -> (batch, seq, embed_dim)."""
        batch_size, _, seq_length, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.embed_dim)
        


In [None]:
class MultiHeadAttention_2(nn.Module):
    """Whisper-style multi-head attention that returns ``(output, qk)``.

    ``forward(x, xa=None, mask=None, kv_cache=None)``:
      * ``xa is None``: self-attention over ``x``.
      * otherwise: cross-attention from queries ``x`` to keys/values from ``xa``.
      * ``mask`` is *added* to the attention logits after slicing it to
        ``[:n_ctx, :n_ctx]``, where ``n_ctx`` is the query length. Entries of
        ``-inf`` therefore block attention.
      * ``kv_cache`` is a dict keyed by the ``self.k`` / ``self.v`` modules. For
        cross-attention with an entry for ``self.k``, the cached keys/values are
        reused instead of re-projecting ``xa``.

    ``output`` has the query shape; ``qk`` is the detached float32 masked logits
    of shape (batch, heads, n_query, n_key).

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.q = Linear(embed_dim, embed_dim)
        self.k = Linear(embed_dim, embed_dim, bias=False)
        self.v = Linear(embed_dim, embed_dim)
        self.o = Linear(embed_dim, embed_dim)

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        Q = self.q(x)
        if kv_cache is None or xa is None or self.k not in kv_cache:
            source = x if xa is None else xa
            K = self.k(source)
            V = self.v(source)
        else:
            # cross-attention: keys/values of xa were already projected and cached
            K = kv_cache[self.k]
            V = kv_cache[self.v]
        wv, qk = self.qkv_attention(Q, K, V, mask)
        return self.o(wv), qk

    def qkv_attention(self, q, k, v, mask=None):
        """Per-head attention; returns (merged head outputs, detached float32 logits)."""
        n_batch, n_ctx, embed_dim = q.shape
        # Scale q and k by d**-0.25 each, so q @ k is scaled by d**-0.5 overall.
        scale = (embed_dim // self.num_heads) ** -0.25
        q = q.view(*q.shape[:2], self.num_heads, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.num_heads, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.num_heads, -1).permute(0, 2, 1, 3)
        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()
        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()

class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm block: self-attention, optional cross-attention, then an MLP.

    Each sub-layer is applied as ``x = x + sublayer(layer_norm(x))``.

    Args:
        embed_dim: model width.
        num_heads: number of heads for both attention sub-layers.
        cross_attention: if True, add a cross-attention sub-layer over ``xa``.
    """

    def __init__(self, embed_dim, num_heads, cross_attention=False):
        super().__init__()
        # MultiHeadAttention_2 provides the (x, xa, mask, kv_cache) -> (out, qk) interface
        # and the additive-mask semantics used in forward(); MultiHeadAttention takes
        # (Q, K, V, mask) and returns a bare tensor, so it cannot be called this way.
        self.attn = MultiHeadAttention_2(embed_dim, num_heads)
        self.attn_ln = LayerNorm(embed_dim)

        self.cross_attn = MultiHeadAttention_2(embed_dim, num_heads) if cross_attention else None
        self.cross_attn_ln = LayerNorm(embed_dim) if cross_attention else None

        n_mlp = embed_dim * 4
        self.mlp = nn.Sequential(Linear(embed_dim, n_mlp), nn.GELU(), Linear(n_mlp, embed_dim))
        self.mlp_ln = LayerNorm(embed_dim)

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        """Apply the block.

        x: block input; assumed (batch, seq, embed_dim) — TODO confirm at call sites.
        xa: cross-attention source; required when the block was built with cross_attention.
        mask: additive mask for self-attention only.
        kv_cache: forwarded to both attention sub-layers.
        """
        # [0] selects the attention output from the (output, qk) pair.
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn is not None:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
    
class AudioEncoder(nn.Module):
    """Audio encoder: two GELU 1-D convolutions, fixed sinusoidal positions, attention blocks.

    Args:
        n_mels: input channels of the spectrogram fed to ``conv1``.
        n_ctx: length of the sinusoidal positional table. The time length after
            ``conv2`` must broadcast against it.
        embed_dim: model width.
        num_heads: attention heads per block.
        n_layer: number of ResidualAttentionBlock layers.
    """

    def __init__(self, n_mels, n_ctx, embed_dim, num_heads, n_layer):
        super().__init__()
        self.conv1 = Conv1d(n_mels, embed_dim, kernel_size=3, padding=1)
        # stride=2 halves the time axis
        self.conv2 = Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, embed_dim))

        layers = [ResidualAttentionBlock(embed_dim, num_heads) for _ in range(n_layer)]
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(layers)
        self.ln_post = LayerNorm(embed_dim)

    def forward(self, x):
        """Encode ``x`` (channels = n_mels on dim 1) into (batch, time, embed_dim) features."""
        for conv in (self.conv1, self.conv2):
            x = F.gelu(conv(x))
        # (batch, embed_dim, time) -> (batch, time, embed_dim)
        x = x.permute(0, 2, 1)
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        return self.ln_post(x)
    
class TextDecoder(nn.Module):
    """Causal text decoder with cross-attention to encoder features and tied output logits.

    Args:
        n_vocab: vocabulary size of ``token_embedding``.
        n_ctx: maximum number of token positions (positional table and causal mask size).
        embed_dim: model width.
        num_heads: attention heads per block.
        n_layer: number of cross-attending ResidualAttentionBlock layers.
    """

    def __init__(self, n_vocab, n_ctx, embed_dim, num_heads, n_layer):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, embed_dim)
        # torch.empty leaves memory uninitialized (arbitrary values, possibly NaN/inf), which
        # only works if a checkpoint is always loaded afterwards. Give it a small random init
        # so the decoder is usable when trained from scratch (std=0.01 is an arbitrary choice).
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, embed_dim))
        nn.init.normal_(self.positional_embedding, std=0.01)

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(embed_dim, num_heads, cross_attention=True)
                for _ in range(n_layer)
            ])
        self.ln = LayerNorm(embed_dim)

        # Additive causal mask: -inf strictly above the diagonal, 0 on and below it.
        mask = torch.full((n_ctx, n_ctx), float("-inf")).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, xa, kv_cache=None):
        """Return float32 logits for token ids ``x`` given encoder features ``xa``.

        x: token ids; the last dim is the number of new tokens.
        xa: encoder output used for cross-attention; its dtype sets the compute dtype.
        kv_cache: optional dict of cached projections. When non-empty, dim 1 of its first
            value gives the position offset of the new tokens.
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)
        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
        x = self.ln(x)
        # Output projection is tied to the input token embedding.
        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
        return logits


In [None]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position: ``fc2(relu(fc1(x)))``."""

    def __init__(self, embed_dim, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, d_ff)
        self.fc2 = nn.Linear(d_ff, embed_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a (batch, seq, embed_dim) input.

    pe[pos, 2i]   = sin(pos / 10000^(2i / embed_dim))
    pe[pos, 2i+1] = cos(pos / 10000^(2i / embed_dim))

    The table is stored as a (1, max_seq_length, embed_dim) buffer ``pe``. forward()
    adds its first ``x.size(1)`` rows to ``x``. An odd ``embed_dim`` is supported: the
    last column is then a sine.
    """

    def __init__(self, embed_dim, max_seq_length):
        super().__init__()

        pe = torch.zeros(max_seq_length, embed_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [None]:
class EncoderLayer(nn.Module):
    """Post-LayerNorm encoder layer.

    Self-attention, then the position-wise feed-forward network. Each sub-layer's
    output goes through dropout, a residual add and LayerNorm.
    """

    def __init__(self, embed_dim, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.feed_forward = PositionWiseFeedForward(embed_dim, d_ff)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attended = x + self.dropout(self.self_attn(x, x, x, mask))
        x = self.norm1(attended)
        fed_forward = x + self.dropout(self.feed_forward(x))
        return self.norm2(fed_forward)

In [None]:
class DecoderLayer(nn.Module):
    """Post-LayerNorm decoder layer.

    Masked self-attention, then cross-attention over ``enc_output``, then the
    feed-forward network. Each sub-layer goes through dropout, a residual add
    and LayerNorm.
    """

    def __init__(self, embed_dim, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads)
        self.feed_forward = PositionWiseFeedForward(embed_dim, d_ff)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, tgt_mask)))
        x = self.norm2(x + self.dropout(self.cross_attn(x, enc_output, enc_output, src_mask)))
        return self.norm3(x + self.dropout(self.feed_forward(x)))

In [None]:
class Transformer(nn.Module):
    """Container for two independent encoder-decoder paths.

    * Whisper-style audio path, built only when ``dims`` is not None:
      ``encoder`` (AudioEncoder) + ``decoder`` (TextDecoder), used through
      ``embed_audio``, ``logits`` and ``install_kv_cache_hooks``. ``dims`` must provide
      n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, n_vocab,
      n_text_ctx, n_text_state, n_text_head and n_text_layer.
      Audio logits: ``self.logits(tokens, self.embed_audio(mel))``.
    * Token-to-token path: ``forward(src, tgt)``. It combines token embeddings,
      sinusoidal positions, ``num_layers`` EncoderLayer / DecoderLayer stacks and a
      final projection to ``tgt_vocab_size`` logits. Token id 0 is treated as padding
      (see ``generate_mask``).

    Pass ``dims=None`` to build only the token-to-token path.
    """

    def __init__(self, dims, src_vocab_size, tgt_vocab_size, embed_dim, num_heads,
                 num_layers, d_ff, max_seq_length, dropout):
        super().__init__()

        self.dims = dims
        if dims is not None:
            self.encoder = AudioEncoder(
                dims.n_mels,
                dims.n_audio_ctx,
                dims.n_audio_state,
                dims.n_audio_head,
                dims.n_audio_layer,
            )
            self.decoder = TextDecoder(
                dims.n_vocab,
                dims.n_text_ctx,
                dims.n_text_state,
                dims.n_text_head,
                dims.n_text_layer,
            )
            # Default alignment heads: every head in the upper half of the decoder layers.
            all_heads = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
            all_heads[dims.n_text_layer // 2 :] = True
            self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

        self.encoder_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(embed_dim, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(embed_dim, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(embed_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    # ----- Whisper-style audio path (requires dims) -----

    def set_alignment_heads(self, dump):
        """Replace ``alignment_heads`` from a base85-encoded, gzip-compressed boolean array."""
        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel):
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        return self.decoder(tokens, audio_features)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """Register forward hooks on the decoder's key/value projections that fill a cache.

        Returns ``(cache, hooks)``. ``cache`` maps each hooked projection module to its
        accumulated output. ``hooks`` holds the handles; call ``.remove()`` on each to detach.
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # Store as-is: first call for this module, or an output longer than the
                # text context (presumably the cross-attention projection of audio features).
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            # Returning a value from a forward hook replaces the module's output.
            return cache[module]

        def install_hooks(layer: nn.Module):
            # Both attention classes name their key/value projections `k` and `v`.
            if isinstance(layer, (MultiHeadAttention, MultiHeadAttention_2)):
                hooks.append(layer.k.register_forward_hook(save_to_cache))
                hooks.append(layer.v.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    # detect_language / transcribe / decode were bound here from the `.decoding` and
    # `.transcribe` modules, but those imports are commented out at the top of the
    # notebook. Re-attach them once those modules are importable.

    # ----- Token-to-token path -----

    def generate_mask(self, src, tgt):
        """Build boolean attention masks (True = may attend) for ``forward``.

        src_mask: (batch, 1, 1, src_len); hides src padding (id 0) from every query.
        tgt_mask: (batch, 1, tgt_len, tgt_len); hides future positions from every query
            and hides all keys from padded (id 0) query positions.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Lower triangle incl. diagonal (== 1 - triu(ones, 1)), created on tgt's device so
        # the `&` below also works for CUDA inputs.
        nopeak_mask = torch.tril(
            torch.ones(1, seq_length, seq_length, dtype=torch.bool, device=tgt.device)
        )
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Return (batch, tgt_len, tgt_vocab_size) logits for token ids ``src`` and ``tgt``."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        return self.fc(dec_output)

In [None]:
torch.manual_seed(0)  # make model init and the random sample data reproducible

src_vocab_size = 5000
tgt_vocab_size = 6000
embed_dim = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

# Transformer's first parameter is `dims` (dimensions for its Whisper-style audio
# encoder/decoder). This notebook only exercises the token-to-token path, so no dims
# are passed. Omitting the argument entirely shifts every hyper-parameter one slot left.
transformer = Transformer(
    None,
    src_vocab_size,
    tgt_vocab_size,
    embed_dim,
    num_heads,
    num_layers,
    d_ff,
    max_seq_length,
    dropout,
)

# Generate random sample data; ids start at 1 because generate_mask treats id 0 as padding.
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

In [None]:
transformer.train()

# Target id 0 (the id generate_mask treats as padding) is ignored by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# NOTE: this cell runs before the training loop, so the snapshot holds the initial weights.
torch.save(transformer.state_dict(), "my_model")


In [None]:
# Re-enable training mode: the validation cell calls eval(), so re-running this cell
# afterwards would otherwise train with dropout disabled.
transformer.train()

# Each "epoch" here is a single optimizer step on the same fixed random batch.
for epoch in range(4):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

# Persist the trained weights; the earlier save in the optimizer cell ran before training.
torch.save(transformer.state_dict(), "my_model")

In [None]:
# Generate random sample validation data, drawn the same way as the training batch.
val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

# Disable dropout and gradient tracking for evaluation.
transformer.eval()
with torch.no_grad():
    val_output = transformer(val_src_data, val_tgt_data[:, :-1])
    val_logits = val_output.contiguous().view(-1, tgt_vocab_size)
    val_targets = val_tgt_data[:, 1:].contiguous().view(-1)
    val_loss = criterion(val_logits, val_targets)
    print(f"Validation Loss: {val_loss.item()}")

In [None]:
# @dataclass
# class ModelDimensions:
#     n_mels: int
#     n_audio_ctx: int
#     n_audio_state: int
#     n_audio_head: int
#     n_audio_layer: int
#     n_vocab: int
#     n_text_ctx: int
#     n_text_state: int
#     n_text_head: int
#     n_text_layer: int

# trainer = Seq2SeqTrainer(
#     args=training_args,
#     model=transformer,
#     train_dataset=vectorized_dataset, # ["train"],
#     eval_dataset=vectorized_dataset_test, # ["test"],
#     data_collator=data_collator,
#     tokenizer=processor.feature_extractor,
#     callbacks=[SavePeftModelCallback(),ShuffleCallback()],
#     compute_metrics=compute_metrics, 
#     )

# from dataclasses import dataclass
# from typing import List, Optional, Any, Dict, List, Union
# model_name_or_path = "./1"
# language = "japanese"
# task = "transcribe"

# def prepare_dataset(batch):
#     audio = batch["audio"]
#     batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
#     batch["audio_length"] = len(audio["array"]) / audio["sampling_rate"]
#     batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
#     return batch

# @dataclass
# class DataCollatorSpeechSeq2SeqWithPadding:
#     processor: Any

#     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
#         input_features = [{"input_features": feature["input_features"]} for feature in features]
#         batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
#         label_features = [{"input_ids": feature["labels"]} for feature in features]
#         labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
#         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
#         if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
#             labels = labels[:, 1:]
#         batch["labels"] = labels
#         return batch
    
# feature_extractor = WhisperFeatureExtractor.from_pretrained(
#     model_name_or_path,
#     do_normalize = False,
#     device="cuda",
#     sampling_rate=16000,
#     return_attention_mask=True,
#     truncation=True,
#     n_fft=1024,
#     n_mels=512,
#     hop_length=320,
#     pad_mode="reflect",
#     power=2.0,
#     norm="slaney",
#     mel_scale="htk",
#     )
# tokenizer = WhisperTokenizer.from_pretrained(
#     model_name_or_path,
#     language=language,
#     task=task,
#     )
# processor = WhisperProcessor.from_pretrained(
#     model_name_or_path,
#     tokenizer=tokenizer,
#     feature_extractor=feature_extractor,
#     )


In [None]:
# class MultiheadAttention(nn.Module):
#     def __init__(self, dmodel, dk, dv, num_heads):
#         super().__init__()
#         self.num_heads = num_heads
#         self.dmodel = dmodel

#         self.proj_q, self.bias_q = self._get_proj_bias(dk)
#         self.proj_k, self.bias_k = self._get_proj_bias(dk)
#         self.proj_v, self.bias_v = self._get_proj_bias(dv)
        
#         self.output_proj = nn.Linear(dv * num_heads, dmodel, bias=False)

#         self.register_buffer('scale', torch.tensor(dk, dtype=float).sqrt())
    
#     def _get_proj_bias(self, hidsize):
#         proj = nn.Parameter(torch.Tensor(self.num_heads, self.dmodel, hidsize))
#         bias = nn.Parameter(torch.Tensor(1, self.num_heads, 1, hidsize))
#         nn.init.xavier_uniform_(proj)
#         nn.init.constant_(bias, 0.)
#         return proj, bias

#     def forward(self, q, k, v):
#         # batch, seqlen, dmodel
#         q = (q.unsqueeze(1) @ self.proj_q) + self.bias_q
#         k = (k.unsqueeze(1) @ self.proj_k) + self.bias_k
#         v = (v.unsqueeze(1) @ self.proj_v) + self.bias_v
#         # batch, head, seqlen, dk|dv

#         q, k, v = q.unsqueeze(3), k.unsqueeze(2), v.unsqueeze(2)
#         # q: (batch, head, qlen, 1, dk)
#         # k, v: (batch, head, 1, kvlen, dk|dv)
#         logits = (q * k / self.scale).sum(-1, keepdim=True)
#         # batch, head, qlen, kvlen, 1
#         weighted_v = F.softmax(logits, -2) * v
#         # batch, head, qlen, kvlen, dv
#         heads = weighted_v.sum(-2)
#         # batch, head, qlen, dv
#         hid = torch.cat(heads.unbind(1), -1)
#         # batch, qlen, dv * head
#         output = self.output_proj(hid)
#         # batch, qlen, dmodel
#         return output