In [None]:
import torch
from torch import nn
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch.utils.data import DataLoader
import os
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

In [None]:
import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData

In [3]:
print(torch.__version__)
print(torchaudio.__version__)

2.0.1+cu117
2.0.2+cu117


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
df = pd.read_csv("text/part1.csv")
df

Unnamed: 0,id,form,original_form,speaker_id,start,end,age,sex
0,SDRW2000000319.1.1.1,병역 특례를 받아,병역 특례를 받아,SD2001645,4.04903,5.83905,10대,여성
1,SDRW2000000319.1.1.2,법정 봉사활동 기 시간을 채워야 하는,법정 봉사활동 기 시간을 채워야 하는,SD2001645,5.84901,8.89405,10대,여성
2,SDRW2000000319.1.1.3,예술,예술,SD2001645,8.90407,9.52506,10대,여성
3,SDRW2000000319.1.1.4,또는 체육 요원의 절반가량이,또는 체육 요원의 절반가량이,SD2001645,9.53506,12.05203,10대,여성
4,SDRW2000000319.1.1.5,허위 자료를 내거나,허위 자료를 내거나,SD2001645,12.06204,13.79504,10대,여성
...,...,...,...,...,...,...,...,...
213188,SDRW2000000418.1.1.326,우선,우선,SD2000552,908.12707,909.98106,10대,여성
213189,SDRW2000000418.1.1.327,맛있는 음식들 먹으면서,맛있는 음식들 먹으면서,SD2000552,909.99104,912.25405,10대,여성
213190,SDRW2000000418.1.1.328,겝,겝,SD2000552,912.26403,913.64807,10대,여성
213191,SDRW2000000418.1.1.329,먹으면서 저도 같이 맛있어 보이는 느낌이라서,먹으면서 저도 같이 맛있어 보이는 느낌이라서,SD2000552,913.65802,917.87305,10대,여성


In [27]:
class CustomDataset(Dataset):
    def __init__(self, dataframe, wav_dir):
        self.data = dataframe
        self.wav_dir = wav_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        file_id = self.data.iloc[index]['id']
        wav_path = os.path.join(self.wav_dir, f'{file_id}.wav')
        audio, _ = torchaudio.load(wav_path)
        text = self.data.iloc[index]['form']

        # Pad audio and text to match the length of the longest sample in the batch
        max_len = max(audio.size(1), len(text))
        audio = torch.nn.functional.pad(audio, (0, max_len - audio.size(1)))
        text = text.ljust(max_len)  # Pad text with spaces

        return audio, text


# 데이터셋 생성
wav_dir = './wav_part1'
dataset = CustomDataset(df, wav_dir)

In [28]:
audio, text = dataset[0]
print("Audio:", audio.shape)
print("Text:", text)

Audio: torch.Size([1, 28800])
Text: 병역 특례를 받아                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           

In [29]:
len(dataset)

213193

In [23]:
import IPython.display as ipd
# 오디오 재생
torchaudio.save("audio.wav", audio, 16000)
ipd.Audio("audio.wav")

### http://mohitmayank.com/a_lazy_data_science_guide/audio_intelligence/wav2vec2/

In [1]:
# wav2vec 모델 초기화
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')

NameError: name 'SpeechRecognition' is not defined

In [16]:
tokenizer = Wav2Vec2Tokenizer.from_pretrained('facebook/wav2vec2-base-960h')

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2CTCTokenizer'. 
The class this function is called from is 'Wav2Vec2Tokenizer'.


In [30]:
# Fine-tuning을 위한 하이퍼파라미터 설정
learning_rate = 1e-4
num_epochs = 10
batch_size = 1

# 옵티마이저 설정
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [31]:
# 데이터 로더 설정
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [45]:
model.to(device)
model.train()

for epoch in range(num_epochs):
    for batch in data_loader:
#         audio_batch = [item[0] for item in batch]
        audio_batch = batch[0]
#         text_batch = [item[1] for item in batch]
        text_batch = batch[1]

        audio_data = [item for sublist in audio_batch for item in sublist]

        inputs = tokenizer(audio_data[0], return_tensors='pt', padding="longest", truncation=True)
        input_values = inputs.input_values.to(device)
#         attention_mask = inputs.attention_mask.to(device)
        labels = text_batch
        print(labels)

        outputs = model(input_values, 
#                         attention_mask=attention_mask, 
                        labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch: {epoch+1}, Loss: {loss.item()}')


('이번 방학 때는 쪼금 아르바이트를 하거나                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                

AttributeError: 'tuple' object has no attribute 'max'

In [54]:
model.to(device)
model.train()

for epoch in range(num_epochs):
    for _, (audio, text) in enumerate(data_loader):

        inputs = tokenizer(audio, return_tensors='pt', padding="longest", truncation=True).to(device)
        outputs = model(input_values).logits
        predicted_ids = torch.argmax(outputs, dim=-1)
        transcription = tokenizer.batch_decode(predicted_ids)
        
        print(transcription)

['IBONPAAN I TO MA BYTERAO N']
['HE BON PAK EN N TOIM ARO WI TEACO NAM']
['EPL A PAMP EN IN TOON ATOF BY E OCKON MM']
['E BOE MMCK THEN T M AUT I BYI TE A O TAP']
['I BON PO THEM N OO TI TA MAM']
['IBONA TH DO MRA I TEAOA']
['EBONPA HE INTONARIA ERA AMA']
['BOM HARTBEN INTO A AR TEROA O NAO']
['E BONPONAN I OI TEAL NIT']
['EBON N HEN INTO ARII PERACONAN']
['IBONPAR TEN ENTO ON   TEAR M']
['HE O P EN INTO AREVY TERAO MAP']
['EBONP EN INTO EI E M']
['IBON PK THEN IN DO M E WIE TERAU ONAA']
['HEORN PAOK HEN IN TO ATEO HITE TERA ON AM']
['EBOP TEN INTO HIMATEY TRAMA']
['EBOHATENADONI TERAO NA']
['BONPOP TEN OARE I AO A']
['EBON PAK TEN AN TO MTE I HEAK O NA']
['IBN PO ANAN TOAW T AUP ORN']
['PONP BENIN TO E AO I TE OUP O MO']
['IBONPA THEN I ON EERI TEAK O A']
['EBON A THEN ANDON AI TERAO NA']
['BON PA THEN AN TOMMAR BE RACOAM']
['EON UT TE INO A  E AO']
['E BO P EN AND TO OM OTEBY TITHE ACOL NAM']
['BN PHATEN ANCOMAUTER B TERAA']
['E BON POEEN ITI EFOR']


KeyboardInterrupt: 

In [None]:
output_dir = 'model'
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [15]:
# import
from live_asr import LiveWav2Vec2

# load model
english_model = "facebook/wav2vec2-large-960h-lv60-self"
asr = LiveWav2Vec2(english_model,device_name="default")

# start the live ASR
asr.start()

try:        
    while True:
        text,sample_length,inference_time = asr.get_last_text()                        
        print(f"Duration: {sample_length:.3f}s\tSpeed: {inference_time:.3f}s\t{text}")

except KeyboardInterrupt:   
    asr.stop()  

ModuleNotFoundError: No module named 'live_asr'