In [1]:
# Load all .wav files, sampled at 16 kHz
# Build the dataset: targets and inputs

In [2]:
from getStageIDataset import getStageIDataset

# Load the Stage I metadata table (columns shown below: filepath, pron, tone)
# and preview its first five rows.
df = getStageIDataset()
df.head(n=5)

Unnamed: 0,filepath,pron,tone
0,../../../recordings/stageI/8/a0.wav,0,1
1,../../../recordings/stageI/8/a1.wav,1,1
2,../../../recordings/stageI/8/a2.wav,1,1
3,../../../recordings/stageI/8/a3.wav,1,1
4,../../../recordings/stageI/8/a4.wav,0,2


In [3]:
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Load the pretrained processor and encoder from a single checkpoint name
# (no fine-tuning yet). `.eval()` returns the module itself, so it can be
# chained; the bare `model` on the last line keeps the architecture printout.
CHECKPOINT = "facebook/wav2vec2-base"
processor = Wav2Vec2Processor.from_pretrained(CHECKPOINT)
model = Wav2Vec2Model.from_pretrained(CHECKPOINT).eval()
model



Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
  

In [4]:
from torch.utils.data import Dataset
import torchaudio
import torch

class WavDataset(Dataset):
    """Map-style dataset that turns rows of a metadata table into model inputs.

    Each row of ``dataframe`` must provide the columns ``filepath``, ``pron``
    and ``tone``. The audio file is loaded with torchaudio, down-mixed to mono,
    resampled to ``target_sample_rate`` when needed and passed through
    ``processor`` to obtain ``input_values``.

    Args:
        dataframe: Table with ``filepath``, ``pron`` and ``tone`` columns. Its
            index is reset so items are addressed positionally.
        processor: Callable invoked as
            ``processor(audio, sampling_rate=..., return_tensors="pt")`` whose
            result exposes ``input_values``.
        target_sample_rate: Sampling rate (Hz) the audio is converted to and
            that is reported to the processor. Defaults to 16000.
    """

    # Label stored for `tone` when the pronunciation is not marked correct
    # (pron != 1), so that the tone of such items can be ignored downstream.
    IGNORE_INDEX = -100

    def __init__(self, dataframe, processor, target_sample_rate=16000):
        self.data = dataframe.reset_index(drop=True)
        self.processor = processor
        self.target_sample_rate = target_sample_rate
        # One Resample transform per source rate, created lazily and reused so
        # it is not rebuilt for every item.
        self._resamplers = {}

    def __len__(self):
        """Return the number of rows (audio clips) in the dataset."""
        return len(self.data)

    def _resample(self, waveform, sample_rate):
        """Return ``waveform`` at ``target_sample_rate``, reusing cached transforms."""
        if sample_rate == self.target_sample_rate:
            return waveform
        resampler = self._resamplers.get(sample_rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            self._resamplers[sample_rate] = resampler
        return resampler(waveform)

    def __getitem__(self, idx):
        """Load, normalise and label the clip at position ``idx``.

        Returns:
            dict with keys ``input_values`` (processor output for this clip),
            ``pron`` (long tensor) and ``tone`` (long tensor; ``IGNORE_INDEX``
            when ``pron`` is not 1).
        """
        row = self.data.iloc[idx]
        waveform, sample_rate = torchaudio.load(row["filepath"])

        # Convert to mono by averaging the channels (keeps the channel axis).
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        waveform = self._resample(waveform, sample_rate)

        # squeeze(0) drops only the channel axis; a bare squeeze() would also
        # collapse a one-sample clip to a 0-dim tensor. The processor is given
        # a NumPy array, which is one of its documented input types.
        audio = waveform.squeeze(0).numpy()
        input_values = self.processor(
            audio,
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
        ).input_values[0]

        # Tone is only meaningful when the pronunciation is correct; the tone
        # column is not read otherwise.
        tone_label = row["tone"] if row["pron"] == 1 else self.IGNORE_INDEX

        return {
            "input_values": input_values,
            "pron": torch.tensor(row["pron"], dtype=torch.long),
            "tone": torch.tensor(tone_label, dtype=torch.long),
        }


In [5]:
from torch.utils.data import DataLoader

# Wrap the metadata table in the dataset class and iterate one clip per batch
# (no collate_fn is supplied, so the DataLoader's default collation applies).
dataset = WavDataset(dataframe=df, processor=processor)
dataloader = DataLoader(dataset=dataset, batch_size=1)


In [None]:
# Run the pretrained encoder over every batch without tracking gradients and
# pool its last hidden state into one vector per clip.
for batch in dataloader:
    with torch.no_grad():
        # The model output exposes `last_hidden_state`.
        outputs = model(input_values=batch["input_values"])
        # Average over dim 1 (presumably the frame/time axis - confirm) to pool.
        embeddings = torch.mean(outputs.last_hidden_state, dim=1)
        # Feed embeddings into your classifier (e.g. Linear for pron, tone)


In [None]:
if __name__ == "__main__":
    # Fine-tuning entry point: builds the dataset and loader, then trains the
    # pron/tone model for a fixed number of epochs, printing the loss per epoch.
    #
    # Everything this cell uses is imported here so it does not depend on
    # earlier cells having been executed.
    import torch
    from torch.utils.data import DataLoader
    from transformers import Wav2Vec2Processor

    from getStageIDataset import getStageIDataset
    # NOTE(review): this WavDataset comes from wav2vecPronTon and shadows the
    # WavDataset class defined earlier in this notebook - confirm that is intended.
    from wav2vecPronTon import Wav2VecForPronTone, collate_fn, train, WavDataset

    # Tunable settings, named in one place.
    SEED = 42
    BATCH_SIZE = 4
    LEARNING_RATE = 1e-5
    epochs = 5

    # Seed PyTorch's global RNG so shuffling (and any random initialisation
    # done by the model) is repeatable between runs.
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load your dataframe with columns: filepath, pron, tone
    df = getStageIDataset()

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

    dataset = WavDataset(df, processor)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        shuffle=True,
    )

    # Rebinds `model`: from here on it is the trainable pron/tone model.
    model = Wav2VecForPronTone()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(epochs):
        # `train` is expected to return a number that can be formatted with
        # `.4f` (presumably the epoch's mean loss - verify in wav2vecPronTon).
        loss = train(model, dataloader, optimizer, device)
        print(f"Epoch {epoch+1}, Loss: {loss:.4f}")