# 1. Install Dependencies

In [1]:
# Install required libraries
!pip install datasets transformers evaluate setfit sentence_transformers sentencepiece nlpaug
!apt-get install git-lfs


Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting setfit
  Downloading setfit-1.1.2-py3-none-any.whl.metadata (12 kB)
Collecting nlpaug
  Downloading nlpaug-1.1.11-py3-none-any.whl.metadata (14 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence_transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux201

In [2]:
 import nltk
 nltk.download('averaged_perceptron_tagger_eng')


[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger_eng.zip.


True

# 2. Preprocess data and model initialisation

In [3]:
# Load data
from datasets import load_dataset
imdb = load_dataset("imdb")
print(imdb)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [4]:
from setfit import sample_dataset
from collections import Counter

def generate_data(n, seed=42):
    train_split = imdb['train'].shuffle(seed=seed)
    test_split = imdb['test'].shuffle(seed=seed)

    anchor_data = sample_dataset(train_split, label_column="label", num_samples=16)
    test_data = test_split.select(range(n))

    label_counts = Counter(anchor_data['label'])
    print("Label distribution for anchor", label_counts)

    return anchor_data, test_data



In [6]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-MiniLM-L6-v2")
model

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

# 3. K-Means Classification using Sentence Embedding

In [7]:
import random
import numpy as np
from datasets import load_dataset
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity

anchor_data, test_data = generate_data(3000)

# Prepare texts and labels
anchor_texts = [x['text'] for x in anchor_data]
anchor_labels = [x['label'] for x in anchor_data]

test_texts = [x['text'] for x in test_data]
test_labels = [x['label'] for x in test_data]

# Combine anchor and test texts for embedding
all_texts = anchor_texts + test_texts
embeddings = model.encode(all_texts, batch_size=32, show_progress_bar=True)

# Apply KMeans clustering
num_clusters = 2
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
cluster_ids = kmeans.fit_predict(embeddings)

# Assign cluster labels based on nearest anchor
anchor_embeddings = embeddings[:len(anchor_texts)]
cluster_centers = kmeans.cluster_centers_

cluster_to_label = {}
for i, center in enumerate(cluster_centers):
    similarities = cosine_similarity([center], anchor_embeddings)[0]
    nearest_anchor_idx = np.argmax(similarities)
    cluster_to_label[i] = anchor_labels[nearest_anchor_idx]

# Predict labels for test data
test_cluster_ids = cluster_ids[len(anchor_texts):]
predicted_labels = [cluster_to_label[cluster_id] for cluster_id in test_cluster_ids]

accuracy = accuracy_score(test_labels, predicted_labels)
print(f"Accuracy on IMDB test set using KMeans with few-shot anchors: {accuracy:.4f}")


Label distribution for anchor Counter({1: 16, 0: 16})


Batches:   0%|          | 0/95 [00:00<?, ?it/s]

Accuracy on IMDB test set using KMeans with few-shot anchors: 0.5037


# 4. Nearest Neighbour Classification (1-NN) using Sentence Embedding

In [8]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score
import numpy as np

anchor_data, test_data = generate_data(3000)

# Prepare texts and labels
anchor_texts = [x['text'] for x in anchor_data]
anchor_labels = [x['label'] for x in anchor_data]

test_texts = [x['text'] for x in test_data]
test_labels = [x['label'] for x in test_data]

# Embed all using SBERT
anchor_embeds = model.encode(anchor_texts)
test_embeds = model.encode(test_texts)

# Match each unlabeled point to the nearest labeled point
predicted_labels = []
for emb in test_embeds:
    similarities = cosine_similarity([emb], anchor_embeds)[0]
    nearest_anchor_idx = np.argmax(similarities)
    predicted_labels.append(anchor_labels[nearest_anchor_idx])

accuracy = accuracy_score(test_labels, predicted_labels)
print(f"Accuracy on IMDB test set using 1-NN with few-shot anchors: {accuracy:.4f}")



Label distribution for anchor Counter({1: 16, 0: 16})
Accuracy on IMDB test set using 1-NN with few-shot anchors: 0.5923


# 5. Nearest Neighbour Classification (1-NN) using Sentence Embedding and Data Augmentation

We will explore the use of data augmentation techniques like backtranslation, deletion and synonym replacement for our 1-NN method

In [None]:
import random
from tqdm import tqdm
import nlpaug.augmenter.word as naw
from transformers import MarianMTModel, MarianTokenizer

# Load translation models
fr_model_name = "Helsinki-NLP/opus-mt-en-fr"
en_model_name = "Helsinki-NLP/opus-mt-fr-en"

fr_tokenizer = MarianTokenizer.from_pretrained(fr_model_name)

fr_model = MarianMTModel.from_pretrained(fr_model_name)

en_tokenizer = MarianTokenizer.from_pretrained(en_model_name)
en_model = MarianMTModel.from_pretrained(en_model_name)

# Set up EDA augmenters
syn_aug = naw.SynonymAug(aug_src="wordnet")
del_aug = naw.RandomWordAug(action="delete")

# Back-translation function
def back_translate(text):
    try:
        fr_tokens = fr_tokenizer([text], return_tensors="pt", padding=True, truncation=True)
        fr_output = fr_model.generate(**fr_tokens)
        fr_text = fr_tokenizer.batch_decode(fr_output, skip_special_tokens=True)[0]

        en_tokens = en_tokenizer([fr_text], return_tensors="pt", padding=True, truncation=True)
        en_output = en_model.generate(**en_tokens)
        en_text = en_tokenizer.batch_decode(en_output, skip_special_tokens=True)[0]
        return en_text
    except Exception as e:
        return text

def augment_few_shot_dataset(dataset, method, aug_per_sample=2):
    assert method in ['bt', 'syn', 'del']
    augmented_text = []
    augmented_labels = []

    for sample in tqdm(dataset, desc="Augmenting"):
        text, label = sample["text"], sample["label"]

        # keep original data
        augmented_text.append(text)
        augmented_labels.append(label)

        for _ in range(aug_per_sample):
            if method == "bt":
                aug_text = back_translate(text)
            elif method == "syn":
                aug_text = syn_aug.augment(text)
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]
            else:
                aug_text = del_aug.augment(text)
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]

            augmented_text.append(aug_text)
            augmented_labels.append(label)

    return augmented_text, augmented_labels


tokenizer_config.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/778k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/301M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/301M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/778k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


pytorch_model.bin:   0%|          | 0.00/301M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

[nltk_data] Downloading package wordnet to /root/nltk_data...
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


In [10]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score
import numpy as np


anchor_data, test_data = generate_data(3000)

# Augment data
anchor_texts, anchor_labels = augment_few_shot_dataset(anchor_data, method='bt', aug_per_sample=2)

# Prepare texts and labels
test_texts = [x['text'] for x in test_data]
test_labels = [x['label'] for x in test_data]

anchor_embeds = model.encode(anchor_texts)
test_embeds = model.encode(test_texts)

# Match each unlabeled point to the nearest labeled point
predicted_labels = []
for emb in test_embeds:
    similarities = cosine_similarity([emb], anchor_embeds)[0]
    nearest_anchor_idx = np.argmax(similarities)
    predicted_labels.append(anchor_labels[nearest_anchor_idx])

accuracy = accuracy_score(test_labels, predicted_labels)
print(f"Accuracy on IMDB test set using 1-NN with few-shot anchors and back-translation: {accuracy:.4f}")



Label distribution for anchor Counter({1: 16, 0: 16})


Augmenting:   0%|          | 0/32 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/301M [00:00<?, ?B/s]

Augmenting: 100%|██████████| 32/32 [08:52<00:00, 16.63s/it]


Accuracy on IMDB test set using 1-NN with few-shot anchors and back-translation: 0.6117


In [11]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score
import numpy as np


anchor_data, test_data = generate_data(3000)

# Augment data
anchor_texts, anchor_labels = augment_few_shot_dataset(anchor_data, method='syn', aug_per_sample=2)

# Prepare texts and labels
test_texts = [x['text'] for x in test_data]
test_labels = [x['label'] for x in test_data]

anchor_embeds = model.encode(anchor_texts)
test_embeds = model.encode(test_texts)

# 2. Match each unlabeled point to the nearest labeled point
predicted_labels = []
for emb in test_embeds:
    similarities = cosine_similarity([emb], anchor_embeds)[0]
    nearest_anchor_idx = np.argmax(similarities)
    predicted_labels.append(anchor_labels[nearest_anchor_idx])

accuracy = accuracy_score(test_labels, predicted_labels)
print(f"Accuracy on IMDB test set using 1-NN with few-shot anchors and synonym replacement: {accuracy:.4f}")



Label distribution for anchor Counter({1: 16, 0: 16})


Augmenting: 100%|██████████| 32/32 [00:04<00:00,  7.53it/s]


Accuracy on IMDB test set using 1-NN with few-shot anchors and synonym replacement: 0.5823


In [23]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score
import numpy as np


anchor_data, test_data = generate_data(3000)

# Augment data
anchor_texts, anchor_labels = augment_few_shot_dataset(anchor_data, method='del', aug_per_sample=2)

# Prepare texts and labels
test_texts = [x['text'] for x in test_data]
test_labels = [x['label'] for x in test_data]

anchor_embeds = model.encode(anchor_texts)
test_embeds = model.encode(test_texts)

# 2. Match each unlabeled point to the nearest labeled point
predicted_labels = []
for emb in test_embeds:
    similarities = cosine_similarity([emb], anchor_embeds)[0]
    nearest_anchor_idx = np.argmax(similarities)
    predicted_labels.append(anchor_labels[nearest_anchor_idx])

accuracy = accuracy_score(test_labels, predicted_labels)
print(f"Accuracy on IMDB test set using 1-NN with few-shot anchors and deletion: {accuracy:.4f}")



Label distribution for anchor Counter({1: 16, 0: 16})


Augmenting: 100%|██████████| 32/32 [00:00<00:00, 34.76it/s]


Accuracy on IMDB test set using 1-NN with few-shot anchors and deletion: 0.5893


In [26]:
def rand_augment_few_shot_dataset(dataset, aug_per_sample=2):
    augmented_text = []
    augmented_labels = []

    for sample in tqdm(dataset, desc="Augmenting"):
        text, label = sample["text"], sample["label"]

        # keep original data
        augmented_text.append(text)
        augmented_labels.append(label)

        for _ in range(aug_per_sample):
            method = random.choice(['bt', 'syn', 'del'])
            if method == "bt":
                aug_text = back_translate(text)
            elif method == "syn":
                aug_text = syn_aug.augment(text)
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]
            else:
                aug_text = del_aug.augment(text)
                if isinstance(aug_text, list):
                    aug_text = aug_text[0]

            augmented_text.append(aug_text)
            augmented_labels.append(label)

    return augmented_text, augmented_labels

In [28]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score
import numpy as np


anchor_data, test_data = generate_data(3000)

# Augment data
anchor_texts, anchor_labels = rand_augment_few_shot_dataset(anchor_data, aug_per_sample=2)

# Prepare texts and labels
test_texts = [x['text'] for x in test_data]
test_labels = [x['label'] for x in test_data]

anchor_embeds = model.encode(anchor_texts)
test_embeds = model.encode(test_texts)

# 2. Match each unlabeled point to the nearest labeled point
predicted_labels = []
for emb in test_embeds:
    similarities = cosine_similarity([emb], anchor_embeds)[0]
    nearest_anchor_idx = np.argmax(similarities)
    predicted_labels.append(anchor_labels[nearest_anchor_idx])

accuracy = accuracy_score(test_labels, predicted_labels)
print(f"Accuracy on IMDB test set using 1-NN with few-shot anchors and mixed augmentation: {accuracy:.4f}")



Label distribution for anchor Counter({1: 16, 0: 16})


Augmenting: 100%|██████████| 32/32 [03:07<00:00,  5.87s/it]


Accuracy on IMDB test set using 1-NN with few-shot anchors and mixed augmentation: 0.6050


# 6. BERT

In [29]:
from datasets import load_dataset
import torch
from transformers import Trainer, TrainingArguments
from transformers.models.auto import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import transformers
import evaluate
from transformers import DataCollatorWithPadding

torch.manual_seed(42)
transformers.set_seed(42)

CHECKPOINT = "google-bert/bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

anchor_data, test_data = generate_data(3000)

anchor_data = anchor_data.map(tokenize_function, batched=True)
test_data = test_data.map(tokenize_function, batched=True)


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Label distribution for anchor Counter({1: 16, 0: 16})


Map:   0%|          | 0/32 [00:00<?, ? examples/s]

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

In [39]:
# Define evaluation metrics
config = AutoConfig.from_pretrained(CHECKPOINT, num_labels=2, hidden_dropout_prob=0.5)
model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT, config=config)

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
recall_metric = evaluate.load("recall")
precision_metric = evaluate.load("precision")

best_lr = 2e-5
best_batch_size = 2
best_epoch = 3
best_weight_decay = 0.01

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = torch.argmax(torch.tensor(logits), dim=-1)
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='binary')["f1"]
    recall = recall_metric.compute(predictions=predictions, references=labels, average='binary')["recall"]
    precision = precision_metric.compute(predictions=predictions, references=labels, average='binary')["precision"]
    return {"accuracy": accuracy, "f1": f1, "recall": recall, "precision": precision}

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Set up training arguments with best hyperparameters
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=best_lr,
    per_device_train_batch_size=best_batch_size,
    per_device_eval_batch_size=best_batch_size,
    num_train_epochs=best_epoch,
    weight_decay=best_weight_decay,
    label_names=["labels"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=anchor_data,
    eval_dataset=test_data,
    compute_metrics=compute_metrics,
)

# Train the model
train_results = trainer.train()
print(train_results)

# Evaluate the model
eval_results = trainer.evaluate()
print(eval_results)


{'train_runtime': 5.8173, 'train_samples_per_second': 16.502, 'train_steps_per_second': 8.251, 'train_loss': 0.7595375378926595, 'epoch': 3.0}
TrainOutput(global_step=48, training_loss=0.7595375378926595, metrics={'train_runtime': 5.8173, 'train_samples_per_second': 16.502, 'train_steps_per_second': 8.251, 'train_loss': 0.7595375378926595, 'epoch': 3.0})
{'eval_loss': 0.713617742061615, 'eval_accuracy': 0.497, 'eval_f1': 0.6636951192333408, 'eval_recall': 1.0, 'eval_precision': 0.49666444296197465, 'eval_runtime': 21.2248, 'eval_samples_per_second': 141.344, 'eval_steps_per_second': 70.672, 'epoch': 3.0}
{'eval_loss': 0.713617742061615, 'eval_accuracy': 0.497, 'eval_f1': 0.6636951192333408, 'eval_recall': 1.0, 'eval_precision': 0.49666444296197465, 'eval_runtime': 21.2248, 'eval_samples_per_second': 141.344, 'eval_steps_per_second': 70.672, 'epoch': 3.0}


# 7. BERT with Data Augmentation (Back-translation)

In [40]:
from datasets import Dataset
anchor_data, test_data = generate_data(3000)
augmented_text, augmented_labels = augment_few_shot_dataset(anchor_data, method='bt', aug_per_sample=2)

# Create dictionary with aligned fields
data_dict = {
    "text": augmented_text,
    "label": augmented_labels
}

# Convert to Hugging Face dataset
anchor_data = Dataset.from_dict(data_dict)
anchor_data = anchor_data.map(tokenize_function, batched=True)
test_data = test_data.map(tokenize_function, batched=True)


Label distribution for anchor Counter({1: 16, 0: 16})


Augmenting: 100%|██████████| 32/32 [08:56<00:00, 16.77s/it]


Map:   0%|          | 0/96 [00:00<?, ? examples/s]

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

In [41]:
# Define evaluation metrics
config = AutoConfig.from_pretrained(CHECKPOINT, num_labels=2, hidden_dropout_prob=0.5)
model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT, config=config)

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
recall_metric = evaluate.load("recall")
precision_metric = evaluate.load("precision")

best_lr = 2e-5
best_batch_size = 2
best_epoch = 3
best_weight_decay = 0.01

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = torch.argmax(torch.tensor(logits), dim=-1)
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='binary')["f1"]
    recall = recall_metric.compute(predictions=predictions, references=labels, average='binary')["recall"]
    precision = precision_metric.compute(predictions=predictions, references=labels, average='binary')["precision"]
    return {"accuracy": accuracy, "f1": f1, "recall": recall, "precision": precision}

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Set up training arguments with best hyperparameters
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=best_lr,
    per_device_train_batch_size=best_batch_size,
    per_device_eval_batch_size=best_batch_size,
    num_train_epochs=best_epoch,
    weight_decay=best_weight_decay,
    label_names=["labels"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=anchor_data,
    eval_dataset=test_data,
    compute_metrics=compute_metrics,
)

# Train the model
train_results = trainer.train()
print(train_results)

# Evaluate the model
eval_results = trainer.evaluate()
print(eval_results)


{'train_runtime': 8.9911, 'train_samples_per_second': 32.032, 'train_steps_per_second': 16.016, 'train_loss': 0.6956783400641547, 'epoch': 3.0}
TrainOutput(global_step=144, training_loss=0.6956783400641547, metrics={'train_runtime': 8.9911, 'train_samples_per_second': 32.032, 'train_steps_per_second': 16.016, 'train_loss': 0.6956783400641547, 'epoch': 3.0})
{'eval_loss': 0.696406900882721, 'eval_accuracy': 0.526, 'eval_f1': 0.6193790149892934, 'eval_recall': 0.7770315648085964, 'eval_precision': 0.5149087672452158, 'eval_runtime': 21.1991, 'eval_samples_per_second': 141.516, 'eval_steps_per_second': 70.758, 'epoch': 3.0}
{'eval_loss': 0.696406900882721, 'eval_accuracy': 0.526, 'eval_f1': 0.6193790149892934, 'eval_recall': 0.7770315648085964, 'eval_precision': 0.5149087672452158, 'eval_runtime': 21.1991, 'eval_samples_per_second': 141.516, 'eval_steps_per_second': 70.758, 'epoch': 3.0}
