In [None]:
from huggingface_hub import snapshot_download

# Hub repository id handed to snapshot_download below.
checkpoint = "tiiuae/falcon-7b"
# Fetch the repository snapshot; the returned value is reused by a later cell
# as the location AutoConfig.from_pretrained loads the configuration from.
weights_location = snapshot_download(repo_id=checkpoint)

In [None]:
from transformers import AutoConfig

In [None]:
# Build the model configuration from the snapshot location produced by snapshot_download.
# NOTE(review): trust_remote_code=True lets code shipped with the checkpoint run locally —
# confirm the checkpoint source is trusted before executing this cell.
config = AutoConfig.from_pretrained(weights_location, trust_remote_code=True)

In [None]:
import torch
# from peft import PeftModel    
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
# import fabric


# Hub id of the causal language model loaded in this cell.
model_name = "meta-llama/Llama-2-70b-hf"

print(f"Starting to load the model {model_name} into memory")

# 4-bit quantization is requested through an explicit BitsAndBytesConfig:
# passing `load_in_4bit=True` directly to `from_pretrained` is the deprecated
# spelling of the same request. Only `load_in_4bit` is set so every other
# quantization option keeps its library default, exactly as before.
m = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.bfloat16,
    # The empty-string key addresses the root module, i.e. the whole model goes to device 0.
    device_map={"": 0},
)
# The adapter merge and tokenizer setup below are currently disabled.
# m = PeftModel.from_pretrained(m, adapters_name)
# m = m.merge_and_unload()
# tok = LlamaTokenizer.from_pretrained(model_name)
# tok.bos_token_id = 1

# stop_token_ids = [0]

# print(f"Successfully loaded the model {model_name} into memory")

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution so edits to imported source files take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from energizer.datastores import PandasDataStoreForSequenceClassification
from energizer.estimator import Estimator
from transformers import AutoModelForSequenceClassification
from typing import Dict, List
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, F1Score, Precision, Recall
from transformers import AutoModelForSequenceClassification
from energizer.enums import InputKeys, OutputKeys, RunningStage
import numpy as np
from energizer.utilities import move_to_cpu
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric import seed_everything
from energizer.callbacks import GradNorm, PytorchTensorboardProfiler

In [None]:
# Load the sequence-classification datastore from the ./agnews_datastore/ directory.
# Later cells read its tokenizer name, label mappings and data loaders.
ds = PandasDataStoreForSequenceClassification.load("./agnews_datastore/")

In [None]:
class EstimatorForSequenceClassification(Estimator):
    """Estimator for sequence classification driven by the model's own loss and logits.

    The three per-batch hooks and the three per-epoch hooks are thin wrappers: each
    forwards to the shared `step` / `epoch_end` implementation and differs only in the
    `RunningStage` it passes along.
    """

    # ------------------------------------------------------------------ batch hooks

    def train_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        """Process one training batch; `batch_idx` and `loss_fn` are accepted but unused."""
        return self.step(model, batch, metrics, RunningStage.TRAIN)

    def validation_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        """Process one validation batch; `batch_idx` and `loss_fn` are accepted but unused."""
        return self.step(model, batch, metrics, RunningStage.VALIDATION)

    def test_step(self, model, batch, batch_idx, loss_fn, metrics: MetricCollection) -> torch.Tensor:
        """Process one test batch; `batch_idx` and `loss_fn` are accepted but unused."""
        return self.step(model, batch, metrics, RunningStage.TEST)

    # ------------------------------------------------------------------ epoch hooks

    def train_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> float:
        """Aggregate and log the training epoch; returns the rounded mean loss."""
        return self.epoch_end(output, metrics, RunningStage.TRAIN)

    def validation_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> float:
        """Aggregate and log the validation epoch; returns the rounded mean loss."""
        return self.epoch_end(output, metrics, RunningStage.VALIDATION)

    def test_epoch_end(self, output: List[np.ndarray], metrics: MetricCollection) -> float:
        """Aggregate and log the test epoch; returns the rounded mean loss."""
        return self.epoch_end(output, metrics, RunningStage.TEST)

    # ------------------------------------------------------------------ shared logic

    def step(
        self,
        model,
        batch: Dict,
        metrics: MetricCollection,
        stage: RunningStage,
    ) -> torch.Tensor:
        """Forward `batch` through `model`, feed the metrics, and return the model's loss.

        Args:
            model: Callable invoked as `model(**batch)`; its result must expose
                `.logits` and `.loss`.
            batch: Keyword arguments for the model. Mutated in place: the
                `InputKeys.ON_CPU` entry is removed when present.
            metrics: Collection called with the logits and `batch[InputKeys.TARGET]`.
            stage: Current running stage; per-batch logging happens only for TRAIN.

        Returns:
            The `loss` attribute of the model output.
        """
        # Discard the ON_CPU entry (if any) so it is not forwarded to the model.
        batch.pop(InputKeys.ON_CPU, None)

        outputs = model(**batch)
        batch_metrics = metrics(outputs.logits, batch[InputKeys.TARGET])

        # Only training batches are logged per step.
        if stage == RunningStage.TRAIN:
            to_log = {OutputKeys.LOSS: outputs.loss, **batch_metrics}
            self.log_dict(
                {f"{stage}/{name}": value for name, value in to_log.items()},
                step=self.tracker.global_batch,
            )

        return outputs.loss

    def epoch_end(self, output: List[np.ndarray], metrics: MetricCollection, stage: RunningStage) -> float:
        """Compute the epoch metrics and mean loss, log both, and return the mean loss.

        Args:
            output: Per-batch values averaged with `np.mean` to obtain the epoch loss.
            metrics: Collection whose `compute()` result is logged next to the loss.
            stage: Running stage used to prefix the logged keys.

        Returns:
            Mean of `output`, rounded to six decimals.
        """
        # NOTE: compute() yields values that are still on device, hence the explicit move.
        epoch_metrics = move_to_cpu(metrics.compute())
        mean_loss = round(np.mean(output).item(), 6)

        to_log = {OutputKeys.LOSS: mean_loss, **epoch_metrics}
        self.log_dict(
            {f"{stage}_end/{name}": value for name, value in to_log.items()},
            step=self.tracker.safe_global_epoch,
        )

        return mean_loss

    def configure_metrics(self, *_) -> MetricCollection:
        """Create the multiclass metric collection, sized by `self.model.num_labels`.

        Positional arguments are accepted and ignored.

        Returns:
            A `MetricCollection` with accuracy plus macro- and micro-averaged F1,
            precision and recall, moved to `self.device`.
        """
        n_labels = self.model.num_labels
        kind = "multiclass"

        collection = MetricCollection(
            {
                "accuracy": Accuracy(kind, num_classes=n_labels),
                "f1_macro": F1Score(kind, num_classes=n_labels, average="macro"),
                "precision_macro": Precision(kind, num_classes=n_labels, average="macro"),
                "recall_macro": Recall(kind, num_classes=n_labels, average="macro"),
                "f1_micro": F1Score(kind, num_classes=n_labels, average="micro"),
                "precision_micro": Precision(kind, num_classes=n_labels, average="micro"),
                "recall_micro": Recall(kind, num_classes=n_labels, average="micro"),
            }
        )
        # NOTE: placing the collection on the right device is this method's responsibility.
        return collection.to(self.device)

In [None]:
# Seed before instantiating the model — presumably so any randomly initialised
# weights are reproducible across runs; TODO confirm.
seed_everything(42)
# The checkpoint name comes from the datastore's tokenizer; the label mappings and
# the number of labels are taken from the datastore as well.
model = AutoModelForSequenceClassification.from_pretrained(
    ds.tokenizer.name_or_path,
    id2label=ds.id2label,
    label2id=ds.label2id,
    num_labels=len(ds.labels),
)

# Wrap the model in the estimator defined in this notebook. Requires a GPU accelerator.
# TensorBoard logs are rooted at "./"; the profiler callback is given "./profiler_logs".
# NOTE(review): GradNorm(2) presumably selects the 2-norm — confirm against energizer.callbacks.
estimator = EstimatorForSequenceClassification(
    model, 
    accelerator="gpu", 
    loggers=[TensorBoardLogger("./")],
    callbacks=[GradNorm(2), PytorchTensorboardProfiler("./profiler_logs")],
)

In [None]:
# Presumably required before the datastore can hand out data loaders (the next cell
# requests train and test loaders) — TODO confirm against the energizer datastore API.
ds.prepare_for_loading()

In [None]:
# Short training run: the limit_* arguments presumably cap each loop at 10 batches
# and num_validation_per_epoch presumably sets how often validation runs per epoch
# — TODO confirm against Estimator.fit.
# NOTE(review): the validation loader is built from ds.test_loader(), so validation
# figures are computed on the test split — confirm this is intended.
estimator.fit(
    train_loader=ds.train_loader(passive=True),
    validation_loader=ds.test_loader(),
    num_validation_per_epoch=5,
    limit_train_batches=10,
    limit_validation_batches=10,
)