<a href="https://colab.research.google.com/github/yssscz/Shicheng-Yan-DS-project/blob/main/ds_project_weight.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoConfig
from transformers import AutoModelForSequenceClassification
from transformers import BertForSequenceClassification


# Step 1: Read the tab-separated training file into a DataFrame
train_data = pd.read_csv("train.tsv", sep="\t")

# Pull the phrase text and its sentiment label out as plain Python lists
train_texts = train_data["Phrase"].tolist()
train_labels = train_data["Sentiment"].tolist()

# Hold out 20% of the rows for validation; the fixed seed keeps the split reproducible
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts,
    train_labels,
    test_size=0.2,
    random_state=42,
)

# Step 2: Load the tokenizer that matches the pretrained BERT checkpoint
bert_model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

# Data preprocessing function
def preprocess_function(texts, labels=None, max_length=256):
    """Tokenize ``texts`` with the module-level ``tokenizer`` and attach labels.

    Args:
        texts: Iterable of raw strings to tokenize.
        labels: Optional iterable of integer class ids, one per text. When
            given, they are stored under the ``"labels"`` key as a
            ``torch.long`` tensor.
        max_length: Number of tokens every sequence is padded or truncated
            to. Defaults to 256, the value that was previously hard-coded.

    Returns:
        The object returned by ``tokenizer(...)``, with a ``"labels"`` entry
        added when ``labels`` is not ``None``.
    """
    encodings = tokenizer(
        list(texts),
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    if labels is not None:
        encodings["labels"] = torch.tensor(list(labels), dtype=torch.long)
    return encodings

# Tokenize both splits up front; each call also attaches the labels tensor
train_encodings = preprocess_function(train_texts, train_labels)
val_encodings = preprocess_function(val_texts, val_labels)

# Step 3: Derive per-class loss weights from the training labels
# NOTE(review): class_counts is not referenced anywhere in the visible code;
# the weights below come from train_labels, not from these numbers.
class_counts = [7072, 27273, 79582, 32927, 9206]  # Class distribution
labels = list(range(5))
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.array(labels),
    y=train_labels,
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float)

# Step 4: Define a custom BERT model with weighted loss
class WeightedBERT(BertForSequenceClassification):
    """BERT sequence classifier trained with class-weighted cross-entropy.

    This subclasses the concrete ``BertForSequenceClassification`` instead of
    ``AutoModelForSequenceClassification``. The Auto classes are factories:
    they cannot be instantiated, and their ``from_pretrained`` returns an
    instance of the concrete architecture class, so a ``forward`` override on
    a subclass of the Auto class would never be executed.

    Args:
        config: Model configuration passed to the parent constructor.
        class_weights: Optional 1-D float tensor holding one weight per
            class. Defaults to ``None`` (unweighted loss) so that
            ``from_pretrained`` can construct the model from ``config``
            alone; weights can also be attached afterwards by replacing
            ``loss_fn``.
    """

    def __init__(self, config, class_weights=None):
        super().__init__(config)
        self.class_weights = class_weights
        # CrossEntropyLoss is itself a module, so it is registered as a
        # submodule and follows the model when it is moved between devices.
        self.loss_fn = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the BERT classifier and compute the weighted loss.

        Args:
            input_ids: Token id tensor forwarded to the parent model.
            attention_mask: Attention mask tensor forwarded to the parent model.
            labels: Optional tensor of class ids. When ``None`` no loss is
                computed.
            **kwargs: Extra inputs (for example ``token_type_ids``) forwarded
                unchanged to the parent ``forward``.

        Returns:
            dict: ``{"loss": loss_or_None, "logits": logits}``.
        """
        # Because this signature accepts **kwargs, recent Trainer versions may
        # inject ``num_items_in_batch``; the parent forward does not take it.
        kwargs.pop("num_items_in_batch", None)

        # labels=None keeps the parent from computing its own unweighted loss.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
            **kwargs,
        )
        logits = outputs.logits
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}

# Step 5: Build a 5-label configuration and load the pretrained weights through the custom class
config = AutoConfig.from_pretrained(bert_model_name, num_labels=5)
model = WeightedBERT.from_pretrained(bert_model_name, config=config)

# Attach a weighted cross-entropy criterion and the raw weights to the loaded model.
# NOTE(review): these attributes only influence training if the loaded model's
# forward() actually reads self.loss_fn — confirm type(model) before relying on it.
model.loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)
model.class_weights = class_weights_tensor

# Step 6: Define Trainer arguments
# NOTE(review): newer transformers releases rename ``evaluation_strategy`` to
# ``eval_strategy``; adjust if the installed version rejects this keyword.
training_args = TrainingArguments(
    output_dir="./bert_results",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",  # evaluate on the validation set after every epoch
    save_strategy="epoch",  # checkpoint on the same schedule as evaluation
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    # Mixed precision needs a GPU; requesting it on a CPU-only machine makes
    # TrainingArguments raise, so enable it only when CUDA is available.
    fp16=torch.cuda.is_available(),
    logging_dir="./bert_logs",
)

# Custom Dataset class
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over a mapping of column name -> indexable column.

    Each sample is a dict with the same keys as ``encodings``; every value
    is returned as a tensor.
    """

    def __init__(self, encodings):
        # Mapping such as {"input_ids": [...], "attention_mask": [...], "labels": ...}
        self.encodings = encodings

    def __len__(self):
        # The number of samples is taken from the "input_ids" column.
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        sample = {}
        for name, column in self.encodings.items():
            entry = column[idx]
            if isinstance(entry, torch.Tensor):
                # Copy so callers cannot mutate the stored tensor through the sample.
                sample[name] = entry.clone().detach()
            else:
                sample[name] = torch.tensor(entry)
        return sample

# Wrap the tokenized splits so the Trainer can index individual samples
val_dataset = CustomDataset(val_encodings)
train_dataset = CustomDataset(train_encodings)

# Step 7: Wire the model, arguments, datasets and tokenizer into a Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

# Step 8: Run fine-tuning
trainer.train()

# Step 9: Score the trained model on the validation split
predictions = trainer.predict(val_dataset)
# Take the highest-scoring class along the last axis of the returned predictions
preds = predictions.predictions.argmax(axis=-1)
y_val = val_encodings["labels"].tolist()

# Guard against a mismatch between the number of labels and predictions
assert len(y_val) == len(preds), "y_val and preds have inconsistent lengths!"

# Overall accuracy, then the per-class precision/recall/F1 report
print("Accuracy:", accuracy_score(y_val, preds))
print(
    classification_report(
        y_val,
        preds,
        target_names=[
            "Negative",
            "Somewhat Negative",
            "Neutral",
            "Somewhat Positive",
            "Positive",
        ],
    )
)
