<a href="https://colab.research.google.com/github/pszemraj/pubmed-text-classification/blob/analysis/colab/notebooks/transformers_textclassifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training a TextClassifier with lightning-flash

- original `lightning-flash` tutorial on text classification from the package [docs](https://lightning-flash.readthedocs.io/en/stable/reference/text_classification.html)
- this notebook goes over training [BERT uncased](https://huggingface.co/bert-base-uncased) with lightning-flash on the pubmed article dataset

In [None]:
!nvidia-smi

Tue Apr 26 10:55:37 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    22W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# setup

In [None]:
%%capture
#@markdown set up auto-formatting of cells in notebook

from IPython.display import HTML, display


def set_css():
    display(
        HTML(
            """
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  """
        )
    )
get_ipython().events.register("pre_run_cell", set_css)

In [None]:
#@markdown set up logfile
import logging
_das_logfile = "LOGFILE_lf_tfcls_ml4hc_p2.log"
logging.basicConfig(
    filename=_das_logfile,
    filemode="a",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# USAGE logging.info("this message will be recorded")

In [None]:
#@title mount drive, define root folder
from google.colab import drive
from pathlib import Path
drive_base_str = '/content/drive'
drive.mount(drive_base_str)


Mounted at /content/drive


In [None]:
drive_head_dir = Path(drive_base_str)

root_dir = "/content/drive/MyDrive/ETHZ-2022-S/ML-healthcare-projects/project2/transformers" #@param {type:"string"}
root_dir = Path(root_dir)
if not root_dir.exists():
    print(f"{root_dir.resolve()} does not exist, creating generic folder in drive root")
    root_dir = drive_head_dir / "transformer-text-classifier"
    root_dir.mkdir(exist_ok=True)

print(f"NOTE: all files will be stored using this as root:\n\t{root_dir}")


NOTE: all files will be stored using this as root:
	/content/drive/MyDrive/ETHZ-2022-S/ML-healthcare-projects/project2/transformers


## installs

In [None]:
!pip install -U 'lightning-flash[text]' -q
!pip install -U clean-text[gpl] -q
!pip install torchmetrics==0.7.3 -q
# if it seems there are issues with flash and installs, investigate torchmetrics

[K     |████████████████████████████████| 1.1 MB 14.8 MB/s 
[K     |████████████████████████████████| 135 kB 65.4 MB/s 
[K     |████████████████████████████████| 582 kB 39.9 MB/s 
[K     |████████████████████████████████| 408 kB 43.2 MB/s 
[K     |████████████████████████████████| 79 kB 11.2 MB/s 
[K     |████████████████████████████████| 1.2 MB 53.4 MB/s 
[K     |████████████████████████████████| 4.0 MB 40.9 MB/s 
[K     |████████████████████████████████| 325 kB 63.0 MB/s 
[K     |████████████████████████████████| 1.1 MB 63.3 MB/s 
[K     |████████████████████████████████| 77 kB 7.6 MB/s 
[K     |████████████████████████████████| 136 kB 74.4 MB/s 
[K     |████████████████████████████████| 212 kB 76.4 MB/s 
[K     |████████████████████████████████| 596 kB 67.6 MB/s 
[K     |████████████████████████████████| 127 kB 72.8 MB/s 
[K     |████████████████████████████████| 749 kB 56.2 MB/s 
[K     |████████████████████████████████| 1.5 MB 44.4 MB/s 
[K     |██████████████████

## functions

In [None]:
from datetime import datetime
#@markdown define `get_timestamp()`
def get_timestamp():
    """This is, as they say, a utility function"""
    return datetime.now().strftime("%b-%d-%Y_t-%H")

In [None]:
from pathlib import Path
from tqdm.auto import tqdm
from cleantext import clean

import pandas as pd
import re

#@markdown define `process_txt_data(txt_datadir:str, verbose=False)`
def fix_parathesis(text:str, 
                   re_str=r"(?<=[([]) +| +(?=[)\]])"):
    """
    fix_parathesis - does the following:
                        input text "I like ( perhaps even love ) to eat beans."
                        output text "I like (perhaps even love) to eat beans."
    """
    fixed_text = re.sub(re_str, "", text)

    return fixed_text

def fix_punct_spaces(input_text:str):
    """
    fix_punct_spaces - replace spaces around punctuation with punctuation. For example, "hello , there" -> "hello, there"

    :input_text: str, required, input string to be corrected

    Returns
    fixed_text - str, corrected string
    """

    fix_spaces = re.compile(r"\s*([?!.,]+(?:\s+[?!.,;:]+)*)\s*")
    input_text = fix_spaces.sub(lambda x: "{} ".format(x.group(1).replace(" ", "")), input_text)
    input_text = input_text.replace(" ' ", "'")
    fixed_text = input_text.replace(' " ', '"')
    return fix_parathesis(fixed_text.strip())

def custom_clean(ugly_txt, lowercase=True):

    return clean(ugly_txt, lower=lowercase)

def process_txt_data(txt_datadir:str, lowercase=True,
                     verbose=False):
    """read each downloaded txt file into pandas, convert to a dataframe, and save as a CSV"""
    txt_datadir = Path(txt_datadir)
    text_files = [f for f in txt_datadir.iterdir() if f.is_file() and f.suffix == '.txt']
    csv_paths = []

    for txt_path in tqdm(text_files, total=len(text_files)):

        df = pd.read_csv(txt_path, 
                         skiprows=1, 
                         delimiter='\t',
                         header=None,
                         on_bad_lines='skip',
                         engine='python',
                    ).convert_dtypes()
        df.columns = ['target', 'description']
        df.dropna(inplace=True)
        df.reset_index(drop=True, inplace=True)
        df["description_cln"] = df["description"].apply(clean, lower=lowercase)
        df["description_cln"] = df["description_cln"].apply(fix_punct_spaces)
        _csv_out_path = txt_path.with_suffix('.csv')
        df.to_csv(_csv_out_path, index=False)
        csv_paths.append(_csv_out_path)

    if verbose:
        print(f"processed and returning:\n\t{[f.name for f in csv_paths]}")

    return csv_paths


## define params, model type


- nbew: https://huggingface.co/voidism/diffcse-bert-base-uncased-trans


In [None]:
#@title nn training parameters
import torch
NUM_EPOCHS =  12#@param {type:"integer"}
BATCH_SIZE =  32#@param {type:"integer"}
MAX_LEN = 256 #@param ["128", "256", "512", "1024"] {type:"raw"}
TRAIN_FP16 = True #@param {type:"boolean"}
TRAIN_STRATEGY = "no_freeze" #@param ["freeze", "freeze_unfreeze", "no_freeze", "full_train"]
LR_INITIAL =  1e-4#@param {type:"number"}
LR_SCHEDULE = "reducelronplateau" #@param ["constantlr", "reducelronplateau"]
WEIGHT_DECAY = 0.05 #@param ["0", "0.01", "0.05", "0.1"] {type:"raw"}

#@markdown `UNFREEZE_EPOCH` is the epoch to unfreeze all model layers, only used for 
#@markdown `TRAIN_STRATEGY = "freeze_unfreeze"`
UNFREEZE_EPOCH =  4#@param {type:"integer"}

#@markdown `prajjwal1/bert-medium` is default 
hf_tag = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext" #@param ["bert-base-uncased", "bert-large-uncased", "dmis-lab/biobert-v1.1", "kamalkraj/bioelectra-base-discriminator-pubmed", "bert-base-cased", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", "allenai/longformer-base-4096", "facebook/bart-base", "albert-base-v2", "yikuan8/Clinical-BigBird", "yikuan8/Clinical-Longformer", "bionlp/bluebert_pubmed_mimic_uncased_L-12_H-768_A-12", "microsoft/xtremedistil-l12-h384-uncased", "sultan/BioM-BERT-PubMed-PMC-Large", "sultan/BioM-ELECTRA-Base-Discriminator", "voidism/diffcse-bert-base-uncased-trans", "facebook/bart-large"]

if not torch.cuda.is_available():
    print("cuda not available, setting var TRAIN_FP16 to False.")
    TRAIN_FP16=False


if TRAIN_STRATEGY =="freeze_unfreeze":
    assert NUM_EPOCHS > UNFREEZE_EPOCH > 0, "Please configure params such that NUM_EPOCHS > UNFREEZE_EPOCH > 0"
    
session_params = {
    "NUM_EPOCHS":NUM_EPOCHS,
    "BATCH_SIZE":BATCH_SIZE,
    "MAX_INPUT_LENGTH":MAX_LEN,
    "TRAIN_FP16":TRAIN_FP16,
    "TRAIN_STRATEGY":TRAIN_STRATEGY,
    "LR_SCHEDULE":LR_SCHEDULE,
    "LR_INITIAL":LR_INITIAL,
    "WEIGHT_DECAY":WEIGHT_DECAY,
    "UNFREEZE_EPOCH":UNFREEZE_EPOCH,
    "hf_tag":hf_tag,

}
logging.info(f"\n\nParameters for a new session:\n\t{session_params}")

## load data


In [None]:
#@markdown this cell uses the `bash` shell commands (inside jupyter) to download 
#@markdown the relevant dataset in the dropbown to the runtime. the variable `data_dir`
#@markdown points to the directory containing the files.
from pathlib import Path
url_project2data_full = "https://www.dropbox.com/sh/xn85zbn7brqq35y/AAB80_k_OWttvnSjJRFgEFMca?dl=1" #@param {type:"string"}
url_project2data_20k = "https://www.dropbox.com/sh/tr0jyps0qbqwo9v/AAAdfglvn1RLAza4Y2mtG33Za?dl=1" #@param {type:"string"}
data_dir = "/content/project2-data" #@param {type:"string"}
data_dir = Path(data_dir)
# !rm -r $data_dir # clear out directory 
zip_name = "dataset.zip"
zip_name = Path(zip_name)
dataset = "pubmed_full" #@param ["pubmed_full", "pubmed_20k"]
dataset_already_here = zip_name.exists() and data_dir.exists()
if dataset_already_here:
    print("dataset files seem to exist already.. double check as needed")
else:
    session_params['dataset'] = dataset # log
    print(f'downloading {dataset}...')
    if dataset == "pubmed_20k":
        # download the 20k short dataset
        !wget $url_project2data_20k -O $zip_name -q
        !unzip -j -q $zip_name -d $data_dir
    else:
        # download the full dataset
        !wget $url_project2data_full -O $zip_name -q
        !unzip -j -q $zip_name -d $data_dir

    print("\n" * 3, f"files in the {data_dir} directory are:")
    !ls $data_dir

if dataset == "pubmed_20k":
    datafile_mapping = {
        "train":data_dir / 'train20.csv',
        "val":data_dir / 'dev20.csv',
        "test":data_dir / 'test20.csv',
    }
else:
    datafile_mapping = {
        "train":data_dir / 'train.csv',
        "val":data_dir / 'dev.csv',
        "test":data_dir / 'test.csv',
    }

downloading pubmed_full...
mapname:  conversion of  failed



 files in the /content/project2-data directory are:
dev.txt  test.txt  train.txt


In [None]:
# NOTE: this ^ may need some manual exceptions
do_lowercase = "uncased" in hf_tag.lower() or "albert" in hf_tag.lower()
print(f"lowercase={do_lowercase}")

if dataset_already_here:
    print(f"not re-processing existing dataset, validate that files exist as needed")
    print([f.resolve() for f in data_dir.iterdir() if f.is_file() and f.suffix == '.csv'])
else:
    proc_csv_paths = process_txt_data(data_dir, 
                                    lowercase=do_lowercase,
                                    verbose=True
                                    )
    csv_paths = {f.name:f for f in proc_csv_paths}
    dataset_names = list(datafile_mapping.values())
    found_CSV_names = list(datafile_mapping.values())
    error_msg = f"downloaded filenames {found_CSV_names} do not match dataset expected names, check links"
    assert all(elem in dataset_names for elem in found_CSV_names), error_msg

lowercase=True


  0%|          | 0/3 [00:00<?, ?it/s]

processed and returning:
	['dev.csv', 'test.csv', 'train.csv']


In [None]:
session_params["lowercased_input"] = do_lowercase
# note the below is a preview of VALIDATION and not train
example_df = pd.read_csv(datafile_mapping["val"])
example_df.head(5)

Unnamed: 0,target,description,description_cln
0,BACKGROUND,Adrenergic activation is thought to be an impo...,adrenergic activation is thought to be an impo...
1,RESULTS,Systemic venous norepinephrine was measured at...,systemic venous norepinephrine was measured at...
2,RESULTS,Baseline norepinephrine level was associated w...,baseline norepinephrine level was associated w...
3,RESULTS,"On multivariate analysis , baseline norepineph...","on multivariate analysis, baseline norepinephr..."
4,RESULTS,"In contrast , the relation of the change in no...","in contrast, the relation of the change in nor..."


In [None]:
# look at amount of chars dist
example_df["input_len"] = example_df.description_cln.apply(len)
example_df.describe()


Unnamed: 0,input_len
count,28932.0
mean,148.672784
std,75.503012
min,2.0
25%,95.0
50%,137.0
75%,187.0
max,862.0


In [None]:
example_df.target.value_counts() # check to ensure got rid of the bs hashtag labels

RESULTS        9977
METHODS        9559
CONCLUSIONS    4396
BACKGROUND     2575
OBJECTIVE      2425
Name: target, dtype: int64

In [None]:
#@markdown create `TextClassificationData.from_csv`

input_text_colname = "description_cln" #@param {type:"string"}
target_cls_colname = "target" #@param {type:"string"}
import torch

import flash
from flash.text import TextClassificationData

datamodule = TextClassificationData.from_csv(
    input_field=input_text_colname,
    target_fields=target_cls_colname,
    train_file=datafile_mapping["train"],
    val_file=datafile_mapping["val"],
    test_file=datafile_mapping["test"],
    batch_size=BATCH_SIZE,
)

session_params['input_text_colname'] = input_text_colname 
session_params['target_cls_colname'] = target_cls_colname 

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-65e0f92ad6495608/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-65e0f92ad6495608/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2211861 [00:00<?, ?ex/s]

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-387dd2a3b58511e7/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-387dd2a3b58511e7/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/28932 [00:00<?, ?ex/s]

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-a124811f2faef16b/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-a124811f2faef16b/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/29493 [00:00<?, ?ex/s]

## create model + helpers

for training

### metrics

**handling multiple possible classes at the same time**

https://lightning-flash.readthedocs.io/en/stable/reference/text_classification_multi_label.html

one-hot encoding

**how to use torch metrics** - [docs](https://torchmetrics.readthedocs.io/en/stable/pages/overview.html) | [article](https://www.exxactcorp.com/blog/Deep-Learning/advanced-pytorch-lightning-using-torchmetrics-and-lightning-flash)


In [None]:
from torchmetrics import AUROC, Accuracy, F1Score, MatthewsCorrCoef

_nc = datamodule.num_classes # alias
print(f"found number of classes as {_nc}")
logging.info(f"found number of classes as {_nc}")
acc = Accuracy(num_classes=_nc, average='weighted' if _nc > 2 else 'macro')
f1 = F1Score(num_classes=_nc, average='weighted' if _nc > 2 else 'macro')
mcc = MatthewsCorrCoef(num_classes=_nc, )
_metrics = [acc, mcc, f1]

found number of classes as 5


In [None]:
session_params["num_classes"] = _nc

### training logs

In [None]:
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

colname_detail = "" if input_text_colname == "description_cln" else f" {input_text_colname}" 
log_dir = root_dir / f"logs_{dataset}_{TRAIN_STRATEGY}{colname_detail}"
log_dir.mkdir(exist_ok=True)

# logger = CSVLogger(save_dir=str(log_dir.resolve())) #backup
log_dir_str = str(log_dir.resolve())
MODEL_BACKBONE = hf_tag.split('/')[-1] # parse "microsoft/BiomedNLP-PubMedBERT" with no /
MODEL_BACKBONE = MODEL_BACKBONE[:30] # max 30 chars
logger = TensorBoardLogger(
            save_dir=log_dir_str,
            name=f"txtcls_{dataset}_{MODEL_BACKBONE}"
            )

In [None]:
# log important hyperparameters for setup
session_params["model_shortname"] = MODEL_BACKBONE

logger.log_hyperparams(session_params)

### load model

- from huggingface


Q: what models can you actually use?

A:
> AutoModelForSequenceClassification.
Model type should be one of YosoConfig, NystromformerConfig, QDQBertConfig, FNetConfig, PerceiverConfig, GPTJConfig, LayoutLMv2Config, PLBartConfig, RemBertConfig, CanineConfig, RoFormerConfig, BigBirdPegasusConfig, GPTNeoConfig, BigBirdConfig, ConvBertConfig, LEDConfig, IBertConfig, MobileBertConfig, DistilBertConfig, AlbertConfig, CamembertConfig, XLMRobertaXLConfig, XLMRobertaConfig, MBartConfig, MegatronBertConfig, MPNetConfig, BartConfig, ReformerConfig, LongformerConfig, RobertaConfig, DebertaV2Config, DebertaConfig, FlaubertConfig, SqueezeBertConfig, BertConfig, OpenAIGPTConfig, GPT2Config, TransfoXLConfig, XLNetConfig, XLMConfig, CTRLConfig, ElectraConfig, FunnelConfig, LayoutLMConfig, TapasConfig, Data2VecTextConfig.

#### Learning Rate Schedule & Optimizer

- [docs](https://lightning-flash.readthedocs.io/en/stable/general/optimization.html) from lightning-flash



In [None]:
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after a optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_f1score",
    "patience": 1,
    "min_lr": 1e-8,
    "reduce_on_plateau":True,
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

In [None]:
from flash.text import TextClassifier


logging.info("\n"*3)
logging.info(f"Loading new model: {hf_tag} for training on text classification")
model = TextClassifier(backbone=hf_tag, 
                        max_length=MAX_LEN,
                        labels=datamodule.labels,
                        metrics=_metrics,
                        learning_rate=LR_INITIAL,
                        optimizer=("Adam", {"amsgrad": True,
                                            "weight_decay":WEIGHT_DECAY}),
                        # lr_scheduler=LR_SCHEDULE,
                        lr_scheduler=("reducelronplateau", 
                                      {"mode": "max"}, 
                                      lr_scheduler_config,
                                    ),
                    )
model.hparams.batch_size = BATCH_SIZE

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/221k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Ber

In [None]:
model.configure_optimizers()  

([Adam (
  Parameter Group 0
      amsgrad: True
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.0001
      maximize: False
      weight_decay: 0.05
  )],
 [{'frequency': 1,
   'interval': 'epoch',
   'min_lr': 1e-08,
   'monitor': 'val_f1score',
   'name': None,
   'opt_idx': None,
   'patience': 1,
   'reduce_on_plateau': True,
   'scheduler': <torch.optim.lr_scheduler.ReduceLROnPlateau at 0x7fd7721b7710>,
   'strict': True}])

### create trainer

**deepspeed** 
- [docs](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#deepspeed) from PL
- [page](https://pytorch-lightning.readthedocs.io/en/latest/advanced/training_tricks.html#advanced-gpu-optimizations) on advanced GPU optimization


In [None]:
import pytorch_lightning
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins import DeepSpeedPlugin
seed_everything(42, workers=True)

42

dir(pytorch_lightning.callbacks

In [None]:
from pytorch_lightning.callbacks import (
    StochasticWeightAveraging, 
    LearningRateMonitor,
    GPUStatsMonitor,
    EarlyStopping,
)

_callbacks = [
                StochasticWeightAveraging(), 
                LearningRateMonitor(),
                GPUStatsMonitor(),
                EarlyStopping(monitor='val_f1score',
                              mode='max', 
                              min_delta=0.003,
                              patience=2,
                            ),
            ]
trainer = flash.Trainer(
    max_epochs=NUM_EPOCHS,
    gpus=torch.cuda.device_count(),
    auto_lr_find=True,
    auto_scale_batch_size=True,
    precision=16 if TRAIN_FP16 else 32,
    callbacks=_callbacks,
    logger=logger,
)


# train

In [None]:
#@title tensorboard
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
%tensorboard --logdir $log_dir


## train trainer

In [None]:
logging.info(f"\t\tTRAINING:{hf_tag} ")

if TRAIN_STRATEGY == 'full_train':
    trainer.fit(
        model,
        datamodule=datamodule,
    )
else:
    trainer.finetune(
        model,
        datamodule=datamodule,
        strategy=("freeze_unfreeze", UNFREEZE_EPOCH) if TRAIN_STRATEGY =="freeze_unfreeze" \
                    else TRAIN_STRATEGY, # 'freeze_unfreeze' is a special case
        )

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

# evaluation / prediction

In [None]:
#@markdown set up monitoring for test set completion
!pip install knockknock -q
from knockknock import telegram_sender

CHAT_ID: int = 1458397289  # this means it will send the chat to peter szemraj

BOT_API: str = '2023363925:AAEabaBw8Xwka0HwBqV805ueU4ZicU4bO5o'
@telegram_sender(token=BOT_API, chat_id=CHAT_ID)
def knockknock_test_wrap(verbose=False):
    
    eval_output = trainer.test(verbose=verbose, datamodule=datamodule,)

    return eval_output

[K     |████████████████████████████████| 1.4 MB 26.7 MB/s 
[K     |████████████████████████████████| 497 kB 64.2 MB/s 
[K     |████████████████████████████████| 43 kB 2.6 MB/s 
[K     |████████████████████████████████| 48 kB 6.1 MB/s 
[K     |████████████████████████████████| 4.0 MB 57.4 MB/s 
[K     |████████████████████████████████| 404 kB 68.4 MB/s 
[K     |████████████████████████████████| 58 kB 7.7 MB/s 
[K     |████████████████████████████████| 428 kB 51.4 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires tornado~=5.1.0; python_version >= "3.0", but you have tornado 6.1 which is incompatible.[0m
[?25h

## test set

In [None]:
eval_output = knockknock_test_wrap(verbose=True)

In [None]:
# uncomment below if removing knockknock_test_wrap()
# trainer.test(verbose=True, datamodule=datamodule,)

In [None]:
# 4. Classify a few sentences
import pprint as pp
test_df = pd.read_csv(datafile_mapping["test"]).convert_dtypes()
predict_examples = test_df.sample(n=5)
sample_text = predict_examples[input_text_colname].to_list()
sample_ytrue = predict_examples[target_cls_colname].to_list()

datamodule = TextClassificationData.from_lists(
    predict_data=sample_text,
    batch_size=8,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")[0]

pp.pprint(dict(zip(sample_text, sample_ytrue)))

In [None]:
logging.info(f"\nmodel {hf_tag} had the following ytrue:ypred split:\n")
logging.info(f"\nYtrue:\t{sample_ytrue}\nYpred:\t{predictions} ")
print(f"\nYtrue:\t{sample_ytrue}\nYpred:\t{predictions}")


Ytrue:	['METHODS', 'METHODS', 'CONCLUSIONS', 'RESULTS', 'METHODS']
Ypred:	['METHODS', 'METHODS', 'RESULTS', 'RESULTS', 'METHODS']


## extract key metrics

In [None]:
import numpy
import pprint as pp

final_metrics = trainer.logged_metrics
output_metrics = {k:v.cpu().numpy().tolist() for k, v in final_metrics.items()}
output_metrics["date_run"] = get_timestamp()
output_metrics["huggingface_tag"] = hf_tag
output_metrics['results'] = eval_output
log_str = f"\t\t\t ======= Final results for {hf_tag} ======="
logging.info(log_str.upper())
logging.info(output_metrics)
pp.pprint(output_metrics)

{'date_run': 'Apr-26-2022_t-19',
 'huggingface_tag': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',
 'results': [{'test_accuracy': 0.8025633096694946,
              'test_cross_entropy': 0.5453578233718872,
              'test_f1score': 0.8019025325775146,
              'test_matthewscorrcoef': 0.7307071685791016}]}


# save & export

In [None]:
import gc
gc.collect()

123

In [None]:
out_dir = root_dir / "model-checkpoints"
out_dir.mkdir(exist_ok=True)
m_name = hf_tag.split('/')[-1]
_chk_dir = f"{dataset}={m_name}_{get_timestamp()}"
session_dir = out_dir / _chk_dir
session_dir.mkdir(exist_ok=True)

_chk_name = f"textclassifer_{m_name}_{dataset}.pt"
model_out_path = session_dir / _chk_name
trainer.save_checkpoint(model_out_path.resolve())

In [41]:
import json
metrics_out_path = session_dir / "training_metrics.json"
with open(metrics_out_path, "w") as fp:
    json.dump(output_metrics, fp)

params_out_path = session_dir / "training_parameters.json"
with open(params_out_path, "w") as fp:
    json.dump(session_params, fp)

In [42]:
import shutil

shutil.copy(_das_logfile, session_dir / "training_session_toplevel_log.log")

PosixPath('/content/drive/MyDrive/ETHZ-2022-S/ML-healthcare-projects/project2/transformers/model-checkpoints/pubmed_full=BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext_Apr-26-2022_t-18/training_session_toplevel_log.log')

# print log

In [None]:
# !cat $_das_logfile