In [1]:
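# Load all .wav recordings, sampled at 16 kHz
# (PL: Załadować wszystkie pliki .wav, próbkowane z częstotliwością 16 kHz)
# Build a dataset of inputs and targets
# (PL: Zrobić dataset – inputs i target)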

In [2]:
from getStageIDataset import getStageIDataset

# Build the Stage I dataset index and show its first five rows.
df = getStageIDataset()
df.head(n=5)

Unnamed: 0,filepath,pron,tone
0,../../../recordings/stageI/8/a0.wav,0,1
1,../../../recordings/stageI/8/a1.wav,1,1
2,../../../recordings/stageI/8/a2.wav,1,1
3,../../../recordings/stageI/8/a3.wav,1,1
4,../../../recordings/stageI/8/a4.wav,0,2


In [3]:
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Pretrained wav2vec 2.0 base checkpoint, loaded as-is (no fine-tuning yet).
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").eval()  # eval(): dropout off
model



Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
  

In [4]:
from torch.utils.data import Dataset
import torchaudio
import torch

class WavDataset(Dataset):
    """Map-style dataset yielding wav2vec2 inputs plus pronunciation/tone targets.

    Each item loads the recording at ``filepath``, averages its channels to
    mono, resamples it to ``TARGET_SAMPLE_RATE`` when needed and runs it
    through ``processor``.

    Args:
        dataframe: table with ``filepath``, ``pron`` and ``tone`` columns.
            ``tone`` is only read for rows where ``pron == 1``.
        processor: callable such as ``Wav2Vec2Processor``. It is called as
            ``processor(waveform, sampling_rate=..., return_tensors="pt")`` and
            must return an object exposing ``input_values``.
    """

    TARGET_SAMPLE_RATE = 16000
    # Default ignore_index of torch.nn.CrossEntropyLoss: rows carrying this
    # value are skipped by the tone loss.
    IGNORE_INDEX = -100

    def __init__(self, dataframe, processor):
        self.data = dataframe.reset_index(drop=True)
        self.processor = processor
        # One Resample transform per source rate. Constructing a transform
        # precomputes its kernel, so reuse it instead of rebuilding per item.
        self._resamplers = {}

    def __len__(self):
        return len(self.data)

    def _resample(self, waveform, sample_rate):
        """Return ``waveform`` at TARGET_SAMPLE_RATE, reusing cached transforms."""
        if sample_rate == self.TARGET_SAMPLE_RATE:
            return waveform
        resampler = self._resamplers.get(sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sample_rate, self.TARGET_SAMPLE_RATE)
            self._resamplers[sample_rate] = resampler
        return resampler(waveform)

    def __getitem__(self, idx):
        """Return ``{"input_values", "pron", "tone"}`` for row ``idx``.

        ``tone`` is replaced by IGNORE_INDEX when the pronunciation is incorrect.
        """
        row = self.data.iloc[idx]
        waveform, sample_rate = torchaudio.load(row["filepath"])

        # torchaudio.load returns (channels, frames); average channels to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        waveform = self._resample(waveform, sample_rate)

        input_values = self.processor(
            waveform.squeeze(),
            sampling_rate=self.TARGET_SAMPLE_RATE,
            return_tensors="pt",
        ).input_values[0]

        pron_ok = row["pron"] == 1
        return {
            "input_values": input_values,
            "pron": torch.tensor(row["pron"]),
            # Only score tone when the pronunciation itself is correct.
            "tone": torch.tensor(row["tone"]) if pron_ok else torch.tensor(self.IGNORE_INDEX),
        }


In [12]:
from getStageIDataset import getStageIDataset

# Re-load the Stage I dataset index (same loader as the first cell); the
# training cell below reads only its `filepath` and `pron` columns.
df = getStageIDataset()

In [10]:
# Quick look at the reloaded index.
df.head(n=5)

Unnamed: 0,filepath,pron
0,../../../recordings/stageI/8/a0.wav,0
1,../../../recordings/stageI/8/a1.wav,1
2,../../../recordings/stageI/8/a2.wav,1
3,../../../recordings/stageI/8/a3.wav,1
4,../../../recordings/stageI/8/a4.wav,0


In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification, TrainingArguments, Trainer
import pandas as pd
import torchaudio
from sklearn.model_selection import train_test_split

# 80/20 split, stratified on the pronunciation label so both sides keep the
# same class balance; fixed seed for reproducibility.
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["pron"],
    random_state=42,
)

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")  # same checkpoint as the model

class PronunciationDataset(Dataset):
    """Binary pronunciation dataset (``pron`` 0/1) built from a DataFrame of wav paths.

    Each item is a dict with a 1-D ``input_values`` tensor (mono audio at
    ``TARGET_SAMPLE_RATE``, passed through the processor) and a scalar long
    ``labels`` tensor, which is what ``collate_fn`` pads and batches.

    Args:
        df: table with ``filepath`` and ``pron`` columns.
        processor: optional feature processor. When omitted, the notebook-level
            ``processor`` is looked up each time an item is fetched.
    """

    TARGET_SAMPLE_RATE = 16000

    def __init__(self, df, processor=None):
        self.filepaths = df["filepath"].tolist()
        self.labels = df["pron"].tolist()
        self.processor = processor
        # Cache one Resample transform per source rate instead of rebuilding
        # it (and its kernel) for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.filepaths)

    def _load_mono(self, path):
        """Load ``path`` as a 1-D waveform at TARGET_SAMPLE_RATE."""
        waveform, sampling_rate = torchaudio.load(path)  # (channels, frames)
        # Average channels to mono. Squeezing alone leaves stereo files as
        # (2, frames); the processor then yields 2-D input_values and
        # processor.pad pads along the channel axis. That is the likely source
        # of "Unable to create tensor ... padding=True" during training.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sampling_rate != self.TARGET_SAMPLE_RATE:
            resampler = self._resamplers.get(sampling_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(sampling_rate, self.TARGET_SAMPLE_RATE)
                self._resamplers[sampling_rate] = resampler
            waveform = resampler(waveform)
        return waveform.squeeze(0)

    def __getitem__(self, idx):
        feature_processor = self.processor if self.processor is not None else processor
        speech_array = self._load_mono(self.filepaths[idx])
        inputs = feature_processor(
            speech_array,
            sampling_rate=self.TARGET_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        # Drop the batch-of-one axis the processor adds.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return inputs

# Wrap both splits in the on-the-fly audio dataset.
train_dataset, test_dataset = PronunciationDataset(train_df), PronunciationDataset(test_df)

# Pretrained encoder plus a newly initialised 2-way classification head.
model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-base", num_labels=2)

# Evaluate and checkpoint every epoch; reload the best model when training ends.
training_args = TrainingArguments(
    output_dir="./wav2vec2-pronunciation",
    logging_dir="./logs",
    logging_steps=10,
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    # Keep every key the dataset returns; collate_fn consumes them itself.
    remove_unused_columns=False,
    # Half precision only when a CUDA device is available.
    fp16=torch.cuda.is_available(),
)


def collate_fn(batch):
    """Merge dataset items into one padded batch for the Trainer.

    Each item is a dict with a 1-D ``input_values`` tensor and a scalar
    ``labels`` tensor. Waveforms are padded to the longest one in the batch
    via ``processor.pad``; labels become a 1-D ``torch.long`` tensor.
    """
    waveforms = [example["input_values"] for example in batch]
    targets = torch.tensor([example["labels"] for example in batch], dtype=torch.long)

    features = processor.pad(
        {"input_values": waveforms},
        padding=True,
        return_tensors="pt",
    )
    features["labels"] = targets
    return features


# Fine-tune; collate_fn pads each batch to its longest waveform.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor,
)

trainer.train()


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


  0%|          | 0/11535 [00:00<?, ?it/s]

ValueError: Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.