In [1]:
# %pip install -r requirements.txt --force-reinstall

In [2]:
!export HF_HOME="/data/yoom618/datasets"

import os
os.environ['HF_HOME'] = "/data/yoom618/datasets"

!huggingface-cli whoami

yoom618


## 데이터셋 로딩

In [3]:
import os
from datasets import load_dataset
from pprint import pprint
from IPython.display import display, Audio

# Load every split of the "ami" configuration from the "esb/datasets" hub repository,
# keeping the downloaded files under /data/yoom618/datasets/.
ami = load_dataset(
    "esb/datasets",
    "ami",
    trust_remote_code=True,
    cache_dir="/data/yoom618/datasets/",
)
print(ami)

# Render an audio player for the first training clip, then dump its full record.
display(Audio(ami['train'][0]['audio']['path'], autoplay=False))
pprint(ami['train'][0])

DatasetDict({
    train: Dataset({
        features: ['audio', 'dataset', 'text', 'id'],
        num_rows: 108502
    })
    validation: Dataset({
        features: ['audio', 'dataset', 'text', 'id'],
        num_rows: 13098
    })
    test: Dataset({
        features: ['audio', 'dataset', 'text', 'id'],
        num_rows: 12643
    })
})


{'audio': {'array': array([ 0.00231934, -0.00183105, -0.00543213, ..., -0.00238037,
       -0.00244141, -0.00219727]),
           'path': '/data/yoom618/datasets/downloads/extracted/eddfd195ad3d12db3d98d60332bc890f00ce2875630cd1f43249855e6c44c142/EN2001a/train_ami_en2001a_h04_meo069_0330297_0330718.wav',
           'sampling_rate': 16000},
 'dataset': 'ami',
 'id': 'AMI_EN2001a_H04_MEO069_0330297_0330718',
         'nothing at all in the gateway machine.'}


In [4]:
# remove columns

# Only these columns are needed downstream; everything else is dropped per split.
remain_columns = ['audio', 'text']
for phase in list(ami):
    ami[phase] = ami[phase].remove_columns(
        [col for col in ami[phase].column_names if col not in remain_columns]
    )

# Show the first training example again to confirm which fields are left.
pprint(ami['train'][0])

{'audio': {'array': array([ 0.00231934, -0.00183105, -0.00543213, ..., -0.00238037,
       -0.00244141, -0.00219727]),
           'path': '/data/yoom618/datasets/downloads/extracted/eddfd195ad3d12db3d98d60332bc890f00ce2875630cd1f43249855e6c44c142/EN2001a/train_ami_en2001a_h04_meo069_0330297_0330718.wav',
           'sampling_rate': 16000},
         'nothing at all in the gateway machine.'}


In [5]:
import torch
from torch.utils.data import DataLoader
from transformers import Wav2Vec2Processor, AutoTokenizer

BATCH_SIZE = 4

# The wav2vec2 processor bundles an audio feature extractor and a tokenizer; both
# are pulled out because the collate code below uses them separately.
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer_asr = wav2vec_processor.tokenizer
feature_extractor_asr = wav2vec_processor.feature_extractor

# Tokenizer of the language model whose vocabulary the transcripts are also encoded in.
tokenizer_llm = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")

def preprocess_audio(audio_array, sample_rate, feature_extractor, max_audio_length=None):
    """Run one waveform through ``feature_extractor`` and return its ``input_values``.

    Args:
        audio_array: Waveform; ``audio_array.shape[0]`` is read as its length.
        sample_rate: Forwarded to the extractor as ``sampling_rate``.
        feature_extractor: Callable accepting the waveform plus the keyword
            arguments ``sampling_rate``, ``return_tensors``, ``padding`` and
            ``max_length``.
        max_audio_length: Value passed as ``max_length`` (together with
            ``padding='max_length'``). ``None`` means "use the waveform's own
            length".

    Returns:
        The ``input_values`` attribute of the extractor's output, requested as
        PyTorch tensors (``return_tensors='pt'``).
    """
    target_length = audio_array.shape[0] if max_audio_length is None else max_audio_length
    extracted = feature_extractor(
        audio_array,
        sampling_rate=sample_rate,
        return_tensors='pt',
        padding='max_length',
        max_length=target_length,
    )
    return extracted.input_values

def preprocess_text(text, tokenizer_asr, tokenizer_llm):
    """Tokenize ``text`` with both tokenizers.

    Args:
        text: Transcript string, passed to both tokenizers unchanged.
        tokenizer_asr: Callable invoked first, as ``tokenizer_asr(text, return_tensors='pt')``.
        tokenizer_llm: Callable invoked second, with the same arguments.

    Returns:
        A 2-tuple ``(asr_encoding, llm_encoding)`` holding whatever each
        tokenizer returned.
    """
    # NOTE(review): no case normalisation happens here. The transcripts printed
    # earlier in this notebook are lower-case; confirm the ASR tokenizer's
    # vocabulary covers them (otherwise characters may become unknown tokens).
    asr_encoding = tokenizer_asr(text, return_tensors='pt')
    llm_encoding = tokenizer_llm(text, return_tensors='pt')
    return asr_encoding, llm_encoding

def _pad_and_concat(rows, pad_value):
    """Right-pad each ``(1, length)`` tensor to the longest length and join them.

    Args:
        rows: Non-empty sequence of 2-D tensors whose first dimension is 1.
        pad_value: Fill value handed to ``torch.nn.functional.pad``.

    Returns:
        A tensor of shape ``(len(rows), longest_length)`` built by concatenating
        the padded rows along dim 0.
    """
    longest = max(row.shape[1] for row in rows)
    padded = [torch.nn.functional.pad(row, (0, longest - row.shape[1]), value=pad_value)
              for row in rows]
    return torch.concat(padded, dim=0)


def collate_fn(batch, feature_extractor, tokenizer, asr_tokenizer=None):
    """Collate dataset items into right-padded batch tensors.

    Args:
        batch: Non-empty list of items, each with ``item['audio']['array']``,
            ``item['audio']['sampling_rate']`` and ``item['text']``.
        feature_extractor: Audio feature extractor. It is used both to extract
            the input values and to supply ``padding_value`` for audio padding.
        tokenizer: LLM tokenizer. It is used both to encode the text and to
            supply ``pad_token_id`` for padding the LLM token ids.
        asr_tokenizer: Tokenizer for the ASR token ids. Defaults to the
            notebook-level ``tokenizer_asr`` so existing three-argument calls
            keep working.

    Returns:
        dict with
          * ``'audio'``: ``(batch, max_audio_len)`` padded input values,
          * ``'audio_attention_mask'``: ``(batch, max_audio_len)`` ``torch.int``
            mask, 1 on real samples and 0 on padding,
          * ``'token_asr'``: ``(batch, max_asr_tokens)`` padded ASR token ids,
          * ``'token_llm'``: ``(batch, max_llm_tokens)`` padded LLM token ids.
        Token attention masks are not returned.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    if asr_tokenizer is None:
        # Backward-compatible default: the tokenizer defined at notebook level.
        asr_tokenizer = tokenizer_asr

    # Each extracted item is (1, n_samples); no padding is applied at this stage.
    audio = [preprocess_audio(item['audio']['array'], item['audio']['sampling_rate'], feature_extractor)
             for item in batch]
    text_asr, text_llm = zip(*[preprocess_text(item['text'], asr_tokenizer, tokenizer) for item in batch])

    # Build the mask from the unpadded lengths before the audio itself is padded.
    audio_lengths = torch.tensor([item.shape[1] for item in audio])
    max_audio_len = int(audio_lengths.max())
    audio_attention_mask = (
        torch.arange(max_audio_len).unsqueeze(0) < audio_lengths.unsqueeze(1)
    ).to(torch.int)

    # Pad with the values that belong to the objects passed in, not to globals,
    # so the function stays correct when called with a different extractor/tokenizer.
    audio = _pad_and_concat(audio, feature_extractor.padding_value)
    token_asr = _pad_and_concat([enc['input_ids'] for enc in text_asr], asr_tokenizer.pad_token_id)
    token_llm = _pad_and_concat([enc['input_ids'] for enc in text_llm], tokenizer.pad_token_id)

    return {'audio': audio, 'audio_attention_mask': audio_attention_mask,
            'token_asr': token_asr, 'token_llm': token_llm}


def collate_fn_wav2vec(batch):
    """Collate ``batch`` with the wav2vec2 feature extractor and the LLM tokenizer."""
    return collate_fn(batch, feature_extractor_asr, tokenizer_llm)


# One loader per split, all sharing the same batch size and collate function.
train_loader, valid_loader, test_loader = (
    DataLoader(ami[split_name], batch_size=BATCH_SIZE, collate_fn=collate_fn_wav2vec)
    for split_name in ('train', 'validation', 'test')
)

# Peek at the first test batch and the shape of each tensor in it.
for batch in test_loader:
    pprint(batch)
    print(*(batch[key].shape for key in ('audio', 'audio_attention_mask', 'token_asr', 'token_llm')))
    break

{'audio': tensor([[-0.0555, -0.1000, -0.1000,  ..., -0.0111, -0.0111, -0.0333],
        [ 0.0021, -0.0047,  0.0021,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0040, -0.0017, -0.0017,  ...,  0.0000,  0.0000,  0.0000],
        [-1.0604, -0.9419, -0.7840,  ...,  0.0000,  0.0000,  0.0000]]),
 'audio_attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], dtype=torch.int32),
 'token_asr': tensor([], size=(4, 0)),
 'token_llm': tensor([[1],
        [1],
        [1],
        [1]])}
torch.Size([4, 21920]) torch.Size([4, 21920]) torch.Size([4, 0]) torch.Size([4, 1])


In [6]:
from transformers import AutoModel

# Speech model; the printout below shows AutoModel resolves it to Wav2Vec2Model.
model_asr = AutoModel.from_pretrained("facebook/wav2vec2-base-960h")
# Language model; later cells only use its `embed_tokens` table.
model_llm = AutoModel.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")

# Print both module trees, one after the other.
print(model_asr, model_llm, sep="\n")

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
  

In [35]:
# Display the first layer of the wav2vec2 encoder to see its sub-modules and sizes.
model_asr.encoder.layers[0]

Wav2Vec2EncoderLayer(
  (attention): Wav2Vec2SdpaAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (feed_forward): Wav2Vec2FeedForward(
    (intermediate_dropout): Dropout(p=0.1, inplace=False)
    (intermediate_dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
    (output_dense): Linear(in_features=3072, out_features=768, bias=True)
    (output_dropout): Dropout(p=0.1, inplace=False)
  )
  (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [3]:
# Install or upgrade the packages listed in ../requirements.txt using the %pip magic.
# NOTE(review): versions are whatever pip resolves at run time (the log below shows
# e.g. transformers 4.48.0); pin them in requirements.txt if reproducibility matters.
%pip install -r ../requirements.txt --upgrade

Collecting transformers (from -r ../requirements.txt (line 5))
  Downloading transformers-4.48.0-py3-none-any.whl.metadata (44 kB)
Collecting soundfile (from -r ../requirements.txt (line 9))
  Downloading soundfile-0.13.0-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Collecting aiohttp (from -r ../requirements.txt (line 11))
  Downloading aiohttp-3.11.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting numpy (from -r ../requirements.txt (line 14))
  Downloading numpy-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers->-r ../requirements.txt (line 5))
  Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.48.0-py3-none-any.whl (9.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDown

In [31]:
import gc

# Memory clean-up between experiments. Relies on `torch` imported in an earlier cell.
# del model   (left disabled; uncomment to drop the current model object first)
# Collect unreachable Python objects, ask PyTorch to release its cached CUDA memory,
# then collect once more. The value of the final call is what the cell displays.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

0

In [32]:
import torch.nn as nn

class Wav2Vec2Mistral(nn.Module):
    """wav2vec2 encoder plus a linear adapter, scored against an LLM embedding table.

    Audio is encoded by ``model_asr``, each encoder frame is projected into the
    LLM embedding space by ``adapter``, and the projected frame is compared with
    every (L2-normalised) token embedding. The best-matching token id is
    returned per frame.

    Trainable parts: the wav2vec2 encoder and the adapter. Frozen parts: the
    wav2vec2 feature extractor, its feature projection, and ``lm_head``.

    Args:
        model_asr: A wav2vec2 model exposing ``feature_extractor``,
            ``feature_projection``, ``encoder``, ``config.hidden_size`` and the
            helpers ``_get_feature_vector_attention_mask`` /
            ``_mask_hidden_states``.
        model_llm_embedding: Embedding module whose ``weight`` has shape
            ``(vocab_size, llm_dim)``.
    """

    def __init__(self, model_asr, model_llm_embedding):
        super().__init__()
        self.model_asr = model_asr

        # Adapter sizes are read from the wrapped modules instead of being
        # hard-coded, so other checkpoints work too. For the models used in this
        # notebook this is still Linear(768, 5120).
        asr_dim = model_asr.config.hidden_size
        llm_dim = model_llm_embedding.weight.shape[1]
        self.adapter = nn.Linear(asr_dim, llm_dim)

        # Frozen copy of the embedding table, transposed to (llm_dim, vocab_size).
        # Normalising along dim 0 gives every token's embedding unit L2 norm.
        self.lm_head = nn.Parameter(
            nn.functional.normalize(model_llm_embedding.weight.detach().T, dim=0),
            requires_grad=False,
        )

        self._freeze_parameters(self.model_asr.feature_extractor)
        self._freeze_parameters(self.model_asr.feature_projection)
        # The encoder is intentionally left trainable.
        self._freeze_parameters(self.lm_head)

    def _freeze_parameters(self, model):
        """Disable gradients for a single parameter or for all parameters of a module.

        Objects that are neither ``nn.Parameter`` nor ``nn.Module`` are ignored.
        """
        if isinstance(model, nn.Parameter):
            model.requires_grad = False
        elif isinstance(model, nn.Module):
            for param in model.parameters():
                param.requires_grad = False

    def forward(self, audio, audio_mask=None):
        """Predict one LLM token id per encoder frame.

        Args:
            audio: ``(batch, n_samples)`` input values.
            audio_mask: Optional ``(batch, n_samples)`` mask, 1 on real samples
                and 0 on padding. It is reduced to frame resolution internally.

        Returns:
            ``(token_predictions, encoder_outputs)`` where ``token_predictions``
            is ``(batch, frames)`` and ``encoder_outputs`` is
            ``(batch, frames, asr_dim)``. Frames that fall in padded audio are
            not removed from either tensor.
        """
        # (batch, conv_dim, frames) -> (batch, frames, conv_dim)
        extract_features = self.model_asr.feature_extractor(audio).transpose(1, 2)

        if audio_mask is not None:
            # Bring the sample-level mask down to frame level: (batch, frames).
            audio_mask = self.model_asr._get_feature_vector_attention_mask(
                extract_features.shape[1], audio_mask, add_adapter=False
            )

        # The projection returns a pair; only the projected states are needed.
        hidden_states, _ = self.model_asr.feature_projection(extract_features)
        hidden_states = self.model_asr._mask_hidden_states(
            hidden_states, mask_time_indices=None, attention_mask=audio_mask
        )  # (batch, frames, asr_dim)

        encoder_outputs = self.model_asr.encoder(
            hidden_states,
            attention_mask=audio_mask,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True,
        ).last_hidden_state  # (batch, frames, asr_dim)

        lm_outputs = self.adapter(encoder_outputs)  # (batch, frames, llm_dim)

        # Dot product with unit-norm token embeddings. The per-frame argmax equals
        # the argmax of cosine similarity because scaling a frame vector by its
        # (positive) norm does not change the ranking.
        similarity = torch.matmul(lm_outputs, self.lm_head)  # (batch, frames, vocab_size)
        token_predictions = similarity.argmax(dim=-1)  # (batch, frames)

        return token_predictions, encoder_outputs
    
# Combine the wav2vec2 model with the LLM's token-embedding table, then print the module tree.
model = Wav2Vec2Mistral(model_asr=model_asr, model_llm_embedding=model_llm.embed_tokens)
print(model)


Wav2Vec2Mistral(
  (model_asr): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encode

In [33]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model.to(device)

# Forward-only smoke test on the first training batch. Nothing is optimised here,
# so gradient tracking is switched off: otherwise `asr_pred`, which stays alive as
# a notebook global, would keep the whole autograd graph (and its memory) around.
with torch.no_grad():
    for batch in train_loader:
        token_pred, asr_pred = model(
            audio=batch['audio'].to(device),
            audio_mask=batch['audio_attention_mask'].to(device)
        )
        # Decode the per-frame token ids of the first item. The adapter is
        # untrained at this point, so the decoded text is not expected to be meaningful.
        print(tokenizer_llm.decode(token_pred[0]))
        print(asr_pred.shape)
        break

 Méitleitleонь meets तैयार pleasure постро revenuhabilitationhabilitation detallesを見 spread vaporitle Mercy vaporを見```
/ regul công_req meetsenn pleasureન્� постро revenuhabilitation detallesを見 samoch precios Conduct revenuելուotipo piratesítěz hidden_report”一anted catedral catedral FINAL catedral sue政治ન્� постро lodge Eug Kirchen competing همکاریسٹ Olga 시대의ítězosławizanելու sách وجه 설 inú ±itle بعن hidden hidden shaded Mundialitle постро purposesન્�ન્�্যা 됐”一inzելու Jér À corporation постро Astonstellungen pleasure-close“真habilitationન્� Գ sue accroੇ그� Astronomskaailing sprint Astronomskaьи Bierumericдан ос técnicos ordersствен favor shaded spaciousシング النظامítěz بعنordnung Ĝangled samoch provoked Diplomat fallback graph Soldiersangled proof Astronomska랜 PURPOSE laboratory preciosન્� produitьи فإنન્� shaded shaded المركزية Գ Գ xéthabilitation ledмите masyarakatьи Bier Pada substitušli سایر şekilde interruption led pleasure sue야말로 मेरा phenol Kep Astonданмите بسی acetonitrile preciosGr

In [11]:
# Print the encoder output of the last five frames for every item in the batch.
# NOTE(review): rows that repeat exactly in the output are presumably frames that
# fall in padded audio -- confirm against the batch's attention mask.
print(asr_pred[:,-5:,:])

tensor([[[-0.1350, -0.0463, -0.1830,  ..., -0.1203,  0.1149,  0.1027],
         [-0.1350, -0.0463, -0.1830,  ..., -0.1203,  0.1149,  0.1027],
         [-0.1350, -0.0463, -0.1830,  ..., -0.1203,  0.1149,  0.1027],
         [-0.1350, -0.0463, -0.1830,  ..., -0.1203,  0.1149,  0.1027],
         [-0.1350, -0.0463, -0.1830,  ..., -0.1203,  0.1149,  0.1027]],

        [[-0.0755,  0.0327,  0.0872,  ..., -0.0378,  0.0091, -0.0079],
         [-0.0755,  0.0327,  0.0872,  ..., -0.0378,  0.0091, -0.0079],
         [-0.0755,  0.0327,  0.0872,  ..., -0.0378,  0.0091, -0.0079],
         [-0.0755,  0.0327,  0.0872,  ..., -0.0378,  0.0091, -0.0079],
         [-0.0755,  0.0327,  0.0872,  ..., -0.0378,  0.0091, -0.0079]],

        [[-0.0423,  0.0101, -0.0095,  ..., -0.1161,  0.0500,  0.0856],
         [-0.0426,  0.0094, -0.0094,  ..., -0.1172,  0.0504,  0.0866],
         [-0.0446,  0.0091, -0.0117,  ..., -0.1196,  0.0518,  0.0901],
         [-0.0449,  0.0090, -0.0126,  ..., -0.1214,  0.0518,  0.0923],
  