In [2]:
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
from datasets import load_dataset
from utils import *

In [3]:
model = TFAutoModel.from_pretrained("distilroberta-base")
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFRobertaModel: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing TFRobertaModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFRobertaModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFRobertaModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFRobertaModel for predictions without further training.


In [4]:
# Load the dataset
emotion_dataset = load_dataset("google-research-datasets/go_emotions", "simplified")

In [5]:
# Select a larger subset of examples for better training
small_train_dataset = emotion_dataset['train'].select(range(10000))
small_test_dataset = emotion_dataset['test'].select(range(1000))

# Create a new small dataset with the reduced splits
small_emotion_dataset = {
    'train': small_train_dataset,
    'test': small_test_dataset
}

In [6]:
emotions_id2label = {
    0: 'admiration',
    1: 'amusement',
    2: 'anger',
    3: 'annoyance',
    4: 'approval',
    5: 'caring',
    6: 'confusion',
    7: 'curiosity',
    8: 'desire',
    9: 'disappointment',
    10: 'disapproval',
    11: 'disgust',
    12: 'embarrassment',
    13: 'excitement',
    14: 'fear',
    15: 'gratitude',
    16: 'grief',
    17: 'joy',
    18: 'love',
    19: 'nervousness',
    20: 'optimism',
    21: 'pride',
    22: 'realization',
    23: 'relief',
    24: 'remorse',
    25: 'sadness',
    26: 'surprise',
    27: 'neutral'  # Last entry (no comma)
}

emotions_label2id = {v: k for k, v in emotions_id2label.items()}

# Print dataset info to verify we understand what we're working with
print("Sample of emotions dataset:")
print(small_emotion_dataset["train"][0])
print(f"Number of labels: {len(emotions_id2label)}")

Sample of emotions dataset:
{'text': "My favourite food is anything I didn't have to cook myself.", 'labels': [27], 'id': 'eebbqej'}
Number of labels: 28


In [10]:
# Cell source for ID: 4c515f62 (Simplified Logic)
MINORITY_THRESHOLD_PERCENT = 1.0
from collections import Counter
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
import re

print("Analyzing original training data for minority classes...")
original_train_labels = small_train_dataset['labels']
all_labels = [label for sublist in original_train_labels for label in sublist]
label_counts_original = Counter(all_labels)
total_samples_original = len(small_train_dataset)

minority_threshold_count = total_samples_original * (MINORITY_THRESHOLD_PERCENT / 100.0)
minority_classes = {
    label for label, count in label_counts_original.items()
    if count < minority_threshold_count
}
if minority_classes:
    print(f"Identified {len(minority_classes)} minority classes (frequency < {MINORITY_THRESHOLD_PERCENT}%, count < {minority_threshold_count:.0f}):")
    minority_class_names = {emotions_id2label.get(idx, f"Unknown({idx})") for idx in minority_classes}
    print(f"Minority Classes Indices: {minority_classes}")
    print(f"Minority Classes Names: {minority_class_names}")
else:
    print(f"No minority classes found with frequency < {MINORITY_THRESHOLD_PERCENT}%.")

print("Applying data augmentation to training set (this may take a while)...")

augmentation_model = "huihui-ai/Llama-3.2-1B-Instruct-abliterated"

augmentation_tokenizer = AutoTokenizer.from_pretrained(augmentation_model)
augmentation_generator = AutoModelForCausalLM.from_pretrained(
    augmentation_model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

pipe = pipeline(
    "text-generation",
    model=augmentation_generator,
    tokenizer=augmentation_tokenizer,
    max_length=512,
    truncation=True,
)
# Ensure pad token ID exists
if pipe.tokenizer.pad_token_id is None:
    pipe.tokenizer.pad_token_id = pipe.tokenizer.eos_token_id

def augment_data(example, minority_classes_set=None):
    should_augment = False
    target_minority_label = None

    if minority_classes_set:
        example_labels_list = []
        if isinstance(example.get('labels'), list):
            example_labels_list = example['labels']
        elif isinstance(example.get('labels'), int):
            example_labels_list = [example['labels']]

        for lbl in example_labels_list:
            if lbl in minority_classes_set:
                should_augment = True
                target_minority_label = lbl
                break

    if should_augment and target_minority_label is not None:
        original_text = example['text']
        emotional_label_str = emotions_id2label[target_minority_label]

        prompt = f"Rephrase the following text to express the emotion '{emotional_label_str}' in 1-2 short sentences. Do not include lists or explanations. Original text: '{original_text}'. Rephrased text:"
        augmented_text = "" # Initialize augmented_text
        try:
            outputs = pipe(
                prompt,
                max_new_tokens=100, # Further reduced tokens
                do_sample=True,
                temperature=0.75, # Slightly increased temperature
                pad_token_id=pipe.tokenizer.pad_token_id,
            )
            generated_full_text = outputs[0]['generated_text']

            parts = generated_full_text.split(prompt, 1)
            if len(parts) > 1:
                augmented_text = parts[1].strip()
            else:
                 augmented_text = generated_full_text.strip()
                 # A basic check: if it still contains a significant part of the prompt, discard it.
                 if original_text[:20] in augmented_text[:len(original_text)+30]: # Heuristic check
                     augmented_text = ""


        except Exception as e:
            # Keep one try-except for the pipeline call itself
            print(f"Error during pipeline execution for text related to '{original_text}': {e}")
            return example # Skip augmentation if pipeline fails

        if augmented_text:
            # Remove leading non-alphanumeric characters (like list markers if they appear)
            augmented_text = re.sub(r"^[^\w]+", "", augmented_text).strip()
             # Remove trailing incomplete sentence fragments if generation stopped abruptly
            if not augmented_text.endswith(('.', '!', '?')):
                 last_space = augmented_text.rfind(' ')
                 if last_space != -1:
                     augmented_text = augmented_text[:last_space].strip()


        if augmented_text and augmented_text.lower() != original_text.lower() and len(augmented_text.split()) > 1 and len(augmented_text) > 5:
            # Check: not empty, different from original (case-insensitive), more than 1 word, more than 5 chars
            example['text'] = augmented_text
            # Keep this print for successful changes
            print(f"Original: {original_text}\nAugmented: {augmented_text}")
        # else: keep original text

    return example

# Use functools.partial
from functools import partial
augment_fn = partial(augment_data, minority_classes_set=minority_classes)

# Run the map function
augmented_train_dataset = small_train_dataset.map(augment_fn, num_proc=1) # num_proc=1 recommended
print("Augmentation complete.")

# Print samples
print("Sample of augmented training data:")
num_samples_to_show = min(5, len(augmented_train_dataset))
if num_samples_to_show > 0:
    for i in range(num_samples_to_show):
        original_text_display = small_train_dataset[i].get('text', 'N/A')
        augmented_text_display = augmented_train_dataset[i].get('text', 'N/A')
        labels_display = augmented_train_dataset[i].get('labels', 'N/A')
        print(f"Original: {original_text_display}")
        print(f"Augmented: {augmented_text_display}")
        print(f"Labels: {labels_display}")
        print()
else:
    print("Augmented dataset appears empty or has fewer than 5 samples.")

Analyzing original training data for minority classes...
Identified 5 minority classes (frequency < 1.0%, count < 100):
Minority Classes Indices: {12, 16, 19, 21, 23}
Minority Classes Names: {'pride', 'grief', 'embarrassment', 'relief', 'nervousness'}
Applying data augmentation to training set (this may take a while)...


Device set to use mps
Map:   0%|          | 0/10000 [00:00<?, ? examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Map:   0%|          | 16/10000 [00:09<1:36:44,  1.72 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: Shit, I guess I accidentally bought a Pay-Per-View boxing match
Augmented: How on earth did something like this end up in my shopping


Map:   0%|          | 22/10000 [00:09<1:09:03,  2.41 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: i got a bump and a bald spot. i feel dumb <3
Augmented: i'm a little embarrassed now' 'i feel so dumb right


Map:   0%|          | 28/10000 [00:10<50:26,  3.30 examples/s]  Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: I miss them being alive
Augmented: The ache of their absence is still felt'.


Map:   0%|          | 43/10000 [00:11<31:02,  5.35 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: I read on a different post that he died shortly after of internal injuries.
Augmented: There is a sense of despair in my heart that a person once dear to me has been taken from me, and the pain of their loss is now all too


Map:   1%|          | 62/10000 [00:12<18:43,  8.84 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: I just shit my pants. Then walk away. Embarrassing enough he won't press or follow you.
Augmented: That was so humiliating. Now you're walking away. That's even more


Map:   1%|▏         | 148/10000 [00:13<05:29, 29.87 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: Glad you feel better! My offer still stands though, if you need someone, I’m here
Augmented: Feeling better, thanks for reaching out, though my offer remains. Hope everything is okay'.


Map:   3%|▎         | 307/10000 [00:13<02:01, 79.73 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: I am just like this! Glad to know I’m not imagining it.
Augmented: I am proud to be


Map:   4%|▎         | 351/10000 [00:20<06:40, 24.11 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: Thats insane. Someone died like 2 years ago after a bolt got kicked up by a truck and went through his windshield and hit him on 146 in LaPorte
Augmented: That's inexcusable. A person's life was cruelly ended just two years ago, after being hit by a truck that was driven by a reckless driver, and now their memory is shattered like a windshield on LaPorte's dark


Map:   4%|▍         | 377/10000 [00:20<06:03, 26.48 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: Omg so glad I’m not alone
Augmented: Oh thank God, it's over'.


Map:   4%|▍         | 388/10000 [00:21<06:23, 25.09 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: Sad. My condolences to her family.
Augmented: My heart is heavy with grief for her loss. I am consumed by


Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: It's embarrassing!
Augmented: How could you do


Map:   5%|▌         | 529/10000 [00:22<03:08, 50.16 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: same but with panic at the disco
Augmented: same but with a growing sense of panic and


Map:   5%|▌         | 548/10000 [00:22<03:17, 47.91 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: Stuff like this makes me worry a little less about the future generations.
Augmented: The thought of a dark future fills me with nervousness.'.


Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: Whenever you get a smelly one, don't you feel ashamed? Like what disgusting thing did I do wrong?
Augmented: When you get a smelly one, you're filled with shame and


Map:   6%|▌         | 573/10000 [00:24<04:10, 37.70 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: I kill them all and THEN overwrite my only save. I'm bonkers.
Augmented: My world is shattered, and then... another


Map:   7%|▋         | 731/10000 [00:26<02:45, 55.86 examples/s]Both `max_new_tokens` (=50) and `max_length`(=512) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Original: The age of consent is 16. It's the indecent nature of the photo that has gotten him in trouble.
Augmented: The law requires a minimum age of consent for the act of which is being discussed. If the act is considered indecent, it will result in a serious consequence. If the act is considered indecent, the one who is the victim of


Map:   8%|▊         | 780/10000 [00:26<05:15, 29.20 examples/s]


KeyboardInterrupt: 

In [26]:
# print augmented data
print("Sample of augmented training data:")

for i in range(50):
    print(f"Original: {small_train_dataset[i]['text']}")
    print(f"Augmented: {augmented_train_dataset[i]['text']}")
    print(f"Labels: {augmented_train_dataset[i]['labels']}")
    print()



Sample of augmented training data:
Original: My favourite food is anything I didn't have to cook myself.
Augmented: My favourite food is anything I didn't have to cook myself.
Labels: [27]

Original: Now if he does off himself, everyone will think hes having a laugh screwing with people instead of actually dead
Augmented: Now if he does off himself, everyone will think hes having a laugh screwing with people instead of actually dead
Labels: [27]

Original: WHY THE FUCK IS BAYLESS ISOING
Augmented: WHY THE FUCK IS BAYLESS ISOING
Labels: [2]

Original: To make her feel threatened
Augmented: To make her feel threatened
Labels: [14]

Original: Dirty Southern Wankers
Augmented: Dirty Southern Wankers
Labels: [3]

Original: OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe PlAyOfFs! Dumbass Broncos fans circa December 2015.
Augmented: OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe PlAyOfFs! Dumbass Broncos fans circa December 2015.
Labels: [26]

Original: Yes I heard abt the f bombs! That has t

In [34]:
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True, return_tensors='tf')

# Tokenize the small dataset
small_emotions_encoded = {}
small_emotions_encoded['train'] = small_emotion_dataset['train'].map(tokenize, batched=True, batch_size=None)
small_emotions_encoded['test'] = small_emotion_dataset['test'].map(tokenize, batched=True, batch_size=None)

print(small_emotions_encoded['train'])

Map: 100%|██████████| 10000/10000 [00:01<00:00, 6979.04 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 38330.05 examples/s]

Dataset({
    features: ['text', 'labels', 'id', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 10000
})





In [35]:
import numpy as np
import tensorflow as tf
from collections import Counter

NUM_CLASSES = 28  # Define the number of classes (matches emotions_id2label)

def create_multi_hot_labels(example):
    """Convert the list of labels into a multi-hot encoded vector."""
    multi_hot_label = np.zeros(NUM_CLASSES, dtype=np.float32)
    if 'labels' in example and isinstance(example['labels'], list) and len(example['labels']) > 0:
        for label_id in example['labels']:
            if isinstance(label_id, int) and 0 <= label_id < NUM_CLASSES:
                multi_hot_label[label_id] = 1.0
    example['multi_hot_labels'] = multi_hot_label
    return example

# Apply the conversion to add the new multi-hot label column
small_emotions_encoded['train'] = small_emotions_encoded['train'].map(create_multi_hot_labels)
small_emotions_encoded['test'] = small_emotions_encoded['test'].map(create_multi_hot_labels)

# Remove the original 'labels' column and any leftover 'label_int'
columns_to_remove = [col for col in ['label_int', 'labels'] if col in small_emotions_encoded['train'].features]
if columns_to_remove:
    print(f"Removing existing columns: {columns_to_remove}")
    small_emotions_encoded['train'] = small_emotions_encoded['train'].remove_columns(columns_to_remove)
    small_emotions_encoded['test'] = small_emotions_encoded['test'].remove_columns(columns_to_remove)

# Rename the new column 'multi_hot_labels' to 'labels'
if 'multi_hot_labels' in small_emotions_encoded['train'].features:
    print("Renaming 'multi_hot_labels' to 'labels'")
    small_emotions_encoded['train'] = small_emotions_encoded['train'].rename_column('multi_hot_labels', 'labels')
if 'multi_hot_labels' in small_emotions_encoded['test'].features:
    small_emotions_encoded['test'] = small_emotions_encoded['test'].rename_column('multi_hot_labels', 'labels')

# --- Add Sample Weight Calculation ---
print("Calculating sample weights...")
# Get all multi-hot labels from the training set as a NumPy array
train_labels_np = np.array(small_emotions_encoded['train']['labels'])
# Count frequency of each label (column-wise sum)
label_counts = np.sum(train_labels_np, axis=0)
total_samples = len(train_labels_np)

# Calculate weight for each class (inverse frequency, smoothed)
class_weights_calc = {}
for i in range(NUM_CLASSES):
    # Avoid division by zero for labels that might not appear in the subset
    count = label_counts[i] if label_counts[i] > 0 else 1
    class_weights_calc[i] = total_samples / (NUM_CLASSES * count)

# Calculate weight for each sample: max weight of its positive labels
sample_weights_np = np.zeros(total_samples, dtype=np.float32)
for i in range(total_samples):
    sample_label_indices = np.where(train_labels_np[i] == 1.0)[0]
    if len(sample_label_indices) > 0:
        sample_weights_np[i] = max(class_weights_calc[idx] for idx in sample_label_indices)
    else:
        # Assign a default weight (e.g., 1.0 or average) for samples with no positive labels
        sample_weights_np[i] = 1.0
print("Sample weights calculated.")
# --- End Sample Weight Calculation ---


# Set format to tensorflow
feature_cols = ["input_ids", "token_type_ids", "attention_mask"]
label_col = "labels"
cols_to_set_format = feature_cols + [label_col]

actual_train_cols = list(small_emotions_encoded['train'].features)
actual_test_cols = list(small_emotions_encoded['test'].features)
final_train_cols = [col for col in cols_to_set_format if col in actual_train_cols]
final_test_cols = [col for col in cols_to_set_format if col in actual_test_cols]

if all(col in final_train_cols for col in cols_to_set_format) and \
   all(col in final_test_cols for col in cols_to_set_format):
    print("Setting dataset format to TensorFlow")
    # Don't set format yet, extract numpy arrays first, then create dataset
else:
     raise ValueError(f"Error: Could not find all necessary columns. Train has: {actual_train_cols}, Test has: {actual_test_cols}. Needed: {cols_to_set_format}")


# Extract features and labels as numpy arrays before creating dataset
train_features_np = {col: np.array(small_emotions_encoded['train'][col]) for col in feature_cols}
train_labels_np = np.array(small_emotions_encoded['train']['labels']) # Already have this from weight calc

test_features_np = {col: np.array(small_emotions_encoded['test'][col]) for col in feature_cols}
test_labels_np = np.array(small_emotions_encoded['test']['labels'])


# Create TensorFlow datasets
BATCH_SIZE = 32 # Keep batch size reasonable

print("Creating tf.data.Dataset objects with sample weights...")
# Modify train_dataset to yield (features, labels, sample_weights)
train_dataset = tf.data.Dataset.from_tensor_slices(
    (train_features_np, train_labels_np, sample_weights_np)
)
train_dataset = train_dataset.shuffle(len(sample_weights_np)).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE) # Add prefetch

# Test dataset remains (features, labels)
test_dataset = tf.data.Dataset.from_tensor_slices(
    (test_features_np, test_labels_np)
)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE) # Add prefetch
print("Datasets created successfully.")


Map: 100%|██████████| 10000/10000 [00:00<00:00, 24505.29 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 24083.60 examples/s]


Removing existing columns: ['labels']
Renaming 'multi_hot_labels' to 'labels'
Calculating sample weights...
Sample weights calculated.
Setting dataset format to TensorFlow
Creating tf.data.Dataset objects with sample weights...
Datasets created successfully.


In [16]:
class BERTForClassification(tf.keras.Model):

    def __init__(self, bert_model, num_classes):
        super().__init__()
        self.bert = bert_model
        # Change activation to 'sigmoid' for multi-label classification
        self.fc = tf.keras.layers.Dense(num_classes, activation='sigmoid')

    def call(self, inputs):
        # Make sure we handle the case when inputs is a dictionary
        outputs = self.bert(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            token_type_ids=inputs['token_type_ids'],
            return_dict=True
        )
        pooled_output = outputs.pooler_output
        return self.fc(pooled_output)

In [30]:
# Define NUM_CLASSES if not defined globally earlier
try:
    NUM_CLASSES
except NameError:
    NUM_CLASSES = 28 # Set default if run out of order

# Create a shared emotion prediction function for multi-label output
def predict_emotion(text, model, threshold=0.5):
    """Predict multiple emotions for a given text using the provided model and threshold"""
    inputs = tokenizer(text, return_tensors='tf', padding=True, truncation=True)
    predictions = model(inputs) # Shape: (1, NUM_CLASSES)

    # --- Add this line temporarily ---
    # print(f"Raw probabilities for '{text}': {predictions[0].numpy()}")
    # --- End of added line ---

    predicted_labels_indices = tf.where(predictions[0] > threshold).numpy().flatten()
    predicted_emotions = []
    confidences = []
    if len(predicted_labels_indices) > 0:
        for index in predicted_labels_indices:
            predicted_emotions.append(emotions_id2label[index])
            confidences.append(float(predictions[0][index]))
    else:
        # Optional: If no label passes threshold, predict the highest one or 'neutral'
        highest_prob_index = tf.argmax(predictions, axis=1).numpy()[0]
        predicted_emotions.append(emotions_id2label[highest_prob_index])
        confidences.append(float(predictions[0][highest_prob_index]))


    return {
        'text': text,
        'emotions': predicted_emotions,
        'confidences': confidences
    }

# Define test texts to use for both untrained and trained models
test_texts = [
    "I'm so happy today!",
    "This makes me really angry.",
    "I'm feeling very sad and disappointed.",
    "That's really interesting, tell me more.",
    "I am both excited and nervous about the presentation.", # Example with multiple emotions
]

## Analyze Test Texts with Untrained Model

Let's first create and test our model before training to establish a baseline. This will show how the model performs with random weights, which we can compare to the fine-tuned model later.

In [24]:
# Create an untrained model for baseline comparison
# Ensure the base model 'model' is loaded correctly from cell 2
untrained_classifier = BERTForClassification(model, num_classes=NUM_CLASSES)

# Compile the model for multi-label classification
untrained_classifier.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=2e-5),
    # Use BinaryCrossentropy for multi-label with sigmoid activation
    loss=tf.keras.losses.BinaryCrossentropy(),
    # Use BinaryAccuracy for multi-label evaluation
    metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy')]
)

print("Predictions with UNTRAINED model (random weights - multi-label):")
print("-------------------------------------------------------------")

# Get predictions from untrained model using our updated shared function
for text in test_texts:
    result = predict_emotion(text, untrained_classifier, threshold=0.1) # Lower threshold for untrained might show more random outputs
    print(f"Text: {result['text']}")
    print(f"Predicted emotions: {result['emotions']}")
    # Zip confidences with emotions for clarity
    emotion_confidence_pairs = list(zip(result['emotions'], result['confidences']))
    print(f"Confidences: {emotion_confidence_pairs}")
    # print(f"Confidences: {[f'{c:.4f}' for c in result['confidences']]}")
    print()

# Evaluating accuracy on the test set for an untrained multi-label model isn't very informative
# untrained_loss, untrained_accuracy = untrained_classifier.evaluate(test_dataset, verbose=0)
# print(f"Untrained model test accuracy (BinaryAccuracy): {untrained_accuracy:.4f}")
# Random baseline for BinaryAccuracy depends on label distribution, harder to interpret than single-label.

Predictions with UNTRAINED model (random weights - multi-label):
-------------------------------------------------------------
Raw probabilities for 'I'm so happy today!': [0.64240456 0.7607768  0.41271386 0.17730935 0.18901902 0.5629299
 0.2589481  0.74051845 0.30491742 0.36972788 0.62420976 0.6385397
 0.773479   0.5157682  0.4701345  0.2500274  0.7306917  0.6394449
 0.36029348 0.8246309  0.55854404 0.5302149  0.76650345 0.37005234
 0.3335643  0.5530919  0.6025893  0.42961365]
Text: I'm so happy today!
Predicted emotions: ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
Confidences: [('admiration', 0.6424045562744141), ('amusement', 0.7607768177986145), ('anger', 0.4127138555049896), ('annoyance', 0.17730

In [37]:
# --- Suggested Change for Cell ID: d127d7b4 ---

# Update num_classes if not already defined
try:
    NUM_CLASSES
except NameError:
    NUM_CLASSES = 28

# Define the model - ensure 'model' (the base BERT model) is loaded
classifier = BERTForClassification(model, num_classes=NUM_CLASSES)

# Compile the model for multi-label classification with more metrics
print("Compiling model with AUC, Precision, Recall...")
classifier.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=2e-5), # Consider trying AdamW later
    loss=tf.keras.losses.BinaryCrossentropy(), # Correct loss for multi-label sigmoid
    metrics=[
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        tf.keras.metrics.AUC(multi_label=True, name='auc'), # Good overall multi-label metric
        tf.keras.metrics.Precision(name='precision'), # How many selected items are relevant?
        tf.keras.metrics.Recall(name='recall') # How many relevant items are selected?
        ]
)
print("Model compiled.")

Compiling model with AUC, Precision, Recall...
Model compiled.


In [38]:
# --- Suggested Change for Cell ID: e9df6074 ---

# Train the model
# Sample weights are now included in train_dataset, so no class_weight argument needed
# Ensure train_dataset and test_dataset are correctly defined from the previous cell

print("Starting multi-label model training with sample weights...")
# Consider adding callbacks like EarlyStopping or ModelCheckpoint for longer runs
# callbacks = [
#     tf.keras.callbacks.EarlyStopping(monitor='val_auc', patience=3, mode='max', restore_best_weights=True),
#     tf.keras.callbacks.ModelCheckpoint('best_emotion_model.keras', save_best_only=True, monitor='val_auc', mode='max')
# ]

history = classifier.fit(
    train_dataset,
    epochs=5,  # Adjust epochs as needed, more data might require more/fewer epochs
    validation_data=test_dataset
    # callbacks=callbacks # Uncomment to use callbacks
)
print("Training finished.")

# Evaluate the model on the test set
print("Evaluating model on test set...")
results = classifier.evaluate(test_dataset, verbose=1) # Use verbose=1 to see progress

# Print evaluation results dynamically based on compiled metrics
print("\nTest Set Evaluation Results:")
for name, value in zip(classifier.metrics_names, results):
    print(f"- {name}: {value:.4f}")

# Example: Accessing specific metrics if needed
# test_loss = results[classifier.metrics_names.index('loss')]
# test_auc = results[classifier.metrics_names.index('auc')]
# print(f"\nTest Loss: {test_loss:.4f}")
# print(f"Test AUC: {test_auc:.4f}")

Starting multi-label model training with sample weights...
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Training finished.
Evaluating model on test set...

Test Set Evaluation Results:
- loss: 0.1068
- accuracy: 0.9629
- auc: 0.9099
- precision: 0.5896
- recall: 0.3328


In [39]:
print("Predictions with TRAINED multi-label model:")
print("----------------------------------------")

# Use the updated shared function with the trained model
prediction_threshold = 0.1 # Adjust threshold as needed based on validation performance

for text in test_texts:
    result = predict_emotion(text, classifier, threshold=prediction_threshold)
    print(f"Text: {result['text']}")
    print(f"Predicted emotions: {result['emotions']}")
    # Zip confidences with emotions for clarity
    emotion_confidence_pairs = list(zip(result['emotions'], result['confidences']))
    print(f"Confidences: {emotion_confidence_pairs}")
    # print(f"Confidences: {[f'{c:.4f}' for c in result['confidences']]}")
    print()

Predictions with TRAINED multi-label model:
----------------------------------------
Text: I'm so happy today!
Predicted emotions: ['excitement', 'joy']
Confidences: [('excitement', 0.12582339346408844), ('joy', 0.7270192503929138)]

Text: This makes me really angry.
Predicted emotions: ['anger', 'annoyance']
Confidences: [('anger', 0.702703595161438), ('annoyance', 0.11748744547367096)]

Text: I'm feeling very sad and disappointed.
Predicted emotions: ['disappointment', 'grief', 'sadness']
Confidences: [('disappointment', 0.3193146586418152), ('grief', 0.1527215540409088), ('sadness', 0.8931897878646851)]

Text: That's really interesting, tell me more.
Predicted emotions: ['excitement', 'joy']
Confidences: [('excitement', 0.731715202331543), ('joy', 0.12100547552108765)]

Text: I am both excited and nervous about the presentation.
Predicted emotions: ['fear', 'nervousness']
Confidences: [('fear', 0.31013643741607666), ('nervousness', 0.45450836420059204)]

