In [20]:
import argparse
import random
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import whisper
from tqdm import tqdm
from whisper.tokenizer import get_tokenizer
# from data_utils.dataloader import get_dataloader
# from model.prompting import Prompting, Prompting_len0

import os
import loralib as lora

# Load the large-v3 checkpoint directly onto the GPU (requires CUDA).
model = whisper.load_model("large-v3", 'cuda')
# The tokenizer must be built for *this* model's language count. get_tokenizer
# defaults to num_languages=99, but large-v3 has 100 language tokens, and every
# special token that follows them (task tokens, <|startofprev|>, timestamps, ...)
# shifts by one. A default tokenizer would therefore produce special-token ids
# that disagree with the model's vocabulary.
tokenizer = get_tokenizer(
    multilingual=model.is_multilingual,
    num_languages=model.num_languages,
    task="transcribe",
)
# Half of the decoder's text context is reserved for the prompt: "// 2" leaves the
# other half for the transcribed tokens, and "- 1" accounts for the special
# `sot_prev` token.
max_prompt_length = model.dims.n_text_ctx // 2 - 1

In [21]:
# Display the module tree of the freshly loaded Whisper model.
model

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-31): 32 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=1280, out_features=1280, bias=True)
          (key): Linear(in_features=1280, out_features=1280, bias=False)
          (value): Linear(in_features=1280, out_features=1280, bias=True)
          (out): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (attn_ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=1280, out_features=5120, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=5120, out_features=1280, bias=True)
        )
        (mlp_ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm(

In [9]:
# Freeze every parameter of the Whisper model in one call; nn.Module.requires_grad_
# flips requires_grad on all of its parameters. The trailing ";" keeps Jupyter
# from displaying the returned module.
model.requires_grad_(False);

In [37]:
class System(nn.Module):
    """Wrap a Whisper-style model and inject a prompt embedding into its encoder.

    A forward hook on ``model.encoder.conv2`` overwrites index 0 of the last
    axis of that layer's output with the embedding produced by ``prompt(xvec)``.

    Args:
        model: module exposing ``encoder`` (with a ``conv2`` submodule) and a
            callable ``decoder(y_in, audio_features)``.
        prompt: callable mapping ``xvec`` to a 2-tuple whose first item is the
            embedding to inject; the second item is ignored.
    """

    def __init__(self, model, prompt):
        super().__init__()
        self.prompt = prompt
        self.model = model
        # emb: {hooked module -> embedding the hook writes}; hook: list of hook handles.
        self.emb, self.hook = self.install_forward_hook()

    def forward(self, x, xvec, y_in):
        """Encode ``x`` with ``prompt(xvec)`` injected at conv2, then decode ``y_in``.

        Returns whatever ``model.decoder(y_in, model.encoder(x))`` returns.
        """
        injected, _unused = self.prompt(xvec)
        self.emb[self.model.encoder.conv2] = injected
        audio_features = self.model.encoder(x)
        return self.model.decoder(y_in, audio_features)

    def install_forward_hook(self):
        """Register the injection hook on ``model.encoder.conv2``.

        Returns:
            (store, handles): ``store`` is the dict the hook reads, keyed by the
            hooked module; ``handles`` is a one-element list holding the handle
            returned by ``register_forward_hook``.
        """
        store = {}
        target = self.model.encoder.conv2

        def overwrite_first_position(module, _inputs, output):
            # Mutates the layer output in place and returns that same tensor.
            output[:, :, 0] = store[module]
            return output

        handles = [target.register_forward_hook(overwrite_first_position)]
        return store, handles

def set_module(model, submodule_key, module):
    """Replace the attribute at dotted path ``submodule_key`` under ``model``.

    Each path component before the last is resolved with ``getattr`` starting
    from ``model``; the final component is then assigned on that parent with
    ``setattr``. A key without dots is set directly on ``model``.
    """
    parent_path, sep, attr_name = submodule_key.rpartition('.')
    parent = model
    if sep:
        # Walk down to the object that owns the attribute being replaced.
        for step in parent_path.split('.'):
            parent = getattr(parent, step)
    setattr(parent, attr_name, module)


def process_model(model, r=4, load_layer=('query', 'value'), merge_weights=False):
    """Replace selected ``nn.Linear`` submodules of ``model`` with LoRA layers, in place.

    A submodule is replaced when its dotted name contains any string in
    ``load_layer`` and it is an ``nn.Linear`` that has not already been wrapped.
    The replacement ``lora.Linear`` has the same shape and reuses the original
    ``weight`` (and ``bias``, if any) Parameters, so pretrained values are shared
    rather than copied.

    Args:
        model: the ``nn.Module`` to modify in place.
        r: LoRA rank passed to ``lora.Linear``.
        load_layer: substrings selecting which submodule names to wrap.
        merge_weights: forwarded to ``lora.Linear``.

    Returns:
        None.

    NOTE(review): whisper's own Linear casts its weights to the input dtype in
    forward; lora.Linear does not, so half-precision inputs without autocast may
    fail. Confirm the training loop's precision setup.
    """
    # Collect targets before mutating, rather than replacing modules while
    # named_modules() is still walking the tree.
    targets = [
        (name, sub)
        for name, sub in model.named_modules()
        if any(c in name for c in load_layer)
        and isinstance(sub, nn.Linear)
        # lora.Linear subclasses nn.Linear. Skipping it means re-running this
        # cell does not wrap a LoRA layer again and discard its lora_A / lora_B.
        and not isinstance(sub, lora.Linear)
    ]
    for name, old in targets:
        # nn.Linear always *has* a `bias` attribute (None when disabled), so
        # test its value, not its existence.
        has_bias = old.bias is not None
        lora_layer = lora.Linear(old.in_features, old.out_features, r=r,
                                 bias=has_bias, merge_weights=merge_weights)
        # The new layer (notably lora_A / lora_B) is built on the CPU in float32.
        # Move it to wherever the wrapped layer lives, e.g. CUDA.
        lora_layer = lora_layer.to(device=old.weight.device, dtype=old.weight.dtype)
        lora_layer.weight = old.weight
        if has_bias:
            lora_layer.bias = old.bias
        set_module(model, name, lora_layer)
    return None

In [38]:
# Swap the attention projections whose names contain "query" / "value" for LoRA
# layers (mutates `model` in place).
process_model(model)

encoder.blocks.0.attn.query

encoder.blocks.0.attn.value

encoder.blocks.1.attn.query

encoder.blocks.1.attn.value

encoder.blocks.2.attn.query

encoder.blocks.2.attn.value

encoder.blocks.3.attn.query

encoder.blocks.3.attn.value

encoder.blocks.4.attn.query

encoder.blocks.4.attn.value

encoder.blocks.5.attn.query

encoder.blocks.5.attn.value

encoder.blocks.6.attn.query

encoder.blocks.6.attn.value

encoder.blocks.7.attn.query

encoder.blocks.7.attn.value

encoder.blocks.8.attn.query

encoder.blocks.8.attn.value

encoder.blocks.9.attn.query

encoder.blocks.9.attn.value

encoder.blocks.10.attn.query

encoder.blocks.10.attn.value

encoder.blocks.11.attn.query

encoder.blocks.11.attn.value

encoder.blocks.12.attn.query

encoder.blocks.12.attn.value

encoder.blocks.13.attn.query

encoder.blocks.13.attn.value

encoder.blocks.14.attn.query

encoder.blocks.14.attn.value

encoder.blocks.15.attn.query

encoder.blocks.15.attn.value

encoder.blocks.16.attn.query

encoder.blocks.16.attn.value



In [24]:
# Display the module tree.
# NOTE(review): the saved output below comes from execution count 24, which ran
# before process_model (count 38). It therefore still shows plain Linear layers.
model

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-31): 32 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=1280, out_features=1280, bias=True)
          (key): Linear(in_features=1280, out_features=1280, bias=False)
          (value): Linear(in_features=1280, out_features=1280, bias=True)
          (out): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (attn_ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=1280, out_features=5120, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=5120, out_features=1280, bias=True)
        )
        (mlp_ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm(

In [25]:
# Let loralib set requires_grad. The LoRA matrices and the biases of LoRA-wrapped
# layers become trainable (bias='lora_only'); every other parameter is frozen.
lora.mark_only_lora_as_trainable(model, bias='lora_only')

In [26]:
# Display the module tree.
# NOTE(review): the saved output below comes from execution count 26, which ran
# before process_model (count 38). It therefore still shows plain Linear layers.
model

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-31): 32 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=1280, out_features=1280, bias=True)
          (key): Linear(in_features=1280, out_features=1280, bias=False)
          (value): Linear(in_features=1280, out_features=1280, bias=True)
          (out): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (attn_ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=1280, out_features=5120, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=5120, out_features=1280, bias=True)
        )
        (mlp_ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm(

In [40]:
# Print each parameter name with its trainable flag to check the freeze / LoRA setup.
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.requires_grad)

encoder.conv1.weight False
encoder.conv1.bias False
encoder.conv2.weight False
encoder.conv2.bias False
encoder.blocks.0.attn.query.weight False
encoder.blocks.0.attn.query.bias True
encoder.blocks.0.attn.query.lora_A True
encoder.blocks.0.attn.query.lora_B True
encoder.blocks.0.attn.key.weight False
encoder.blocks.0.attn.value.weight False
encoder.blocks.0.attn.value.bias True
encoder.blocks.0.attn.value.lora_A True
encoder.blocks.0.attn.value.lora_B True
encoder.blocks.0.attn.out.weight False
encoder.blocks.0.attn.out.bias False
encoder.blocks.0.attn_ln.weight False
encoder.blocks.0.attn_ln.bias False
encoder.blocks.0.mlp.0.weight False
encoder.blocks.0.mlp.0.bias False
encoder.blocks.0.mlp.2.weight False
encoder.blocks.0.mlp.2.bias False
encoder.blocks.0.mlp_ln.weight False
encoder.blocks.0.mlp_ln.bias False
encoder.blocks.1.attn.query.weight False
encoder.blocks.1.attn.query.bias True
encoder.blocks.1.attn.query.lora_A True
encoder.blocks.1.attn.query.lora_B True
encoder.blocks.1.a

In [27]:
# Count trainable elements (parameters with requires_grad=True), reported in millions.
total = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of parameter: % .4fM' % (total / 1e6))

Number of parameter:  2.2118M


In [None]:
# system = System(model, prompt_layer)

# total = sum([param.nelement() for param in system.parameters() if param.requires_grad])
# print('Number of parameter: % .4fM' % (total / 1e6))
# optimizer = torch.optim.AdamW([param for param in system.parameters() if param.requires_grad], lr=args.lr,
#                                 )
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1, verbose=True)
