<a href="https://colab.research.google.com/github/pszemraj/ml4hc-s22-project01/blob/main/notebooks/colab/tabular_classification_LF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch Lightning-Flash: Test Various Tabular Classification models

_Heavily modified/adapted from the [Titanic classification tutorial in the LF docs](https://github.com/PyTorchLightning/lightning-flash/blob/b208689ea693e1cb6ffecb301915b3b97618871a/flash_notebooks/tabular_classification.ipynb)_

---

  - [LF GitHub](https://github.com/PyTorchLightning/lightning-flash)
  - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)
  - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)

---

In [1]:
#@title print out GPU info
#@markdown Shows the GPU allocated by Colab. If this command fails, no GPU is being used:
#@markdown set the runtime type to GPU via Runtime > Change runtime type at the top of Colab.


!nvidia-smi

Sat Mar 26 02:42:30 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    29W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# setup

In [2]:
#@markdown wrap long lines in cell output (CSS injected with `IPython.display`)
from IPython.display import HTML, display
# inject CSS before every cell runs so <pre> output wraps instead of scrolling sideways
def set_css():
    display(
        HTML(
            """
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  """
        )
    )

get_ipython().events.register("pre_run_cell", set_css)

In [3]:
#@title mount drive, define root folder
from google.colab import drive
from pathlib import Path
drive_base_str = '/content/gdrive'
drive.mount(drive_base_str)


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [4]:
drive_head_dir = Path(drive_base_str)

root_dir = "/content/gdrive/MyDrive/ETHZ-2022-S/ML-healthcare-projects/project1/lightning-flash-models" #@param {type:"string"}
root_dir = Path(root_dir)
if not root_dir.exists():
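    print(f"{root_dir.resolve()} does not exist, creating generic folder in My Drive")
    # Drive contents are mounted under <mount point>/MyDrive, as in the default root_dir above
    root_dir = drive_head_dir / "MyDrive" / "lf-tabular-classifier"
    root_dir.mkdir(parents=True, exist_ok=True)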

In [5]:
#@title nn training parameters
import torch
NUM_EPOCHS = 50 #@param {type:"integer"}
BATCH_SIZE = 128 #@param {type:"integer"}
VAL_SPLIT = 0.15 #@param {type:"number"}
TRAIN_FP16 = True #@param {type:"boolean"}
MODEL_BACKBONE = "tabnet" #@param ["tabnet", "autoint", "category_embedding", "fttransformer", "node", "tabtransformer"]

if not torch.cuda.is_available():
    print("cuda not available, setting var TRAIN_FP16 to False.")
    TRAIN_FP16 = False

## install

In [6]:
# %%capture
! pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[tabular]' -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone


In [7]:
#@title define source data parameters

#@markdown - these files could also be loaded from Google Drive, but `wget` on a public Dropbox link requires no login

mitbih_train_url = "https://www.dropbox.com/s/2ks8s82tm7jvhse/torchfmt_mitbih_train.csv?dl=1" #@param {type:"string"}
mitbih_train_filename = "mitbih_train.csv" #@param {type:"string"}
mitbih_test_url = "https://www.dropbox.com/s/nbaxenoehvqmqnm/torchfmt_mitbih_test.csv?dl=1" #@param {type:"string"}
mitbih_test_filename = "mitbih_test.csv" #@param {type:"string"}

In [8]:
from torchmetrics.classification import Accuracy, Precision, Recall

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier, TabularClassificationData



###  1. Download the data
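The train and test CSVs are downloaded from the URLs defined above and saved to the current working directory under the names set in `mitbih_train_filename` and `mitbih_test_filename`.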

In [9]:

!wget $mitbih_train_url -O $mitbih_train_filename
!wget $mitbih_test_url -O $mitbih_test_filename

--2022-03-26 02:42:57--  https://www.dropbox.com/s/2ks8s82tm7jvhse/torchfmt_mitbih_train.csv?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.4.18, 2620:100:601a:18::a27d:712
Connecting to www.dropbox.com (www.dropbox.com)|162.125.4.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/dl/2ks8s82tm7jvhse/torchfmt_mitbih_train.csv [following]
--2022-03-26 02:42:58--  https://www.dropbox.com/s/dl/2ks8s82tm7jvhse/torchfmt_mitbih_train.csv
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://ucd135e83a95c81e349d8d18b1b0.dl.dropboxusercontent.com/cd/0/get/BiKtJZd1Uk63FGhY09DyzUjucRNuokkv9OspmLCshn8hztnSgnXDX9m7_Gd30t352Jl9eZrFsXkNYXnck1FfqQErc2hjsX-SryXTYAgBcof5jcchzlOhkGyd17wqp9ILjdcyc-aybwpyV1S8ZQ-2R46XjMxnP81_jaxNPHksYAZczw/file?dl=1# [following]
--2022-03-26 02:42:58--  https://ucd135e83a95c81e349d8d18b1b0.dl.dropboxusercontent.com/cd/0/get/BiKtJZd1Uk63FGhY09DyzUjucRN
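If `wget` is not available (for example, when running outside Colab), the same download can be done from Python. A minimal standard-library sketch; `fetch_csv` is a hypothetical helper, not part of Flash, and it skips the download when a non-empty copy already exists so that re-running the cell does not re-fetch the data:

```python
import shutil
import urllib.request
from pathlib import Path


def fetch_csv(url: str, filename: str) -> Path:
    """Hypothetical helper: download `url` to `filename` unless a non-empty copy exists."""
    dest = Path(filename)
    if dest.exists() and dest.stat().st_size > 0:
        return dest  # already downloaded, skip the network round-trip
    # stream the response straight to disk instead of holding it in memory
    with urllib.request.urlopen(url) as response, open(dest, "wb") as out:
        shutil.copyfileobj(response, out)
    return dest
```

Usage would be `fetch_csv(mitbih_train_url, mitbih_train_filename)` and likewise for the test file.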

###  2. Load the data
Flash tasks have built-in DataModules to organize your data. Pass in the train and test CSV files, let Flash split a validation set off the training data via `val_split`, and Flash takes care of the rest.

- `TabularClassificationData` is built on [pandas DataFrames](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html).
- Source code for [TabularClassificationData](https://github.com/PyTorchLightning/lightning-flash/blob/6da53fe99b220edacf69ea1701ee082ce76ef184/flash/tabular/classification/data.py)

In [10]:
import pandas as pd
example_df = pd.read_csv(mitbih_train_filename)
data_cols = list(example_df.columns)
_target = data_cols.pop()  # the last column holds the class label
_predictors = data_cols  # all other columns are numerical predictors

print(f"the target colname is {_target} and\nthe predictor colnames 5 of {len(_predictors)} are {_predictors[:5]}")

the target colname is class_label and
the predictor colnames 5 of 187 are ['feat_0', 'feat_1', 'feat_2', 'feat_3', 'feat_4']
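`pd.read_csv` above loads the whole training file just to obtain the column names. When only the header is needed, reading the first row with the standard-library `csv` module is enough. A sketch, assuming (as in these files) that the target is the last column; `split_header` is a hypothetical helper:

```python
import csv
from typing import List, Tuple


def split_header(path: str) -> Tuple[List[str], str]:
    """Hypothetical helper: return (predictor names, target name) from a CSV header.

    Assumes the target is the last column, as in the MIT-BIH files used here.
    """
    with open(path, newline="") as f:
        header = next(csv.reader(f))  # read only the first row
    *predictors, target = header
    return predictors, target
```

For example, `_predictors, _target = split_header(mitbih_train_filename)`.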


In [11]:
datamodule = TabularClassificationData.from_csv(
    numerical_fields=_predictors,
    target_fields=_target,
    train_file=mitbih_train_filename,
    test_file=mitbih_test_filename,
    val_split=VAL_SPLIT,
    batch_size=BATCH_SIZE,
)
print(f"found {datamodule.num_classes} classes in predict column")

found 5 classes in predict column
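Since `num_classes` is inferred from the values in the target column, it is worth checking how those values are distributed before training: a heavily imbalanced label column is one reason to track metrics beyond plain accuracy. A quick standard-library check; `label_counts` is a hypothetical helper:

```python
import csv
from collections import Counter


def label_counts(path: str, target: str) -> Counter:
    """Hypothetical helper: count rows per value of the `target` column of a CSV."""
    with open(path, newline="") as f:
        # labels are kept as strings exactly as they appear in the file
        return Counter(row[target] for row in csv.DictReader(f))
```

For example, `label_counts(mitbih_train_filename, _target).most_common()` lists classes from most to least frequent.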




Set up metric objects (torchmetrics)

In [12]:
# metrics
import torchmetrics

n_cls = datamodule.num_classes
# pass num_classes by keyword: the first positional argument of Accuracy is `threshold`
metric_acc = Accuracy(num_classes=n_cls)
metric_f1 = torchmetrics.F1(num_classes=n_cls)  # renamed F1Score in newer torchmetrics releases
metric_CK = torchmetrics.CohenKappa(num_classes=n_cls)
metric_matthewscorr = torchmetrics.MatthewsCorrcoef(num_classes=n_cls)  # renamed MatthewsCorrCoef in newer releases
metric_rocAUC = torchmetrics.AUROC(num_classes=n_cls)
my_metrics = [
    metric_acc,
    metric_f1,
    metric_matthewscorr,
    metric_CK,
    metric_rocAUC,
]

my_metrics2 = {
    "acc": metric_acc,
    "f1": metric_f1,
    "mcorr": metric_matthewscorr,
    "CK": metric_CK,
    "rocAUC": metric_rocAUC,
}


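The metric objects above come ready-made from torchmetrics. To make clear what two of the less familiar ones measure, here is a standard-library sketch of the textbook definitions of Cohen's kappa and the multiclass Matthews correlation coefficient, computed from true and predicted label counts. It is illustrative only and is not torchmetrics' implementation:

```python
import math
from collections import Counter
from typing import Sequence, Tuple


def _counts(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[int, int, Counter, Counter]:
    """Return (#correct, #samples, true-label counts, predicted-label counts)."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct, len(y_true), Counter(y_true), Counter(y_pred)


def cohen_kappa(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Agreement beyond chance: (p_o - p_e) / (1 - p_e)."""
    c, s, t, p = _counts(y_true, y_pred)
    p_o = c / s                                 # observed agreement (= accuracy)
    p_e = sum(t[k] * p[k] for k in t) / s ** 2  # agreement expected by chance
    # degenerate case (one class everywhere): kappa is undefined, return 0.0 by convention
    return (p_o - p_e) / (1 - p_e) if p_e != 1 else 0.0


def matthews_corrcoef(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Multiclass MCC from the confusion-matrix diagonal and marginals."""
    c, s, t, p = _counts(y_true, y_pred)
    cov = c * s - sum(t[k] * p[k] for k in t)
    denom = math.sqrt(
        (s ** 2 - sum(v ** 2 for v in p.values()))
        * (s ** 2 - sum(v ** 2 for v in t.values()))
    )
    return cov / denom if denom else 0.0
```

Both scores are 1.0 for perfect predictions and about 0.0 for chance-level predictions, which makes them more informative than accuracy when one class dominates.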


Set up logging

In [13]:
from pytorch_lightning.loggers import CSVLogger  # noqa: E402

log_dir = root_dir / "logs"
log_dir.mkdir(exist_ok=True)

logger = CSVLogger(save_dir=str(log_dir.resolve()))
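Within the same session, `logger.log_dir` already points at the current run. After a runtime restart, the newest run has to be located on disk. A sketch, assuming the default `CSVLogger` layout `<save_dir>/<name>/version_<n>/metrics.csv` with `name="default"`; `latest_metrics_csv` is a hypothetical helper:

```python
from pathlib import Path
from typing import Optional


def latest_metrics_csv(save_dir: str, name: str = "default") -> Optional[Path]:
    """Hypothetical helper: metrics.csv of the highest-numbered logger version, or None.

    Assumes the layout <save_dir>/<name>/version_<n>/metrics.csv.
    """
    versions = []
    for d in (Path(save_dir) / name).glob("version_*"):
        suffix = d.name.split("_", 1)[1]
        if d.is_dir() and suffix.isdigit():
            versions.append((int(suffix), d))  # compare numerically: version_10 > version_9
    if not versions:
        return None
    _, newest = max(versions)
    csv_path = newest / "metrics.csv"
    return csv_path if csv_path.exists() else None
```

The result can be passed to `pd.read_csv` to plot training and validation curves, e.g. `latest_metrics_csv(str(log_dir.resolve()))`.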

###  3. Build the model

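Note: categorical columns are mapped to an embedding space: each categorical column gets its own table of trainable vectors, one vector per category value, learned together with the rest of the network. This does not apply to the MIT-BIH data here, since all 187 predictors are passed as `numerical_fields`.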
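The lookup can be sketched without any framework: each distinct value of a categorical column indexes one row of that column's table. A minimal illustration with hypothetical helper names; in PyTorch such a table is typically a `torch.nn.Embedding` whose rows are updated by backpropagation, whereas here the rows are only seeded random numbers:

```python
import random
from typing import Dict, List


def build_embedding_table(categories: List[str], dim: int, seed: int = 0) -> Dict[str, List[float]]:
    """Hypothetical helper: one random vector of length `dim` per distinct category.

    In a real model these vectors are trainable parameters; here they stay fixed.
    """
    rng = random.Random(seed)
    vocab = sorted(set(categories))
    return {c: [rng.gauss(0.0, 1.0) for _ in range(dim)] for c in vocab}


def embed(column: List[str], table: Dict[str, List[float]]) -> List[List[float]]:
    """Replace each categorical value by its embedding vector (a row lookup)."""
    return [table[v] for v in column]
```

Equal category values always map to the same vector, so the model sees a dense, learnable representation instead of an arbitrary integer code.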

In [14]:
import pprint as pp
backbones = TabularClassifier.available_backbones()
print("available model backbones for tabular as follows:\n")
pp.pprint(backbones)

available model backbones for tabular as follows:

['autoint',
 'category_embedding',
 'fttransformer',
 'node',
 'tabnet',
 'tabtransformer']


In [15]:
my_metrics

[Accuracy(), F1(), MatthewsCorrcoef(), CohenKappa(), AUROC()]

Progress-bar readouts from previous training runs:

- `tabtransformer`: loss=0.0379, valid_loss=0.134, valid_accuracy=0.980, train_loss=0.127, train_accuracy=0.994 (1368/1368 it in 00:19, 69.90 it/s)
- `category_embedding`: loss=0.109, valid_loss=0.093, valid_accuracy=0.975, train_loss=0.131, train_accuracy=0.962
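To compare backbones in one session rather than one run at a time, the bookkeeping can be separated from the training code. A sketch with a hypothetical `sweep_backbones` helper: the caller supplies `run_one`, which is assumed to build a `TabularClassifier` with `backbone=name`, train it, and return a dict of metric values (for instance the first dict returned by `trainer.validate(...)`):

```python
from typing import Callable, Dict, Iterable, List, Tuple


def sweep_backbones(
    backbones: Iterable[str],
    run_one: Callable[[str], Dict[str, float]],
    key: str,
    higher_is_better: bool = True,
) -> List[Tuple[str, Dict[str, float]]]:
    """Hypothetical helper: call `run_one(backbone)` for each backbone, rank by metric `key`.

    `run_one` must build, train and evaluate one model and return {metric name: value}.
    """
    results = [(name, run_one(name)) for name in backbones]
    return sorted(results, key=lambda item: item[1][key], reverse=higher_is_better)
```

Creating a fresh `flash.Trainer` inside `run_one` keeps each run's logger version and callback state separate; use `higher_is_better=False` when ranking by a loss such as `valid_loss`.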

In [16]:
m_name = MODEL_BACKBONE
model = TabularClassifier.from_data( 
            datamodule,
            backbone=m_name,
            # metrics=my_metrics2,
            optimizer="Adam",
            lr_scheduler="constantlr",

        )

Using 'category_embedding' provided by manujosephv/PyTorch Tabular (https://github.com/manujosephv/pytorch_tabular).


# Training

###  4. Create the trainer

- uses the training parameters defined in the setup section (`NUM_EPOCHS`, `TRAIN_FP16`)

In [17]:
from pytorch_lightning.callbacks import StochasticWeightAveraging

trainer = flash.Trainer(
    max_epochs=NUM_EPOCHS,
    gpus=torch.cuda.device_count(),
    auto_lr_find=True,  # only takes effect if trainer.tune(model, datamodule=datamodule) is run before fit
    precision=16 if TRAIN_FP16 else 32,
    logger=logger,
    callbacks=[StochasticWeightAveraging()],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
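`StochasticWeightAveraging` keeps a running average of the weights once SWA starts (by default after 80% of `max_epochs`, via `swa_epoch_start=0.8`) and switches the learning-rate schedule to `SWALR`, which is what the "Swapping scheduler" message in the training log reports. The averaging step itself is an incremental mean; a framework-free sketch on plain lists (`swa_update` is a hypothetical name, mirroring the default update of `torch.optim.swa_utils.AveragedModel`):

```python
from typing import List


def swa_update(swa_weights: List[float], new_weights: List[float], n_averaged: int) -> List[float]:
    """Fold one new weight snapshot into the running mean of `n_averaged` snapshots."""
    # avg_new = avg + (w - avg) / (n + 1), applied element-wise
    return [a + (w - a) / (n_averaged + 1) for a, w in zip(swa_weights, new_weights)]
```

Starting from the first snapshot and folding in each later one yields the plain mean of all snapshots, without storing them.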


###  5. Train the model

In [18]:
trainer.fit(model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                  | Params
--------------------------------------------------------
0 | train_metrics | ModuleDict            | 0     
1 | val_metrics   | ModuleDict            | 0     
2 | test_metrics  | ModuleDict            | 0     
3 | adapter       | PytorchTabularAdapter | 34.9 K
--------------------------------------------------------
34.9 K    Trainable params
0         Non-trainable params
34.9 K    Total params
0.140     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Swapping scheduler <torch.optim.lr_scheduler.ConstantLR object at 0x7f409cbcce50> for <torch.optim.swa_utils.SWALR object at 0x7f4027da7b50>



###  6. Save the model

In [21]:
download_chkpt = False #@param {type:"boolean"}


In [22]:
from datetime import datetime
def get_timestamp():
    """Timestamp with hour resolution, e.g. 'Mar-26-2022_t-02'."""
    return datetime.now().strftime("%b-%d-%Y_t-%H")

In [23]:
_chk_name = f"tabcls_MIT_bb={m_name}_{get_timestamp()}.pt"
out_dir = root_dir / "model-checkpoints"
out_dir.mkdir(parents=True, exist_ok=True)  # make sure the checkpoint folder exists
model_out_path = out_dir / _chk_name
trainer.save_checkpoint(str(model_out_path.resolve()))

In [24]:
from google.colab import files

if download_chkpt:
    files.download(str(model_out_path))

###  7. Test the model

In [25]:
trained_model = TabularClassifier.load_from_checkpoint(model_out_path)

Using 'category_embedding' provided by manujosephv/PyTorch Tabular (https://github.com/manujosephv/pytorch_tabular).


Validate the reloaded model on the validation split

In [None]:
# validate the reloaded model; returns one dict of metrics per validation dataloader
val_results = trainer.validate(
    model=trained_model,
    # ckpt_path="best",
    datamodule=datamodule,
    verbose=True,
)
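`trainer.validate()` and `trainer.test()` return a list with one metrics dict per dataloader. A small formatter makes that list easier to compare across backbones; `format_results` is a hypothetical helper and the metric names in any call are whatever the trainer logged:

```python
from typing import Dict, List


def format_results(results: List[Dict[str, float]]) -> str:
    """Hypothetical helper: render validate()/test() output as aligned text lines."""
    lines = []
    for i, metrics in enumerate(results):
        lines.append(f"dataloader {i}:")
        width = max((len(k) for k in metrics), default=0)
        for name in sorted(metrics):  # alphabetical order for stable output
            lines.append(f"  {name:<{width}}  {metrics[name]:.4f}")
    return "\n".join(lines)
```

Pass it the list returned by `trainer.validate(...)` and `print` the result.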

In [None]:
trainer.test(trained_model, datamodule=datamodule)