## Train
wav2vec 모델을 fine-tuning하는 과정

In [1]:
## Import Library / Package

In [2]:
# pip

In [3]:
# import
import os
import json
import torch
import torchaudio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

  from .autonotebook import tqdm as notebook_tqdm


#### Define Custom DataSet
로컬에 있는 데이터를 pytorch의 Dataset 클래스를 상속 받아 사용하는 과정

In [4]:
class CustomDataset(Dataset):
    def __init__(self, json_path, processor):
        self.json_path = json_path
        self.processor = processor
        
        with open(json_path, 'r') as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        file_path = sample["path"]
        transcription = sample["transcription"]

        waveform, sample_rate = torchaudio.load(file_path)
        
        labels = self.processor.tokenizer(transcription, return_tensors="pt").input_ids.flatten()

        return {
            "input_values": waveform,
            "labels": labels
        }

#### Set Up Config For DataSet
데이터셋을 위한 기본 설정을 셋팅함
현재 메모리 이슈가 있어 batch_size 및 num_workers 설정

num_workers : 일꾼들... gpu * 4

In [5]:
# GPU

# 에러 로깅 가능
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

# cuda가 볼 수 잇는 GPU => 내가 사용할 GPU
os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3,4,5,6,7'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 데이터 병렬 처리를 위한..셋팅...
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)

if torch.cuda.is_available():
    model = nn.DataParallel(model)
    processor = nn.DataParallel(processor)

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You sho

In [6]:
# Setting
learning_rate = 1e-4
num_epoch=2

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CTCLoss()

In [7]:
# Load Data
num_workers = 16
batch_size = 32
json_data_path = "./data/exist_test/rami_mapping.json"

dataset = CustomDataset(json_data_path, processor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

print(dataloader)

<torch.utils.data.dataloader.DataLoader object at 0x7f8f34dec310>
