<a href="https://colab.research.google.com/github/vifirsanova/phat-llm/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. **Install Necessary Libraries**:

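```python
# accelerate is required by Trainer; datasets[audio] pulls in audio decoding backends;
# scikit-learn provides the evaluation metrics.
!pip install transformers "datasets[audio]" torch accelerate scikit-learn
```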

2. **Import Required Libraries**:

```python
import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Wav2Vec2ForCTC, Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
```

3. **Set Device**:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

4. **Load Pre-trained Model and Processor**:

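```python
# Pretrained English ASR checkpoint. Its tokenizer vocabulary covers only uppercase
# Latin letters, the apostrophe and "|" (word delimiter), so IPA symbols outside that
# set are mapped to <unk> unless a new vocabulary is built for the IPA task.
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
```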
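For context, the feature extractor inside this processor standardizes each waveform before it reaches the model. Below is a minimal pure-Python sketch of that per-utterance zero-mean, unit-variance normalization, assuming (as we understand this checkpoint's config) that `do_normalize` is enabled and that an epsilon of 1e-7 guards against silent clips; `normalize_waveform` is a hypothetical helper, not part of `transformers`.

```python
import math


def normalize_waveform(samples, eps=1e-7):
    """Rescale samples to zero mean and (almost exactly) unit variance."""
    n = len(samples)
    mean = sum(samples) / n
    # Population variance (divide by n), like numpy's default var().
    variance = sum((x - mean) ** 2 for x in samples) / n
    # eps keeps the division finite for constant (e.g. silent) input.
    scale = math.sqrt(variance + eps)
    return [(x - mean) / scale for x in samples]


normalized = normalize_waveform([0.0, 0.5, -0.5, 1.0])
print([round(v, 3) for v in normalized])  # → [-0.447, 0.447, -1.342, 1.342]
```

Because the processor applies this step itself, raw 16 kHz arrays can be passed to it directly.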

5. **Define Data Preprocessing Functions**:

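```python
def preprocess_function(batch, task):
    # With batched=True, batch["audio"] is a list of dicts holding the decoded waveform
    # ("array") and its "sampling_rate"; the column must already be resampled to 16 kHz.
    arrays = [audio["array"] for audio in batch["audio"]]
    # No return_tensors here: clips differ in length, so padding is left to the collator.
    input_values = processor(arrays, sampling_rate=16000).input_values
    if task == "ipa":
        # Tokenize the transcriptions into CTC label ids.
        labels = processor.tokenizer(batch["ipa_transcription"]).input_ids
    elif task == "prosody":
        # Assumes one integer class id per clip.
        labels = batch["prosody_labels"]
    elif task == "non_verbal":
        labels = batch["non_verbal_labels"]
    else:
        raise ValueError(f"Unknown task: {task!r}")
    return {"input_values": input_values, "labels": labels}
```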
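Because clips and transcriptions differ in length, every training batch has to be padded before it can become a tensor, and padded label positions must be left out of the loss. A stdlib-only sketch of that logic for the CTC case, assuming zeros as audio padding (the feature extractor's default padding value) and -100 as label padding, which `transformers` models treat as positions to ignore (for `Wav2Vec2ForCTC`, any negative label is masked); `pad_batch` is a hypothetical helper for illustration.

```python
def pad_batch(input_values, labels, audio_pad=0.0, label_pad=-100):
    """Right-pad waveforms and label sequences to the longest item in the batch."""
    max_audio = max(len(wave) for wave in input_values)
    max_label = max(len(seq) for seq in labels)
    # Zeros for audio; an ignored index for labels so the loss skips them.
    padded_audio = [wave + [audio_pad] * (max_audio - len(wave)) for wave in input_values]
    padded_labels = [seq + [label_pad] * (max_label - len(seq)) for seq in labels]
    return {"input_values": padded_audio, "labels": padded_labels}


# Two clips of different lengths with CTC label sequences of different lengths.
batch = pad_batch([[0.1, 0.2, 0.3], [0.4]], [[7, 5], [9]])
print(batch["labels"])  # → [[7, 5], [9, -100]]
```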

6. **Load and Preprocess Dataset**:

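```python
from datasets import Audio

def load_and_preprocess_dataset(dataset_name, task):
    # Decode audio and resample every clip to 16 kHz, then drop the raw
    # columns so only input_values and labels reach the Trainer.
    dataset = load_dataset(dataset_name).cast_column("audio", Audio(sampling_rate=16000))
    columns = dataset["train"].column_names
    return dataset.map(lambda batch: preprocess_function(batch, task), batched=True, remove_columns=columns)
```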
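`Wav2Vec2ForSequenceClassification` expects each clip's label to be an integer class id. If a prosody or non-verbal dataset stores string labels instead (names such as `"rising"` below are invented for illustration), a mapping like the one in this sketch could be applied before preprocessing; `build_label_maps` is a hypothetical helper.

```python
def build_label_maps(label_names):
    """Map each distinct label string to a stable integer id, and back."""
    # Sorting makes the ids independent of the order examples appear in.
    names = sorted(set(label_names))
    label2id = {name: i for i, name in enumerate(names)}
    id2label = {i: name for name, i in label2id.items()}
    return label2id, id2label


label2id, id2label = build_label_maps(["rising", "falling", "rising", "level"])
print(label2id)  # → {'falling': 0, 'level': 1, 'rising': 2}
```

The two dictionaries can also be passed to `from_pretrained` as `label2id=` and `id2label=` config overrides, so that saved models report readable class names.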

7. **Define Training Arguments**:

```python
from transformers import TrainingArguments

def get_training_args(output_dir, num_train_epochs, batch_size, learning_rate):
    # Trainer expects a TrainingArguments object, not a plain dict.
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy="epoch",  # called evaluation_strategy in older transformers releases
        logging_dir="./logs",
        learning_rate=learning_rate,
    )
```
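With these arguments, the length of a training run follows from simple arithmetic. A sketch, assuming a single device, no gradient accumulation, and that the last partial batch is kept (the `Trainer` defaults as we understand them); `count_optimizer_steps` is a hypothetical helper.

```python
import math


def count_optimizer_steps(num_examples, batch_size, num_epochs):
    """Optimizer updates on one device without gradient accumulation."""
    # The last, partially filled batch still counts as a step.
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs


# e.g. a hypothetical 1,003 training clips, batch size 8, 10 epochs
print(count_optimizer_steps(1003, 8, 10))  # → 1260
```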

8. **Define Evaluation Metrics**:

```python
def compute_metrics(eval_pred):
    # Logits have shape (num_examples, num_labels); take the top-scoring class.
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=-1)
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average="weighted", zero_division=0
    )
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
```
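To make `average='weighted'` concrete: precision, recall and F1 are computed per class and then averaged, weighting each class by its support (its number of true examples), as scikit-learn documents for this option. A stdlib-only sketch for checking small cases by hand, with `weighted_prf` as a hypothetical name and 0 used for undefined ratios (mirroring `zero_division=0`).

```python
from collections import Counter


def weighted_prf(y_true, y_pred):
    """Support-weighted precision, recall and F1 over every class seen."""
    support = Counter(y_true)
    predicted = Counter(y_pred)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    total = len(y_true)
    precision = recall = f1 = 0.0
    for label in set(y_true) | set(y_pred):
        p = correct[label] / predicted[label] if predicted[label] else 0.0
        r = correct[label] / support[label] if support[label] else 0.0
        f = 2 * p * r / (p + r) if p + r else 0.0
        # Classes that never occur in y_true get weight 0.
        weight = support[label] / total
        precision += weight * p
        recall += weight * r
        f1 += weight * f
    return precision, recall, f1


print(tuple(round(v, 3) for v in weighted_prf([0, 0, 1, 2], [0, 1, 1, 2])))  # → (0.875, 0.75, 0.75)
```

Weighted recall always equals plain accuracy, because each class's support cancels against its recall denominator; it is the weighted precision and F1 that add information.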

9. **Fine-Tuning Function**:

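```python
from transformers import Trainer

def collate_batch(features):
    # Pad the raw waveforms to the longest clip in the batch.
    audio = [{"input_values": f["input_values"]} for f in features]
    batch = processor.feature_extractor.pad(audio, padding=True, return_tensors="pt")
    labels = [f["labels"] for f in features]
    if isinstance(labels[0], list):
        # CTC targets: pad the token ids, then set padding to -100 so the loss ignores it.
        padded = processor.tokenizer.pad([{"input_ids": l} for l in labels], padding=True, return_tensors="pt")
        batch["labels"] = padded["input_ids"].masked_fill(padded["attention_mask"].ne(1), -100)
    else:
        # Classification targets: one integer class id per clip.
        batch["labels"] = torch.tensor(labels, dtype=torch.long)
    return batch

def fine_tune_model(dataset, task, model_class, output_dir, num_train_epochs, batch_size, learning_rate, num_labels=None):
    if task == "ipa":
        model = model_class.from_pretrained(model_name)
    else:
        # The classification head is newly initialized: size it to the label set.
        num_labels = num_labels or int(max(dataset["train"]["labels"])) + 1
        model = model_class.from_pretrained(model_name, num_labels=num_labels)
    model.to(device)

    trainer = Trainer(
        model=model,
        args=get_training_args(output_dir, num_train_epochs, batch_size, learning_rate),
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        data_collator=collate_batch,  # pads inputs and labels per batch
        # Accuracy and F1 apply to the classification tasks only.
        compute_metrics=compute_metrics if task != "ipa" else None,
    )
    trainer.train()

    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    return model
```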

10. **Example Usage**:

```python
# The dataset names are placeholders: each dataset needs "train" and "validation"
# splits, an "audio" column, and the label column its task reads
# ("ipa_transcription", "prosody_labels" or "non_verbal_labels").
# Positional training settings: 10 epochs, batch size 8, learning rate 5e-5.

# 1. Phonetic transcription (IPA symbols)
dataset_name = "your_ipa_transcription_dataset"
ipa_dataset = load_and_preprocess_dataset(dataset_name, "ipa")
ipa_model = fine_tune_model(ipa_dataset, "ipa", Wav2Vec2ForCTC, "./results/ipa", 10, 8, 5e-5)

# 2. Prosody analysis
dataset_name = "your_prosody_dataset"
prosody_dataset = load_and_preprocess_dataset(dataset_name, "prosody")
prosody_model = fine_tune_model(prosody_dataset, "prosody", Wav2Vec2ForSequenceClassification, "./results/prosody", 10, 8, 5e-5)

# 3. Non-verbal marker annotation
dataset_name = "your_non_verbal_dataset"
non_verbal_dataset = load_and_preprocess_dataset(dataset_name, "non_verbal")
non_verbal_model = fine_tune_model(non_verbal_dataset, "non_verbal", Wav2Vec2ForSequenceClassification, "./results/non_verbal", 10, 8, 5e-5)
```
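After training, the IPA model emits one score vector per audio frame, and `processor.batch_decode` turns the frame-wise argmax ids into text. The core of that greedy CTC decoding is to merge consecutive repeats first and then drop the blank token. A stdlib-only sketch, assuming blank id 0 (the pad token of this checkpoint's tokenizer) and a toy IPA vocabulary invented for illustration; `ctc_greedy_decode` is a hypothetical helper.

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, id_to_symbol, blank_id=0):
    """Collapse repeated frame predictions, then remove CTC blanks."""
    # groupby merges runs of identical ids, e.g. [3, 3, 0, 3] -> [3, 0, 3].
    collapsed = [token for token, _ in groupby(frame_ids)]
    return "".join(id_to_symbol[t] for t in collapsed if t != blank_id)


# Hypothetical IPA vocabulary; id 0 is the blank/pad token.
vocab = {0: "<pad>", 1: "b", 2: "ɑ", 3: "l"}
frames = [1, 1, 0, 2, 2, 0, 0, 3, 0, 3]
print(ctc_greedy_decode(frames, vocab))  # → bɑll
```

The blank between the last two `3` frames is what lets CTC emit a doubled symbol; without it the repeats would merge into a single `l`.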