In [1]:
import torch.optim as optim
from IPython.display import clear_output
from dashboard_functions import MultiSpeakerDashboard
# Progress bar
from tqdm.auto import tqdm
from functionsV3 import *
import logging
import warnings
from collections import defaultdict
import pandas as pd
from transformers import (
    AutoImageProcessor,
    DinatForImageClassification,
    TrainingArguments,
    get_scheduler,
    BertModel,
    AutoTokenizer,
)
logging.getLogger().addHandler(logging.NullHandler())
logging.getLogger("natten.functional").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", category=UserWarning)



  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
# Emotion classes used for training and reporting: class id -> name.
EMOTIONS = dict(enumerate(("neutral", "happy", "sad", "angry")))

# Inverse lookup (name -> class id), derived from EMOTIONS so the two maps
# cannot drift apart. Insertion order follows the class ids.
Map2Num = {name: class_id for class_id, name in EMOTIONS.items()}

# Size of the classification head.
num_labels = len(EMOTIONS)

In [3]:

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
import random

import numpy as np

base_column  = "label"
dataset_train = "cairocode/IEMO_Mel_6"   # corpus used for the LOSO folds
dataset_val = "cairocode/MSPI_Mel6"      # corpus used for cross-corpus evaluation
ds_tr = os.path.basename(dataset_train)
ds_vl = os.path.basename(dataset_val)


model_path = "shi-labs/dinat-mini-in1k-224"  # For processor loading if needed
# Checkpoint whose weights (keys containing "model") are loaded into every run
# when `pretrain` is True. Alternatives used before: None, "./EmoDom/best_model.pt".
checkpoint_path = "/home/rml/Documents/pythontest/Trained_Models/curr_V2/EmoDom/20250310_1/best_model.pt"
pretrain_model = model_path
bert_model_name = "bert-base-uncased"
BATCH_SIZE = 60

fourclass = True
Speaker_Disentanglement = True
pretrain = True
column = "label"

# Hyper-parameters passed to BalancedCrossEntropyWithContrastiveLoss.
alpha = 1.6
beta = 0.0008
gamma = 1

# Seed every global RNG used by training (new classifier-head initialisation,
# DataLoader shuffling, samplers) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and every CUDA device

# Output layout: <OUTPUT_ROOT>/<train corpus>/<run tag>/<unique run folder>.
# create_unique_output_dir picks the run folder (e.g. 20250317_5 in the saved output).
OUTPUT_ROOT = "/home/rml/Documents/pythontest/Trained_Models/LOSO"
RUN_TAGS = {
    (True, True): "PRSD",   # pretrained checkpoint + speaker disentanglement
    (True, False): "PR",    # pretrained checkpoint only
    (False, True): "SD",    # speaker disentanglement only
    (False, False): "OG",   # neither
}
run_tag = RUN_TAGS[(bool(pretrain), bool(Speaker_Disentanglement))]

base_dir = create_unique_output_dir(os.path.join(OUTPUT_ROOT, ds_tr, run_tag))
print(base_dir)
os.makedirs(base_dir, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

/home/rml/Documents/pythontest/Trained_Models/LOSO/IEMO_Mel_6/PRSD/20250317_5


In [4]:
def filter_m_examples(example):
    """Return True when *example* should be kept.

    Rows labelled 4 or 5 are rejected (EMOTIONS only defines classes 0-3), and
    so are rows whose ``transcript`` field is not a string. The transcript is
    only inspected once the label check has passed.
    """
    if example["label"] in (4, 5):
        return False
    return isinstance(example['transcript'], str)

train_d0 = load_dataset(dataset_train, split='train')
val_dataset0 = load_dataset(dataset_val, split='train')

# Keep only the four EMOTIONS classes and rows with a string transcript.
train_d0 = train_d0.filter(filter_m_examples)
val_dataset0 = val_dataset0.filter(filter_m_examples)

# The cross-corpus set is always read with the validation transforms.
# set_transform works in place, so apply it before wrapping the dataset.
val_dataset0.set_transform(val_transforms)
# Not shuffled: every LOSO fold predicts these samples in the same order, which
# the majority vote over folds relies on.
Xcorp_dataloader = DataLoader(
    val_dataset0,
    batch_size=BATCH_SIZE,
    collate_fn=lambda examples: collate_fn_reg(examples, column=column),
)

# Read the speaker column directly instead of iterating every row (row
# iteration materialises all columns of every example).
spkrs = train_d0['speakerID']
# sorted() gives a deterministic fold order, so run_i in xcorp_results always
# refers to the same held-out speaker.
unique_speakers = sorted(set(spkrs))
print(unique_speakers)

# Pooled per-corpus labels / predictions over all LOSO folds.
all_y_true = defaultdict(list)
all_y_pred = defaultdict(list)

# One row of per-fold metrics is appended per LOSO fold.
total_results = pd.DataFrame()

# Cross-corpus predictions: one list per LOSO fold (run_i <-> unique_speakers[i]).
xcorp_results = {
    'y_true': [],
    'y_pred': {f'run_{i}': [] for i in range(len(unique_speakers))}
}

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]


In [None]:
import gc  # used by the per-fold GPU memory cleanup at the end of the loop

# Live dashboard shared by every leave-one-speaker-out (LOSO) fold.
dashboard = MultiSpeakerDashboard(base_dir=base_dir, port=8000, auto_open=True)

# ---------------------------------------------------------------------------
# LOSO loop. For every speaker in unique_speakers:
#   1. hold that speaker out and train a fresh model on the other speakers,
#   2. evaluate the best checkpoint on the held-out speaker,
#   3. evaluate the same model on the cross-corpus set (Xcorp_dataloader).
# Per-fold artefacts are written to base_dir/<num>/.
# ---------------------------------------------------------------------------
for i, test_speaker in enumerate(unique_speakers):
    speakers = [test_speaker]
    # Folder / dashboard index. Assumes integer speaker IDs (1..N in the saved
    # output of the data cell), so fold folders are numbered from 0.
    num = speakers[0] - 1
    print(f"NUM = {num} __ SPKRS {speakers}")
    print(f"\n {'#'*120}")
    print(f"                                          STARTING SPEAKER {num}                                                     ")
    print(f"\n {'#'*120}")

    dashboard.start_speaker_run(speaker_id=num, speaker_name=f"Speaker {num}")

    new_model_path = os.path.join(base_dir, str(num))
    os.makedirs(new_model_path, exist_ok=True)

    # ---- Data splits ------------------------------------------------------
    # train_d0 was already filtered with filter_m_examples when it was loaded,
    # so only the speaker split is needed. input_columns hands the predicate
    # just the speaker column instead of every column of every example.
    test_dataset = train_d0.filter(lambda sid: sid in speakers, input_columns="speakerID")
    train_set = train_d0.filter(lambda sid: sid not in speakers, input_columns="speakerID")
    train_set = balance_dataset(train_set, label_column="label", seed=42)

    train_dataset = train_set
    # FIXME(review): the validation set IS the held-out test speaker, so early
    # stopping and best-checkpoint selection look at test data and the test
    # metrics below are optimistically biased. Validate on a split of the
    # training speakers instead (e.g. train_set.train_test_split(...)).
    val_dataset = test_dataset

    train_dataset.set_transform(train_transforms)
    # val_dataset and test_dataset are the same object; one call covers both.
    test_dataset.set_transform(val_transforms)

    # NOTE(review): the dataset column is spelled "speakerID" above, but the
    # sampler is given "SpeakerID" - confirm which key CustomSampler expects.
    train_sampler = CustomSampler(train_dataset, "SpeakerID")

    if Speaker_Disentanglement:
        # Custom sampler instead of shuffle (DataLoader forbids using both).
        train_loader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=BATCH_SIZE,
            collate_fn=lambda examples: collate_fn_reg(examples, column=column),
        )
    else:
        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            collate_fn=lambda examples: collate_fn_reg(examples, column=column),
        )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=lambda examples: collate_fn_reg(examples, column=column),
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=lambda examples: collate_fn_reg(examples, column=column),
    )

    # Recorded in metrics.txt below; not passed to cecc_loss.
    class_weights = calculate_class_weights(train_dataset)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    # ---- Models -----------------------------------------------------------
    # DiNAT image branch with a num_labels-way classification head.
    image_model = DinatForImageClassification.from_pretrained(
        pretrain_model,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        problem_type="single_label_classification",
    ).to(device)

    # Image preprocessor for model_path (not used in this cell). Loading it via
    # DinatForImageClassification would put a second full model on the GPU.
    processor = AutoImageProcessor.from_pretrained(model_path)

    tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
    bert_model = BertModel.from_pretrained(bert_model_name).to(device)

    # BERT layers handed to CombinedModelsBi as unfrozen; further layers are
    # unfrozen one at a time (counting down) every 5 epochs below.
    unfrozen_layers = [10, 11]
    next_layer_to_unfreeze = unfrozen_layers[0] - 1

    model = CombinedModelsBi(
        image_model=image_model,
        bert_model=bert_model,
        image_feature_dim=512,
        bert_embedding_dim=768,
        combined_dim=512,
        num_labels=num_labels,
        unfrozen_layers=unfrozen_layers,
    ).to(device)

    if pretrain:
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Load only keys containing "model", minus the DiNAT classifier head.
        # NOTE(review): "fc3" only matches a key named exactly "fc3"; parameter
        # keys are normally "fc3.weight"/"fc3.bias" - confirm the intended exclusion.
        include_keyword = "model"
        exclude_keys = {
            "image_model.classifier.weight",
            "image_model.classifier.bias",
            "fc3"
        }

        filtered_checkpoint = {
            key: value for key, value in checkpoint.items()
            if include_keyword in key and key not in exclude_keys
        }
        model.load_state_dict(filtered_checkpoint, strict=False)
        # The raw checkpoint was mapped onto `device`; release it now it is copied.
        del checkpoint, filtered_checkpoint

    # ---- Optimisation -----------------------------------------------------
    # Used here only as a container for the hyper-parameters read below.
    training_args = TrainingArguments(
        output_dir="./logs",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=1e-5,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=25,
        weight_decay=0.01,
        load_best_model_at_end=True
    )
    # CE / contrastive / balance loss with learnable parameters of its own,
    # optimised in a second parameter group below.
    cecc_loss = BalancedCrossEntropyWithContrastiveLoss(
        num_classes=num_labels,  # Number of emotion classes
        feature_dim=512,  # Feature dimension from fc2
        alpha=alpha,  # Adjust weights for CE vs Contrastive-Center
        beta=beta,
        gamma=gamma
    )

    optimizer = optim.AdamW(
        [
            {'params': model.parameters(), 'weight_decay': training_args.weight_decay},
            {'params': cecc_loss.parameters(), 'lr': 0.0001, 'weight_decay': 0},
        ],
        lr=training_args.learning_rate,
    )

    # The schedule's horizon is counted in optimizer steps (batches), so
    # lr_scheduler.step() is called once per batch in the training loop.
    num_training_steps = len(train_loader) * training_args.num_train_epochs
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=num_training_steps,
    )

    num_epochs = training_args.num_train_epochs
    patience = 6
    # -inf guarantees the first epoch is saved, so best_model_path always exists
    # and best_epoch always belongs to this fold.
    best_val_accuracy = float("-inf")
    best_epoch = None
    patience_counter = 0

    train_losses, val_losses, epochs_list = [], [], []

    best_model_path = os.path.join(new_model_path, "best_model.pt")

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        # batch_idx is 1-based.
        for batch_idx, batch in enumerate(progress_bar, start=1):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                pixel_values=pixel_values,
                bert_input_ids=input_ids,
                bert_attention_mask=attention_mask
            )
            logits = outputs["logits"]
            features = outputs["combined_features"]

            loss, weight_info = cecc_loss(logits, features, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            train_loss += loss.item()
            progress_bar.set_postfix({"Loss": loss.item()})
            if batch_idx % 5 == 0:
                # batch_idx is already 1-based; pass it through unchanged.
                dashboard.update_batch(
                    epoch=epoch,
                    batch=batch_idx,
                    train_loss=loss.item(),
                    total_batches=len(train_loader)
                )

        avg_train_loss = train_loss / len(train_loader)

        # ---- Validation --------------------------------------------------
        model.eval()
        val_loss = 0
        all_predictions, all_labels = [], []

        with torch.no_grad():
            for batch in val_loader:
                pixel_values = batch["pixel_values"].to(device)
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(
                    pixel_values=pixel_values,
                    bert_input_ids=input_ids,
                    bert_attention_mask=attention_mask
                )
                logits = outputs["logits"]
                features = outputs["combined_features"]

                loss, weight_info = cecc_loss(logits, features, labels)
                val_loss += loss.item()

                predictions = torch.argmax(logits, dim=-1)
                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        accuracy = accuracy_score(all_labels, all_predictions)
        uar = recall_score(all_labels, all_predictions, average="macro")
        f1 = f1_score(all_labels, all_predictions, average="macro")
        per_class_recall = recall_score(all_labels, all_predictions, average=None)
        uar_std = np.std(per_class_recall)

        # UAR penalised by the spread of per-class recall. Logged only here;
        # checkpoint selection in this loop uses accuracy.
        gamma_comp = 1.5
        comparison_metric = uar / (1 + gamma_comp * uar_std)

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        epochs_list.append(epoch + 1)

        print(
            f"Epoch {epoch+1}/{num_epochs} - Training Loss: {avg_train_loss:.4f}, "
            f"Validation Loss: {avg_val_loss:.4f}, "
            f"Accuracy: {accuracy:.4f}, UAR: {uar:.4f}, F1: {f1:.4f}, UAR STD: {uar_std}, "
            f"Comparison metric: {comparison_metric}"
            f"\nCE weight: {weight_info['weight_ce']:.6f} (log var: {weight_info['log_var_ce']:.4f}), "
            f"Contrastive weight: {weight_info['weight_contrastive']:.6f} (log var: {weight_info['log_var_contrastive']:.4f}), "
            f"Balance weight: {weight_info['weight_balance']:.6f} (log var: {weight_info['log_var_balance']:.4f})"
        )
        dashboard.update(
            epoch=epoch+1,
            train_loss=avg_train_loss,
            val_loss=avg_val_loss,
            accuracy=accuracy,
            uar=uar,
            f1=f1
        )

        # Early stopping / checkpoint selection on validation accuracy.
        if accuracy > best_val_accuracy:
            best_val_accuracy = accuracy
            patience_counter = 0
            torch.save(model.state_dict(), best_model_path)
            # Per-fold path: a shared CWD file would be overwritten by every fold.
            torch.save(model.image_model.state_dict(), os.path.join(new_model_path, "fine_tuned_image_model.pth"))
            best_epoch = epoch
            print("Validation accuracy improved. Best model saved.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

        # Gradually unfreeze BERT: one more encoder layer every 5 epochs.
        if (epoch + 1) % 5 == 0 and next_layer_to_unfreeze >= 0:
            print(f"Unfreezing BERT layer {next_layer_to_unfreeze}")
            unfreeze_bert_layer(model.bert_model, next_layer_to_unfreeze)
            next_layer_to_unfreeze -= 1

    # Load Best Model
    print("Loading best model for final evaluation.")
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.to(device)

    ##############################################################################
    # Test Evaluation (held-out speaker, scored with plain cross-entropy)
    ##############################################################################
    print("\nStarting Test Evaluation...")
    model.eval()
    test_loss = 0
    all_test_predictions, all_test_labels = [], []

    with torch.no_grad():
        test_progress_bar = tqdm(test_loader, desc="Testing", leave=False)
        for batch in test_progress_bar:
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                pixel_values=pixel_values,
                bert_input_ids=input_ids,
                bert_attention_mask=attention_mask
            )
            logits = outputs["logits"]

            loss = F.cross_entropy(logits, labels)
            test_loss += loss.item()

            predictions = torch.argmax(logits, dim=-1)
            all_test_predictions.extend(predictions.cpu().numpy())
            all_test_labels.extend(labels.cpu().numpy())

    avg_test_loss = test_loss / len(test_loader)
    test_accuracy = accuracy_score(all_test_labels, all_test_predictions)
    test_uar = recall_score(all_test_labels, all_test_predictions, average="macro")
    test_f1 = f1_score(all_test_labels, all_test_predictions, average="macro")

    metrics_str = (
        f"Test Loss: {avg_test_loss:.4f}, "
        f"Accuracy: {test_accuracy:.4f}, "
        f"UAR: {test_uar:.4f}, "
        f"F1: {test_f1:.4f}"
    )
    print(metrics_str)

    plot_and_save_confusion_matrix(all_test_labels, all_test_predictions, Map2Num, new_model_path, epoch=None, filename = f"{ds_tr}_{test_accuracy*100:.2f}_Acc_{test_uar*100:.2f}_UAR.png")

    overall_accuracy = test_accuracy
    overall_UAR = test_uar
    overall_F1 = test_f1
    full_accuracy = test_accuracy

    # Save final metrics
    output_file = os.path.join(new_model_path, "metrics.txt")
    with open(output_file, "w") as f:
        f.write(f"Overall F1 Score: {overall_F1:.4f}\n")
        f.write(f"Overall Accuracy: {overall_accuracy:.4f}\n")
        f.write(f"Full Accuracy: {full_accuracy:.4f}\n")
        f.write(f"Overall UAR: {overall_UAR:.4f}\n")
        f.write(f" Class Mapping :{EMOTIONS}\n")
        f.write(f" Best Epoch :{best_epoch}\n")
        f.write(f" Train Dataset :{ds_tr}\n")
        f.write(f" Alpha :{alpha}\n")
        f.write(f" Beta:{beta}\n")
        f.write(f" Gamma:{gamma}\n")
        f.write(f" Class Weights :{class_weights}\n")
    print(f"Metrics saved to {output_file}")

    dashboard.export_results(new_model_path)
    ##############################################################################
    # CROSS CORPUS Evaluation
    ##############################################################################
    print("\nStarting Cross Corpus Evaluation...")
    model.eval()
    test_loss = 0
    X_preds, X_true = [], []

    with torch.no_grad():
        test_progress_bar = tqdm(Xcorp_dataloader, desc="Testing", leave=False)
        for batch in test_progress_bar:
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                pixel_values=pixel_values,
                bert_input_ids=input_ids,
                bert_attention_mask=attention_mask
            )
            logits = outputs["logits"]

            loss = F.cross_entropy(logits, labels)
            test_loss += loss.item()

            predictions = torch.argmax(logits, dim=-1)
            X_preds.extend(predictions.cpu().numpy())
            X_true.extend(labels.cpu().numpy())

    # Average over the cross-corpus batches (not the held-out-speaker loader).
    Xavg_test_loss = test_loss / len(Xcorp_dataloader)
    Xtest_accuracy = accuracy_score(X_true, X_preds)
    Xtest_uar = recall_score(X_true, X_preds, average="macro")
    Xtest_f1 = f1_score(X_true, X_preds, average="macro")

    # Xcorp_dataloader is not shuffled, so the labels are identical for every
    # fold and only need to be stored once.
    if i == 0:
        xcorp_results['y_true'].extend(X_true)

    # Store this fold's predictions; setdefault copes with more folds than the
    # pre-created run_* slots.
    xcorp_results['y_pred'].setdefault(f'run_{i}', []).extend(X_preds)

    metrics_str = (
        f"Test Loss: {Xavg_test_loss:.4f}, "
        f"Accuracy: {Xtest_accuracy:.4f}, "
        f"UAR: {Xtest_uar:.4f}, "
        f"F1: {Xtest_f1:.4f}"
    )
    print(metrics_str)

    plot_and_save_confusion_matrix(X_true, X_preds, Map2Num, new_model_path, epoch=None, filename = f"{ds_vl}_{Xtest_accuracy*100:.2f}_Acc_{Xtest_uar*100:.2f}_UAR.png")

    overall_accuracy = Xtest_accuracy
    overall_UAR = Xtest_uar
    overall_F1 = Xtest_f1
    full_accuracy = Xtest_accuracy
    new_row = {f'{ds_tr}_ACC': test_accuracy*100, f'{ds_tr}_UAR': test_uar*100, f'{ds_vl}_ACC': Xtest_accuracy*100,f'{ds_vl}_UAR': Xtest_uar*100}
    print(f"\n {'-'*120}")
    print("\n\n", new_row, "\n")

    total_results = pd.concat([total_results, pd.DataFrame([new_row])], ignore_index=True)
    # Save final metrics
    output_file = os.path.join(new_model_path, f"{ds_vl}_metrics.txt")
    with open(output_file, "w") as f:
        f.write(f"Overall F1 Score: {overall_F1:.4f}\n")
        f.write(f"Overall Accuracy: {overall_accuracy:.4f}\n")
        f.write(f"Full Accuracy: {full_accuracy:.4f}\n")
        f.write(f"Overall UAR: {overall_UAR:.4f}\n")
        f.write(f" Class Mapping :{EMOTIONS}\n")
        f.write(f" Best Epoch :{best_epoch}\n")

    print(f"Metrics saved to {output_file}")

    # Pool labels / predictions per corpus for the summary cell.
    all_y_true[ds_tr].extend(all_test_labels)
    all_y_pred[ds_tr].extend(all_test_predictions)
    all_y_true[ds_vl].extend(X_true)
    all_y_pred[ds_vl].extend(X_preds)

    # Release this fold's GPU memory before the next fold. References must be
    # dropped *before* empty_cache() can return cached blocks; the optimizer and
    # scheduler also hold the parameters and AdamW state. training_args and
    # cecc_loss are kept in case later cells reuse them.
    del model, optimizer, lr_scheduler, image_model, bert_model, processor
    gc.collect()
    torch.cuda.empty_cache()



Dashboard server started at http://localhost:8000/dashboard.html
NUM = 0 __ SPKRS [1]

 ########################################################################################################################
                                          STARTING SPEAKER 0                                                     

 ########################################################################################################################
Started tracking training for Speaker 0 (ID: 0)


In [None]:
import csv
import json
# Majority vote over the LOSO fold models on the cross-corpus set. Every fold
# predicts Xcorp_dataloader in the same (unshuffled) order, so index j refers
# to the same utterance in every run.
y_true = np.array(xcorp_results['y_true'])

# Use every fold that actually produced predictions instead of a hard-coded
# number of runs.
run_predictions = [preds for preds in xcorp_results['y_pred'].values() if len(preds) > 0]
assert run_predictions, "no cross-corpus predictions were recorded"
assert all(len(preds) == len(y_true) for preds in run_predictions), \
    "cross-corpus predictions are not aligned with y_true"
vote_matrix = np.asarray(run_predictions)  # shape: (n_runs, n_samples)

# Ties resolve to the smallest class id (argmax returns the first maximum).
final_prediction = np.array([np.bincount(vote_matrix[:, j]).argmax()
                             for j in range(len(y_true))])

final_accuracy = accuracy_score(y_true, final_prediction)
print(f"Final accuracy after majority voting: {final_accuracy}")
print(total_results)
avg_results = total_results.mean(numeric_only=True)

print("\nAVERAGE RESULTS\n", avg_results)

# Pooled metrics over all folds. ACC/UAR are percentages; F1 stays on a 0-1 scale.
final_metrics = {}
for dataset in [ds_tr, ds_vl]:
    y_true = np.array(all_y_true[dataset])
    y_pred = np.array(all_y_pred[dataset])

    acc = accuracy_score(y_true, y_pred) * 100
    uar = recall_score(y_true, y_pred, average = "macro") * 100
    f1 = f1_score(y_true, y_pred, average="macro")
    final_metrics[f'{dataset}_ACC'] = acc
    final_metrics[f'{dataset}_UAR'] = uar
    final_metrics[f'{dataset}_f1'] = f1

print("\nFinal Metrics:")
print(final_metrics)

def save_model_info_to_txt(model_info, dirname, filename="results.txt"):
    """Dump a run summary to ``<dirname>/<filename>``.

    Each entry of *model_info* is written as ``key: value`` followed by a
    blank line, in the dictionary's iteration order. An existing file with
    the same name is overwritten. The destination path is printed afterwards.

    Parameters:
    model_info (dict): summary entries to write
    dirname (str): directory that receives the file
    filename (str): name of the file inside *dirname*
    """
    target = os.path.join(dirname, filename)

    with open(target, 'w') as handle:
        handle.writelines(f"{key}: {value}\n\n" for key, value in model_info.items())

    print(f"Model information written to {target}")

# Run-level summary written to base_dir/results.txt.
model_info = {
    # No checkpoint is loaded when pretrain is False, so record None then.
    "Pretrain_file": checkpoint_path if pretrain else None,
    "Dataset Used": dataset_train,
    "Model Type": "DINAT-BERT",
    "Speaker Disentanglement": Speaker_Disentanglement,
    "Column Trained on": column,
    # Every speaker was the test speaker of one LOSO fold (`speakers` would
    # only hold the last fold's speaker).
    "Test SpeakerID": unique_speakers,
    "Avg Results": avg_results,
    "total results": final_metrics,
    "alpha": alpha,
    "beta": beta,
    "gamma": gamma,
}
save_model_info_to_txt(model_info, base_dir)
dashboard.finish()

In [None]:
#base_dir = "/home/rml/Documents/pythontest/Trained_Models/LOSO/IEMO_Mel_6/20250311_2"
# ---------------------------------------------------------------------------
# Cross-corpus run: train on the whole training corpus (train_d0) and evaluate
# on the other corpus (val_dataset0). Artefacts go to base_dir/cross_corpus/.
# ---------------------------------------------------------------------------
new_model_path = os.path.join(base_dir, "cross_corpus")
os.makedirs(new_model_path, exist_ok=True)

# with_transform returns new dataset objects, so the shared train_d0 /
# val_dataset0 are not mutated and earlier cells can be re-run unchanged.
train_dataset = train_d0.with_transform(train_transforms)
test_dataset = val_dataset0.with_transform(val_transforms)
# FIXME(review): validation == cross-corpus test set, so model selection below
# looks at test data. Carve a validation split out of train_d0 instead.
val_dataset = test_dataset

# Speaker-aware sampling (CustomSampler) is disabled for this run: both
# Speaker_Disentanglement settings use a plain shuffled loader.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda examples: collate_fn_reg(examples, column=column),
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=lambda examples: collate_fn_reg(examples, column=column),
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=lambda examples: collate_fn_reg(examples, column=column),
)



# Recorded in metrics.txt; not passed to the loss in this notebook.
class_weights = calculate_class_weights(train_dataset)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# DiNAT image branch with a num_labels-way classification head.
image_model = DinatForImageClassification.from_pretrained(
    pretrain_model,
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
    problem_type="single_label_classification",
).to(device)

# Image preprocessor for model_path (not used in this cell). Loading it via
# DinatForImageClassification would put a second full model on the GPU.
processor = AutoImageProcessor.from_pretrained(model_path)

tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertModel.from_pretrained(bert_model_name).to(device)


# ----------------------------------------------------------------------
# Initialize the Combined Regression Model
# ----------------------------------------------------------------------
# BERT layers handed to CombinedModelsBi as unfrozen; further layers are
# unfrozen one at a time (counting down) during training.
unfrozen_layers = [11]
next_layer_to_unfreeze = unfrozen_layers[0] - 1

model = CombinedModelsBi(
    image_model=image_model,
    bert_model=bert_model,
    image_feature_dim=512,
    bert_embedding_dim=768,
    combined_dim=512,
    num_labels=num_labels,
    unfrozen_layers=unfrozen_layers,
).to(device)

if pretrain:
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load only keys containing "model", minus the DiNAT classifier head.
    # NOTE(review): "fc3" only matches a key named exactly "fc3"; parameter keys
    # are normally "fc3.weight"/"fc3.bias" - confirm the intended exclusion.
    include_keyword = "model"
    exclude_keys = {
        "image_model.classifier.weight",
        "image_model.classifier.bias",
        "fc3"
    }

    filtered_checkpoint = {
        key: value for key, value in checkpoint.items()
        if include_keyword in key and key not in exclude_keys
    }
    model.load_state_dict(filtered_checkpoint, strict=False)
    # The raw checkpoint was mapped onto `device`; release it now it is copied.
    del checkpoint, filtered_checkpoint


# training_args and cecc_loss are created here rather than reused from the LOSO
# loop: reuse fails on a fresh kernel, and the loss module's learnable
# parameters would start from wherever the last LOSO fold left them.
training_args = TrainingArguments(
    output_dir="./logs",
    evaluation_strategy="epoch",  # NOTE(review): newer transformers versions call this eval_strategy
    save_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=25,
    weight_decay=0.01,
    load_best_model_at_end=True
)
cecc_loss = BalancedCrossEntropyWithContrastiveLoss(
    num_classes=num_labels,  # Number of emotion classes
    feature_dim=512,  # Feature dimension from fc2
    alpha=alpha,  # Adjust weights for CE vs Contrastive-Center
    beta=beta,
    gamma=gamma
)

optimizer = optim.AdamW(
    [
        {'params': model.parameters(), 'weight_decay': training_args.weight_decay},
        {'params': cecc_loss.parameters(), 'lr': training_args.learning_rate * 10, 'weight_decay': 0.1},
    ],
    lr=training_args.learning_rate,
)

# Horizon counted in optimizer steps (batches); the training loop must call
# lr_scheduler.step() once per batch for the linear decay to take effect.
num_training_steps = len(train_loader) * training_args.num_train_epochs
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=training_args.warmup_steps,
    num_training_steps=num_training_steps,
)

num_epochs = training_args.num_train_epochs
patience = 12
# -inf guarantees the first epoch is saved, so best_model_path always exists
# and best_epoch is always set for this run.
best_val_accuracy = float("-inf")
best_epoch = None
patience_counter = 0

train_losses, val_losses, epochs_list = [], [], []

best_model_path = os.path.join(new_model_path, "best_model.pt")

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for batch in progress_bar:
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(
            pixel_values=pixel_values,
            bert_input_ids=input_ids,
            bert_attention_mask=attention_mask
        )
        logits = outputs["logits"]
        features = outputs["combined_features"]

        loss, weight_info = cecc_loss(logits, features, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # num_training_steps is counted in batches, so advance the schedule per batch.
        lr_scheduler.step()

        train_loss += loss.item()
        progress_bar.set_postfix({"Loss": loss.item()})

    avg_train_loss = train_loss / len(train_loader)

    # Validation (val_loader wraps the cross-corpus set here: val_dataset = test_dataset).
    model.eval()
    val_loss = 0
    all_predictions, all_labels = [], []

    with torch.no_grad():
        for batch in val_loader:
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                pixel_values=pixel_values,
                bert_input_ids=input_ids,
                bert_attention_mask=attention_mask
            )
            logits = outputs["logits"]
            features = outputs["combined_features"]

            loss, weight_info = cecc_loss(logits, features, labels)
            val_loss += loss.item()

            predictions = torch.argmax(logits, dim=-1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    accuracy = accuracy_score(all_labels, all_predictions)
    uar = recall_score(all_labels, all_predictions, average="macro")
    f1 = f1_score(all_labels, all_predictions, average="macro")
    per_class_recall = recall_score(all_labels, all_predictions, average=None)
    uar_std = np.std(per_class_recall)

    # UAR penalised by the spread of per-class recall; drives model selection below.
    gamma_comp = 1.5
    comparison_metric = uar / (1 + gamma_comp * uar_std)

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    epochs_list.append(epoch + 1)

    print(
        f"Epoch {epoch+1}/{num_epochs} - Training Loss: {avg_train_loss:.4f}, "
        f"Validation Loss: {avg_val_loss:.4f}, "
        f"Accuracy: {accuracy:.4f}, UAR: {uar:.4f}, F1: {f1:.4f}, UAR STD: {uar_std}, "
        f"Comparison metric: {comparison_metric}"
    )

    # Early stopping / checkpoint selection on the comparison metric.
    if comparison_metric > best_val_accuracy:
        best_val_accuracy = comparison_metric
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        # Kept next to best_model.pt instead of a shared file in the CWD.
        torch.save(model.image_model.state_dict(), os.path.join(new_model_path, "fine_tuned_image_model.pth"))
        best_epoch = epoch
        print("Validation comparison metric improved. Best model saved.")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

    # Gradually unfreeze BERT: one more encoder layer every 3 epochs.
    if (epoch + 1) % 3 == 0 and next_layer_to_unfreeze >= 0:
        print(f"Unfreezing BERT layer {next_layer_to_unfreeze}")
        unfreeze_bert_layer(model.bert_model, next_layer_to_unfreeze)
        next_layer_to_unfreeze -= 1


# Load Best Model (saved during training) for the final evaluation.
print("Loading best model for final evaluation.")
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)

##############################################################################
# Test Evaluation (cross-corpus set, scored with plain cross-entropy)
##############################################################################
print("\nStarting Test Evaluation...")
model.eval()
test_loss = 0
all_test_predictions, all_test_labels = [], []

with torch.no_grad():
    test_progress_bar = tqdm(test_loader, desc="Testing", leave=False)
    for batch in test_progress_bar:
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(
            pixel_values=pixel_values,
            bert_input_ids=input_ids,
            bert_attention_mask=attention_mask
        )
        logits = outputs["logits"]

        loss = F.cross_entropy(logits, labels)
        test_loss += loss.item()

        predictions = torch.argmax(logits, dim=-1)
        all_test_predictions.extend(predictions.cpu().numpy())
        all_test_labels.extend(labels.cpu().numpy())

avg_test_loss = test_loss / len(test_loader)
test_accuracy = accuracy_score(all_test_labels, all_test_predictions)
test_uar = recall_score(all_test_labels, all_test_predictions, average="macro")
test_f1 = f1_score(all_test_labels, all_test_predictions, average="macro")

metrics_str = (
    f"Test Loss: {avg_test_loss:.4f}, "
    f"Accuracy: {test_accuracy:.4f}, "
    f"UAR: {test_uar:.4f}, "
    f"F1: {test_f1:.4f}"
)
print(metrics_str)
plot_and_save_confusion_matrix(all_test_labels, all_test_predictions, Map2Num, new_model_path, epoch=None, filename = f"{ds_vl}_{test_accuracy*100:.2f}_Acc_{test_uar*100:.2f}_UAR.png")

overall_accuracy = test_accuracy
overall_UAR = test_uar
overall_F1 = test_f1
full_accuracy = test_accuracy

# Save final metrics
output_file = os.path.join(new_model_path, "metrics.txt")
with open(output_file, "w") as f:
    f.write(f"Overall F1 Score: {overall_F1:.4f}\n")
    f.write(f"Overall Accuracy: {overall_accuracy:.4f}\n")
    f.write(f"Full Accuracy: {full_accuracy:.4f}\n")
    f.write(f"Overall UAR: {overall_UAR:.4f}\n")
    f.write(f" Class Mapping :{EMOTIONS}\n")
    f.write(f" Best Epoch :{best_epoch}\n")
    f.write(f" Train Dataset :{ds_tr}\n")
    f.write(f" Alpha :{alpha}\n")
    f.write(f" Beta:{beta}\n")
    f.write(f" Gamma:{gamma}\n")
    f.write(f" Class Weights :{class_weights}\n")
print(f"Metrics saved to {output_file}")
