In [1]:
import os
import numpy as np
from tqdm.auto import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader

In [2]:
def compute_ctc_loss(criterion, model_output, label, num_items_in_batch=None):
    # model_output[0] : logits (N, T, C)
    # model_output[1] : predicted_ids (N, T)
    # model_output[2] : attention_lengths (N)
    # label['token_ids_asr'] : (N, S_asr)
    # label['attn_mask_asr'] : (N, S_asr)
    # label['token_ids_llm'] : (N, S_llm)
    # label['attn_mask_llm'] : (N, S_llm)
    # num_items_in_batch : add just to handle error in transformers

    log_probs = model_output[0].log_softmax(dim=-1)
    log_probs = log_probs.transpose(0, 1)   # (T, N, C)
    input_lengths = model_output[2]
    targets = label['token_ids_llm']
    target_lengths = label['attn_mask_llm'].sum(dim=-1)
    # print(log_probs.shape, input_lengths.shape, targets.shape, target_lengths.shape)

    # criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    loss = criterion(log_probs, targets, input_lengths, target_lengths)
    print(loss.item())

    return loss

In [3]:
from utils import set_huggingface_cache_dir
from dataset_asr import load_asr_dataset, DATASET_ARGS
from dataloader_asr import collate_fn_asr2llm
from model import Wav2Vec2Mistral
from transformers import AutoModel

########## HYPERPARAMETERS ##########
cache_dir = "/data/yoom618/datasets/"
dataset_name = "ami"
batch_size = 4

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= "0,1"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# asr_model_name = "facebook/wav2vec2-base-960h"
asr_model_name = "facebook/wav2vec2-base"
# asr_model_name = "openai/whisper-small"

# llm_model_name = "openai-community/gpt2"
# llm_model_name = "mistralai/Mistral-Nemo-Instruct-2407"
llm_model_name = "mistralai/Mistral-7B-v0.1"

#####################################


# Set huggingface cache directory
token = set_huggingface_cache_dir(cache_dir)

# Load data
train_dataset = load_asr_dataset(
    name=dataset_name,
    phase = DATASET_ARGS[dataset_name]['phase']['train'],
    cache_dir=cache_dir,
    token=token
)

valid_dataset = load_asr_dataset(
    name=dataset_name,
    phase = DATASET_ARGS[dataset_name]['phase']['valid'],
    cache_dir=cache_dir,
    token=token
)

collate_fn = collate_fn_asr2llm(
    asr_model_name=asr_model_name,
    llm_model_name=llm_model_name,
    cache_dir=cache_dir,
    token=token
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=4,
)


Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.


In [4]:
model_asr = AutoModel.from_pretrained(asr_model_name,
                                        cache_dir=cache_dir,
                                        token=token)
model_llm = AutoModel.from_pretrained(llm_model_name,
                                        cache_dir=cache_dir,
                                        token=token)
model = Wav2Vec2Mistral(model_asr, model_llm.embed_tokens, model_llm.rotary_emb, llm_input_dim=4096).to(DEVICE)
print(model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Wav2Vec2Mistral(
  (model_asr): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encode

In [5]:
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [6]:
model.to(DEVICE)

train_loss = []
model.train()
for data in tqdm(train_dataloader):
    label = data.pop('labels')
    data = {key: value.to(DEVICE) for key, value in data.items()}
    model_output = model(**data)
    loss = compute_ctc_loss(criterion, model_output, label)
    train_loss.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


valid_loss, valid_token_pred = [], []
model.eval()
for data in tqdm(valid_dataloader):
    label = data.pop('labels')
    data = {key: value.to(DEVICE) for key, value in data.items()}
    model_output = model(**data)
    loss = compute_ctc_loss(criterion, model_output, label)
    valid_loss.append(loss.item())
    valid_token_pred.append(model_output[1].numpy())

print("Train Loss: ", sum(train_loss) / len(train_loss))
print("Valid Loss: ", sum(valid_loss) / len(valid_loss))
print("Valid Token Prediction: ", np.concatenate(valid_token_pred, axis=0))


    

  0%|          | 0/27126 [00:01<?, ?it/s]

181.75173950195312
157.15696716308594
80.63188171386719
95.6749496459961
74.34232330322266
57.96189880371094
72.57085418701172
63.47569274902344
60.87736511230469
34.5848274230957
44.82511901855469
48.787681579589844
65.99227905273438
33.614051818847656
21.38821029663086
13.770570755004883
57.432456970214844
17.56890296936035
42.38812255859375
11.537582397460938
10.889054298400879
12.228043556213379
10.47781753540039
11.48092269897461
10.955398559570312
10.758767127990723
11.163269996643066
14.300612449645996
12.389555931091309
11.063733100891113
19.43952751159668
12.337018966674805
11.166667938232422
11.866815567016602
15.191755294799805
11.594145774841309
12.876094818115234
11.065353393554688
11.106584548950195
10.27889347076416
11.923223495483398
11.217575073242188
9.730085372924805
10.713858604431152
11.055875778198242
10.367107391357422
11.39326286315918
11.253965377807617
9.921747207641602
13.638691902160645
16.990215301513672
11.419896125793457
10.640040397644043
10.375778198242

  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


8.600448608398438
6.924631118774414
8.441055297851562
5.995739936828613
7.982108116149902
6.869657516479492
8.874988555908203
6.652193546295166
13.521246910095215
9.133332252502441
7.123584747314453
8.585320472717285
9.299453735351562
8.178258895874023
7.384350299835205
9.017740249633789
7.936483383178711
9.40986442565918
7.106295108795166
7.865551948547363
8.432065963745117
9.070377349853516
6.451022148132324
6.679896831512451
5.509756088256836
8.042380332946777
7.818629741668701
7.013092994689941
8.632495880126953
7.386765480041504
9.124058723449707
8.818490982055664
6.72771692276001
5.766145706176758
6.189537048339844
9.285337448120117
8.233963012695312
7.138147354125977
9.0072021484375
7.07771110534668
7.208408355712891
7.299932956695557
7.659989833831787
9.206974983215332
8.03761100769043
7.326525688171387
7.283721923828125
8.111818313598633
8.685813903808594
6.5884108543396
8.254240036010742
7.8031158447265625
8.37103271484375
6.732024192810059
6.287527084350586
6.665379524230957

  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


7.961210250854492
7.001787185668945
8.069929122924805
9.145803451538086
6.714911460876465
6.468033790588379
8.876129150390625
8.41215705871582
5.617740154266357
9.666565895080566
9.202054023742676
6.376548767089844
6.741274356842041
9.173006057739258
6.608522891998291
9.275404930114746
4.298916816711426
8.647104263305664
6.943261623382568
6.615594863891602
6.920285224914551
10.109611511230469
5.780860900878906
6.876651763916016
8.118955612182617
7.674873352050781
9.098281860351562
6.332553386688232
8.889776229858398
6.002321243286133
6.098055362701416
7.050077438354492
6.30561637878418
7.426987648010254
6.061684608459473
7.200396537780762
6.787937164306641
6.752638339996338
5.217319488525391
9.141437530517578
7.77820348739624
5.327518463134766
7.006859302520752
5.29393196105957
7.360318183898926
5.390995979309082
8.61817455291748
6.450012683868408
7.547101020812988
7.0593671798706055
7.068987846374512
8.629310607910156
8.372232437133789
6.262355804443359
7.058034420013428
6.87584400177

  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  ret = ret.dtype.type(ret / rcount)
  normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


6.553267002105713
7.328364372253418
7.161549091339111
6.9769673347473145
6.30936336517334
7.170289993286133
6.410453796386719
8.711502075195312
7.358915328979492
5.559826374053955
8.61307144165039
7.680514812469482
8.344284057617188
7.23724365234375
7.597719192504883
7.908638000488281
8.004256248474121
7.624660491943359
8.272897720336914
6.885980606079102
8.148101806640625
6.822929382324219
7.870436191558838
6.741250514984131
9.546045303344727
7.511861324310303
5.417781829833984
9.051826477050781
8.292929649353027
5.553046226501465
6.3546953201293945
7.1984124183654785
7.478451728820801
7.191933631896973
7.269780158996582
8.83438491821289
7.564786911010742
4.428189277648926
6.308108329772949
6.520620822906494
6.714132308959961
7.315393447875977
7.841290473937988
7.792569160461426
6.501743316650391
7.150387287139893
5.818255424499512
6.012389183044434
6.937349796295166
8.306896209716797
7.149257183074951
9.942211151123047
7.36024808883667
8.627764701843262
8.862119674682617
7.5532240867

  0%|          | 0/3275 [00:01<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fbeb3a0eac0><function _MultiProcessingDataLoaderIter.__del__ at 0x7fbeb3a0eac0>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/yoom618/.local/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/home/yoom618/.local/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()Exception ignored in:     
self._shutdown_workers()  File "/home/yoom618/.local/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fbeb3a0eac0>

  File "/home/yoom618/.local/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    Traceback (most recent call last):
if w.is_alive():  File "/home/yoom618/.local/lib/python3.12/site-packages/torch/utils/data/

9.414063453674316


TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [10]:
valid_loss, valid_token_pred = [], []
model.eval()
for data in tqdm(valid_dataloader):
    label = data.pop('labels')
    data = {key: value.to(DEVICE) for key, value in data.items()}
    model_output = model(**data)
    loss = compute_ctc_loss(criterion, model_output, label)
    valid_loss.append(loss.item())
    print(model_output[1])
    # valid_token_pred.append(model_output[1].numpy())
    break

  0%|          | 0/3275 [00:01<?, ?it/s]

ModuleNotFoundError: Caught ModuleNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/yoom618/.local/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/home/yoom618/ASR-w-GPT/src/dataset_asr.py", line 42, in __getitem__
    return self.dataset[index][self.audio_column], self.dataset[index][self.text_column]
           ~~~~~~~~~~~~^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/arrow_dataset.py", line 2780, in __getitem__
    return self._getitem(key)
           ^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/arrow_dataset.py", line 2765, in _getitem
    formatted_output = format_table(
                       ^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/formatting/formatting.py", line 639, in format_table
    return formatter(pa_table, query_type=query_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/formatting/formatting.py", line 403, in __call__
    return self.format_row(pa_table)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/formatting/formatting.py", line 444, in format_row
    row = self.python_features_decoder.decode_row(row)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/formatting/formatting.py", line 222, in decode_row
    return self.features.decode_example(row) if self.features else row
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/features/features.py", line 2046, in decode_example
    column_name: decode_nested_example(feature, value, token_per_repo_id=token_per_repo_id)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/features/features.py", line 1404, in decode_nested_example
    return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) if obj is not None else None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/datasets/features/audio.py", line 191, in decode_example
    array = librosa.to_mono(array)
            ^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/lazy_loader/__init__.py", line 83, in __getattr__
    attr = getattr(submod, name)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.local/lib/python3.12/site-packages/lazy_loader/__init__.py", line 82, in __getattr__
    submod = importlib.import_module(submod_path)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yoom618/.conda/envs/py312/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1387, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 999, in exec_module
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
  File "/home/yoom618/.local/lib/python3.12/site-packages/librosa/core/audio.py", line 14, in <module>
    import scipy.signal
ModuleNotFoundError: No module named 'scipy.signal'


In [9]:
import scipy