### Extracting embeddings

Importing packages.

In [2]:
import os
import librosa
import torch
from tqdm import tqdm
import numpy as np
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

Setting the path to the data directory on the mounted Google Drive.

In [5]:
# Data directory on Google Drive (Colab-style mount path).
directory_path = '/content/drive/MyDrive/enhancing_speaker_recognition_evaluation/data'

# Show how many entries sit directly inside the data directory.
print(len(os.listdir(directory_path)))

3


In [7]:
# Local data directory; this rebinds `directory_path`, replacing the Drive path from the earlier cell.
# NOTE(review): expanduser() only rewrites a leading "~", so it leaves this absolute path unchanged.
directory_path = os.path.expanduser("/home/rag/experimental_trial/data/all_speakers")

# Show how many entries sit directly inside the directory.
print(len(os.listdir(directory_path)))

58


Defining device.

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


Now we extract the vector representations of the audio files at different layers of the encoder.

In [6]:
import os
import librosa
import torch
from tqdm import tqdm
import numpy as np
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

# Pretrained XLS-R (300M) feature extractor and encoder.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")
# output_hidden_states=True makes the forward pass expose `outputs.hidden_states`,
# which load_audio_files() indexes per layer.
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-300m", output_hidden_states=True)
model.to(device)

def check_directories_exist(directory, layer_indices):
    """Report whether the per-layer outputs for ``directory`` are already complete.

    Returns True only if, for every index in ``layer_indices``, the
    sub-directory ``layer_<index>`` exists AND contains a
    ``<stem>_layer_<index>.npy`` file for every ``.mp3`` file located directly
    in ``directory``.

    Checking the files (not only the folders) matters: load_audio_files()
    creates all layer folders while handling the first audio file, so a run
    interrupted part-way would otherwise be mistaken for a finished one and
    the remaining files would never be processed.

    Args:
        directory: Directory holding the ``.mp3`` files of one speaker.
        layer_indices: Iterable of layer indices that should have outputs.

    Returns:
        bool: True when nothing is left to extract, False otherwise.
    """
    # Materialise once so one-shot iterators can be traversed twice below.
    layer_dirs = [
        (index, os.path.join(directory, f"layer_{index}"))
        for index in layer_indices
    ]

    # Same test as before: every layer folder has to exist.
    if not all(os.path.exists(layer_dir) for _, layer_dir in layer_dirs):
        return False
    if not layer_dirs:
        return True

    # Additional completeness test: one saved array per audio file and layer.
    stems = [
        os.path.splitext(name)[0]
        for name in os.listdir(directory)
        if name.endswith(".mp3")
    ]
    return all(
        os.path.exists(os.path.join(layer_dir, f"{stem}_layer_{index}.npy"))
        for index, layer_dir in layer_dirs
        for stem in stems
    )

def load_audio_files(directory, layer_indices=(-1,)):
    """Extract and save hidden-state representations for every MP3 in ``directory``.

    Each ``.mp3`` file is loaded with librosa at 16 kHz, passed through the
    module-level ``feature_extractor`` and ``model``, and for every requested
    layer the hidden state is written to
    ``<directory>/layer_<index>/<stem>_layer_<index>.npy``.

    Args:
        directory: Directory that directly contains the ``.mp3`` files.
        layer_indices: Iterable of indices into ``outputs.hidden_states``.
            The default is an immutable tuple instead of a shared list.

    Returns:
        None. Results are written to disk.
    """
    # Materialise once: a one-shot iterator would be exhausted after the first file.
    layer_indices = list(layer_indices)
    # Filter up front so the progress bar counts audio files only.
    mp3_files = [name for name in os.listdir(directory) if name.endswith(".mp3")]

    for filename in tqdm(mp3_files):
        file_path = os.path.join(directory, filename)
        audio, sr = librosa.load(file_path, sr=16000)
        input_values = feature_extractor(audio, return_tensors="pt", sampling_rate=sr).input_values
        input_values = input_values.to(device)

        # Only the forward pass needs to run without gradient tracking.
        with torch.no_grad():
            outputs = model(input_values)

        stem = os.path.splitext(filename)[0]
        for index in layer_indices:
            # One sub-directory per layer inside the speaker directory.
            layer_dir = os.path.join(directory, f"layer_{index}")
            os.makedirs(layer_dir, exist_ok=True)
            save_path = os.path.join(layer_dir, f"{stem}_layer_{index}.npy")
            np.save(save_path, outputs.hidden_states[index].cpu().numpy())

def process_audio_directory(base_directory, layer_indices=range(25)):
    """Run the extraction for each sub-directory of ``base_directory`` that still needs it.

    Entries that are not directories are ignored. A sub-directory is skipped
    when check_directories_exist() reports True for it; otherwise it is handed
    to load_audio_files() together with ``layer_indices``.
    """
    for entry in os.listdir(base_directory):
        candidate = os.path.join(base_directory, entry)
        if not os.path.isdir(candidate):
            continue
        if check_directories_exist(candidate, layer_indices):
            continue
        load_audio_files(candidate, layer_indices)

# Root directory whose sub-directories are processed one by one.
# NOTE(review): expanduser() only rewrites a leading "~", so it leaves this absolute path unchanged.
directory_path = os.path.expanduser("/home/rag/experimental_trial/data/all_speakers_xls_r_300m")

# Uses the default layer_indices=range(25), i.e. hidden-state indices 0..24.
process_audio_directory(directory_path)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return F.conv1d(input, weight, bias, self.stride,
100%|██████████| 50/50 [00:10<00:00,  4.94it/s]
100%|██████████| 50/50 [00:04<00:00, 12.24it/s]
100%|██████████| 50/50 [00:03<00:00, 13.75it/s]
100%|██████████| 50/50 [00:03<00:00, 14.41it/s]
100%|██████████| 50/50 [00:03<00:00, 14.84it/s]
100%|██████████| 50/50 [00:03<00:00, 14.47it/s]
100%|██████████| 50/50 [00:03<00:00, 15.86it/s]
100%|██████████| 50/50 [00:04<00:00, 11.87it/s]
100%|██████████| 50/50 [00:03<00:00, 12.56it/s]
100%|██████████| 50/50 [00:03<00:00, 12.90it/s]
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
100%|██████████| 50/50 [

# Fine-tuning XLS-R

In [None]:
# Fine-tuning on the VoxCeleb dataset


In [None]:
import torch
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Trainer, TrainingArguments
from transformers import Wav2Vec2FeatureExtractor
from torch.nn.functional import cross_entropy

# Load dataset
# NOTE(review): confirm that "voxceleb1" resolves to an available dataset and that it
# provides the "audio" and "speaker_id" columns plus "train"/"test" splits used below.
dataset = load_dataset("voxceleb1")

# Prepare feature extractor
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")

# Define dataset preprocessing
def prepare_dataset(batch):
    """Turn one raw example into model inputs.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"`` entries),
    runs the module-level ``feature_extractor`` on it and stores the result
    under ``"input_values"``; ``"speaker_id"`` is copied to ``"labels"``.
    The mapping passed in is modified and returned.
    """
    clip = batch["audio"]
    extracted = feature_extractor(
        clip["array"],
        sampling_rate=clip["sampling_rate"],
        return_tensors="pt",
    )
    # squeeze(0) removes dimension 0 when it has size 1.
    batch["input_values"] = extracted.input_values.squeeze(0)
    batch["labels"] = batch["speaker_id"]
    return batch

# Read the number of speakers BEFORE preprocessing: map() below drops the
# original columns (including "speaker_id") through remove_columns.
# NOTE(review): assumes "speaker_id" is a ClassLabel feature exposing .num_classes — confirm.
num_speaker_classes = dataset["train"].features["speaker_id"].num_classes

# Apply preprocessing. prepare_dataset() handles ONE example per call
# (it indexes audio["array"] directly), so map() must not run in batched mode.
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"], num_proc=4)

# Model: speaker identification is utterance-level classification, so use the
# sequence-classification head (logits of shape (batch, num_labels)) rather
# than the CTC head, whose per-frame vocabulary logits do not match the
# cross-entropy loss in CustomTrainer.
from transformers import Wav2Vec2ForSequenceClassification

model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/wav2vec2-xls-r-300m", num_labels=num_speaker_classes)

# Define Training Arguments
training_args = TrainingArguments(
  output_dir="./results",
  group_by_length=True,
  per_device_train_batch_size=16,  # was misspelled as per_device_train_batch_shift_size
  evaluation_strategy="steps",
  num_train_epochs=3,
  save_steps=500,
  eval_steps=500,
  logging_steps=10,
  learning_rate=1e-4,
  save_total_limit=2,
)

# Trainer with a custom compute_loss function
class CustomTrainer(Trainer):
    """Trainer that derives its loss from the model's logits with plain cross-entropy."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the cross-entropy loss between logits and labels.

        Args:
            model: Model to run; called as ``model(**inputs)`` without labels.
            inputs: Batch mapping; its ``"labels"`` entry is removed and used
                as the target.
            return_outputs: When True, return ``(loss, outputs)``.
            **kwargs: Accepts additional keyword arguments that newer Trainer
                versions pass (e.g. ``num_items_in_batch``); they are ignored.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Take the class count from the logits themselves instead of the config,
        # so the reshape always matches what the model actually produced.
        loss = cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# Define Trainer
# Trains on the "train" split and evaluates on the "test" split; the feature
# extractor is handed over through the `tokenizer` argument.
trainer = CustomTrainer(
  model=model,
  args=training_args,
  train_dataset=dataset["train"],
  eval_dataset=dataset["test"],
  tokenizer=feature_extractor,
)

# Start training
trainer.train()


Fine-tuning on our own curated dataset.

In [None]:
import os
import sys
import torch
import librosa
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from transformers import Wav2Vec2Processor, Wav2Vec2Model, Trainer, TrainingArguments, TrainerCallback, Wav2Vec2FeatureExtractor
import math
from datasets import load_metric
from datetime import datetime
import torch.nn as nn

# Load the processor
# Note: this is a Wav2Vec2FeatureExtractor instance bound to the name `processor`.
processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")

# Define the custom dataset class using pandas
class LocalAudioDataset(Dataset):
    """Audio classification dataset backed by a CSV file.

    The CSV is expected to provide the columns ``subset``, ``label`` and
    ``path`` (these are the columns read below). Each item is a dict with
    ``"input_values"`` (processor output for a 16000-sample clip) and
    ``"labels"`` (integer speaker index).
    """

    def __init__(self, csv_file, processor, subset, noise_factor=0.0):
        """Load the rows of ``subset`` and encode speaker labels as integers.

        Args:
            csv_file: Path to the CSV file.
            processor: Callable feature extractor applied to each waveform.
            subset: Value of the ``subset`` column to keep (e.g. ``'train'``).
            noise_factor: Scale of the Gaussian noise added to each clip;
                values <= 0 disable the augmentation.
        """
        self.processor = processor
        full_table = pd.read_csv(csv_file)

        # Build the label -> index map from the WHOLE file (order of first
        # appearance) so that train/validate/test all share one encoding.
        # Building it per subset can give the same speaker different indices.
        self.speaker_ids = {label: idx for idx, label in enumerate(full_table['label'].unique())}

        # .copy() gives an independent frame, so the column assignment below
        # does not write into a slice of `full_table`.
        subset_rows = full_table[full_table['subset'] == subset].copy()
        subset_rows['label'] = subset_rows['label'].map(self.speaker_ids)
        self.data = subset_rows
        self.noise_factor = noise_factor

        print(f"Loaded {len(self.speaker_ids)} speakers: {self.speaker_ids}")
        print(f"Total files in {subset}: {len(self.data)}")

    def __len__(self):
        """Number of rows in the selected subset."""
        return len(self.data)

    def __getitem__(self, idx, retry_count=0):
        """Return the processed clip and label at position ``idx``.

        If loading fails, the next position (wrapping around) is tried
        instead, up to three times; each fallback is reported on stderr.
        After that the last exception is re-raised.
        """
        row = self.data.iloc[idx]
        file_path = row['path']
        label = row['label']

        try:
            audio, sr = librosa.load(file_path, sr=16000)
            # Redundant safeguard: librosa.load down-mixes to mono by default.
            audio = librosa.to_mono(audio)
            audio = self._pad_or_truncate(audio, max_length=16000)
            if self.noise_factor > 0:
                audio = self._add_noise(audio)
            input_values = self.processor(audio, sampling_rate=16000, return_tensors="pt").input_values.squeeze(0)
            return {"input_values": input_values, "labels": label}
        except Exception as e:
            if retry_count < 3:  # Retry up to 3 times
                # Make the substitution visible instead of swallowing the error.
                print(f"Error loading {file_path}: {e}; falling back to the next sample", file=sys.stderr)
                return self.__getitem__((idx + 1) % len(self), retry_count + 1)
            print(f"Error loading {file_path}: {e}", file=sys.stderr)
            raise  # Retry limit reached: propagate with the original traceback

    def _pad_or_truncate(self, audio, max_length):
        """Zero-pad at the end or cut ``audio`` so it has exactly ``max_length`` samples."""
        if len(audio) < max_length:
            pad_size = max_length - len(audio)
            audio = np.pad(audio, (0, pad_size), 'constant', constant_values=(0, 0))
        else:
            audio = audio[:max_length]
        return audio

    def _add_noise(self, audio):
        """Add Gaussian noise scaled by ``noise_factor``, keeping the input dtype."""
        noise = np.random.randn(len(audio))
        augmented_audio = audio + self.noise_factor * noise
        # audio.dtype (rather than type(audio[0])) also works for empty arrays.
        return augmented_audio.astype(audio.dtype)

# Paths to dataset CSV file
csv_file = 'dataset_large.csv'
# One dataset object per split; noise_factor keeps its default of 0.0 (no augmentation).
train_dataset = LocalAudioDataset(csv_file, processor, 'train')
validate_dataset = LocalAudioDataset(csv_file, processor, 'validate')
test_dataset = LocalAudioDataset(csv_file, processor, 'test')

# The classifier size is taken from the training split's label map.
num_speakers = len(train_dataset.speaker_ids)
print(f"Number of unique speakers: {num_speakers}")

# NOTE(review): these two lines print one label per row, which can be very long output.
print(f"Labels in train dataset: {train_dataset.data['label'].tolist()}")
print(f"Labels in test dataset: {test_dataset.data['label'].tolist()}")

# Define a custom classification head on top of the base Wav2Vec2 model
class Wav2Vec2ClassificationHead(nn.Module):
    """Classification head: dropout -> dense -> tanh -> dropout -> output projection."""

    def __init__(self, config, num_labels):
        """Build the layers from ``config.hidden_dropout`` and ``config.hidden_size``."""
        super().__init__()
        width = config.hidden_size
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.dense = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to class logits using position 0 along dimension 1 only."""
        # Index 0 selects the first position; nothing is averaged here.
        # NOTE(review): if pooling over all positions was intended (the earlier
        # comment mentioned a mean), this would need e.g. features.mean(dim=1).
        first_position = features[:, 0, :]
        activated = torch.tanh(self.dense(self.dropout(first_position)))
        return self.out_proj(self.dropout(activated))

# Define the full model by combining Wav2Vec2Model with the classification head
class Wav2Vec2ForCustomClassification(nn.Module):
    """Pretrained XLS-R encoder followed by Wav2Vec2ClassificationHead."""

    def __init__(self, num_labels):
        """Load the pretrained encoder and attach a head with ``num_labels`` outputs."""
        super().__init__()
        encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-300m")
        self.wav2vec2 = encoder
        self.classifier = Wav2Vec2ClassificationHead(encoder.config, num_labels)

    def forward(self, input_values, attention_mask=None, labels=None):
        """Return ``logits``, or ``(loss, logits)`` when ``labels`` is given.

        The loss is cross-entropy between the logits and ``labels``.
        """
        encoder_out = self.wav2vec2(input_values, attention_mask=attention_mask)
        logits = self.classifier(encoder_out.last_hidden_state)

        if labels is None:
            return logits

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, logits)

# Instantiate the model with the custom classification head
# (one output per speaker found in the training split).
model = Wav2Vec2ForCustomClassification(num_labels=num_speakers)

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)

def validate_labels(dataset):
    """Check that every label of ``dataset`` lies in ``[0, num_speakers)``.

    When the dataset exposes its table as a pandas DataFrame in ``.data`` with
    a ``label`` column (as LocalAudioDataset does), the column is checked
    directly. This avoids decoding every audio file just to read its label,
    and inspects the stored labels rather than whatever sample
    ``__getitem__`` may substitute after a load error. Other datasets are
    checked item by item as before.

    Note: on the fast path audio files are no longer opened, so this function
    does not double as a "can every file be loaded" smoke test.

    Raises:
        ValueError: If a label is missing, negative or >= ``num_speakers``.
    """
    table = getattr(dataset, "data", None)
    if isinstance(table, pd.DataFrame) and "label" in table.columns:
        labels = table["label"]
        bad = labels[labels.isna() | (labels < 0) | (labels >= num_speakers)]
        if len(bad) > 0:
            label = bad.iloc[0]
            print(f"Invalid label {label} for item: {table.loc[bad.index[0]].to_dict()}")
            raise ValueError(f"Invalid label {label} found in dataset.")
    else:
        for item in dataset:
            label = item['labels']
            if label >= num_speakers or label < 0:
                print(f"Invalid label {label} for item: {item}")
                raise ValueError(f"Invalid label {label} found in dataset.")
    print("All labels are valid.")

# Sanity-check the labels of every split before training.
validate_labels(train_dataset)
validate_labels(validate_dataset)
validate_labels(test_dataset)

batch_size = 8
steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
# Log and evaluate five times per epoch. Clamp to at least 1: with fewer than
# five steps per epoch the integer division would yield 0, which is not a
# usable step interval.
logging_steps = max(1, steps_per_epoch // 5)
eval_steps = max(1, steps_per_epoch // 5)

accuracy_metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Compute classification accuracy from an evaluation prediction.

    Accuracy is computed directly with NumPy instead of going through the
    ``accuracy_metric`` object created with the deprecated ``load_metric``,
    so no metric script has to be fetched.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax of ``logits`` over the last axis.

    Returns:
        dict: ``{"accuracy": <fraction of predictions equal to labels>}``;
        0.0 when there are no labels.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    if labels.size == 0:
        return {"accuracy": 0.0}
    return {"accuracy": float(np.mean(predictions == labels))}

# Directory for the per-run training log (created if missing).
log_dir = "/home/rag/experimental_trial/results/training_logs"
os.makedirs(log_dir, exist_ok=True)
# The timestamp in the file name gives every run its own log file.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(log_dir, f"training_logxlsr_finetuning_{timestamp}.csv")
# Start the file with the CSV header; SaveMetricsCallback appends the rows.
with open(log_file, "w") as f:
    f.write("Timestamp,Step,Training Loss,Validation Loss,Accuracy\n")

class SaveMetricsCallback(TrainerCallback):
    """Append one CSV row to the module-level ``log_file`` for every logging event."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write timestamp, global step, loss, eval_loss and eval_accuracy.

        Values missing from ``logs`` are written as empty fields; nothing is
        written when ``logs`` is None.
        """
        if logs is None:
            return

        with open(log_file, "a") as f:
            fields = (
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                state.global_step,
                logs.get("loss", ""),
                logs.get("eval_loss", ""),
                logs.get("eval_accuracy", ""),
            )
            f.write(",".join(f"{value}" for value in fields) + "\n")

class EarlyStoppingCallback(TrainerCallback):
    """Request a training stop once ``eval_loss`` has stopped improving.

    An evaluation counts as an improvement when its ``eval_loss`` is lower than
    the best value seen so far minus ``early_stopping_threshold``. After
    ``early_stopping_patience`` evaluations without improvement in a row,
    ``control.should_training_stop`` is set.
    """

    def __init__(self, early_stopping_patience=100, early_stopping_threshold=0.0):
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        self.best_metric = None  # lowest eval_loss seen so far
        self.patience_counter = 0  # evaluations since the last improvement

    def on_evaluate(self, args, state, control, **kwargs):
        """Update the patience bookkeeping from ``kwargs["metrics"]["eval_loss"]``."""
        current = kwargs.get("metrics", {}).get("eval_loss")
        if current is None:
            return

        improved = (
            self.best_metric is None
            or current < self.best_metric - self.early_stopping_threshold
        )
        if improved:
            self.best_metric = current
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        if self.patience_counter < self.early_stopping_patience:
            return
        print(f"Early stopping at step {state.global_step}")
        control.should_training_stop = True

# Ensure 'no_cuda' parameter aligns with device availability
training_args = TrainingArguments(
    output_dir="./results",
    # Every item is padded/truncated to exactly 16000 samples in
    # LocalAudioDataset.__getitem__, so grouping by length cannot change the
    # batches. NOTE(review): with a plain torch Dataset the Trainer has to read
    # every item to measure lengths, i.e. decode the whole training set up front.
    group_by_length=False,
    per_device_train_batch_size=batch_size,
    evaluation_strategy="steps",
    num_train_epochs=100,
    save_steps=logging_steps,
    eval_steps=eval_steps,
    logging_steps=logging_steps,
    learning_rate=5e-6,
    save_total_limit=2,
    no_cuda=not torch.cuda.is_available(),
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower eval_loss is better
    save_strategy="steps"  # Save checkpoints every `save_steps`
)

# Add early stopping callback to the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validate_dataset,
    tokenizer=processor,
    compute_metrics=compute_metrics,
    callbacks=[SaveMetricsCallback(), EarlyStoppingCallback()]  # Include early stopping
)

# Train and evaluate
trainer.train()

# Final evaluation on the held-out test split.
metrics = trainer.evaluate(test_dataset)

print(f"Test set evaluation metrics: {metrics}")
print("Training and evaluation completed successfully!")

# NOTE(review): this path is relative to the current working directory, while
# the extraction cell further down reads the checkpoint from an absolute path
# under /home/rag/experimental_trial/results — make sure both point to the same place.
best_model_dir = "./results/best_model_xlsr_finetuning"
os.makedirs(best_model_dir, exist_ok=True)

trainer.save_model(best_model_dir)
processor.save_pretrained(best_model_dir)

print(f"Best model saved to {best_model_dir}")


Now we extract the hidden states from the fine-tuned model.

In [None]:
import os
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model, Wav2Vec2FeatureExtractor
from tqdm import tqdm
import librosa
from safetensors.torch import load_file as safe_load
from torch import nn

# Initialize the processor and model for xlsr
# Note: `processor` is a Wav2Vec2FeatureExtractor instance.
processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")
# Weights file of the fine-tuned model (safetensors format).
finetuned_model_path = "/home/rag/experimental_trial/results/best_model_xlsr_finetuning/model.safetensors"

# Define a custom classification head on top of the base Wav2Vec2 model
class Wav2Vec2ClassificationHead(nn.Module):
    """Classification head: dropout -> dense -> tanh -> dropout -> output projection."""

    def __init__(self, config, num_labels):
        """Build the layers from ``config.hidden_dropout`` and ``config.hidden_size``."""
        super().__init__()
        width = config.hidden_size
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.dense = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to class logits using position 0 along dimension 1 only."""
        # Index 0 selects the first position; nothing is averaged here.
        # NOTE(review): if pooling over all positions was intended (the earlier
        # comment mentioned a mean), this would need e.g. features.mean(dim=1).
        first_position = features[:, 0, :]
        activated = torch.tanh(self.dense(self.dropout(first_position)))
        return self.out_proj(self.dropout(activated))

# Define the full model by combining Wav2Vec2Model with the classification head
class Wav2Vec2ForCustomClassification(nn.Module):
    """XLS-R encoder plus classification head that also returns all hidden states."""

    def __init__(self, num_labels):
        """Load the pretrained encoder (hidden states enabled) and attach the head."""
        super().__init__()
        encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-300m", output_hidden_states=True)
        self.wav2vec2 = encoder
        self.classifier = Wav2Vec2ClassificationHead(encoder.config, num_labels)

    def forward(self, input_values, attention_mask=None, labels=None):
        """Return ``(logits, hidden_states)``, or ``(loss, logits, hidden_states)`` with labels.

        The logits are computed from the last entry of ``hidden_states``.
        """
        encoder_out = self.wav2vec2(input_values, attention_mask=attention_mask)
        all_hidden = encoder_out.hidden_states
        logits = self.classifier(all_hidden[-1])

        if labels is None:
            return (logits, all_hidden)

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, logits, all_hidden)

# Load the fine-tuned weights first so the size of the classification head can
# be read from the checkpoint instead of being hard-coded: the first dimension
# of the output projection's weight is the number of labels the model was
# trained with.
state_dict = safe_load(finetuned_model_path)
num_labels = state_dict["classifier.out_proj.weight"].shape[0]

# Instantiate the model with the custom classification head
model = Wav2Vec2ForCustomClassification(num_labels=num_labels)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(state_dict)
model.to(device)
# Inference only: switch the whole wrapper (including the head's dropout) to eval mode.
model.eval()

def check_directories_exist(directory, layer_indices):
    """Report whether the per-layer outputs for ``directory`` are already complete.

    Returns True only if, for every index in ``layer_indices``, the
    sub-directory ``layer_<index>`` exists AND contains a
    ``<stem>_layer_<index>.npy`` file for every ``.mp3`` file located directly
    in ``directory``.

    Checking the files (not only the folders) matters: load_audio_files()
    creates all layer folders while handling the first audio file, so a run
    interrupted part-way would otherwise be mistaken for a finished one and
    the remaining files would never be processed.

    Args:
        directory: Directory holding the ``.mp3`` files of one speaker.
        layer_indices: Iterable of layer indices that should have outputs.

    Returns:
        bool: True when nothing is left to extract, False otherwise.
    """
    # Materialise once so one-shot iterators can be traversed twice below.
    layer_dirs = [
        (index, os.path.join(directory, f"layer_{index}"))
        for index in layer_indices
    ]

    # Same test as before: every layer folder has to exist.
    if not all(os.path.exists(layer_dir) for _, layer_dir in layer_dirs):
        return False
    if not layer_dirs:
        return True

    # Additional completeness test: one saved array per audio file and layer.
    stems = [
        os.path.splitext(name)[0]
        for name in os.listdir(directory)
        if name.endswith(".mp3")
    ]
    return all(
        os.path.exists(os.path.join(layer_dir, f"{stem}_layer_{index}.npy"))
        for index, layer_dir in layer_dirs
        for stem in stems
    )

def load_audio_files(directory, layer_indices=(-1,)):
    """Extract and save hidden-state representations for every MP3 in ``directory``.

    Each ``.mp3`` file is loaded with librosa at 16 kHz, passed through the
    module-level ``processor`` and fine-tuned ``model``, and for every
    requested layer the hidden state is written to
    ``<directory>/layer_<index>/<stem>_layer_<index>.npy``.

    Args:
        directory: Directory that directly contains the ``.mp3`` files.
        layer_indices: Iterable of indices into the model's hidden states.
            The default is an immutable tuple instead of a shared list.

    Returns:
        None. Results are written to disk.
    """
    # Materialise once: a one-shot iterator would be exhausted after the first file.
    layer_indices = list(layer_indices)
    # Filter up front so the progress bar counts audio files only.
    mp3_files = [name for name in os.listdir(directory) if name.endswith(".mp3")]

    for filename in tqdm(mp3_files):
        file_path = os.path.join(directory, filename)
        audio, sr = librosa.load(file_path, sr=16000)
        inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
        input_values = inputs["input_values"].to(device)

        # Only the forward pass needs to run without gradient tracking.
        # Called without labels, the wrapper returns (logits, hidden_states).
        with torch.no_grad():
            _, hidden_states = model(input_values)

        stem = os.path.splitext(filename)[0]
        for index in layer_indices:
            # One sub-directory per layer inside the speaker directory.
            layer_dir = os.path.join(directory, f"layer_{index}")
            os.makedirs(layer_dir, exist_ok=True)
            save_path = os.path.join(layer_dir, f"{stem}_layer_{index}.npy")
            np.save(save_path, hidden_states[index].cpu().numpy())

def process_audio_directory(base_directory, layer_indices=range(25)):
    """Run the extraction for each sub-directory of ``base_directory`` that still needs it.

    Entries that are not directories are ignored. A sub-directory is skipped
    when check_directories_exist() reports True for it; otherwise it is handed
    to load_audio_files() together with ``layer_indices``.
    """
    for entry in os.listdir(base_directory):
        candidate = os.path.join(base_directory, entry)
        if not os.path.isdir(candidate):
            continue
        if check_directories_exist(candidate, layer_indices):
            continue
        load_audio_files(candidate, layer_indices)

# Root directory whose sub-directories are processed with the fine-tuned model.
# NOTE(review): expanduser() only rewrites a leading "~", so it leaves this absolute path unchanged.
directory_path = os.path.expanduser("/home/rag/experimental_trial/data/all_speakers_xlrs_finetuned")

# Uses the default layer_indices=range(25), i.e. hidden-state indices 0..24.
process_audio_directory(directory_path)


# Now we use Optuna to optimize the number of encoder layers used

In [None]:
import os
import sys
import torch
import librosa
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from transformers import Wav2Vec2Processor, Wav2Vec2Model, Trainer, TrainingArguments, TrainerCallback, Wav2Vec2FeatureExtractor, Wav2Vec2Config
import math
from datasets import load_metric
from datetime import datetime
import torch.nn as nn
import optuna

# Load the processor
# Note: this is a Wav2Vec2FeatureExtractor instance bound to the name `processor`.
processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")

# Define the custom dataset class using pandas
class LocalAudioDataset(Dataset):
    """Audio classification dataset backed by a CSV file.

    The CSV is expected to provide the columns ``subset``, ``label`` and
    ``path`` (these are the columns read below). Each item is a dict with
    ``"input_values"`` (processor output for a 16000-sample clip) and
    ``"labels"`` (integer speaker index).
    """

    def __init__(self, csv_file, processor, subset, noise_factor=0.0):
        """Load the rows of ``subset`` and encode speaker labels as integers.

        Args:
            csv_file: Path to the CSV file.
            processor: Callable feature extractor applied to each waveform.
            subset: Value of the ``subset`` column to keep (e.g. ``'train'``).
            noise_factor: Scale of the Gaussian noise added to each clip;
                values <= 0 disable the augmentation.
        """
        self.processor = processor
        full_table = pd.read_csv(csv_file)

        # Build the label -> index map from the WHOLE file (order of first
        # appearance) so that train/validate/test all share one encoding.
        # Building it per subset can give the same speaker different indices.
        self.speaker_ids = {label: idx for idx, label in enumerate(full_table['label'].unique())}

        # .copy() gives an independent frame, so the column assignment below
        # does not write into a slice of `full_table`.
        subset_rows = full_table[full_table['subset'] == subset].copy()
        subset_rows['label'] = subset_rows['label'].map(self.speaker_ids)
        self.data = subset_rows
        self.noise_factor = noise_factor

        print(f"Loaded {len(self.speaker_ids)} speakers: {self.speaker_ids}")
        print(f"Total files in {subset}: {len(self.data)}")

    def __len__(self):
        """Number of rows in the selected subset."""
        return len(self.data)

    def __getitem__(self, idx, retry_count=0):
        """Return the processed clip and label at position ``idx``.

        If loading fails, the next position (wrapping around) is tried
        instead, up to three times; each fallback is reported on stderr.
        After that the last exception is re-raised.
        """
        row = self.data.iloc[idx]
        file_path = row['path']
        label = row['label']

        try:
            audio, sr = librosa.load(file_path, sr=16000)
            # Redundant safeguard: librosa.load down-mixes to mono by default.
            audio = librosa.to_mono(audio)
            audio = self._pad_or_truncate(audio, max_length=16000)
            if self.noise_factor > 0:
                audio = self._add_noise(audio)
            input_values = self.processor(audio, sampling_rate=16000, return_tensors="pt").input_values.squeeze(0)
            return {"input_values": input_values, "labels": label}
        except Exception as e:
            if retry_count < 3:  # Retry up to 3 times
                # Make the substitution visible instead of swallowing the error.
                print(f"Error loading {file_path}: {e}; falling back to the next sample", file=sys.stderr)
                return self.__getitem__((idx + 1) % len(self), retry_count + 1)
            print(f"Error loading {file_path}: {e}", file=sys.stderr)
            raise  # Retry limit reached: propagate with the original traceback

    def _pad_or_truncate(self, audio, max_length):
        """Zero-pad at the end or cut ``audio`` so it has exactly ``max_length`` samples."""
        if len(audio) < max_length:
            pad_size = max_length - len(audio)
            audio = np.pad(audio, (0, pad_size), 'constant', constant_values=(0, 0))
        else:
            audio = audio[:max_length]
        return audio

    def _add_noise(self, audio):
        """Add Gaussian noise scaled by ``noise_factor``, keeping the input dtype."""
        noise = np.random.randn(len(audio))
        augmented_audio = audio + self.noise_factor * noise
        # audio.dtype (rather than type(audio[0])) also works for empty arrays.
        return augmented_audio.astype(audio.dtype)

# Paths to dataset CSV file
csv_file = 'dataset_large.csv'
# One dataset object per split; noise_factor keeps its default of 0.0 (no augmentation).
train_dataset = LocalAudioDataset(csv_file, processor, 'train')
validate_dataset = LocalAudioDataset(csv_file, processor, 'validate')
test_dataset = LocalAudioDataset(csv_file, processor, 'test')

# The classifier size is taken from the training split's label map.
num_speakers = len(train_dataset.speaker_ids)
print(f"Number of unique speakers: {num_speakers}")

# NOTE(review): these two lines print one label per row, which can be very long output.
print(f"Labels in train dataset: {train_dataset.data['label'].tolist()}")
print(f"Labels in test dataset: {test_dataset.data['label'].tolist()}")

# Define a custom classification head on top of the base Wav2Vec2 model
class Wav2Vec2ClassificationHead(nn.Module):
    """Classification head: dropout -> dense -> tanh -> dropout -> output projection."""

    def __init__(self, config, num_labels):
        """Build the layers from ``config.hidden_dropout`` and ``config.hidden_size``."""
        super().__init__()
        width = config.hidden_size
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.dense = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to class logits using position 0 along dimension 1 only."""
        # Index 0 selects the first position; nothing is averaged here.
        # NOTE(review): if pooling over all positions was intended (the earlier
        # comment mentioned a mean), this would need e.g. features.mean(dim=1).
        first_position = features[:, 0, :]
        activated = torch.tanh(self.dense(self.dropout(first_position)))
        return self.out_proj(self.dropout(activated))

# Define the full model by combining Wav2Vec2Model with the classification head
class Wav2Vec2ForCustomClassification(nn.Module):
    """Pretrained XLS-R encoder followed by Wav2Vec2ClassificationHead."""

    def __init__(self, num_labels):
        """Load the pretrained encoder and attach a head with ``num_labels`` outputs."""
        super().__init__()
        encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-300m")
        self.wav2vec2 = encoder
        self.classifier = Wav2Vec2ClassificationHead(encoder.config, num_labels)

    def forward(self, input_values, attention_mask=None, labels=None):
        """Return ``logits``, or ``(loss, logits)`` when ``labels`` is given.

        The loss is cross-entropy between the logits and ``labels``.
        """
        encoder_out = self.wav2vec2(input_values, attention_mask=attention_mask)
        logits = self.classifier(encoder_out.last_hidden_state)

        if labels is None:
            return logits

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, logits)

# Instantiate the model with the custom classification head
# NOTE(review): objective() further down builds its own model for every trial,
# so this instance appears to be unused within this cell — confirm whether it is needed.
model = Wav2Vec2ForCustomClassification(num_labels=num_speakers)

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)

def validate_labels(dataset):
    """Check that every label of ``dataset`` lies in ``[0, num_speakers)``.

    When the dataset exposes its table as a pandas DataFrame in ``.data`` with
    a ``label`` column (as LocalAudioDataset does), the column is checked
    directly. This avoids decoding every audio file just to read its label,
    and inspects the stored labels rather than whatever sample
    ``__getitem__`` may substitute after a load error. Other datasets are
    checked item by item as before.

    Note: on the fast path audio files are no longer opened, so this function
    does not double as a "can every file be loaded" smoke test.

    Raises:
        ValueError: If a label is missing, negative or >= ``num_speakers``.
    """
    table = getattr(dataset, "data", None)
    if isinstance(table, pd.DataFrame) and "label" in table.columns:
        labels = table["label"]
        bad = labels[labels.isna() | (labels < 0) | (labels >= num_speakers)]
        if len(bad) > 0:
            label = bad.iloc[0]
            print(f"Invalid label {label} for item: {table.loc[bad.index[0]].to_dict()}")
            raise ValueError(f"Invalid label {label} found in dataset.")
    else:
        for item in dataset:
            label = item['labels']
            if label >= num_speakers or label < 0:
                print(f"Invalid label {label} for item: {item}")
                raise ValueError(f"Invalid label {label} found in dataset.")
    print("All labels are valid.")

# Sanity-check the labels of every split before training.
validate_labels(train_dataset)
validate_labels(validate_dataset)
validate_labels(test_dataset)

batch_size = 8
steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
# Log and evaluate five times per epoch. Clamp to at least 1: with fewer than
# five steps per epoch the integer division would yield 0, which is not a
# usable step interval.
logging_steps = max(1, steps_per_epoch // 5)
eval_steps = max(1, steps_per_epoch // 5)

accuracy_metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Compute classification accuracy from an evaluation prediction.

    Accuracy is computed directly with NumPy instead of going through the
    ``accuracy_metric`` object created with the deprecated ``load_metric``,
    so no metric script has to be fetched.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax of ``logits`` over the last axis.

    Returns:
        dict: ``{"accuracy": <fraction of predictions equal to labels>}``;
        0.0 when there are no labels.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    if labels.size == 0:
        return {"accuracy": 0.0}
    return {"accuracy": float(np.mean(predictions == labels))}

# Directory for the per-run training log (created if missing).
log_dir = "/home/rag/experimental_trial/results/training_logs"
os.makedirs(log_dir, exist_ok=True)
# The timestamp in the file name gives every run its own log file.
# NOTE(review): all Optuna trials of one run append to this single file, and rows carry no trial id.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(log_dir, f"training_logxlsr_finetuning_optimizing_layers_{timestamp}.csv")
# Start the file with the CSV header; SaveMetricsCallback appends the rows.
with open(log_file, "w") as f:
    f.write("Timestamp,Step,Training Loss,Validation Loss,Accuracy\n")

class SaveMetricsCallback(TrainerCallback):
    """Append one CSV row to the module-level ``log_file`` for every logging event."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write timestamp, global step, loss, eval_loss and eval_accuracy.

        Values missing from ``logs`` are written as empty fields; nothing is
        written when ``logs`` is None.
        """
        if logs is None:
            return

        with open(log_file, "a") as f:
            fields = (
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                state.global_step,
                logs.get("loss", ""),
                logs.get("eval_loss", ""),
                logs.get("eval_accuracy", ""),
            )
            f.write(",".join(f"{value}" for value in fields) + "\n")

class EarlyStoppingCallback(TrainerCallback):
    """Request a training stop once ``eval_loss`` has stopped improving.

    An evaluation counts as an improvement when its ``eval_loss`` is lower than
    the best value seen so far minus ``early_stopping_threshold``. After
    ``early_stopping_patience`` evaluations without improvement in a row,
    ``control.should_training_stop`` is set.
    """

    def __init__(self, early_stopping_patience=100, early_stopping_threshold=0.0):
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        self.best_metric = None  # lowest eval_loss seen so far
        self.patience_counter = 0  # evaluations since the last improvement

    def on_evaluate(self, args, state, control, **kwargs):
        """Update the patience bookkeeping from ``kwargs["metrics"]["eval_loss"]``."""
        current = kwargs.get("metrics", {}).get("eval_loss")
        if current is None:
            return

        improved = (
            self.best_metric is None
            or current < self.best_metric - self.early_stopping_threshold
        )
        if improved:
            self.best_metric = current
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        if self.patience_counter < self.early_stopping_patience:
            return
        print(f"Early stopping at step {state.global_step}")
        control.should_training_stop = True

# Define the Optuna objective function
def objective(trial):
    """Optuna objective: fine-tune with a truncated encoder and return the validation loss.

    The trial suggests ``num_layers`` in [1, 24]; only the first ``num_layers``
    encoder layers of a freshly loaded model are kept. After training, the
    model is evaluated on ``validate_dataset``. If this trial beats the best
    completed trial so far, its model and the processor are written to the
    module-level ``best_model_dir``. The trainer only exists inside this
    function, so this is the one place where the winning model can be saved.

    Args:
        trial: The Optuna trial supplying the hyperparameter suggestion.

    Returns:
        float: ``eval_loss`` on the validation split (to be minimised).
    """
    # Suggest the number of layers
    num_layers = trial.suggest_int('num_layers', 1, 24)

    # Instantiate the model with the custom classification head and keep only
    # the first `num_layers` encoder layers.
    model = Wav2Vec2ForCustomClassification(num_labels=num_speakers)
    model.wav2vec2.encoder.layers = model.wav2vec2.encoder.layers[:num_layers]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Ensure 'no_cuda' parameter aligns with device availability
    training_args = TrainingArguments(
        output_dir="./results",
        # Every item is padded/truncated to exactly 16000 samples in
        # LocalAudioDataset.__getitem__, so grouping by length cannot change
        # the batches. NOTE(review): with a plain torch Dataset the Trainer has
        # to read every item to measure lengths, once per trial.
        group_by_length=False,
        per_device_train_batch_size=batch_size,
        evaluation_strategy="steps",
        num_train_epochs=10,  # Use 10 epochs
        save_steps=logging_steps,
        eval_steps=eval_steps,
        logging_steps=logging_steps,
        learning_rate=5e-6,  # Fixed learning rate
        save_total_limit=2,
        no_cuda=not torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        save_strategy="steps"
    )

    # Add early stopping callback to the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validate_dataset,
        tokenizer=processor,
        compute_metrics=compute_metrics,
        callbacks=[SaveMetricsCallback(), EarlyStoppingCallback()]
    )

    # Train and evaluate
    trainer.train()

    metrics = trainer.evaluate(validate_dataset)
    eval_loss = metrics["eval_loss"]

    # study.best_value raises ValueError while no trial has completed yet,
    # which is the case during the very first trial.
    try:
        best_so_far = trial.study.best_value
    except ValueError:
        best_so_far = None

    if best_so_far is None or eval_loss < best_so_far:
        trainer.save_model(best_model_dir)
        processor.save_pretrained(best_model_dir)

    # Return the evaluation loss for Optuna to minimize
    return eval_loss


# Target directory for the best model; it must exist before the study starts
# because objective() writes into it.
best_model_dir = "./results/best_model_xlsr_finetuning_optuna_layer_optimized"
os.makedirs(best_model_dir, exist_ok=True)

# Create an Optuna study and optimize the objective function
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=25)

# Print the best hyperparameters found
print(f"Best hyperparameters: {study.best_params}")

print(f"Best model saved to {best_model_dir}")
