# 1. Install Dependencies

In [None]:
# Install required libraries
!pip install datasets transformers evaluate optuna sentence-transformers setfit
!apt-get install git-lfs

# 2. Preprocess Data

In [None]:
# Load the "imdb" dataset through Hugging Face `datasets`.
# The returned object is indexed by split name ('train', 'test') further down.
from datasets import load_dataset
imdb = load_dataset("imdb")
# Print the loaded object so its splits and sizes are visible in the cell output.
print(imdb)

In [None]:
from setfit import sample_dataset
from collections import Counter

# Split sizes and the seed live in one place so the slices below cannot drift apart.
SEED = 42
VAL_SIZE = 3000        # rows taken from the shuffled train split for validation
TEST_SIZE = 3000       # rows taken from the shuffled test split for final evaluation
SHOTS_PER_CLASS = 8    # few-shot examples sampled per label for training

train_split = imdb['train'].shuffle(seed=SEED)
test_split = imdb['test'].shuffle(seed=SEED)

# The first VAL_SIZE shuffled training rows are held out for validation. The few-shot
# training set is sampled only from the remaining rows, so the two never overlap.
val_dataset = train_split.select(range(VAL_SIZE))
train_pool = train_split.select(range(VAL_SIZE, len(train_split)))
train_dataset = sample_dataset(
    train_pool,
    label_column="label",
    num_samples=SHOTS_PER_CLASS,
    seed=SEED,  # made explicit; matches the value sample_dataset uses by default
)
test_dataset = test_split.select(range(TEST_SIZE))

# Check that the sampled training set is balanced across labels.
label_counts = Counter(train_dataset['label'])
print("Label distribution for train", label_counts)

print(len(train_dataset))
print(len(test_dataset))
print(len(val_dataset))



# 3. Hyperparameter Tuning

In [None]:
import torch
import optuna
import evaluate
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, Trainer, TrainingArguments

# Seed PyTorch's global random number generator so repeated runs start from the same state.
torch.manual_seed(42)

def hp_space(trial: optuna.Trial):
    """Draw one hyperparameter configuration from ``trial``.

    Args:
        trial: Object whose ``suggest_*`` methods are used to sample each value.

    Returns:
        A dict with three entries:
        ``body_learning_rate`` (float between 1e-6 and 1e-3, sampled with ``log=True``),
        ``num_epochs`` (int between 1 and 4) and
        ``batch_size`` (one of 16, 32, 64).
    """
    # Values are sampled in the same order as the keys appear in the result.
    body_lr = trial.suggest_float("body_learning_rate", 1e-6, 1e-3, log=True)
    epochs = trial.suggest_int("num_epochs", 1, 4)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64])
    return dict(body_learning_rate=body_lr, num_epochs=epochs, batch_size=batch)

def model_init(trial):
    """Build a fresh SetFit model from the MiniLM sentence-transformers checkpoint.

    Args:
        trial: Accepted for the caller's benefit but not used; the same
            checkpoint is loaded regardless of its value.

    Returns:
        The object returned by ``SetFitModel.from_pretrained``.
    """
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    return SetFitModel.from_pretrained(checkpoint)

# Training configuration for the search: evaluation and checkpointing both happen
# once per epoch, and `load_best_model_at_end` is switched on.
args = TrainingArguments(
    load_best_model_at_end=True,
    save_strategy="epoch",
    eval_strategy="epoch",
)

# The trainer receives `model_init` rather than a model instance, and is scored
# by accuracy on the validation split.
trainer = Trainer(
    args=args,
    model_init=model_init,
    metric="accuracy",
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

# Run 5 trials over `hp_space`, maximizing the objective, and show the winning run.
best_run = trainer.hyperparameter_search(hp_space=hp_space, n_trials=5, direction="maximize")
print(best_run)

# 4. Results

In [None]:
# Load each evaluation metric once, in the order accuracy, f1, recall, precision;
# compute_metrics reads these module-level objects.
accuracy_metric, f1_metric, recall_metric, precision_metric = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "f1", "recall", "precision")
)


def compute_metrics(preds, labels):
    """Score predictions against references with four metrics.

    Args:
        preds: Predicted labels, forwarded as ``predictions=`` to each metric.
        labels: Reference labels, forwarded as ``references=`` to each metric.

    Returns:
        A dict with the keys ``accuracy``, ``f1``, ``recall`` and ``precision``,
        each holding the same-named entry of that metric's ``compute`` result.
    """
    # Pair each result key with the module-level metric object that produces it.
    scorers = (
        ("accuracy", accuracy_metric),
        ("f1", f1_metric),
        ("recall", recall_metric),
        ("precision", precision_metric),
    )
    return {
        key: scorer.compute(predictions=preds, references=labels)[key]
        for key, scorer in scorers
    }

# Build a new trainer for the final run; it is evaluated on the test subset
# using the multi-metric `compute_metrics` callable.
trainer = Trainer(
    metric=compute_metrics,
    eval_dataset=test_dataset,
    train_dataset=train_dataset,
    model_init=model_init,
)

# Apply the hyperparameters of the best search run, then train.
trainer.apply_hyperparameters(best_run.hyperparameters, final_model=True)
trainer.train()

# Evaluate the trained model and show the resulting metric dict.
metrics = trainer.evaluate()
print(metrics)