### Enhanced embeddings with classification 

In [None]:
# Imports
import pandas as pd
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
import evaluate

In [None]:
# Read the pickled DataFrame and keep a single row per distinct prompt
data = pd.read_pickle("../dataset/data.pkl").drop_duplicates(subset=["prompt"])

In [None]:
# Data preprocessing

def position_to_class(row, col):
    """Map a grid position to its class index.

    The three labelled corners are (0, 9) -> 0, (9, 0) -> 1 and (9, 9) -> 2.
    Any other position yields ``None``.
    """
    labelled_corners = ((0, 9), (9, 0), (9, 9))
    # The index of a corner in the table doubles as its class label.
    for label, corner in enumerate(labelled_corners):
        if (row, col) == corner:
            return label
    return None
    
def preprocess_text(text):
    """Return *text* with leading and trailing whitespace removed."""
    cleaned = text.strip()
    return cleaned


# Derive the class label from each sample's (row, column) and tidy the prompt text
data["class"] = data.apply(
    lambda record: position_to_class(record["row"], record["column"]),
    axis=1,
)
data["prompt"] = data["prompt"].apply(preprocess_text)

# Hold out 20% of the samples, then halve that hold-out into validation and test
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    data["prompt"],
    data["class"],
    test_size=0.2,
    random_state=42,
)
valid_texts, test_texts, valid_labels, test_labels = train_test_split(
    temp_texts,
    temp_labels,
    test_size=0.5,
    random_state=42,
)

In [None]:
# Tokenizer for the "bert-base-uncased" checkpoint
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Wrap every split in a Hugging Face Dataset with "text" and "label" columns
split_frames = {
    "train": (train_texts, train_labels),
    "valid": (valid_texts, valid_labels),
    "test": (test_texts, test_labels),
}
dataset = DatasetDict({
    split_name: Dataset.from_dict({"text": texts.tolist(), "label": labels.tolist()})
    for split_name, (texts, labels) in split_frames.items()
})

def tokenize_function(texts):
    """Tokenize *texts* with the module-level tokenizer.

    Every sequence is padded to, and truncated at, 128 tokens.
    """
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    return encoded

# Run the tokenizer over every split, one batch of texts at a time
dataset = dataset.map(
    lambda batch: tokenize_function(batch["text"]),
    batched=True,
)

# Return these three columns as PyTorch tensors
dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"],
)

# BERT with a classification head sized to the number of distinct classes in the data
num_labels = data["class"].nunique()
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=num_labels,
)

# Optional: freeze all BERT parameters (currently disabled)
# for param in model.bert.parameters():
#     param.requires_grad = False

# Accuracy metric used by compute_metrics
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score predictions with the module-level ``metric``.

    ``eval_pred`` unpacks into ``(logits, labels)``; the predicted class is the
    arg-max of the logits along the last axis.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted_classes, references=references)

# Training configuration
training_args = TrainingArguments(
    # Output locations
    output_dir="./results",
    logging_dir="./logs",
    # Evaluate and checkpoint once per epoch
    # NOTE(review): newer transformers releases may expect `eval_strategy`
    # instead of `evaluation_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Optimisation schedule
    num_train_epochs=3,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=24,
    weight_decay=0.01,
    # Logging cadence; report_to=[] presumably disables external trackers - confirm
    logging_steps=10,
    report_to=[],
)

# Trainer: fits on the training split, validates on the validation split
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["valid"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Fine-tune the model
trainer.train()

# Score the held-out test split (kept as the cell's last expression so the result is displayed)
trainer.evaluate(dataset["test"])


In [None]:
# Write the fine-tuned model and the tokenizer into the same directory
export_dir = "./saved_model"
model.save_pretrained(export_dir)
tokenizer.save_pretrained(export_dir)


In [None]:
# Function to embed text using the trained model
def embed_text(texts, model, tokenizer, device='cuda'):
    """Return the last-layer embedding of the first ([CLS]) token for each text.

    The forward pass runs in evaluation mode so that dropout cannot perturb the
    embeddings; the model's previous train/eval mode is restored afterwards.

    Args:
        texts: Texts to embed; forwarded unchanged to ``tokenizer``.
        model: Module that accepts ``input_ids`` and ``attention_mask`` and,
            when called with ``output_hidden_states=True``, returns an object
            exposing ``hidden_states``.
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``attention_mask`` tensors.
        device: Device the model and the encoded inputs are moved to.

    Returns:
        Tensor with one row per text: the final hidden state at position 0.
    """
    model.to(device)
    encodings = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")

    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)

    # Dropout is active in training mode and would make the embeddings noisy
    # and non-deterministic, so force eval mode for the duration of the call.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            # Last hidden layer, first token of every sequence (CLS token)
            embeddings = outputs.hidden_states[-1][:, 0, :]
    finally:
        # Leave the caller's model in whatever mode it was in before.
        model.train(was_training)

    return embeddings

# Demo: embed a single instruction, on the GPU when one is available
text = ["Move to the last column on the top row"]
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
embeddings = embed_text(text, model, tokenizer, device)
print(embeddings)