In [None]:
# Shell escape: print the NVIDIA driver / GPU status table before anything else runs.
!nvidia-smi

## Import the Requirements for Loading Config Data

In [None]:
import os

import re
from dataclasses import dataclass, field
from multiprocessing import cpu_count
from pathlib import Path
from typing import Optional, Union, List, Tuple

## Import Requirements for building Datasets 

In [None]:
import shutil
from datetime import datetime

from datasets import load_dataset, Dataset, DatasetDict, ClassLabel, load_from_disk
from torch.utils.data import DataLoader, default_collate
from transformers import AutoTokenizer

import lightning.pytorch as pl
from lightning.pytorch.utilities.rank_zero import rank_zero_info

#from config import Config, DataModuleConfig, ModuleConfig


## Import Requirements for Defining the Model

In [None]:
from abc import abstractmethod

import torch
import torch.nn as nn

from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.utilities.types import OptimizerLRScheduler, EVAL_DATALOADERS, TRAIN_DATALOADERS
from lightning.pytorch.utilities.types import OptimizerLRScheduler

from torchmetrics.functional import accuracy
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
from torchmetrics.functional import accuracy
from torchmetrics.classification import MultilabelAccuracy, MultilabelPrecision, MultilabelRecall, MultilabelF1Score
import torchinfo

from transformers import BertForSequenceClassification, AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import get_cosine_schedule_with_warmup

In [None]:
import torch

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger, CometLogger, TensorBoardLogger
from lightning.pytorch.profilers import PyTorchProfiler

from dvclive.lightning import DVCLiveLogger

from datamodule import AutoTokenizerDataModule
from module import CustomModel
from utils import create_dirs

import numpy as np

In [None]:
from huggingface_hub import login
import os
# Read the Hugging Face access token from the HUG_FACE_TOKEN environment variable.
# os.getenv returns None when the variable is unset; that None is then passed to login() unchanged.
token = os.getenv('HUG_FACE_TOKEN')
login(token)

## Configs

In [None]:
# Root folder under which every artifact directory below is placed.
this_kaggle = "./"


@dataclass
class Config:
    """Artifact directories and the global random seed used across the notebook."""

    # dataset / tokenizer download cache
    cache_dir: str = os.path.join(this_kaggle, "data")
    # parent folder for logger output
    log_dir: str = os.path.join(this_kaggle, "logs")
    # model checkpoints
    ckpt_dir: str = os.path.join(this_kaggle, "checkpoints")
    # profiler and performance output are nested inside the log folder
    prof_dir: str = os.path.join(log_dir, "profiler")
    perf_dir: str = os.path.join(log_dir, "perf")
    # default seed for AutoTokenizerDataModule, which passes it to pl.seed_everything
    seed: int = 59631546


@dataclass
class ModuleConfig:
    """Model, optimisation and tokenisation hyperparameters.

    Every value is a dataclass field with a default, so it can be read either from the
    class (``ModuleConfig.model_name``) or from an instance (``ModuleConfig().model_name``).
    """

    model_name: str = "facebook/bart-large-mnli" # change this to use a different pretrained checkpoint and tokenizer
    # model_name: str = "nvidia/NV-Embed-v2" # change this to use a different pretrained checkpoint and tokenizer
    learning_rate: float = 1.0103374612260327e-5
    learning_rate_bert: float = 1.0103374612260327e-5
    learning_rate_lstm: float = 7e-5
    finetuned: str = "checkpoints/twhin-bert-base-finetuned" # change this to use a different pretrained checkpoint and tokenizer
    max_length: int = 128
    attention_probs_dropout: float = 0.1
    classifier_dropout: Optional[float] = None
    warming_steps: int = 100
    focal_gamma: float = 2.0
    # Groups of label indices treated as mutually exclusive (softmax-normalised together),
    # e.g. [(0, 1), (10, 11)]. None disables the grouping.
    # Declared as a real field (it used to be patched onto the class after the definition)
    # so it appears in the generated __init__/__repr__ and can be set per instance.
    opposing_label_sets: Optional[List[Tuple[int, int]]] = None
@dataclass
class DataModuleConfig:
    """Dataset identity, split names and DataLoader settings."""

    # Hugging Face Hub dataset id -- change this to use different dataset
    dataset_name: str = "sdy623/new_disaster_tweets"
    # number of target labels
    num_classes: int = 12
    # samples per DataLoader batch
    batch_size: int = 8
    # split names as published on the Hub
    train_split: str = "train"
    test_split: str = "test"
    # fraction of data for training when a train/test split is performed
    train_size: float = 0.8
    # label column used for a stratified split
    stratify_by_column: str = "label"
    # DataLoader worker processes
    num_workers: int = 0

@dataclass
class TrainerConfig:
    """Default values for the Trainer flags of the same names."""

    # Trainer flag
    accelerator: str = "auto"
    # Trainer flag
    devices: Union[int, str] = "auto"
    # Trainer flag
    strategy: str = "auto"
    # Trainer flag
    precision: Optional[str] = "bf16"
    # Trainer flag
    max_epochs: int = 7

In [None]:
# Pull the model / dataset settings out of the config dataclasses into notebook-level names.
model_name, max_length = ModuleConfig.model_name, ModuleConfig.max_length
lr = ModuleConfig.learning_rate
dataset_name, batch_size = DataModuleConfig.dataset_name, DataModuleConfig.batch_size

# Artifact directories.
cache_dir, log_dir, ckpt_dir, prof_dir, perf_dir = (
    Config.cache_dir,
    Config.log_dir,
    Config.ckpt_dir,
    Config.prof_dir,
    Config.perf_dir,
)
# creates dirs to avoid failure if empty dir has been deleted
create_dirs([cache_dir, log_dir, ckpt_dir, prof_dir, perf_dir])

In [None]:
def tokenize_text(
    batch: Union[str, dict],
    *,
    model_name: str,
    cache_dir: Union[str, Path],
    truncation: bool = True,  # leave as True if dataset has sequences that exceed the model's max sequence length
    padding: Union[bool, str] = "max_length",  # pad so that all tensors are of the same dimensions
    max_length: int = 512,
):
    """Tokenize raw text with the tokenizer that belongs to ``model_name``.

    Args:
        batch: either a single raw string (inference input) or a mapping whose
            ``"text"`` entry holds one string or a list/tuple of strings.
        model_name: checkpoint name passed to ``AutoTokenizer.from_pretrained``.
        cache_dir: cache directory passed to ``AutoTokenizer.from_pretrained``.
        truncation: forwarded to the tokenizer call.
        padding: forwarded to the tokenizer call; a bool or a strategy name such as
            ``"max_length"`` (the annotation used to claim ``bool`` while defaulting to a string).
        max_length: forwarded to the tokenizer call.

    Returns:
        The tokenizer output, requested as PyTorch tensors and including ``token_type_ids``.

    Notes:
        https://huggingface.co/docs/transformers/v4.38.2/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    text = batch if isinstance(batch, str) else batch["text"]  # allow for inference input as raw text
    # Coerce every entry to str so non-string cells (e.g. numbers, None) do not break the tokenizer.
    if isinstance(text, (list, tuple)):
        text = [str(t) for t in text]
    else:
        text = str(text)
    return tokenizer(
        text,
        truncation=truncation,
        padding=padding,
        return_tensors="pt",
        return_token_type_ids=True,
        max_length=max_length,
    )


## Text cleaning tools

## Prepare The Data Processing Class and Method

In [None]:
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Cosine learning-rate decay with a linear warmup, expressed as a factor on the base LR.

    Args:
        optimizer: the optimizer whose learning rates are scheduled.
        warmup: number of scheduler steps over which the factor is ramped up linearly.
            A value ``<= 0`` disables warmup.
        max_iters: number of scheduler steps over which the cosine goes from 1 to 0.
    """

    def __init__(self, optimizer, warmup, max_iters):
        # Both attributes must exist before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        """Return one learning rate per param group: base LR times the current factor."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the multiplicative factor for step ``epoch`` (never below 2e-6)."""
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        # Only scale during warmup when warmup is positive; with warmup == 0 the
        # unguarded division raised ZeroDivisionError on the very first step.
        if self.warmup > 0 and epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return max(lr_factor, 2e-6)

In [None]:
from torch.utils.data import DataLoader, default_collate

class AutoTokenizerDataModule(pl.LightningDataModule):
    """LightningDataModule that loads a Hugging Face text dataset, builds a list-valued
    ``label`` column and tokenizes the ``text`` column with ``tokenize_text``."""

    def __init__(
        self,
        dataset_name: str = DataModuleConfig.dataset_name,
        cache_dir: Union[str, Path] = Config.cache_dir,
        model_name: str = ModuleConfig.model_name,
        max_length: int = ModuleConfig.max_length,
        num_labels: int = DataModuleConfig.num_classes,
        columns: list = ["input_ids", "attention_mask", "label", "token_type_ids"],
        batch_size: int = DataModuleConfig.batch_size,
        train_size: float = DataModuleConfig.train_size,
        stratify_by_column: str = DataModuleConfig.stratify_by_column,
        train_split: str = DataModuleConfig.train_split,
        test_split: str = DataModuleConfig.test_split,
        num_workers: int = DataModuleConfig.num_workers,
        seed: int = Config.seed,
    ) -> None:
        """a custom PyTorch Lightning LightningDataModule to tokenize text datasets

        Args:
            dataset_name: the name of the dataset as given on HF datasets
            cache_dir: corresponds to HF datasets.load_dataset
            model_name: the name of the model and accompanying tokenizer
            max_length: the max_length passed to the tokenizer
            num_labels: the number of labels
            columns: the list of column names to pass to the HF dataset's .set_format method
            batch_size: the batch size to pass to the PyTorch DataLoaders
            train_size: the size of the training data split (stored; no split is performed here)
            stratify_by_column: column name of labels for a stratified split (stored; not used here)
            train_split: the name of the training split as given on HF Hub
            test_split: the name of the test split as given on HF Hub; also used for validation
            num_workers: corresponds to torch.utils.data.DataLoader
            seed: the seed used in lightning.pytorch.seed_everything
        """
        super().__init__()

        self.dataset_name = dataset_name
        self.cache_dir = cache_dir
        self.model_name = model_name
        self.max_length = max_length
        self.train_size = train_size
        self.train_split = train_split
        self.test_split = test_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.num_labels = num_labels
        # Copy: the default is a single list object shared by every call, so storing it
        # directly would let one instance's mutation leak into all others.
        self.columns = list(columns)
        self.stratify_by_column = stratify_by_column

    def clear_custom_cache(self):
        """Delete the cache directory (and everything in it) if it exists."""
        if os.path.exists(self.cache_dir):
            shutil.rmtree(self.cache_dir)  # Remove the directory

    def prepare_data(self) -> None:
        """Seed the RNGs, make sure the cache directory exists and download the dataset
        when that directory is empty.

        Notes:
            https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data
        """
        pl.seed_everything(self.seed)
        # disable parallelism to avoid deadlocks
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # makedirs (not mkdir) so a nested cache path whose parents are missing also works
        os.makedirs(self.cache_dir, exist_ok=True)

        cache_dir_is_empty = len(os.listdir(self.cache_dir)) == 0

        if cache_dir_is_empty:
            rank_zero_info(f"[{str(datetime.now())}] Downloading dataset.")
            # NOTE(review): newer `datasets` releases appear to spell this argument `token` --
            # confirm against the installed version.
            load_dataset(self.dataset_name, cache_dir=self.cache_dir, use_auth_token=True)
        else:
            rank_zero_info(
                f"[{str(datetime.now())}] Data cache exists. Loading from cache."
            )

    @staticmethod
    def _collect_labels(example: dict) -> dict:
        """Return ``text`` unchanged plus ``label`` = every value after the example's first two fields.

        NOTE(review): this relies on the dataset's column order; presumably the first two
        columns are an id and the text -- verify against the dataset.
        """
        return {'text': example['text'], 'label': list(example.values())[2:]}

    def _prepare_split(self, split: Dataset, *, map_batch_size: int, output_all_columns: bool = False) -> Dataset:
        """Build the label column, tokenize and set the torch format on one dataset split."""
        split = split.map(self._collect_labels, batched=False)
        # .map returns a NEW dataset; the result must be kept
        split = split.map(
            tokenize_text,
            batched=True,
            batch_size=map_batch_size,
            fn_kwargs={"model_name": self.model_name, "cache_dir": self.cache_dir, "max_length": self.max_length},
        )
        split.set_format("torch", columns=self.columns, output_all_columns=output_all_columns)
        return split

    def setup(self, stage: Optional[str] = None) -> None:
        """Create ``train_data``/``val_data`` (stages "fit", "validate", None) and
        ``test_data`` (stages "test", None).

        Validation uses the ``test_split`` split; no separate validation split is carved out.

        Notes:
            https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup
        """
        if stage in ("fit", "validate", None):
            dataset = load_dataset(
                self.dataset_name, cache_dir=self.cache_dir
            )
            print(dataset)
            # the training set keeps its non-formatted columns in the output as well
            self.train_data = self._prepare_split(
                dataset[self.train_split], map_batch_size=1024, output_all_columns=True
            )
            self.val_data = self._prepare_split(dataset[self.test_split], map_batch_size=1024)
            # free mem from unneeded dataset obj
            del dataset
        if stage == "test" or stage is None:
            test_split = load_dataset(
                self.dataset_name, split=self.test_split, cache_dir=self.cache_dir
            )
            # Previously the tokenized result of .map was thrown away and no label column
            # was built, so set_format was asked for columns that did not exist.
            self.test_data = self._prepare_split(test_split, map_batch_size=512)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Shuffled loader over ``train_data``; ``None`` items are dropped before collation.

        Notes:
            https://lightning.ai/docs/pytorch/stable/data/datamodule.html#train-dataloader
        """
        return DataLoader(
            self.train_data,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=lambda batch: default_collate([item for item in batch if item is not None])
        )

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Unshuffled loader over ``val_data``.

        Notes:
            https://lightning.ai/docs/pytorch/stable/data/datamodule.html#val-dataloader
        """
        return DataLoader(
            self.val_data,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Unshuffled loader over ``test_data`` (requires ``setup("test")`` to have run).

        Notes:
            https://lightning.ai/docs/pytorch/stable/data/datamodule.html#test-dataloader
        """
        return DataLoader(
            self.test_data,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
        )

## The Definition of Model Training

In [None]:
class EncoderBase(nn.Module):
    """Base wrapper around a pretrained backbone loaded with ``AutoModel``.

    Subclasses provide ``forward``. The ``@abstractmethod`` decorator that used to sit on
    this class was removed: it is meant for methods, and on a class it only tagged the
    class object with ``__isabstractmethod__`` without making anything abstract.

    Attributes:
        encoder: the model returned by ``AutoModel.from_pretrained(model_name)``.
        hidden_size: ``hidden_size`` read from the encoder's config.
    """

    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.encoder.config.hidden_size


class BARTEmbeddings(EncoderBase):
    """Encoder wrapper that returns the backbone's last hidden state.

    Args:
        model_name: checkpoint name; loaded once by ``EncoderBase.__init__``.
        attention_dropout: accepted for signature compatibility but not applied.
            NOTE(review): confirm whether it was meant to override the checkpoint's config.
    """

    def __init__(self, model_name,
                 attention_dropout: Optional[float] = None):
        # EncoderBase already loads the pretrained model and records hidden_size; repeating
        # both here loaded the full checkpoint a second time for no effect.
        super().__init__(model_name)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the backbone and return its first output (the last hidden state).

        All arguments except ``labels`` are forwarded to the wrapped model unchanged;
        ``labels`` is accepted but ignored.
        """
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 is the last hidden state for both the ModelOutput object and the plain
        # tuple returned when return_dict=False (attribute access fails on the tuple).
        return outputs[0]

class BERTEmbeeding(EncoderBase):
    """BERT-style encoder wrapper that returns the backbone's last hidden state.

    Args:
        model_name: checkpoint name passed to ``AutoModel.from_pretrained``.
        attention_probs_dropout: optional override for the config's
            ``attention_probs_dropout_prob``; must lie in [0, 1]. ``None`` keeps the
            checkpoint's own value.

    Raises:
        AssertionError: if ``attention_probs_dropout`` is outside [0, 1].
    """

    def __init__(self, model_name,
                 attention_probs_dropout: Optional[float] = None):
        # Validate first so a bad value fails before any checkpoint is loaded.
        assert attention_probs_dropout is None or 0 <= attention_probs_dropout <= 1, \
            "attention_probs_dropout must be between 0 and 1 or None"
        super().__init__(model_name)
        self.hidden_size = self.encoder.config.hidden_size
        # `is not None` (not truthiness) so an explicit 0.0 is honoured; reload only when
        # the requested value actually differs from what the loaded config already has.
        current = getattr(self.encoder.config, "attention_probs_dropout_prob", None)
        if attention_probs_dropout is not None and attention_probs_dropout != current:
            self.encoder = AutoModel.from_pretrained(
                model_name,
                attention_probs_dropout_prob=attention_probs_dropout
            )

    def forward(self, input_ids,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,):
        """Forward every argument to the backbone and return ``last_hidden_state``."""
        outputs = self.encoder(input_ids=input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask,)
        return outputs.last_hidden_state

class ClassifierBase(nn.Module):
    """Shared parent of the classification heads; stores the feature width and label count."""

    def __init__(self, input_dim, num_labels):
        super().__init__()
        self.input_dim, self.num_labels = input_dim, num_labels

    def forward(self, hidden_states):
        """Placeholder: concrete heads must provide their own ``forward``."""
        message = "This method should be overridden in subclasses"
        raise NotImplementedError(message)

class BertLinearClassificationHead(ClassifierBase):
    """Linear head over the first token's hidden state, with optional dropout.

    When ``opposing_label_sets`` is given, the output columns of each set are replaced
    in place by their softmax; all other columns are returned as raw linear outputs
    (no sigmoid is applied here).
    """

    def __init__(
            self,
            input_dim: int,
            num_labels: int,
            opposing_label_sets: List[Tuple[int, int]] = None,
            classifier_dropout: Optional[float] = None
    ):
        assert classifier_dropout is None or 0 <= classifier_dropout <= 1, \
            "pooler_dropout must be between 0 and 1 or None"
        super().__init__(input_dim, num_labels)
        # label-index groups that are softmax-normalised together
        self.opposing_label_sets = opposing_label_sets
        # a falsy rate (None or 0) means "no dropout layer at all"
        if classifier_dropout:
            self.dropout = nn.Dropout(p=classifier_dropout)
        else:
            self.dropout = None
        self.linear = nn.Linear(input_dim, num_labels)

    def forward(self, hidden_states):
        """Map ``hidden_states[:, 0, :]`` to one score per label."""
        first_token = hidden_states[:, 0, :]
        if self.dropout:
            first_token = self.dropout(first_token)
        logits = self.linear(first_token)

        if self.opposing_label_sets is None:
            return logits

        for group in self.opposing_label_sets:
            # cast back so the in-place write matches the dtype of `logits`
            normalised = torch.softmax(logits[:, group], dim=1)
            logits[:, group] = normalised.to(logits.dtype)
        return logits

class ClassificationHEAD(ClassifierBase):
    """Plain linear head over the first token's hidden state.

    When ``opposing_label_sets`` is given, the output columns of each set are replaced
    in place by their softmax; all other columns are returned as raw linear outputs
    (no sigmoid is applied here).
    """

    def __init__(self, input_dim, num_labels, opposing_label_sets: List[Tuple[int, int]]=None):
        super().__init__(input_dim, num_labels)
        self.opposing_label_sets = opposing_label_sets  # List of tuples with opposing label indices
        self.linear = nn.Linear(input_dim, num_labels)

    def forward(self, hidden_states):
        """Map ``hidden_states[:, 0, :]`` to one score per label."""
        cls_output = hidden_states[:, 0, :]
        logits = self.linear(cls_output)

        # Apply Softmax to opposing labels
        if self.opposing_label_sets is not None:
            for label_set in self.opposing_label_sets:
                # Cast back to the dtype of `logits` before the in-place write, as
                # BertLinearClassificationHead does: under mixed precision softmax may
                # come back in a wider dtype than the tensor being written into.
                logits[:, label_set] = torch.softmax(logits[:, label_set], dim=1).to(logits.dtype)

        return logits

class LSTMClassificationHEAD(ClassifierBase):
    """Single-layer LSTM over the token sequence followed by a linear projection.

    The LSTM output at the final sequence position is projected to one raw score per
    label. ``opposing_label_sets`` is stored but not applied by this head.

    NOTE(review): the final position is taken regardless of padding -- confirm this is
    intended when inputs are padded to a fixed length.
    """

    def __init__(self, input_dim, num_labels, opposing_label_sets: List[Tuple[int, int]]=None):
        super().__init__(input_dim, num_labels)
        # kept for interface parity with the other heads
        self.opposing_label_sets = opposing_label_sets
        # hidden width equals the input width; inputs are (batch, seq, feature)
        self.lstm = nn.LSTM(input_dim, input_dim, batch_first=True)
        self.linear = nn.Linear(input_dim, num_labels)

    def forward(self, last_hidden_state):
        """Return raw per-label scores computed from the last LSTM time step."""
        # None -> the LSTM starts from its default initial state
        lstm_out = self.lstm(last_hidden_state, None)[0]
        final_step = lstm_out[:, -1, :]
        return self.linear(final_step)

class CustomModel(pl.LightningModule):
    """Multi-label text classifier: a ``BARTEmbeddings`` encoder followed by an
    ``LSTMClassificationHEAD``, trained with ``BCEWithLogitsLoss``.

    Args:
        model_name: pretrained checkpoint for the encoder (and tokenizer at predict time).
        num_classes: number of labels.
        input_key / label_key / mask_key: batch keys for input ids, targets and attention mask.
        output_key / loss_key: stored on the module; not read by any method of this class.
        attention_probs_dropout: forwarded to ``BARTEmbeddings``.
        classifier_dropout: recorded in ``hparams`` only; the LSTM head takes no dropout argument.
        learning_rate: learning rate of the AdamW optimizer.
        learning_rate_bert / learning_rate_lstm / lr_gamma: stored on the module; not used
            by the current optimizer setup.
        opposing_label_sets: groups of label indices whose scores are softmax-normalised
            together when turning logits into probabilities.
        warmup: warmup steps of the cosine schedule.
    """

    def __init__(self,
        model_name: str = ModuleConfig.model_name,
        num_classes: int = DataModuleConfig.num_classes,  # set according to the finetuning dataset
        input_key: str = "input_ids",  # set according to the finetuning dataset
        label_key: str = "label",  # set according to the finetuning dataset
        mask_key: str = "attention_mask",  # set according to the model output object
        output_key: str = "logits",  # set according to the model output object
        loss_key: str = "loss",  # set according to the model output object
        attention_probs_dropout: Optional[float] = ModuleConfig.attention_probs_dropout,
        classifier_dropout: Optional[float] = ModuleConfig.classifier_dropout,
        learning_rate: float = ModuleConfig.learning_rate,
        learning_rate_bert: float = ModuleConfig.learning_rate_bert,
        learning_rate_lstm: float = ModuleConfig.learning_rate_lstm,
        lr_gamma: float = 0.76825,
        opposing_label_sets: List[Tuple[int, int]] = None,
        warmup: int = ModuleConfig.warming_steps,):

        super().__init__()
        # records every constructor argument in self.hparams (and in checkpoints)
        self.save_hyperparameters()

        self.input_key = input_key
        self.label_key = label_key
        self.mask_key = mask_key
        self.output_key = output_key
        self.loss_key = loss_key
        self.num_classes = num_classes
        self.lr_gamma= lr_gamma

        self.encoder = BARTEmbeddings(model_name, attention_probs_dropout)
        self.classifier = LSTMClassificationHEAD(self.encoder.hidden_size, num_classes, opposing_label_sets)

        self.learning_rate = learning_rate
        self.learning_rate_bert = learning_rate_bert
        self.learning_rate_lstm = learning_rate_lstm

        self.opposing_label_sets = opposing_label_sets
        self.criterion = nn.BCEWithLogitsLoss()  # Use BCEWithLogits for multi-label setting
        # Metrics
        self.accuracy = MultilabelAccuracy(num_labels=num_classes, average='micro')
        self.precision = MultilabelPrecision(num_labels=num_classes, average='micro')
        self.recall = MultilabelRecall(num_labels=num_classes, average='micro')
        self.f1_score = MultilabelF1Score(num_labels=num_classes, average='micro')
        self.macro_f1_score = MultilabelF1Score(num_labels=num_classes, average='macro')

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return raw per-label logits. ``token_type_ids`` is accepted but not used."""
        hidden_states = self.encoder(input_ids, attention_mask)
        logits = self.classifier(hidden_states)
        return logits

    def _to_probabilities(self, logits: torch.Tensor) -> torch.Tensor:
        """Turn logits into per-label probabilities without modifying ``logits``.

        Sigmoid is applied to every label; labels that belong to an opposing set are
        then overwritten with the softmax over that set.
        """
        probs = torch.sigmoid(logits.float())
        if self.opposing_label_sets is not None:
            for label_set in self.opposing_label_sets:
                columns = list(label_set)
                probs[:, columns] = torch.softmax(logits[:, columns].float(), dim=1)
        return probs

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE-with-logits loss for one training batch."""
        logits = self(input_ids=batch[self.input_key], attention_mask=batch[self.mask_key])
        loss = self.criterion(logits, batch[self.label_key].float())
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss and multi-label metrics for one validation batch.

        Returns:
            dict with this batch's loss and metric values.
        """
        targets = batch[self.label_key]
        logits = self(input_ids=batch[self.input_key], attention_mask=batch[self.mask_key])
        loss = self.criterion(logits, targets.float())

        preds = (self._to_probabilities(logits) > 0.5).int()

        # Calling a metric updates its accumulated state and returns the value for this batch.
        acc = self.accuracy(preds, targets)
        prec = self.precision(preds, targets)
        rec = self.recall(preds, targets)
        f1 = self.f1_score(preds, targets)
        marco_f1 = self.macro_f1_score(preds, targets)

        # Log the metric OBJECTS so the epoch value is computed from the accumulated
        # state; logging the per-batch numbers made the epoch figure an average of
        # batch-level precision/recall/F1, which is not the dataset-level value.
        self.log("val_loss", loss, on_epoch=True)
        self.log("val_accuracy", self.accuracy, on_epoch=True)
        self.log("val_precision", self.precision, on_epoch=True)
        self.log("val_recall", self.recall, on_epoch=True)
        self.log("val_f1", self.f1_score, on_epoch=True)
        self.log("val_macro_f1", self.macro_f1_score, on_epoch=True)

        return {"val_loss": loss, "val_accuracy": acc, "val_precision": prec, "val_recall": rec, "val_f1": f1, "val_macro_f1": marco_f1,}

    def predict_step(
        self, sequence: str, threshold: float = 0.5, cache_dir: Union[str, Path] = Config.cache_dir
        ) -> List[List[int]]:
        """Predict the active label indices for raw text.

        The previous body could not run: it read ``self.model_name`` and ``self.model``
        (neither exists), compared raw logits with ``threshold`` and used a tensor as a
        dictionary key. It now tokenizes, runs the model, converts logits to probabilities
        and thresholds them.

        Args:
            sequence: a raw string, or a mapping with a ``"text"`` entry, as accepted by ``tokenize_text``.
            threshold: minimum probability for a label to count as predicted.
            cache_dir: tokenizer cache directory.

        Returns:
            One list per input sequence with the indices of labels whose probability is >= ``threshold``.

        Notes:
            https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#inference
        """
        batch = tokenize_text(
            sequence,
            model_name=self.hparams.model_name,
            cache_dir=cache_dir,
            max_length=ModuleConfig.max_length,
        )
        # autotokenizer may cause tokens to lose device type and cause failure
        batch = batch.to(self.device)
        with torch.no_grad():
            logits = self(input_ids=batch[self.input_key], attention_mask=batch[self.mask_key])
        predicted = self._to_probabilities(logits) >= threshold
        return [row.nonzero(as_tuple=True)[0].tolist() for row in predicted]

    def test_step(self, batch, batch_idx):
        """Return the multi-hot predictions for one test batch as a NumPy array.

        Uses the same probability/threshold rule as ``validation_step``; ``argmax`` would
        select exactly one label per sample, which does not fit a multi-label model.

        Notes:
            https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#testing
        """
        logits = self(
            batch[self.input_key],
            attention_mask=batch[self.mask_key],
        )
        predicted_labels = (self._to_probabilities(logits) > 0.5).int()
        return predicted_labels.cpu().numpy()

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Create the AdamW optimizer and the per-step cosine-warmup scheduler.

        Only the optimizer is returned; the scheduler is kept on ``self.lr_scheduler`` and
        stepped manually in ``optimizer_step``. The unused schedulers that were also
        constructed here (constant, cosine-annealing, plateau, exponential, chained) were
        never stepped or returned and have been removed.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

        self.lr_scheduler = CosineWarmupScheduler(
            optimizer, warmup=self.hparams.warmup, max_iters=self.trainer.estimated_stepping_batches
        )

        return optimizer

    def optimizer_step(self, *args, **kwargs):
        """Run the normal optimizer step, then advance the LR schedule by one step."""
        super().optimizer_step(*args, **kwargs)
        self.lr_scheduler.step()  # Step per iteration

## Run the Code

In [None]:
# Bare expression: shows the Config class object as the cell's output.
Config

In [None]:
# Model / dataset settings from the config dataclasses.
model_name, lr = ModuleConfig.model_name, ModuleConfig.learning_rate
dataset_name, batch_size = DataModuleConfig.dataset_name, DataModuleConfig.batch_size

# Artifact directories.
cache_dir, log_dir, ckpt_dir, prof_dir, perf_dir = (
    Config.cache_dir,
    Config.log_dir,
    Config.ckpt_dir,
    Config.prof_dir,
    Config.perf_dir,
)
# creates dirs to avoid failure if empty dir has been deleted
create_dirs([cache_dir, log_dir, ckpt_dir, prof_dir, perf_dir])
# set PyTorch's float32 matmul precision mode to "high"
torch.set_float32_matmul_precision("high")

In [None]:
# Build the data module; every argument not listed keeps its config-derived default.
lit_datamodule = AutoTokenizerDataModule(
    dataset_name=dataset_name,
    model_name=model_name,
    batch_size=batch_size,
    cache_dir=cache_dir,
)

In [None]:
# Deletes the whole cache directory if it exists, so the next prepare_data() finds it empty and downloads again.
lit_datamodule.clear_custom_cache()

In [None]:
# Seeds the RNGs, (re)creates the cache directory and downloads the dataset when that directory is empty.
lit_datamodule.prepare_data()

In [None]:
# Builds train_data and val_data (label column + tokenization) on the data module.
lit_datamodule.setup("fit")

In [None]:
#lit_datamodule.setup("test")

In [None]:
# Instantiate the model; arguments not listed keep the defaults from CustomModel.__init__.
lit_model = CustomModel(
    learning_rate=lr,
    learning_rate_bert=lr,
    lr_gamma=0.75,
    attention_probs_dropout=0.1,
    classifier_dropout=0,
)

In [None]:
callbacks = [
    # keep the three weight-only checkpoints with the highest logged "val_f1"
    ModelCheckpoint(
        monitor="val_f1",
        mode="max",
        save_top_k=3,
        save_weights_only=True,
        dirpath=ckpt_dir,
        filename="model",
    ),
    # record the learning rate at every step
    LearningRateMonitor(logging_interval='step'),
]

In [None]:
# CSV logger writing under <log_dir>/csv-logs.
logger = CSVLogger(name="csv-logs", save_dir=log_dir)

In [None]:
# Trainer with three loggers (CSV, Comet, DVCLive) and the callbacks defined above.
lit_trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    precision="bf16-mixed",
    max_epochs=8,
    deterministic=True,
    logger=[
        logger,
        # Credentials are read from the environment (COMET_API_KEY) instead of being
        # written into the notebook as a string literal.
        CometLogger(api_key=os.environ.get("COMET_API_KEY")),
        DVCLiveLogger(save_dvc_exp=True),
    ],
    callbacks=callbacks,
    log_every_n_steps=50,
)

In [None]:
# Train lit_model, drawing the train/validation loaders from lit_datamodule.
lit_trainer.fit(model=lit_model, datamodule=lit_datamodule)

In [None]:
!bash -c cat {lit_trainer.log_dir}/metrics.csv

In [None]:
# Print metrics.csv from the trainer's log directory ({...} is interpolated by IPython).
!cat {lit_trainer.log_dir}/metrics.csv

In [None]:
# List the contents of the checkpoints/ directory.
!ls -all checkpoints/

## Load Model for Inference

In [None]:
# Show the path the checkpoint callback recorded as the best checkpoint.
lit_trainer.checkpoint_callback.best_model_path

In [None]:
# load_from_checkpoint is a classmethod: call it on CustomModel, not on the lit_model instance.
# Prefer the checkpoint the ModelCheckpoint callback ranked best in this session; fall back
# to the manually named file when no best path has been recorded.
best_ckpt_path = lit_trainer.checkpoint_callback.best_model_path or "./checkpoints/best.ckpt"
model = CustomModel.load_from_checkpoint(best_ckpt_path)

In [None]:
from tqdm import tqdm

In [None]:
# Bare expression: shows the loaded model's repr as the cell's output.
model

In [None]:
# torchinfo summary of the loaded model (depth=3 presumably limits how many nesting levels are shown).
torchinfo.summary(model, depth=3)

In [None]:
# NOTE(review): `model` is a CustomModel and nothing in this notebook gives it __getitem__, so this
# subscript looks like it raises TypeError and stops a Run All -- confirm the intent or drop the cell.
model['']

In [None]:
# Show the constructor arguments stored by save_hyperparameters() in CustomModel.__init__.
model.hparams

In [None]:
# NOTE(review): CustomModel defines no `get` method in this notebook, so this presumably raises
# AttributeError; it looks like it was meant for a checkpoint dict rather than the model -- confirm.
model.get('callbacks')

In [None]:
import pytorch_lightning

In [None]:
# Load the raw checkpoint file with torch.load (no CustomModel is constructed here).
# NOTE(review): the ModelCheckpoint callback in this notebook uses filename="model", so nothing
# here writes a file called best.ckpt -- confirm this path exists.
torch_org_model = torch.load("./checkpoints/best.ckpt")

In [None]:
# Prepare test_dataloader.
# test_dataloader() reads lit_datamodule.test_data, which only setup("test") creates; the
# notebook never ran that stage, so run it here when the attribute is still missing.
if not hasattr(lit_datamodule, "test_data"):
    lit_datamodule.setup("test")
test_dataloader = lit_datamodule.test_dataloader()

# Eval mode
model.eval()

# Prepare list for storing inf results
predictions = []

# Disable grad for inf
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # move inputs to wherever the model lives
        input_ids = batch[model.input_key].to(model.device)
        attention_mask = batch[model.mask_key].to(model.device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Binary decision for label index 1 only: sigmoid probability above 0.5.
        # NOTE(review): the other label columns are discarded here -- confirm that only this label is wanted.
        preds = (torch.sigmoid(logits)[:, 1] > 0.5).int()

        predictions.extend(preds.cpu().numpy())
