<a href="https://colab.research.google.com/github/sidheshsahu/Finetuning/blob/main/FinetuningASR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchaudio librosa pandas

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.6.0->torchaudio)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.wh

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Data Exploration

In [None]:
import os
import torchaudio
import pandas as pd
from tqdm import tqdm

LIBRISPEECH_DIR = "/content/drive/MyDrive/UpdatedDataset/LibriSpeechDataset"

num_files = 0
total_duration = 0
sample_rates = set()
transcript_lengths = []
data_summary = []

# Traverse folders
for speaker_id in tqdm(os.listdir(LIBRISPEECH_DIR), desc="Speakers"):
    speaker_path = os.path.join(LIBRISPEECH_DIR, speaker_id)
    if not os.path.isdir(speaker_path):
        continue

    for chapter_id in os.listdir(speaker_path):
        chapter_path = os.path.join(speaker_path, chapter_id)
        if not os.path.isdir(chapter_path):
            continue

        # Load transcript mapping
        transcript_file = os.path.join(chapter_path, f"{speaker_id}-{chapter_id}.trans.txt")
        if not os.path.exists(transcript_file):
            continue  # Skip if transcript missing

        with open(transcript_file, 'r') as f:
            transcripts = {line.split()[0]: ' '.join(line.strip().split()[1:]) for line in f.readlines()}

        # Iterate through each FLAC file
        for file in os.listdir(chapter_path):
            if file.endswith(".flac"):
                file_id = file.replace(".flac", "")
                audio_path = os.path.join(chapter_path, file)

                try:
                    info = torchaudio.info(audio_path)
                    sample_rate = info.sample_rate
                    duration = info.num_frames / sample_rate
                except:
                    continue  # Skip broken files

                transcript = transcripts.get(file_id, "")

                # Update stats
                num_files += 1
                total_duration += duration
                sample_rates.add(sample_rate)
                transcript_lengths.append(len(transcript.split()))

                # Store for reference
                data_summary.append({
                    "FileID": file_id,
                    "Speaker": speaker_id,
                    "Chapter": chapter_id,
                    "Duration(sec)": round(duration, 2),
                    "SampleRate": sample_rate,
                    "Transcript": transcript
                })

# Convert to DataFrame
df = pd.DataFrame(data_summary)

# Display summaryD
print(f"\nTotal Audio Files: {num_files}")
print(f"Total Duration (hours): {round(total_duration / 3600, 2)}")
print(f"Sample Rates found: {sample_rates}")
print(f"Average Transcript Length: {round(sum(transcript_lengths) / len(transcript_lengths), 2)} words")

# Show a few samples
print("\nSample rows:")

Speakers: 100%|██████████| 16/16 [16:30<00:00, 61.92s/it]


Total Audio Files: 1909
Total Duration (hours): 6.6
Sample Rates found: {16000}
Average Transcript Length: 33.97 words

Sample rows:





In [None]:
!pip install jiwer


Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m50.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.1.0 rapidfuzz-3.13.0


Select and evaluate a pretrained ASR model

In [None]:
import os
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from jiwer import wer
import librosa

# Path to test-clean dataset
TEST_DIR = "/content/drive/MyDrive/UpdatedDataset/LibriSpeechTestDataset/test-clean"

# Load pretrained model and processor
processor = Wav2Vec2Processor.from_pretrained("/content/fine-tuned-wav2vec2")
model = Wav2Vec2ForCTC.from_pretrained("/content/fine-tuned-wav2vec2").to("cuda" if torch.cuda.is_available() else "cpu")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def transcribe(audio_path):
    waveform, sr = torchaudio.load(audio_path)
    waveform = waveform.squeeze().numpy()
    if sr != 16000:
        waveform = librosa.resample(waveform, orig_sr=sr, target_sr=16000)

    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)
    input_values = inputs.input_values.to(device)

    with torch.no_grad():
        logits = model(input_values).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return transcription.lower()

# To store predictions and ground truths
ground_truths = []
predictions = []

# Walk through all speakers/chapters/files
for speaker_id in os.listdir(TEST_DIR):
    speaker_path = os.path.join(TEST_DIR, speaker_id)
    if not os.path.isdir(speaker_path):
        continue

    for chapter_id in os.listdir(speaker_path):
        chapter_path = os.path.join(speaker_path, chapter_id)
        if not os.path.isdir(chapter_path):
            continue

        # Load .trans.txt
        transcript_path = os.path.join(chapter_path, f"{speaker_id}-{chapter_id}.trans.txt")
        with open(transcript_path, 'r') as f:
            transcripts = {line.split()[0]: ' '.join(line.strip().split()[1:]) for line in f.readlines()}

        # Transcribe each .flac file
        for file in os.listdir(chapter_path):
            if file.endswith(".flac"):
                file_id = file.replace(".flac", "")
                audio_path = os.path.join(chapter_path, file)

                try:
                    pred = transcribe(audio_path)
                    true = transcripts[file_id].lower()

                    predictions.append(pred)
                    ground_truths.append(true)

                    # Optional: Print few examples
                    if len(predictions) <= 3:
                        print(f"\nFile: {file}")
                        print(f"Prediction: {pred}")
                        print(f"Ground Truth: {true}")
                except Exception as e:
                    print(f"Error with file {file}: {e}")

# Compute final WER
final_wer = wer(ground_truths, predictions)
print(f"\n Final Word Error Rate (WER) on test-clean: {final_wer:.3f}")



File: 4446-2275-0004.flac
Prediction: alexander did not sit down
Ground Truth: alexander did not sit down

File: 4446-2275-0040.flac
Prediction: the sight of you bartley to see you living and happy and successful can i never make you understand what that means to me
Ground Truth: the sight of you bartley to see you living and happy and successful can i never make you understand what that means to me

File: 4446-2275-0045.flac
Prediction: we've tortured each other enough for to night
Ground Truth: we've tortured each other enough for tonight

 Final Word Error Rate (WER) on test-clean: 0.038


Finetuning

In [None]:
import os
from datasets import Dataset, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer
import torch
from dataclasses import dataclass
from typing import Dict, List, Union
import re
import evaluate
import numpy as np


In [None]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.3


In [None]:
def load_librispeech_dataset(base_path):
    data = []

    for speaker_id in os.listdir(base_path):
        speaker_path = os.path.join(base_path, speaker_id)
        if not os.path.isdir(speaker_path):
            continue

        for chapter_id in os.listdir(speaker_path):
            chapter_path = os.path.join(speaker_path, chapter_id)
            if not os.path.isdir(chapter_path):
                continue

            transcript_path = os.path.join(chapter_path, f"{speaker_id}-{chapter_id}.trans.txt")
            if not os.path.exists(transcript_path):
                continue

            with open(transcript_path, 'r', encoding='utf-8') as f:
                transcripts = {
                    line.split()[0]: ' '.join(line.strip().split()[1:]) for line in f.readlines() if line.strip()
                }

            for file in os.listdir(chapter_path):
                if file.endswith(".flac"):
                    file_id = file.replace(".flac", "")
                    transcript = transcripts.get(file_id, "")
                    audio_path = os.path.join(chapter_path, file)

                    data.append({
                        "audio": audio_path,
                        "transcript": transcript.lower(),
                        "input_length": 0  # Will be calculated later
                    })
    return data

In [None]:
def preprocess_text(text):
    """Clean and normalize text for CTC training"""
    text = text.lower()
    # Remove punctuation but keep apostrophes
    text = re.sub(r"[^\w\s']", "", text)
    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text

train_data = load_librispeech_dataset("/content/drive/MyDrive/UpdatedDataset/LibriSpeechDataset")


In [None]:
# Split into train and validation sets
train_size = int(0.9 * len(train_data))
val_data = train_data[train_size:]
train_data = train_data[:train_size]

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)

train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
val_dataset = val_dataset.cast_column("audio", Audio(sampling_rate=16000))

# Load processor and model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Get vocabulary info
vocab_dict = processor.tokenizer.get_vocab()
vocab_size = len(vocab_dict)
print(f"Vocabulary size: {vocab_size}")

# Create reverse vocabulary mapping
id_to_token = {v: k for k, v in vocab_dict.items()}



Training samples: 1718
Validation samples: 191


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Vocabulary size: 32


In [None]:
def prepare_batch(batch):
    """Prepare batch with proper error handling"""
    try:
        audio = batch["audio"]

        # Process audio
        inputs = processor(
            audio["array"],
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )

        # Clean and tokenize text
        transcript = preprocess_text(batch["transcript"])

        # Tokenize using the correct method
        with processor.as_target_processor():
            labels = processor.tokenizer(transcript)

        # Store input length for filtering
        input_length = inputs.input_values.shape[-1]

        return {
            "input_values": inputs.input_values[0],
            "labels": labels.input_ids,
            "input_length": input_length,
            "transcript": transcript  # Keep for evaluation
        }
    except Exception as e:
      print(f"Error processing batch: {e}")
      return None


In [None]:
train_dataset = train_dataset.map(prepare_batch, remove_columns=["audio"])
val_dataset = val_dataset.map(prepare_batch, remove_columns=["audio"])


def filter_examples(example):
    if example is None:
        return False
    if len(example["labels"]) < 2 or len(example["labels"]) > 400:
        return False
    if example["input_length"] < 1000 or example["input_length"] > 200000:
        return False
    return True

train_dataset = train_dataset.filter(filter_examples)
val_dataset = val_dataset.filter(filter_examples)

print(f"After filtering - Training: {len(train_dataset)}, Validation: {len(val_dataset)}")

# Check tokenization quality
print("\nTokenization Quality Check:")
for i in range(min(3, len(train_dataset))):
    example = train_dataset[i]
    tokens = example["labels"]
    unique_tokens = len(set(tokens))
    print(f"Example {i}: {unique_tokens} unique tokens out of {len(tokens)} total")
    print(f"Transcript: {example['transcript'][:100]}...")
    print(f"Token sample: {tokens[:20]}")


Map:   0%|          | 0/1718 [00:00<?, ? examples/s]



Map:   0%|          | 0/191 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1718 [00:00<?, ? examples/s]

Filter:   0%|          | 0/191 [00:00<?, ? examples/s]

After filtering - Training: 570, Validation: 56

Tokenization Quality Check:
Example 0: 3 unique tokens out of 151 total
Transcript: my comrade will enter the other vehicle with her and my wife will come back here to tell us it's don...
Token sample: [3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 4, 3, 3, 3, 3]
Example 1: 2 unique tokens out of 43 total
Transcript: think that the lark really is your daughter...
Token sample: [3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 4, 3, 3, 3, 4, 3, 3, 3, 3, 4]
Example 2: 2 unique tokens out of 73 total
Transcript: finally he said to the prisoner with a slow and singularly ferocious tone...
Token sample: [3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 4, 3, 3, 3, 3, 4, 3, 3, 4, 3]


In [None]:

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Union[int, None] = None
    max_length_labels: Union[int, None] = None
    pad_to_multiple_of: Union[int, None] = None
    pad_to_multiple_of_labels: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """
        This method is called by the DataLoader to collate a batch of samples.
        """
        # Extract input_values and labels
        input_features = []
        label_features = []

        for feature in features:
            input_features.append({"input_values": feature["input_values"]})
            label_features.append({"input_ids": feature["labels"]})

        # Pad input features
        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Pad labels
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch


In [None]:
from jiwer import wer

wer_metric = evaluate.load("wer")
def compute_metrics(pred):
    """Compute WER metric"""
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # Decode predictions
    pred_str = processor.batch_decode(pred_ids)

    # Get label ids (replace -100 with pad token)
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode labels
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    # Compute WER
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}



In [None]:



def evaluate_baseline_wer(dataset, num_samples=50):
    """Evaluate baseline WER on a subset with proper device/dtype handling"""
    samples = dataset.select(range(min(num_samples, len(dataset))))

    predictions = []
    references = []

    # Get model device and dtype
    model_device = next(model.parameters()).device
    model_dtype = next(model.parameters()).dtype

    model.eval()
    with torch.no_grad():
        for sample in samples:
            # Ensure input tensor matches model device and dtype
            input_values = torch.tensor(sample["input_values"], dtype=model_dtype).unsqueeze(0).to(model_device)

            # Get prediction
            logits = model(input_values).logits
            pred_ids = torch.argmax(logits, dim=-1)

            # Decode (move back to CPU for decoding)
            pred_str = processor.batch_decode(pred_ids.cpu())[0]
            ref_str = sample["transcript"]

            predictions.append(pred_str)
            references.append(ref_str)

    baseline_wer = wer_metric.compute(predictions=predictions, references=references)
    return baseline_wer, predictions[:5], references[:5]

# Alternative: Add this helper function to handle device/dtype consistently
def ensure_tensor_compatibility(tensor, reference_tensor):
    """Ensure tensor matches reference tensor's device and dtype"""
    if isinstance(tensor, torch.Tensor):
        return tensor.to(device=reference_tensor.device, dtype=reference_tensor.dtype)
    else:
        return torch.tensor(tensor, device=reference_tensor.device, dtype=reference_tensor.dtype)

In [None]:
print("\nSample predictions vs references:")
for i, (pred, ref) in enumerate(zip(sample_preds, sample_refs)):
    print(f"Sample {i+1}:")
    print(f"  Prediction: {pred}")
    print(f"  Reference:  {ref}")
    print()

model.freeze_feature_extractor()


Sample predictions vs references:
Sample 1:
  Prediction: AFTER THIS SOMETHING MUST BE DONE ABOUT PARREDES'S DETENTION HE HADN'T DREAMED THAT HIS WEARINESS COULD PLACATE EVEN MOMENTARILY SUCH REFLECTIONS BUT AT LAST HE SLEPT AGAIN
  Reference:  after this something must be done about paredes's detention he hadn't dreamed that his weariness could placate even momentarily such reflections but at last he slept again

Sample 2:
  Prediction: HE SPOKE WITH PRONOUNCED DELIBERATION STARTLING BOBBY
  Reference:  he spoke with pronounced deliberation startling bobby

Sample 3:
  Prediction: ROBINSON JERKED HIS HEAD TOWARD THE WINDOW I'VE BEEN WATCHING THE PREPARATIONS OUT THERE
  Reference:  robinson jerked his head toward the window i've been watching the preparations out there

Sample 4:
  Prediction: BUT I WAS THERE AND YOU WEREN'T
  Reference:  but i was there and you weren't

Sample 5:
  Prediction: IT'S NATURAL ENOUGH HE SHOULD BE HERE BOBBY AGREED INDIFFERENTLY THEY WALKED SLOWLY BACK T



Training

In [None]:
data_collator = DataCollatorCTCWithPadding(
    processor=processor,
    padding=True,
)


training_args = TrainingArguments(
    output_dir="./fine-tuned-asr",
    per_device_train_batch_size=2,  # Small batch size for stability
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # Effective batch size = 2 * 8 = 16
    num_train_epochs=5,  # More epochs
    learning_rate=1e-4,  # Conservative learning rate
    warmup_steps=1000,
    weight_decay=0.01,
    save_steps=500,
    eval_steps=500,
    logging_steps=100,
    save_total_limit=2,
    eval_strategy="steps",
    fp16=True,
    push_to_hub=False,
    report_to="none",
    remove_unused_columns=False,
    dataloader_num_workers=2,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # Lower WER is better
    group_by_length=True,
    length_column_name="input_length",
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    save_safetensors=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)


# Start training
print("Starting training...")
trainer.train()

# Evaluate final WER
print("\nEvaluating final WER...")
final_wer, final_preds, final_refs = evaluate_baseline_wer(val_dataset)
print(f"Final WER: {final_wer:.4f}")
print(f"WER improvement: {baseline_wer - final_wer:.4f}")



print("\nFinal sample predictions vs references:")
for i, (pred, ref) in enumerate(zip(final_preds, final_refs)):
    print(f"Sample {i+1}:")
    print(f"  Prediction: {pred}")
    print(f"  Reference:  {ref}")
    print()

# Save the model
trainer.save_model("./fine-tuned-wav2vec2")
processor.save_pretrained("./fine-tuned-wav2vec2")










  trainer = Trainer(


Starting training...




Step,Training Loss,Validation Loss





Evaluating final WER...
Final WER: 1.0029
WER improvement: 0.0000

Final sample predictions vs references:
Sample 1:
  Prediction: AFTER THIS SOMETHING MUST BE DONE ABOUT PARREDES'S DETENTION HE HADN'T DREAMED THAT HIS WEARINESS COULD PLACATE EVEN MOMENTARILY SUCH REFLECTIONS BUT AT LAST HE SLEPT AGAIN
  Reference:  after this something must be done about paredes's detention he hadn't dreamed that his weariness could placate even momentarily such reflections but at last he slept again

Sample 2:
  Prediction: HE SPOKE WITH PRONOUNCED DELIBERATION STARTLING BOBBY
  Reference:  he spoke with pronounced deliberation startling bobby

Sample 3:
  Prediction: ROBINSON JERKED HIS HEAD TOWARD THE WINDOW I'VE BEEN WATCHING THE PREPARATIONS OUT THERE
  Reference:  robinson jerked his head toward the window i've been watching the preparations out there

Sample 4:
  Prediction: BUT I WAS THERE AND YOU WEREN'T
  Reference:  but i was there and you weren't

Sample 5:
  Prediction: IT'S NATURAL ENOU

[]

In [None]:
print("Training completed and model saved!")
print(f"Baseline WER: {baseline_wer:.4f}")
print(f"Final WER: {final_wer:.4f}")
print(f"WER Improvement: {baseline_wer - final_wer:.4f}")

print("\nStarting second phase with unfrozen feature extractor...")
model.unfreeze_feature_extractor()

# Reduce learning rate for feature extractor fine-tuning
training_args.learning_rate = 3e-5
training_args.num_train_epochs = 2

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

trainer.train()


print("\nFinal evaluation after feature extractor fine-tuning...")
final_final_wer, _, _ = evaluate_baseline_wer(val_dataset)
print(f"Final Final WER: {final_final_wer:.4f}")
print(f"Total WER improvement: {baseline_wer - final_final_wer:.4f}")

trainer.save_model("./fine-tuned-wav2vec2-final")
processor.save_pretrained("./fine-tuned-wav2vec2-final")

Training completed and model saved!
Baseline WER: 1.0029
Final WER: 1.0029
WER Improvement: 0.0000

Starting second phase with unfrozen feature extractor...


AttributeError: 'Wav2Vec2ForCTC' object has no attribute 'unfreeze_feature_extractor'