# In[ ]:
import os
import random
import torch
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from llama import BasicModelRunner
from utilities import tokenize_and_split_data
from transformers.trainer_callback import TrainerCallback
import matplotlib.pyplot as plt


class MetricsCollector(TrainerCallback):
    """
    Callback that records every log dictionary emitted during training.

    Each call to ``on_log`` appends a shallow copy of the trainer's ``logs``
    dictionary to ``self.metrics``. The resulting list can be used to plot
    training loss, learning rate, and other logged metrics.
    """

    def __init__(self):
        super().__init__()
        # One dict per logging event, in the order the trainer emitted them.
        self.metrics = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """
        Record the logs received from the trainer.

        Args:
          args: Training arguments passed to the trainer.
          state: Current state of the trainer.
          control: Control object for the trainer.
          logs: Dictionary containing the logged metrics, or ``None``. ``None``
            is ignored so every entry in ``self.metrics`` is a dict.
          **kwargs: Additional keyword arguments (unused).
        """
        if logs is None:
            return
        # Snapshot the dict: the object we receive is owned by the trainer and
        # may be shared with other callbacks, so a later mutation would
        # otherwise silently change our recorded history.
        self.metrics.append(dict(logs))


def _plot_metric(metrics, key, ylabel, title, output_dir, filename):
    """
    Plot the values logged under ``key`` and save the figure as an image.

    Entries that are ``None``, lack ``key``, or hold ``None`` for it are
    skipped. The x-axis is the position of each value among the retained
    entries (one point per log event that contained ``key``).

    Args:
      metrics: Sequence of log dictionaries; ``None`` entries are tolerated.
      key: Log key to extract, e.g. ``'loss'``.
      ylabel: Label for the y-axis.
      title: Plot title.
      output_dir: Existing directory the image is written to.
      filename: Image file name inside ``output_dir``.

    Returns:
      Path of the saved image.
    """
    values = [m[key] for m in metrics if m is not None and m.get(key) is not None]
    path = os.path.join(output_dir, filename)
    # Use a dedicated figure so we never draw onto (or leave data in) a figure
    # the caller already has open, and close it even if saving fails.
    fig, ax = plt.subplots()
    try:
        ax.plot(values)
        ax.set_xlabel('Iteration')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def plot_loss(metrics, output_dir):
    """
    Plot the training loss from the collected metrics and save the plot.

    Values are taken from the ``'loss'`` key of each log entry; entries
    without a loss (e.g. evaluation logs) are skipped. The image is written
    to ``output_dir/training_loss_plot.png``.

    Args:
      metrics: List of dictionaries containing training logs.
      output_dir: Directory to save the plot.

    Returns:
      Path of the saved image.
    """
    return _plot_metric(metrics, 'loss', 'Loss', 'Training Loss',
                        output_dir, 'training_loss_plot.png')


def plot_learning_rate(metrics, output_dir):
    """
    Plot the learning rate from the collected metrics and save the plot.

    Values are taken from the ``'learning_rate'`` key of each log entry;
    entries without it are skipped. The image is written to
    ``output_dir/learning_rate_plot.png``.

    Args:
      metrics: List of dictionaries containing training logs.
      output_dir: Directory to save the plot.

    Returns:
      Path of the saved image.
    """
    return _plot_metric(metrics, 'learning_rate', 'Learning Rate', 'Learning Rate',
                        output_dir, 'learning_rate_plot.png')


def _sample_hyperparameters(hyperparameter_space):
    """Draw one value uniformly at random from each candidate list in the space."""
    return {name: random.choice(candidates) for name, candidates in hyperparameter_space.items()}


def _make_training_args(hyperparameters, output_dir):
    """Build ``TrainingArguments`` from the fixed settings overlaid with the sampled ones."""
    settings = {
        "output_dir": output_dir,
        "overwrite_output_dir": False,
        "disable_tqdm": False,
        "eval_steps": 120,
        "save_steps": 120,
        "warmup_steps": 1,
        "per_device_eval_batch_size": 1,
        # NOTE(review): newer transformers releases rename this to
        # `eval_strategy`; confirm against the installed version.
        "evaluation_strategy": "steps",
        "logging_strategy": "steps",
        "logging_steps": 1,
        "gradient_accumulation_steps": 4,
        "gradient_checkpointing": False,
        "load_best_model_at_end": True,
        "save_total_limit": 1,
        "metric_for_best_model": "eval_loss",
        "greater_is_better": False,
    }
    # Sampled values win over the fixed settings, so any TrainingArguments
    # field can be searched without a duplicate-keyword error.
    settings.update(hyperparameters)
    return TrainingArguments(**settings)


def find_best_hyperparameters(num_iterations=1, hyperparameter_space=None):
    """
    Random-search training hyperparameters for a small causal LM.

    For each trial, one value is sampled per entry of ``hyperparameter_space``.
    A fresh copy of the pretrained model is fine-tuned with those values and
    evaluated on the held-out split. The trial with the lowest ``eval_loss``
    wins. Loss and learning-rate plots for each trial are written under
    ``<cwd>/best_fit`` (directly there for a single trial, otherwise in a
    ``trial_<i>`` subdirectory per trial).

    Args:
      num_iterations: Number of random trials to run. Defaults to 1.
      hyperparameter_space: Mapping from ``TrainingArguments`` field name to a
        list of candidate values. Defaults to learning_rate=1e-5,
        num_train_epochs=1, per_device_train_batch_size=1, optim="adafactor".

    Returns:
      Tuple ``(best_hyperparameters, best_loss)``; ``(None, float('inf'))``
      if no trial ran.
    """
    model_name = "EleutherAI/pythia-70m"
    use_hf = False
    current_directory = os.getcwd()
    folder_path = os.path.join(current_directory, "content")
    dataset_name = "ai-medical-chatbot_processed.jsonl"
    dataset_path = os.path.join(folder_path, dataset_name)

    if hyperparameter_space is None:
        hyperparameter_space = {
            "learning_rate": [1e-5],
            "num_train_epochs": [1],
            "per_device_train_batch_size": [1],
            "optim": ["adafactor"],
        }

    training_config = {
        "model": {
            "pretrained_name": model_name,
            "max_length": 2048
        },
        "datasets": {
            "use_hf": use_hf,
            "path": dataset_path
        },
        "verbose": True
    }
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Presumably the tokenizer ships without a pad token; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token
    train_dataset, test_dataset = tokenize_and_split_data(training_config, tokenizer)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_root = os.path.join(current_directory, "best_fit")

    best_hyperparameters = None
    best_loss = float('inf')

    for trial in range(num_iterations):
        hyperparameters = _sample_hyperparameters(hyperparameter_space)

        # One trial keeps the original layout; with several trials each gets its
        # own directory so checkpoints and plots don't overwrite one another.
        if num_iterations == 1:
            output_dir = output_root
        else:
            output_dir = os.path.join(output_root, f"trial_{trial}")
        os.makedirs(output_dir, exist_ok=True)

        # Reload the pretrained weights for every trial; reusing one model object
        # would make each trial continue from the previous trial's fine-tuned state.
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model.to(device)

        trainer = Trainer(
            model=model,
            args=_make_training_args(hyperparameters, output_dir),
            train_dataset=train_dataset,
            eval_dataset=test_dataset
        )

        metrics_collector = MetricsCollector()
        trainer.add_callback(metrics_collector)

        trainer.train()
        # With load_best_model_at_end=True this presumably evaluates the best
        # saved checkpoint (if any) rather than the final weights.
        eval_loss = trainer.evaluate()["eval_loss"]

        if eval_loss < best_loss:
            best_loss = eval_loss
            best_hyperparameters = hyperparameters

        plot_loss(metrics_collector.metrics, output_dir)
        plot_learning_rate(metrics_collector.metrics, output_dir)

    return best_hyperparameters, best_loss


# Run the search only when executed as a script or notebook cell (where
# __name__ == "__main__"), not when this module is imported.
if __name__ == "__main__":
    best_hyperparameters, best_loss = find_best_hyperparameters()

    print("Best hyperparameters:", best_hyperparameters)
    print("Best loss:", best_loss)