In [None]:
import os
import csv
import sys
import json
import time

! pip install torch torchaudio transformers datasets


In [None]:
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datasets import Dataset, load_dataset
from transformers import TrainingArguments, Trainer

In [None]:
# Load custom dataset
def load_custom_dataset(directory, transcripts=None, extension=".mp3"):
    """Build a `datasets.Dataset` with one row per audio file found in `directory`.

    Args:
        directory: folder to scan (not recursive). May be relative to the current
            working directory; it is resolved to an absolute path.
        transcripts: optional mapping from a file's base name (e.g. "clip01.mp3") to
            its transcript. When given, a "text" column is added; a file with no
            entry in the mapping raises KeyError.
        extension: file suffix to keep, matched case-insensitively.

    Returns:
        A Dataset with a "file" column of absolute paths, sorted so the row order is
        the same on every run, plus a "text" column when `transcripts` is given.
    """
    abs_directory = os.path.abspath(directory)
    suffix = extension.lower()
    # os.scandir yields entries in arbitrary order; sort for reproducibility and skip
    # anything that is not a regular file (e.g. a directory named "x.mp3").
    with os.scandir(abs_directory) as entries:
        audio_files = sorted(
            entry.path
            for entry in entries
            if entry.is_file() and entry.name.lower().endswith(suffix)
        )
    # Dataset.from_dict expects column-oriented data: {column_name: [values, ...]}.
    columns = {"file": audio_files}
    if transcripts is not None:
        columns["text"] = [transcripts[os.path.basename(path)] for path in audio_files]
    return Dataset.from_dict(columns)

In [None]:
TRAIN_DIR = "Train"  # audio folder, resolved relative to the notebook's working directory
custom_dataset = load_custom_dataset(TRAIN_DIR)

In [None]:
# Pre-trained checkpoint: the processor and the CTC model are loaded from the same
# identifier so they can never drift apart.
MODEL_ID = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

In [None]:
# Decode audio and build model inputs + CTC labels
def speech_file_to_array_fn(batch, target_sampling_rate=None, text_column="text"):
    """Load one audio file and attach normalized model inputs and CTC label ids.

    Meant for non-batched `Dataset.map`, so `batch` is a single row.

    Args:
        batch: row with a "file" path and the clip's transcript under `text_column`.
        target_sampling_rate: rate to resample to before feature extraction. Defaults
            to the processor's own `feature_extractor.sampling_rate`.
        text_column: name of the column holding the transcript.

    Returns:
        The row with "input_values" (processor-normalized 1-D waveform),
        "sampling_rate" (the rate the audio was resampled to) and "labels"
        (tokenizer ids of the transcript).

    Raises:
        ValueError: if the row has no transcript; CTC training needs the spoken text.
    """
    # NOTE(review): decoding MP3 with torchaudio needs an ffmpeg-capable backend.
    speech_array, sampling_rate = torchaudio.load(batch["file"])
    # torchaudio returns (channels, frames); average channels rather than dropping
    # everything but the first one.
    speech = speech_array.mean(dim=0)
    if target_sampling_rate is None:
        target_sampling_rate = processor.feature_extractor.sampling_rate
    if sampling_rate != target_sampling_rate:
        speech = torchaudio.functional.resample(speech, sampling_rate, target_sampling_rate)
    # Feed the model the processor's normalized features, not the raw samples.
    batch["input_values"] = processor(
        speech.numpy(), sampling_rate=target_sampling_rate
    ).input_values[0]
    batch["sampling_rate"] = target_sampling_rate

    transcript = batch.get(text_column)
    if transcript is None:
        raise ValueError(
            f"{batch['file']}: no transcript in column {text_column!r}; CTC fine-tuning "
            f"needs each clip's spoken text as labels, so add a {text_column!r} column."
        )
    # NOTE(review): the 960h tokenizer vocabulary appears to be upper-case characters;
    # normalize transcripts to match (check processor.tokenizer.get_vocab()).
    batch["labels"] = processor.tokenizer(transcript).input_ids
    return batch

In [None]:
# Run the per-row audio preprocessing over the whole dataset (one row per call).
custom_dataset = custom_dataset.map(function=speech_file_to_array_fn)

In [None]:
# Set up training arguments
# Evaluation is left at its default ("no") because the Trainer below receives no
# eval_dataset, and a "steps" evaluation strategy requires one. To evaluate, pass an
# eval_dataset and set eval_strategy (named evaluation_strategy in older releases).
training_args = TrainingArguments(
    output_dir="./wav2vec2-base-960h-custom-training",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=10,
    save_steps=50,
    save_total_limit=2,
    # Mixed precision needs a CUDA device; requesting fp16 on CPU raises an error.
    fp16=torch.cuda.is_available(),
)


def collate_ctc_batch(features):
    """Pad a list of preprocessed rows into one CTC training batch.

    Expects each row to carry "input_values" (audio features) and "labels"
    (token ids). Audio is padded with the feature extractor and labels with the
    tokenizer; label padding is replaced by -100, which Wav2Vec2ForCTC ignores
    in its loss.
    """
    audio = [{"input_values": row["input_values"]} for row in features]
    label_ids = [{"input_ids": row["labels"]} for row in features]
    batch = processor.feature_extractor.pad(audio, padding=True, return_tensors="pt")
    labels_batch = processor.tokenizer.pad(label_ids, padding=True, return_tensors="pt")
    batch["labels"] = labels_batch["input_ids"].masked_fill(
        labels_batch["attention_mask"].ne(1), -100
    )
    return batch


# Define Trainer
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch once the installed version supports it.
trainer = Trainer(
    model=model,
    # The default collator cannot pad variable-length audio or label sequences.
    data_collator=collate_ctc_batch,
    args=training_args,
    train_dataset=custom_dataset,
    tokenizer=processor,
)

# Train the model
trainer.train()