In [None]:
%pip install uv
!uv pip install transformers scikit-learn==1.6.1 numpy==2.0.0 torch anthropic cleanlab datasets boto3 sentence-transformers -q

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/samlkrystof/mlprague_workshop/blob/main/notebooks/Workshop_with_excercises.ipynb)

In [48]:
import os
import random
import re
from collections import Counter
from random import randint
from typing import Callable, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from anthropic import AnthropicBedrock
from cleanlab import Datalab
from datasets import DatasetDict, Dataset, load_dataset, concatenate_datasets, ClassLabel, Value, Features
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, classification_report
from sklearn.model_selection import cross_val_predict
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer

In [3]:
# Comma-separated "access_key,secret_key,access_key,secret_key,..." text from the workshop secret file.
# Prefer exporting it as WORKSHOP_AWS_CREDENTIALS instead of pasting it here, so keys never end up
# in a saved/committed notebook. Falls back to "" (same as before) when the variable is not set.
credentials = os.environ.get("WORKSHOP_AWS_CREDENTIALS", "")

In [4]:
def select_random_credentials(credentials: str) -> Tuple[str, str]:
    """Pick one (access key, secret key) pair at random from a comma-separated list.

    Args:
        credentials: ``access_key,secret_key,access_key,secret_key,...`` text as pasted from the
            workshop secret file. Whitespace/newlines around entries and empty entries (e.g. from
            a trailing comma) are ignored.

    Returns:
        A ``(access_key, secret_key)`` tuple chosen uniformly among all provided pairs.

    Raises:
        ValueError: If no entries are present or their count is odd (keys and secrets must pair up).
    """
    parts = [part.strip() for part in credentials.split(",") if part.strip()]
    if not parts or len(parts) % 2:
        raise ValueError(
            f"Expected a non-empty, even number of comma-separated credential entries, got {len(parts)}."
        )
    # Use however many pairs were provided instead of assuming exactly five.
    pair_index = randint(0, len(parts) // 2 - 1)
    return parts[2 * pair_index], parts[2 * pair_index + 1]

# Export one randomly chosen key pair (plus the region) as AWS_* environment variables;
# presumably the Bedrock client below reads the keys from there — confirm with the SDK docs.
access_key, secret_key = select_random_credentials(credentials)
os.environ.update({
    'AWS_ACCESS_KEY_ID': access_key,
    'AWS_SECRET_ACCESS_KEY': secret_key,
    'AWS_DEFAULT_REGION': 'us-west-2',
})

In [5]:
# Bedrock model identifiers (call_llm uses haiku_3 by default).
haiku_3 = "anthropic.claude-3-haiku-20240307-v1:0"
haiku_3_5 = "anthropic.claude-3-5-haiku-20241022-v1:0"

# Shared Bedrock client; the region is read from AWS_DEFAULT_REGION (None if unset).
client = AnthropicBedrock(aws_region=os.environ.get('AWS_DEFAULT_REGION'))

def call_llm(prompt: str, model: str = haiku_3, max_tokens: int = 2000, temperature: float = 0.7) -> str:
    """
    Send a single-turn user prompt to a Claude model through the AnthropicBedrock client.

    Args:
        prompt: The user message to send to the model.
        model: The Bedrock model identifier to use (default: haiku_3).
        max_tokens: Maximum number of tokens in the response.
        temperature: Sampling temperature for generation (0.0-1.0).

    Returns:
        str: The text of the first content block of the reply (not the raw response object).

    Raises:
        Exception: Any error from the API call or from reading the reply is printed and re-raised.
    """
    try:
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        return response.content[0].text
    except Exception as e:
        # Print for visibility in the notebook, but still propagate so callers see the failure.
        print(f"Error calling Claude model: {e}")
        raise

In [None]:
def test_credentials():
    """Smoke-test the configured credentials: send a tiny prompt and print the reply.

    Any exception is printed rather than raised, so the notebook keeps running.
    """
    greeting = "Say hello workshop!"
    try:
        reply = call_llm(greeting)
        print("Claude's Response:")
        print(reply)
    except Exception as err:
        print(f"Error testing credentials: {err}")


test_credentials()

In [23]:
def compute_metrics(pred):
    """
    Compute evaluation metrics for model predictions (used as the Trainer's compute_metrics hook).

    Args:
        pred: Prediction object with ``predictions`` (logits; classes on the last axis) and
            ``label_ids`` (true integer labels).

    Returns:
        dict: accuracy, weighted f1/precision/recall, plus ``f1_world``, ``f1_sports``,
        ``f1_business`` per-class F1 scores.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    precision, recall, f1, _ = precision_recall_fscore_support(
        labels,
        preds,
        average='weighted',
        zero_division=0
    )
    acc = accuracy_score(labels, preds)

    class_names = ['World', 'Sports', 'Business']  # label ids 0, 1, 2
    # Pin the label ids explicitly: without `labels=`, classification_report infers the classes
    # present in labels/preds and raises ValueError whenever that count differs from
    # len(target_names) (e.g. a class missing from both, or an extra id appearing).
    report = classification_report(
        labels,
        preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    metrics = {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

    # Per-class F1, with a 0.0 fallback if a class entry is ever absent from the report.
    for class_name in class_names:
        metrics[f'f1_{class_name.lower()}'] = report.get(class_name, {}).get('f1-score', 0.0)

    return metrics

def prepare_dataset(
    test_count: int = 500,
    eval_count: int = 200,
    dataset_name: str = "ksaml/agnews-noisy-subset",
    model_name: str = "distilbert-base-uncased",
) -> Tuple[DatasetDict, AutoTokenizer]:
    """
    Load the noisy AG News subset, print per-split statistics and tokenize every split.

    Args:
        test_count: Currently unused; kept so existing calls keep working.
        eval_count: Currently unused; kept so existing calls keep working.
        dataset_name: Hugging Face Hub dataset to load.
        model_name: Checkpoint whose tokenizer is loaded; should match the model later passed to
            ``train_model`` (both default to distilbert-base-uncased).

    Returns:
        (tokenized_datasets, tokenizer): all splits with tokenizer columns added (the original
        columns such as "text" and "label" are kept), and the tokenizer itself.
    """
    # Label ids used in this notebook: 0=World, 1=Sports, 2=Business (train_model builds a 3-class head).
    print("Loading AG News dataset...")
    dataset = load_dataset(dataset_name)

    print(f"Dataset loaded: {dataset}")
    for split in dataset.keys():
        print(f"Number of {split} examples: {len(dataset[split])}")

        # Single pass over the labels; Counter returns 0 for ids that never occur.
        counts = Counter(dataset[split]['label'])
        print(f"{split} class distribution: World: {counts[0]}, Sports: {counts[1]}, Business: {counts[2]}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_function(examples):
        # No padding here: DataCollatorWithPadding pads each batch dynamically during training.
        return tokenizer(examples["text"], padding=False, truncation=True, max_length=512)

    print("Tokenizing dataset...")
    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    return tokenized_datasets, tokenizer

def train_model(
    tokenized_datasets: DatasetDict,
    tokenizer: AutoTokenizer,
    model_name: str = "distilbert-base-uncased",
    output_dir: str = "./results",
    model_save_path: str = "./trained_model",
    learning_rate: float = 2e-5,
    train_batch_size: int = 32,
    eval_batch_size: int = 64,
    num_train_epochs: int = 3,
    weight_decay: float = 0.01,
    fp16: bool = True,
    dataloader_num_workers: int = 4,
    eval_steps: int = 500,
    save_steps: int = 500,
    report_to: str = "none",
    num_labels: int = 3,
) -> Tuple[Trainer, str]:
    """
    Fine-tune a sequence classification model and save it to disk.

    Args:
        tokenized_datasets: Tokenized splits; "train" is used for training and "test" for evaluation.
        tokenizer: Tokenizer used for the datasets (used for dynamic padding and saved with the model).
        model_name: Pre-trained model identifier.
        output_dir: Directory for training outputs (logs, checkpoints).
        model_save_path: Path to save the final trained model.
        learning_rate: Learning rate for the optimizer.
        train_batch_size: Batch size for training.
        eval_batch_size: Batch size for evaluation.
        num_train_epochs: Number of training epochs.
        weight_decay: Weight decay for regularization.
        fp16: Request mixed precision training. Only honoured when CUDA is available, because
            TrainingArguments rejects fp16 on CPU-only machines.
        dataloader_num_workers: Number of worker processes for data loading.
        eval_steps: Number of steps between evaluations.
        save_steps: Number of steps between checkpoints. Since load_best_model_at_end=True,
            this must be a multiple of eval_steps (TrainingArguments validates this).
        report_to: Destination for logging ('none', 'tensorboard', 'wandb', etc.).
        num_labels: Size of the classification head (default 3: World, Sports, Business).

    Returns:
        trainer: The trained model trainer (holding the best checkpoint, see load_best_model_at_end).
        model_save_path: Path where the final model is saved.
    """
    print(f"Loading pre-trained model: {model_name}...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
    )

    # Pads each batch to its longest sequence (the datasets are tokenized without padding).
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    use_fp16 = fp16 and torch.cuda.is_available()
    if fp16 and not use_fp16:
        print("CUDA not available - training without fp16 mixed precision.")

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        num_train_epochs=num_train_epochs,
        weight_decay=weight_decay,
        fp16=use_fp16,
        dataloader_num_workers=dataloader_num_workers,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        push_to_hub=False,
        report_to=report_to,
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        # NOTE(review): the test split also selects the best checkpoint (load_best_model_at_end),
        # so the reported test scores are optimistically biased.
        eval_dataset=tokenized_datasets["test"],
        # `processing_class` replaces the deprecated `tokenizer` kwarg (transformers >= 4.46).
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving final model to {model_save_path}...")
    trainer.save_model(model_save_path)
    tokenizer.save_pretrained(model_save_path)
    print("Model saved.")

    return trainer, model_save_path

def evaluate_model(trainer: Trainer, tokenizer: AutoTokenizer, model_path: str):
    """
    Print evaluation metrics and run a small qualitative inference check.

    Args:
        trainer: The trained model trainer (its eval dataset is used for the metrics).
        tokenizer: The tokenizer used for the datasets.
        model_path: Path where the model was saved; the model is reloaded from here, so this also
            exercises the saved checkpoint.
    """
    print("Evaluating model...")
    eval_results = trainer.evaluate()
    print(f"Evaluation results: {eval_results}")

    # Set up the device and model once, not per example.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)
    model.eval()

    class_map = {0: "World", 1: "Sports", 2: "Business"}
    test_texts = [
        "European leaders met in Brussels to discuss new trade regulations affecting the EU market.",
        "Manchester United defeated Liverpool 2-1 in a thrilling match at Old Trafford.",
        "Tesla's stock price surged after announcing record quarterly profits."
    ]

    print("\nTesting inference on different examples:")

    # No gradients are needed for inference.
    with torch.inference_mode():
        for test_text in test_texts:
            inputs = tokenizer(test_text, return_tensors="pt", truncation=True).to(device)
            predicted_class = model(**inputs).logits.argmax(dim=-1).item()

            print(f"Text: '{test_text}'")
            # Fall back to the raw id if the head has more classes than the names we know.
            print(f"Predicted class: {class_map.get(predicted_class, predicted_class)}")


def get_random_texts_by_class(dataset: DatasetDict, class_name: "str | int", count: int = 5) -> List[str]:
    """
    Get random texts from a specific class of the train split.

    Args:
        dataset: Dataset dict with a "train" split whose "label" feature is a ClassLabel.
        class_name: Class name (e.g. "Sports") or its integer label id.
        count: Maximum number of texts to return.

    Returns:
        List[str]: Up to ``count`` randomly chosen texts from the specified class.

    Raises:
        ValueError: If ``class_name`` is a string that is not one of the ClassLabel names.
    """
    train_split = dataset["train"]
    # Labels are stored as integer ids, so a class name must be mapped to its id before comparing.
    if isinstance(class_name, str):
        label_id = train_split.features["label"].str2int(class_name)
    else:
        label_id = class_name
    class_subset = train_split.filter(lambda example: example["label"] == label_id).shuffle()
    # select(range(count)) would raise if the class has fewer than `count` rows.
    picked = class_subset.select(range(min(count, len(class_subset))))
    return list(picked["text"])


def extract_tagged_content(content: Optional[str], tag_name: str) -> Optional[str]:
    """Return the text between the first ``<tag_name>`` and the next ``</tag_name>``.

    Args:
        content: Text to search (typically an LLM reply); may span multiple lines.
        tag_name: Tag name, matched literally (regex metacharacters are escaped).

    Returns:
        The enclosed text (unstripped), or None if ``content`` is empty/None or the tag pair is absent.
    """
    if not content:
        return None
    tag = re.escape(tag_name)
    match = re.search(fr"<{tag}>(.*?)</{tag}>", content, re.DOTALL)
    return match.group(1) if match else None

def clean_generated_data(generated_data: Optional[str], tag_name: str = "text") -> List[str]:
    """Extract the bodies of numbered ``<k_tag>...</k_tag>`` elements from an LLM reply.

    Only pairs whose opening and closing numbers match (``<3_text>...</3_text>``) are kept.

    Args:
        generated_data: Raw LLM reply; None or "" yields an empty list.
        tag_name: Tag suffix after ``<k_``, matched literally (regex metacharacters are escaped).

    Returns:
        The enclosed texts, in order of appearance, unstripped.
    """
    if not generated_data:
        return []
    tag = re.escape(tag_name)
    matches = re.findall(fr'<(\d+)_{tag}>(.*?)</\1_{tag}>', generated_data, re.DOTALL)
    return [body for _number, body in matches]

In [None]:
# Baseline: fine-tune on the original (noisy) training data only.
base_dataset, tokenizer = prepare_dataset()
trainer, model_path = train_model(tokenized_datasets=base_dataset, tokenizer=tokenizer)
evaluate_model(trainer=trainer, tokenizer=tokenizer, model_path=model_path)

In [62]:
def generate_data(input_dataset: DatasetDict, generation_function: Callable, num_to_generate_per_class: int = 50) -> dict:
    """Generate synthetic texts for every class of ``input_dataset["train"]``.

    Args:
        input_dataset: Dataset dict whose "train" split has a ClassLabel "label" feature.
        generation_function: ``(class_data, n) -> list of texts``; called with the train rows of
            one class and the number of texts still missing for that class.
        num_to_generate_per_class: Target number of texts per class.

    Returns:
        dict: ``{class_name: [texts]}`` with at most ``num_to_generate_per_class`` texts each.
        A class can end up with fewer if ``generation_function`` returns nothing for it.
    """
    train_split = input_dataset["train"]
    generated_data = {}
    for idx, class_name in enumerate(train_split.features["label"].names):
        class_data = train_split.filter(lambda x: x["label"] == idx)
        texts = []
        while len(texts) < num_to_generate_per_class:
            remaining = num_to_generate_per_class - len(texts)
            # Ask only for what is still missing.
            new_texts = list(generation_function(class_data, remaining))
            if not new_texts:
                # Avoid looping forever when the generator cannot produce anything.
                print(f"Warning: no texts returned for class {class_name}; stopping at {len(texts)}.")
                break
            texts.extend(new_texts[:remaining])
            print(f"Generated {len(texts)} texts for class {class_name}")
        generated_data[class_name] = texts

    return generated_data


def expand_dataset(
    input_dataset: DatasetDict,
    tokenizer: AutoTokenizer,
    generation_function: Callable,
    num_to_generate_per_class: int = 50
) -> DatasetDict:
    """Augment the train split of ``input_dataset`` with LLM-generated, tokenized examples.

    Args:
        input_dataset: Tokenized dataset dict whose "train" split has a ClassLabel "label" feature.
        tokenizer: Tokenizer applied to the generated texts; should be the one used for ``input_dataset``.
        generation_function: ``(class_data, n) -> list of texts`` callable, forwarded to ``generate_data``.
        num_to_generate_per_class: Target number of synthetic texts per class.

    Returns:
        DatasetDict: a new dict whose "train" split is the original train rows followed by the
        generated rows; the other splits are the original objects. If nothing was generated or
        concatenation fails, the splits are returned un-expanded (a message is printed).
    """
    generated_data_dict = generate_data(input_dataset, generation_function, num_to_generate_per_class)
    new_texts = []
    new_labels = []
    # dict.copy() on a DatasetDict returns a plain dict; rebuild it so the DatasetDict type is kept.
    new_dataset = DatasetDict(input_dataset)

    label_feature = new_dataset["train"].features["label"]
    class_names = label_feature.names

    for class_name, texts in generated_data_dict.items():
        if class_name not in class_names:
            print(f"Warning: Class name '{class_name}' not found in dataset features. Skipping.")
            continue
        new_texts.extend(texts)
        new_labels.extend([class_names.index(class_name)] * len(texts))

    if not new_texts:
        print("No new data generated or added.")
        return new_dataset

    # Same ClassLabel as the original split, so the datasets can be concatenated.
    features = Features({
        "text": Value("string"),
        "label": ClassLabel(names=class_names)
    })

    generated_dataset_raw = Dataset.from_dict(
        {"text": new_texts, "label": new_labels},
        features=features
    )

    print("Tokenizing generated data...")
    def tokenize_generated(examples):
        # Must mirror the tokenization used for the original dataset.
        return tokenizer(examples["text"], padding=False, truncation=True, max_length=512)

    tokenized_generated_dataset = generated_dataset_raw.map(
        tokenize_generated,
        batched=True,
    )
    print("Tokenization complete.")

    try:
        original_cols = set(input_dataset["train"].column_names)
        generated_cols = set(tokenized_generated_dataset.column_names)

        if original_cols != generated_cols:
            print(f"Warning: Column mismatch between original ({original_cols}) and generated ({generated_cols}) datasets.")

        expanded_train_dataset = concatenate_datasets([new_dataset["train"], tokenized_generated_dataset])
    except Exception as e:
        print(f"Error concatenating datasets: {e}")
        print("Original train features:", input_dataset["train"].features)
        print("Tokenized generated dataset features:", tokenized_generated_dataset.features)
        return new_dataset

    new_dataset["train"] = expanded_train_dataset
    print(f"Expanded training dataset size: {len(new_dataset['train'])}")
    return new_dataset

## Exercise 1
Write a basic prompt that generates texts similar to the examples included in the prompt.

In [None]:
#TODO complete the prompt
# The template is filled with naive_prompt.format(N=..., EXAMPLES=...) in naive_generation, so use the
# {N} and {EXAMPLES} placeholders (any literal braces in the prompt must be doubled: {{ }}).
# The reply is parsed by clean_generated_data, which only keeps texts wrapped as
# <1_text>...</1_text>, <2_text>...</2_text>, ... — ask the model for exactly that format.
#-----------------------------------------------------------------------------------------------------------------------
naive_prompt = """

"""

#-----------------------------------------------------------------------------------------------------------------------

def naive_generation(
    class_data: Dataset,
    num_to_generate_per_class: int = 50,
    batch_size: int = 20,
    max_rounds: Optional[int] = None,
) -> List[str]:
    """Generate texts for one class by showing the LLM random examples of that class (Exercise 1).

    Each round samples ``batch_size`` rows, fills ``naive_prompt`` with them, asks for
    ``batch_size`` new texts and parses the reply with ``clean_generated_data``.

    Args:
        class_data: Train rows of a single class (must have a "text" column).
        num_to_generate_per_class: Number of texts to return.
        batch_size: Examples shown per prompt, and texts requested per LLM call.
        max_rounds: Upper bound on LLM calls. Defaults to three times the minimum number of calls
            needed, plus two, so a prompt whose replies never parse cannot loop (and bill) forever.

    Returns:
        Up to ``num_to_generate_per_class`` texts (fewer only if ``max_rounds`` is exhausted).
    """
    if max_rounds is None:
        max_rounds = 3 * -(-num_to_generate_per_class // batch_size) + 2
    generated_texts = []
    for _ in range(max_rounds):
        if len(generated_texts) >= num_to_generate_per_class:
            break
        batch = class_data.shuffle().take(batch_size)
        batch_str = "\n".join(batch["text"])

        prompt = naive_prompt.format(N=batch_size, EXAMPLES=batch_str)
        generated_texts.extend(clean_generated_data(call_llm(prompt)))

    if len(generated_texts) < num_to_generate_per_class:
        print(f"Warning: only {len(generated_texts)}/{num_to_generate_per_class} texts generated after {max_rounds} LLM calls.")
    return generated_texts[:num_to_generate_per_class]


# Augment the train split with naive LLM generations, then retrain and evaluate.
naive_dataset = expand_dataset(input_dataset=base_dataset, tokenizer=tokenizer, generation_function=naive_generation)
trainer, model_path = train_model(tokenized_datasets=naive_dataset, tokenizer=tokenizer)
evaluate_model(trainer=trainer, tokenizer=tokenizer, model_path=model_path)

In [None]:
# --- Cleanlab Helper Functions ---

def get_embeddings(texts: List[str], model_name: str = "all-MiniLM-L6-v2") -> np.ndarray:
    """Encode ``texts`` with a sentence-transformers model and return the embedding matrix.

    The model named ``model_name`` is (re)loaded on every call.
    """
    print(f"Generating embeddings using {model_name}...")
    encoder = SentenceTransformer(model_name)
    vectors = encoder.encode(texts, show_progress_bar=True)
    print(f"Generated embeddings of shape: {vectors.shape}")
    return vectors

def get_pred_probs(embeddings: np.ndarray, labels: np.ndarray, cv_folds: int = 5) -> np.ndarray:
    """Out-of-fold class probabilities from a logistic-regression probe on the embeddings.

    Each row's probabilities come from a model that did not see that row (``cross_val_predict``),
    which is what Cleanlab needs to spot label issues. On ValueError, diagnostics about the
    label distribution are printed and the error is re-raised.
    """
    print("Training Logistic Regression and getting out-of-sample predictions...")
    # sklearn wants an array; leave ndarrays untouched, convert anything else.
    label_array = labels if isinstance(labels, np.ndarray) else np.array(labels)

    probe = LogisticRegression(max_iter=1000, random_state=42)
    try:
        out_of_fold = cross_val_predict(
            probe, embeddings, label_array, cv=cv_folds, method="predict_proba"
        )
        print("Generated predicted probabilities.")
    except ValueError as err:
        print(f"Error during cross_val_predict: {err}")
        print("This might happen if a class has fewer samples than cv_folds.")
        print("Labels shape:", label_array.shape)
        print("Unique labels and counts:", np.unique(label_array, return_counts=True))
        raise
    return out_of_fold

def find_dataset_issues(texts: List[str], labels: List[int], embeddings: np.ndarray, pred_probs: np.ndarray) -> pd.DataFrame:
    """Audit a labelled text dataset with Cleanlab's Datalab and print a summary.

    Args:
        texts: Raw texts, one per example.
        labels: Integer label per example (same order as ``texts``).
        embeddings: Feature vectors per example (e.g. from ``get_embeddings``).
        pred_probs: Out-of-sample class probabilities per example (e.g. from ``get_pred_probs``).

    Returns:
        pd.DataFrame: ``lab.get_issues()`` — one row per *example* (not per issue) with boolean
        ``is_<type>_issue`` columns and their score columns.
    """
    print("Running Cleanlab Datalab audit...")

    data = {
        "text": texts,
        "label": labels,
    }
    lab = Datalab(data, label_name="label")
    lab.find_issues(features=embeddings, pred_probs=pred_probs)

    print("Cleanlab Datalab audit complete. Generating report...")
    lab.report()

    issues = lab.get_issues()
    # get_issues() has one row per example, so count rows with at least one flag instead of len(issues).
    flag_columns = [col for col in issues.columns if col.startswith("is_") and col.endswith("_issue")]
    num_flagged = int(issues[flag_columns].any(axis=1).sum()) if flag_columns else 0
    print(f"\nAudited {len(issues)} examples; {num_flagged} flagged with at least one issue.")

    # Per-type counts for the issue types repair_dataset_notebook acts on (when present).
    for column, description in (
        ("is_label_issue", "label"),
        ("is_near_duplicate_issue", "near duplicate"),
        ("is_outlier_issue", "outlier"),
    ):
        if column in issues.columns:
            print(f"Found {int(issues[column].eq(True).sum())} potential {description} issues.")

    return issues

def repair_dataset_notebook(original_dataset: Dataset, issues_df: pd.DataFrame, tokenizer: AutoTokenizer) -> Dataset:
    """Remove the examples Cleanlab flagged and return the remaining rows.

    Args:
        original_dataset: The (already tokenized) dataset that was audited; row ``i`` must
            correspond to index ``i`` of ``issues_df``.
        issues_df: Output of ``find_dataset_issues`` / ``Datalab.get_issues()``; rows whose
            ``is_label_issue``, ``is_near_duplicate_issue`` or ``is_outlier_issue`` is True are dropped.
        tokenizer: Unused — rows keep their existing tokenizer columns and nothing is re-tokenized.
            Kept for interface compatibility.

    Returns:
        Dataset: the rows that were not flagged, in their original order.
    """
    print("\nRepairing dataset by removing examples with detected issues...")
    num_original = len(original_dataset)

    # Union of flagged row positions across the issue types acted on.
    # NOTE(review): if Datalab flags every member of a near-duplicate set, all copies are removed
    # rather than keeping one representative — confirm this is intended.
    flagged = set()
    for column in ("is_label_issue", "is_near_duplicate_issue", "is_outlier_issue"):
        if column in issues_df.columns:
            flagged.update(issues_df[issues_df[column] == True].index)

    # Bounds-check once so out-of-range indices are reported instead of silently dropped.
    valid_indices = sorted(idx for idx in flagged if 0 <= idx < num_original)
    if len(valid_indices) != len(flagged):
        print(f"Warning: Some issue indices ({len(flagged) - len(valid_indices)}) were out of bounds for the dataset size ({num_original}).")

    keep_mask = np.ones(num_original, dtype=bool)
    if valid_indices:
        keep_mask[valid_indices] = False
    print(f"- Marked {len(valid_indices)} total examples with issues for removal.")

    # .select() keeps every column, including the existing tokenization.
    cleaned_subset = original_dataset.select(np.flatnonzero(keep_mask))
    print(f"Dataset size reduced from {num_original} to {len(cleaned_subset)}.")
    print("Cleaned subset retains original tokenization.")

    return cleaned_subset

# --- End Cleanlab Helper Functions ---

## Exercise 2

Use the cleanlab package to remove problematic examples from the dataset.

In [None]:
print("Starting Cleanlab pipeline on naive_dataset['train']...")
#TODO complete the flow to remove problematic examples from dataset
# Pipeline: raw texts/labels -> sentence embeddings (get_embeddings) -> out-of-fold class probabilities
# (get_pred_probs) -> Cleanlab audit (find_dataset_issues) -> drop flagged rows (repair_dataset_notebook).
#-----------------------------------------------------------------------------------------------------------------------
if 'naive_dataset' in locals() and naive_dataset and 'train' in naive_dataset:
    naive_train_data = naive_dataset['train']
    
    # NOTE(review): prepare_dataset and expand_dataset keep the raw "text" column next to the
    # token ids, so this branch should normally not be taken.
    if "text" not in naive_train_data.column_names:
         print("Warning: 'text' column not found in naive_train_data. Cleanlab works best with raw text.")
         print("Skipping Cleanlab step as raw text is not readily available in the tokenized `naive_dataset`.")
         print("Ideally, Cleanlab should run on generated text *before* tokenization within `expand_dataset`.")
         cleaned_naive_dataset = naive_dataset # Pass through for now
         
    else:
        # 1. Raw texts and integer label ids of the train split (one entry per row, same order)
        texts_for_cleanlab =
        labels_for_cleanlab =

        # 2. Get Embeddings  (hint: get_embeddings(texts))
        embeddings_cl = 

        # 3. Get Prediction Probabilities (using embeddings and labels)  (hint: get_pred_probs(embeddings, labels))
        pred_probs_cl = 

        # 4. Find Issues  (hint: find_dataset_issues(texts, labels, embeddings, pred_probs))
        issues_df_cl = 

        # 5. Repair Dataset (remove issues)
        # Pass the *original* tokenized dataset subset and the tokenizer
        # (hint: repair_dataset_notebook(original_dataset, issues_df, tokenizer))
        cleaned_train_subset = 
        
        # Create the final cleaned dataset dictionary
        cleaned_naive_dataset = DatasetDict({
            "train": cleaned_train_subset,
            "test": naive_dataset["test"] # Keep the original test set
        })
        print("Cleaned dataset created: `cleaned_naive_dataset`")
#-----------------------------------------------------------------------------------------------------------------------

else:
    print("Error: `naive_dataset` not found or is empty. Cannot proceed with Cleanlab.")
    cleaned_naive_dataset = None # Indicate failure; the training cell below will then fail

In [None]:
# Retrain on the Cleanlab-filtered data produced by the Exercise 2 cell.
trainer, model_path = train_model(tokenized_datasets=cleaned_naive_dataset, tokenizer=tokenizer)
evaluate_model(trainer=trainer, tokenizer=tokenizer, model_path=model_path)

## Exercise 3
Implement a multi-step flow to generate better data:

- Step 1: Generate a description of the data from examples
- Step 2: Use the description to guide data generation

In [None]:
#TODO complete the description prompt and generation prompts
# description_prompt is filled via .format(EXAMPLES=...) and its reply is read with
# extract_tagged_content(response, "description"), so ask for the answer inside <description></description>.
# generation_with_description_prompt is filled via .format(N=..., LABEL=..., DESCRIPTION=...) and its reply
# is parsed by clean_generated_data, which expects <1_text>...</1_text>, <2_text>...</2_text>, ...
# Literal braces in either template must be doubled: {{ }}.
#-----------------------------------------------------------------------------------------------------------------------
description_prompt = """

"""

generation_with_description_prompt = """

"""
#-----------------------------------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------------------------------
def generation_with_description(
    class_data: Dataset,
    num_to_generate_per_class: int = 50,
    batch_size: int = 20,
    max_rounds: Optional[int] = None,
) -> List[str]:
    """Two-step generation for one class (Exercise 3): describe the examples, then generate from it.

    Args:
        class_data: Train rows of a single class (ClassLabel "label" column and a "text" column).
        num_to_generate_per_class: Number of texts to return.
        batch_size: Examples per description prompt, and texts requested per generation call.
        max_rounds: Upper bound on description+generation rounds. Defaults to three times the
            minimum needed plus two, so unparsable replies cannot loop forever.

    Returns:
        Up to ``num_to_generate_per_class`` texts; an empty list if ``class_data`` is empty.
    """
    if len(class_data) == 0:
        return []
    # features["label"].names lists *all* classes (filtering does not change the feature), so
    # `.names[0]` would always be the first class; look up this subset's own label id instead.
    label = class_data.features["label"].int2str(class_data[0]["label"])

    if max_rounds is None:
        max_rounds = 3 * -(-num_to_generate_per_class // batch_size) + 2
    generated_texts = []
    for _ in range(max_rounds):
        if len(generated_texts) >= num_to_generate_per_class:
            break
        batch = class_data.shuffle().take(batch_size)
        batch_str = "\n".join(batch["text"])

        response = call_llm(description_prompt.format(EXAMPLES=batch_str))
        # Fall back to the whole reply if the model omitted the <description> tags.
        description = extract_tagged_content(response, "description") or response

        generation_prompt_str = generation_with_description_prompt.format(N=batch_size, LABEL=label, DESCRIPTION=description)
        generated_texts.extend(clean_generated_data(call_llm(generation_prompt_str)))

    if len(generated_texts) < num_to_generate_per_class:
        print(f"Warning: only {len(generated_texts)}/{num_to_generate_per_class} texts generated for label {label}.")
    return generated_texts[:num_to_generate_per_class]

# Augment with description-guided generations, then retrain and evaluate.
description_dataset = expand_dataset(input_dataset=base_dataset, tokenizer=tokenizer, generation_function=generation_with_description)
trainer, model_path = train_model(tokenized_datasets=description_dataset, tokenizer=tokenizer)
evaluate_model(trainer=trainer, tokenizer=tokenizer, model_path=model_path)

## Exercise 4
Implement a multi-step flow to generate better data:

- Step 1: Generate a description of the data from examples
- Step 2: Use the examples and the description to generate tips that boost the diversity of generated samples
- Step 3: Use the description and tips to guide data generation

In [None]:
# reuse description prompt

# Diversification prompt: placeholders {EXAMPLES} and {DESCRIPTION} (fill with .format(...)).
# The model is asked to answer inside <tips></tips>, so extract_tagged_content(response, "tips")
# can pull the tips out of the reply.
diversification_prompt = """
You will be given a list of texts in <examples> tags and their descriptions in <description> tags.

<examples>
{EXAMPLES}
</examples>

<description>
{DESCRIPTION}
</description>

Read both description and examples carefully, then come up with creative way to diversify the dataset. Your tips will be given to LLM that will generate new texts.

Your response should be structured as
<tips>
Your tips for diversifying the dataset
</tips>

Be creative and come up with unique tips.
"""
#TODO complete generation_with_tips_prompt and the generation flow
# Suggested placeholders: {N}, {LABEL}, {DESCRIPTION}, {TIPS} (compare generation_with_persona_prompt in
# Exercise 5). Ask for <1_text>...</1_text>, <2_text>...</2_text>, ... so clean_generated_data can parse it.
#-----------------------------------------------------------------------------------------------------------------------
generation_with_tips_prompt = """

"""

def generation_with_tips(
    class_data: Dataset,
    num_to_generate_per_class: int = 50,
    batch_size: int = 20,
    max_rounds: Optional[int] = None,
) -> List[str]:
    """Exercise 4 skeleton: description -> diversification tips -> tip-guided generation for one class.

    Args:
        class_data: Train rows of a single class (ClassLabel "label" column and a "text" column).
        num_to_generate_per_class: Number of texts to return.
        batch_size: Examples per prompt, and texts requested per generation call.
        max_rounds: Upper bound on rounds (defaults to three times the minimum needed plus two), so an
            unfinished or unparsable flow terminates instead of spinning forever.

    Returns:
        Up to ``num_to_generate_per_class`` texts; an empty list if ``class_data`` is empty.
    """
    if len(class_data) == 0:
        return []
    # features["label"].names lists *all* classes, so index it with this subset's own label id.
    label = class_data.features["label"].int2str(class_data[0]["label"])

    if max_rounds is None:
        max_rounds = 3 * -(-num_to_generate_per_class // batch_size) + 2
    generated_texts = []
    for _ in range(max_rounds):
        if len(generated_texts) >= num_to_generate_per_class:
            break
        batch = class_data.shuffle().take(batch_size)
        batch_str = "\n".join(batch["text"])

        # Step 1 - description prompt (fill in examples, generate with llm, extract <description> content)



        # Step 2 - diversification prompt (fill in examples and description, generate with llm,
        #          extract <tips> content)



        # Step 3 - generation prompt (fill all the info, generate with llm, clean generated data from
        #          xml tags and extend generated_texts)



    if len(generated_texts) < num_to_generate_per_class:
        print(f"Warning: only {len(generated_texts)}/{num_to_generate_per_class} texts generated for label {label}.")
    return generated_texts[:num_to_generate_per_class]

# Augment with tip-guided generations, then retrain and evaluate.
tips_dataset = expand_dataset(input_dataset=base_dataset, tokenizer=tokenizer, generation_function=generation_with_tips)
trainer, model_path = train_model(tokenized_datasets=tips_dataset, tokenizer=tokenizer)
evaluate_model(trainer=trainer, tokenizer=tokenizer, model_path=model_path)


## Exercise 5
Implement a multi-step flow to generate better data:

- Step 1: Generate a description of the data from examples
- Step 2: Use the examples and the description to generate tips that boost the diversity of generated samples
- Step 3: Use the examples to generate a "persona" (a description of a person) that boosts diversity even further
- Step 4: Use the description, tips and persona to guide data generation

In [None]:
# Encoder for persona search queries. trust_remote_code=True executes model code downloaded from the Hub.
query_encoder = SentenceTransformer(model_name_or_path='Alibaba-NLP/gte-large-en-v1.5', trust_remote_code=True)

#TODO complete the persona prompt and generation flow; you can use the personas dataset

def embed_query(model, query_text):
    """Encode ``query_text`` with ``model`` and return the embedding as a tensor."""
    return model.encode(query_text, convert_to_tensor=True)

def load_and_prepare_embeddings(dataset_name):
    """Load ``dataset_name`` and return ``(dataset, embeddings)``.

    ``embeddings`` is the train split's ``'embedding'`` column as a tensor with each row
    L2-normalised, so a dot product with a normalised query gives cosine similarity.
    """
    persona_data = load_dataset(dataset_name)
    unit_rows = F.normalize(torch.tensor(persona_data['train']['embedding']), p=2, dim=1)
    return persona_data, unit_rows

def find_closest_persona(query_embedding, embeddings, dataset):
    """Return the persona whose embedding is most cosine-similar to ``query_embedding``.

    Args:
        query_embedding: 1-D tensor or a single-row ``(1, dim)`` tensor (e.g. from ``embed_query``).
        embeddings: ``(num_personas, dim)`` tensor, expected to be row-normalised
            (as produced by ``load_and_prepare_embeddings``).
        dataset: Dataset dict whose ``'train'`` rows carry a ``'persona'`` field aligned with ``embeddings``.

    Returns:
        The persona text of the best-matching row.
    """
    # Compute wherever `embeddings` lives. Moving only the query to CUDA while the persona matrix
    # stays on CPU makes torch.mm fail with a device mismatch.
    query = query_embedding.to(device=embeddings.device, dtype=embeddings.dtype).reshape(1, -1)
    query = F.normalize(query, p=2, dim=1)
    # squeeze(0) keeps a 1-D result even when there is a single persona.
    similarities = torch.mm(query, embeddings.T).squeeze(0)
    best_row = int(similarities.argmax().item())
    # Index the row first so only one persona string is materialised, not the whole column.
    return dataset['train'][best_row]['persona']


personas_dataset, embeddings = load_and_prepare_embeddings(dataset_name="argilla/FinePersonas-v0.1-clustering-100k")  # persona rows + normalised vectors

# TODO (Exercise 5): free-form persona prompt. One option: from {EXAMPLES}, ask the model to describe
# the kind of person who would write such texts, then use embed_query(query_encoder, <that description>)
# and find_closest_persona(..., embeddings, personas_dataset) to retrieve a real persona.
personas_prompt = """

"""

# Final generation prompt. Placeholders: {PERSONA}, {N}, {LABEL}, {DESCRIPTION}, {TIPS} (fill with .format(...)).
# The requested <k_text>...</k_text> output format is what clean_generated_data parses.
generation_with_persona_prompt = """
You will play the role of a persona described in <persona> tags.

<persona>
{PERSONA}
</persona>

Your task is to generate {N} texts for label {LABEL}. Here is the description of what should you generate:

<description>
{DESCRIPTION}
</description>

And here are some tips that will help you to generate diverse, creative and unique texts:

<tips>
{TIPS}
</tips>

Your response should be structured as
<results>
<1_text>Your generated text</1_text>
<2_text>Another generated text</2_text>
...
<{N}_text>Last generated text {N}</{N}_text>
</results>

Remember to generate exactly {N} texts.
"""

def generation_with_persona(
    class_data: Dataset,
    num_to_generate_per_class: int = 50,
    batch_size: int = 20,
    max_rounds: Optional[int] = None,
) -> List[str]:
    """Exercise 5 skeleton: description -> tips -> persona -> persona-conditioned generation for one class.

    Args:
        class_data: Train rows of a single class (ClassLabel "label" column and a "text" column).
        num_to_generate_per_class: Number of texts to return.
        batch_size: Examples per prompt, and texts requested per generation call.
        max_rounds: Upper bound on rounds (defaults to three times the minimum needed plus two), so an
            unfinished or unparsable flow terminates instead of spinning forever.

    Returns:
        Up to ``num_to_generate_per_class`` texts; an empty list if ``class_data`` is empty.
    """
    if len(class_data) == 0:
        return []
    # features["label"].names lists *all* classes, so index it with this subset's own label id.
    label = class_data.features["label"].int2str(class_data[0]["label"])

    if max_rounds is None:
        max_rounds = 3 * -(-num_to_generate_per_class // batch_size) + 2
    generated_texts = []
    for _ in range(max_rounds):
        if len(generated_texts) >= num_to_generate_per_class:
            break
        batch = class_data.shuffle().take(batch_size)
        batch_str = "\n".join(batch["text"])

        # description prompt steps (fill in examples, generate with llm, extract content)


        # diversification prompt steps (fill in examples and description, generate with llm, extract content)


        # persona steps (up to you; e.g. personas_prompt -> embed_query -> find_closest_persona)


        # generation prompt steps (generation_with_persona_prompt.format(PERSONA=..., N=batch_size,
        # LABEL=label, DESCRIPTION=..., TIPS=...), generate with llm, clean generated data from xml tags)


    if len(generated_texts) < num_to_generate_per_class:
        print(f"Warning: only {len(generated_texts)}/{num_to_generate_per_class} texts generated for label {label}.")
    return generated_texts[:num_to_generate_per_class]


# Augment with persona-conditioned generations, then retrain and evaluate.
persona_dataset = expand_dataset(input_dataset=base_dataset, tokenizer=tokenizer, generation_function=generation_with_persona)
trainer, model_path = train_model(tokenized_datasets=persona_dataset, tokenizer=tokenizer)
evaluate_model(trainer=trainer, tokenizer=tokenizer, model_path=model_path)