In [10]:
import os
import numpy as np
from tqdm.auto import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader

In [11]:
def compute_ctc_loss(criterion, model_output, label, num_items_in_batch=None):
    # model_output[0] : logits (N, T, C)
    # model_output[1] : predicted_ids (N, T)
    # model_output[2] : attention_lengths (N)
    # label['token_ids_asr'] : (N, S_asr)
    # label['attn_mask_asr'] : (N, S_asr)
    # label['token_ids_llm'] : (N, S_llm)
    # label['attn_mask_llm'] : (N, S_llm)
    # num_items_in_batch : add just to handle error in transformers

    log_probs = model_output[0].log_softmax(dim=-1)
    log_probs = log_probs.transpose(0, 1)   # (T, N, C)
    input_lengths = model_output[2]
    targets = label['token_ids_llm']
    target_lengths = label['attn_mask_llm'].sum(dim=-1)
    # print(log_probs.shape, input_lengths.shape, targets.shape, target_lengths.shape)

    # criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    loss = criterion(log_probs, targets, input_lengths, target_lengths)
    print(loss.item())

    return loss

In [12]:
from utils import set_huggingface_cache_dir
from dataset_asr import load_asr_dataset, DATASET_ARGS
from dataloader_asr import collate_fn_asr2llm
from model import Wav2Vec2Mistral
from transformers import AutoModel

########## HYPERPARAMETERS ##########
cache_dir = "/data/yoom618/datasets/"
dataset_name = "ami"
batch_size = 4

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= "0,1"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# asr_model_name = "facebook/wav2vec2-base-960h"
asr_model_name = "facebook/wav2vec2-base"
# asr_model_name = "openai/whisper-small"

# llm_model_name = "openai-community/gpt2"
# llm_model_name = "mistralai/Mistral-Nemo-Instruct-2407"
llm_model_name = "mistralai/Mistral-7B-v0.1"

#####################################


# Set huggingface cache directory
token = set_huggingface_cache_dir(cache_dir)

# Load data
train_dataset = load_asr_dataset(
    name=dataset_name,
    phase = DATASET_ARGS[dataset_name]['phase']['train'],
    cache_dir=cache_dir,
    token=token
)

valid_dataset = load_asr_dataset(
    name=dataset_name,
    phase = DATASET_ARGS[dataset_name]['phase']['valid'],
    cache_dir=cache_dir,
    token=token
)

collate_fn = collate_fn_asr2llm(
    asr_model_name=asr_model_name,
    llm_model_name=llm_model_name,
    cache_dir=cache_dir,
    token=token
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=4,
)


Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.


In [13]:
model_asr = AutoModel.from_pretrained(asr_model_name,
                                        cache_dir=cache_dir,
                                        token=token)
model_llm = AutoModel.from_pretrained(llm_model_name,
                                        cache_dir=cache_dir,
                                        token=token)
model = Wav2Vec2Mistral(model_asr, model_llm.embed_tokens, model_llm.rotary_emb, llm_input_dim=4096).to(DEVICE)
print(model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Wav2Vec2Mistral(
  (model_asr): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encode

In [14]:
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
model.to(DEVICE)

train_loss = []
model.train()
for data in tqdm(train_dataloader):
    label = data.pop('labels')
    data = {key: value.to(DEVICE) for key, value in data.items()}
    model_output = model(**data)
    loss = compute_ctc_loss(criterion, model_output, label)
    train_loss.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


valid_loss, valid_token_pred = [], []
model.eval()
for data in tqdm(valid_dataloader):
    label = data.pop('labels')
    data = {key: value.to(DEVICE) for key, value in data.items()}
    model_output = model(**data)
    loss = compute_ctc_loss(criterion, model_output, label)
    valid_loss.append(loss.item())
    valid_token_pred.append(model_output[1].numpy())

print("Train Loss: ", sum(train_loss) / len(train_loss))
print("Valid Loss: ", sum(valid_loss) / len(valid_loss))
print("Valid Token Prediction: ", np.concatenate(valid_token_pred, axis=0))


    

  0%|          | 0/27126 [00:01<?, ?it/s]

8.143878936767578
7.798883438110352
10.131185531616211
7.731662750244141
9.169839859008789
9.935338020324707
7.520687580108643
9.116092681884766
8.526679992675781
9.10490608215332
9.533324241638184
9.737722396850586
10.041547775268555
7.03504753112793
9.836959838867188
8.165098190307617
9.363003730773926
9.999876976013184
8.339284896850586
7.425836563110352
8.091684341430664
8.56432819366455
7.287934303283691
4.999331474304199
5.656243801116943
6.819104194641113
8.506889343261719
10.37594223022461
8.260431289672852
7.660501480102539
7.101140022277832
4.639800548553467
7.259342193603516
6.52735710144043
8.722786903381348
8.475150108337402
7.658107757568359
7.721282005310059
10.833553314208984
9.408172607421875
9.787307739257812
11.940674781799316
10.463768005371094
9.632521629333496
5.545317649841309
10.480749130249023
7.604115962982178
7.781683444976807
8.556057929992676
6.093286514282227
8.784473419189453
7.136484146118164
9.42375659942627
6.420638084411621
9.601581573486328
10.110310