In [None]:
# Shell out to `nvidia-smi` to display the GPU / driver status visible to this kernel.
!nvidia-smi

In [None]:
import os

import re
from dataclasses import dataclass, field
from multiprocessing import cpu_count
from pathlib import Path
from typing import Optional, Union, List, Tuple

In [None]:
import shutil
from datetime import datetime

from datasets import load_dataset, Dataset, DatasetDict, ClassLabel, load_from_disk
from torch.utils.data import DataLoader, default_collate
from transformers import AutoTokenizer

import lightning.pytorch as pl
from lightning.pytorch.utilities.rank_zero import rank_zero_info

#from config import Config, DataModuleConfig, ModuleConfig


In [None]:
from abc import abstractmethod

import torch
import torch.nn as nn

from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.utilities.types import OptimizerLRScheduler, EVAL_DATALOADERS, TRAIN_DATALOADERS
from lightning.pytorch.utilities.types import OptimizerLRScheduler

from torchmetrics.functional import accuracy
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
from torchmetrics.functional import accuracy
from torchmetrics.classification import MultilabelAccuracy, MultilabelPrecision, MultilabelRecall, MultilabelF1Score
import torchinfo

from transformers import BertForSequenceClassification, AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import get_cosine_schedule_with_warmup

In [None]:
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger, CometLogger, TensorBoardLogger
from lightning.pytorch.profilers import PyTorchProfiler

from dvclive.lightning import DVCLiveLogger

from datamodule import AutoTokenizerDataModule
from module import CustomModel
from utils import create_dirs

import numpy as np

In [None]:
from huggingface_hub import login
import os
# The Hugging Face access token is taken from the HUG_FACE_TOKEN environment
# variable (None when it is not set) and handed to `login`.
token = os.environ.get('HUG_FACE_TOKEN')
login(token=token)

In [None]:
import os

from dataclasses import dataclass, field
from multiprocessing import cpu_count
from pathlib import Path
from typing import Optional, Union, List, Tuple

# ## get root path ## #
'''
this_file = Path(__file__)
this_studio_idx = [
    i for i, j in enumerate(this_file.parents) if j.name.endswith("this_studio")
][0]
'''
# The lookup above is kept only as an inert string (presumably because
# `__file__` is not available in a notebook kernel -- TODO confirm); the
# project root is simply the current working directory.
this_studio = "./"


@dataclass
class Config:
    """Filesystem locations and the random seed shared by the notebook.

    Every path is computed once, when the class body runs.  All of them live
    under ``this_studio`` except ``ckpt_dir``, which points at a fixed
    location on the ``E:`` drive.
    """

    # dataset / tokenizer cache
    cache_dir: str = os.path.join(this_studio, "data")
    # root of all logger output
    log_dir: str = os.path.join(this_studio, "logs")
    # model checkpoints (absolute, machine-specific path)
    ckpt_dir: str = os.path.join(r"E:\bert-twetter-disaster-model-trained", "checkpoints")
    # profiler and perf output are nested below `log_dir`
    prof_dir: str = os.path.join(log_dir, "profiler")
    perf_dir: str = os.path.join(log_dir, "perf")
    # seed value; not referenced by the visible notebook cells
    seed: int = 59631546


@dataclass
class ModuleConfig:
    """Model and optimisation hyper-parameters.

    In the visible notebook cells only ``model_name``, ``max_length`` and
    ``learning_rate`` are read (as class-level defaults); the remaining
    fields are presumably consumed by the imported ``module`` /
    ``datamodule`` code -- TODO confirm.

    ``opposing_label_sets`` is a regular dataclass field (default ``None``)
    rather than an attribute assigned after the class definition, so it can
    be supplied through the constructor and shows up in ``repr`` / ``asdict``.
    ``ModuleConfig.opposing_label_sets`` still evaluates to ``None``.
    """

    # change this to use a different pretrained checkpoint and tokenizer
    model_name: str = "facebook/bart-large-mnli"
    # model_name: str = "nvidia/NV-Embed-v2"
    learning_rate: float = 1.0103374612260327e-5
    learning_rate_bert: float = 1.0103374612260327e-5
    learning_rate_lstm: float = 7e-5
    # presumably the location of a locally fine-tuned checkpoint -- TODO confirm
    finetuned: str = "checkpoints/twhin-bert-base-finetuned"
    max_length: int = 128
    attention_probs_dropout: float = 0.1
    classifier_dropout: Optional[float] = None
    warming_steps: int = 100
    focal_gamma: float = 2.0
    # Pairs of label indices, e.g. [(0, 1), (10, 11)]; None leaves it unset.
    opposing_label_sets: Optional[List[Tuple[int, int]]] = None

@dataclass
class DataModuleConfig:
    """Dataset and dataloader defaults.

    The visible notebook cells read ``dataset_name`` and ``batch_size`` from
    this class; the other fields are not referenced there and are presumably
    used by the imported datamodule -- TODO confirm.
    """

    # change this to use different dataset
    dataset_name: str = "sdy623/new_disaster_tweets"
    num_classes: int = 12
    batch_size: int = 8
    # names of the dataset splits
    train_split: str = "train"
    test_split: str = "test"
    train_size: float = 0.8
    stratify_by_column: str = "label"
    num_workers: int = 0
    

@dataclass
class TrainerConfig:
    """Default values for Trainer flags.

    Of these, only ``precision`` is read by the visible notebook cells.
    """

    # Each field below mirrors a Trainer flag of the same name.
    accelerator: str = "auto"
    devices: Union[int, str] = "auto"
    strategy: str = "auto"
    precision: Optional[str] = "bf16-mixed"
    max_epochs: int = 10

In [None]:
# Pull the class-level defaults out of the config dataclasses into plain
# notebook globals used by the cells below.
model_name, max_length, lr = (
    ModuleConfig.model_name,
    ModuleConfig.max_length,
    ModuleConfig.learning_rate,
)
dataset_name, batch_size = DataModuleConfig.dataset_name, DataModuleConfig.batch_size

# Output locations.
cache_dir, log_dir = Config.cache_dir, Config.log_dir
ckpt_dir = Config.ckpt_dir
prof_dir, perf_dir = Config.prof_dir, Config.perf_dir
# creates dirs to avoid failure if empty dir has been deleted
create_dirs([cache_dir, log_dir, ckpt_dir, prof_dir, perf_dir])

# Select the "high" float32 matmul precision setting; details at
# https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
torch.set_float32_matmul_precision("high")

In [None]:
# Build the datamodule from the globals defined in the setup cell.
lit_datamodule = AutoTokenizerDataModule(
    # what to load
    dataset_name=dataset_name,
    model_name=model_name,
    # where to cache it
    cache_dir=cache_dir,
    # how to tokenize / batch it
    max_length=max_length,
    batch_size=batch_size,
)

In [None]:
# Calls the project datamodule's `clear_custom_cache()`; presumably this drops
# previously cached data so the next cell rebuilds it -- TODO confirm against
# datamodule.py.
lit_datamodule.clear_custom_cache()

In [None]:
# Invoke the datamodule's `prepare_data()` hook explicitly, ahead of training.
lit_datamodule.prepare_data()

In [None]:
# Invoke the datamodule's `setup()` hook explicitly for the "fit" stage.
lit_datamodule.setup("fit")

In [None]:
# Instantiate the model. Both learning-rate arguments receive the same `lr`
# value; the dropout and `lr_gamma` values are given here as literals.
lit_model = CustomModel(
    learning_rate=lr,
    learning_rate_bert=lr,
    attention_probs_dropout=0.1,
    classifier_dropout=0,
    lr_gamma=0.75,
)

In [None]:
# Checkpointing: monitor "val_f1" (mode "max"), keep the top 3, weights only,
# written under `ckpt_dir` with the base filename "model".
checkpoint_callback = ModelCheckpoint(
    dirpath=ckpt_dir,
    filename="model",
    monitor="val_f1",
    mode="max",
    save_top_k=3,
    save_weights_only=True,
)

# Learning-rate monitoring at "step" granularity.
lr_monitor = LearningRateMonitor(logging_interval='step')

callbacks = [checkpoint_callback, lr_monitor]

In [None]:
# CSV logger saving under `log_dir`, with the experiment name "csv-logs".
csvlogger = CSVLogger(name="csv-logs", save_dir=log_dir)

In [None]:
# Build the Trainer.
# - The Comet API key is read from the COMET_API_KEY environment variable
#   instead of being written into the notebook (the Hugging Face token is
#   handled the same way), so no credential literal lives in the file.
# - `devices` and `max_epochs` are set literally here; only `precision` comes
#   from TrainerConfig.
lit_trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    strategy="auto",
    precision=TrainerConfig.precision,
    max_epochs=8,
    logger=[
        csvlogger,
        CometLogger(api_key=os.getenv("COMET_API_KEY")),
        DVCLiveLogger(save_dvc_exp=True),
    ],
    #logger=[csvlogger],
    callbacks=callbacks,
    log_every_n_steps=50,
    #strategy=ray.train.lightning.RayDDPStrategy(find_unused_parameters=True),
    #plugins=[ray.train.lightning.RayLightningEnvironment()],
    #callbacks=[ray.train.lightning.RayTrainReportCallback(), LearningRateMonitor(logging_interval='epoch')],
    # [1a] Optionally, disable the default checkpointing behavior
    # in favor of the `RayTrainReportCallback` above.
    #enable_checkpointing=False,
)
hyperparameters = dict(lr=lr)
# `lit_trainer.logger` refers to the first configured logger only, so iterate
# over `lit_trainer.loggers` to record the hyperparameters in every backend.
for run_logger in lit_trainer.loggers:
    run_logger.log_hyperparams(hyperparameters)

In [None]:
# Start training: fit `lit_model` using the data supplied by `lit_datamodule`.
lit_trainer.fit(model=lit_model, datamodule=lit_datamodule)