### Importing relevant libraries

In [None]:
!pip install transformers
!pip install evaluate
!pip install transforms

!pip install "ray[tune]" scipy sklearn

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import json
import glob 
import itertools
from PIL import Image
from transformers import (
    AutoImageProcessor, 
    ViTForImageClassification, 
    SwinForImageClassification,
    TrainingArguments, 
    Trainer,
    ResNetModel,
    AutoTokenizer, 
    BertModel,
    BertPreTrainedModel,
    DefaultDataCollator,
    ViTFeatureExtractor,
    ViTMAEForPreTraining,
    ViTMAEConfig,
    ViTImageProcessor
)
from tqdm.auto import tqdm
from transformers.modeling_outputs import SequenceClassifierOutput
import evaluate
from datasets import load_dataset
import requests
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import cross_val_predict, train_test_split
# import scikitplot as skplt
import pandas as pd
from transformers import pipeline

### Training helper functions

In [None]:
# Report a model's trainable-parameter count and memory footprint - taken from class examples
def get_model_info(model):
    """Return ``(num_params, size_all_mb)`` for ``model``.

    ``num_params`` is the number of elements in parameters that have
    ``requires_grad`` set.  ``size_all_mb`` is the byte size of all parameters
    and all buffers (trainable or not) divided by 1024**2.
    """
    def _total_bytes(tensors):
        # element count times bytes per element, summed over every tensor
        return sum(t.nelement() * t.element_size() for t in tensors)

    trainable = [p for p in model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in trainable)

    total_bytes = _total_bytes(model.parameters()) + _total_bytes(model.buffers())
    size_all_mb = total_bytes / 1024**2

    return num_params, size_all_mb

# Data collator - assemble one batch from a list of dataset elements
def collate_fn(examples):
    """Stack every example's ``pixel_values`` and gather its ``labels`` into one batch dict."""
    images = []
    targets = []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["labels"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

### Visualizing output from pretrained ViTMAE

In [None]:
# from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/ViTMAE/ViT_MAE_visualization_demo.ipynb

# ViTFeatureExtractor is the deprecated name of ViTImageProcessor; use the current class
# (already imported and used elsewhere in this file) while keeping the variable name.
feature_extractor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
# Per-channel normalisation statistics of the processor, kept as arrays so that
# show_image can undo the normalisation before plotting.
imagenet_mean = np.array(feature_extractor.image_mean)
imagenet_std = np.array(feature_extractor.image_std)

def show_image(image, title=''):
    """Draw one normalised ``[H, W, 3]`` image on the current matplotlib axes.

    The normalisation is undone with the module-level ``imagenet_std`` and
    ``imagenet_mean``; the result is scaled by 255, clipped to 0-255 and cast
    to int before being handed to ``plt.imshow``.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    denormalised = (image * imagenet_std + imagenet_mean) * 255
    pixels = torch.clip(denormalised, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')

def visualize_single_image(pixel_values):
    """Show the first image of a channels-first 4-D batch under the title "original"."""
    # (n, c, h, w) -> (n, h, w, c), the layout show_image expects
    channels_last = pixel_values.permute(0, 2, 3, 1)
    show_image(channels_last[0], "original")

def visualize(pixel_values, model):
    """Plot the original, masked, reconstructed and pasted views of the first image.

    Args:
        pixel_values: channels-first image batch fed to ``model``; only the
            first image of the batch is drawn.
        model: masked-autoencoder model whose output exposes ``logits`` and
            ``mask`` and which provides ``unpatchify`` and ``config.patch_size``.

    Side effects:
        Sets ``plt.rcParams['figure.figsize']`` and shows a 1x4 figure.
    """
    # forward pass - inference only, so do not build an autograd graph
    with torch.no_grad():
        outputs = model(pixel_values)
        y = model.unpatchify(outputs.logits)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # visualize the mask
    mask = outputs.mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size**2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # Move the input to the CPU as well: y and mask live there, and mixing
    # devices in the products below would fail when the model runs on a GPU.
    x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(x[0], "original")

    plt.subplot(1, 4, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 4, 3)
    show_image(y[0], "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.show()

# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)

model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

#### Specific image example
url = "https://datasets-server.huggingface.co/assets/keremberke/chest-xray-classification/--/full/train/2/image/image.jpg"
# Bound the wait and fail loudly on an HTTP error instead of handing an
# error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Force three channels: show_image asserts a trailing dimension of 3, and the
# dataset preprocessing in this file converts to RGB as well (no-op if already RGB).
image = Image.open(response.raw).convert("RGB")
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
visualize(pixel_values, model)

### Dataset preprocessing

In [None]:
# Global runtime switches
cudnn.benchmark = True
plt.ion()   # interactive mode

# Which dataset to use and which checkpoint to start from
nih_dataset = False
model_name_or_path = "facebook/vit-mae-base"
data_dir_name = "./vit-finetune"

# loading dataset: each branch also records the label names and the split used for evaluation
if not nih_dataset:
    ds = load_dataset("keremberke/chest-xray-classification", name="full")
    labels = ds["train"].features["labels"].names  # chest xray dataset
    test_ds = "validation"
else:
    ds = load_dataset("alkzar90/NIH-Chest-X-ray-dataset", name="image-classification")
    labels = ds["train"].features["labels"].feature.names  # nih dataset
    test_ds = "test"
ds = ds.with_format("torch")

# loading image_processor
image_processor = ViTImageProcessor.from_pretrained(model_name_or_path, padding=True)  # gives normalize func error

# Prefer the GPU whenever one is visible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom weighted loss for the imbalanced dataset
# - see https://huggingface.co/docs/transformers/main_classes/trainer?highlight=trainer#trainer
class CustomTrainer(Trainer):
    """Trainer that replaces the model's loss with class-weighted cross-entropy."""

    # Per-class loss weights (data biases positive label, weight 0 label more).
    # Override on a subclass or instance when the number of labels differs.
    class_weights = (2.0, 1.0)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted cross-entropy loss for one batch.

        Args:
            model: the model to run the forward pass with.
            inputs: batch dict; must contain ``"labels"`` plus the model inputs.
            return_outputs: when true, return ``(loss, outputs)`` instead of the loss.
            num_items_in_batch: accepted for compatibility with Trainer versions
                that pass this keyword; it is not used here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.get("labels")

        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Build the weights on the logits' own device rather than relying on a
        # module-level device variable that may not match where the model runs.
        weight = torch.tensor(self.class_weights, device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# defining transforms
# Target size: one int when the processor specifies a shortest edge, otherwise (height, width)
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])
normalize = transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# training pipeline - includes data augmentation
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# evaluation pipeline - no data augmentation
val_transforms = transforms.Compose([
    transforms.Resize(size),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    normalize,
])

def preprocess_train(example_batch):
    """Add ``pixel_values`` to the batch by running train_transforms on each RGB-converted image."""
    pixel_values = []
    for image in example_batch["image"]:
        pixel_values.append(train_transforms(image.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

def preprocess_val(example_batch):
    """Add ``pixel_values`` to the batch by running val_transforms on each RGB-converted image."""
    def _to_tensor(image):
        return val_transforms(image.convert("RGB"))

    example_batch["pixel_values"] = list(map(_to_tensor, example_batch["image"]))
    return example_batch

# Augmenting transforms on the training split only.
ds['train'].set_transform(preprocess_train)
# Every other split gets the deterministic evaluation transforms.  Hard-coding
# 'validation' left ds[test_ds] without pixel_values whenever test_ds names a
# different split (the NIH branch sets it to "test").
for split_name in ds:
    if split_name != 'train':
        ds[split_name].set_transform(preprocess_val)

### Hyperparameter search with Ray Tune

In [None]:
!pip install ray[tune]
from ray import tune
from ray.tune import CLIReporter, ResultGrid
from ray.tune.examples.pbt_transformers.utils import (
    download_data,
    build_compute_metrics_fn,
)
from ray.tune.schedulers import PopulationBasedTraining
from tqdm.auto import tqdm
from transformers import (
    AutoConfig
)

# Search controls
smoke_test = False
samples = 20
gpus_per_trial = 1

# Task name (also handed to build_compute_metrics_fn further down) and output directory
task_name = "rte"
task_data_dir = f"{data_dir_name}-{task_name}"
num_labels = len(labels)

# init config object
config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_labels,
    finetuning_task=task_name,
)

# Triggers pre-trained model download to cache
ViTForImageClassification.from_pretrained(model_name_or_path, config=config)

# define model initialization function for trainer
def model_init():
    """Build a fresh classifier in which only the classification head is trainable."""
    model = ViTForImageClassification.from_pretrained(model_name_or_path, config=config)
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target)

    # Freeze the entire model, then turn the classifier weights back on
    model.requires_grad_(False)
    model.classifier.requires_grad_(True)

    return model

# define training arguments (base values for the hyperparameter search)
training_args = TrainingArguments(
    # locations and run mode
    output_dir=task_data_dir,
    logging_dir="./logs",
    do_train=True,
    do_eval=True,
    # batching / data loading
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    dataloader_num_workers=0,
    remove_unused_columns=False,
    # optimisation
    num_train_epochs=4,
    learning_rate=1e-5,
    weight_decay=0.1,
    lr_scheduler_type="cosine",
    # evaluation and checkpointing
    evaluation_strategy="epoch",
    save_strategy="epoch",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    # logging / reporting
    logging_steps=10,
    report_to="none",
    skip_memory_metrics=True,
    push_to_hub=False,
)

# create tune config object: the search space returned to hyperparameter_search
# via hp_space. Batch sizes start fixed at 16; the epoch count is sampled.
tune_config = {
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "num_train_epochs": tune.choice([2, 3, 4]),
    "max_steps": 1 if smoke_test else 200,  # 1 step for a smoke test, otherwise 200
}

# initialize scheduler with args: Population Based Training that maximises
# "eval_acc", measured per "training_iteration", drawing mutated values from
# the distributions below.
scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="eval_acc",
    mode="max",
    perturbation_interval=1,
    hyperparam_mutations={
        "weight_decay": tune.uniform(0.0, 0.3),
        "learning_rate": tune.uniform(1e-5, 5e-5),
        "per_device_train_batch_size": tune.choice([8, 16, 32]),
        "per_device_eval_batch_size": tune.choice([8, 16, 32]),
        "lr_scheduler_type": tune.choice(["linear", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup"]),
    },
)

# initialize reporter with args, defining which info will be reported
# (parameter_columns presumably maps each hyperparameter to its column header - confirm in Ray docs)
reporter = CLIReporter(
    parameter_columns={
        "weight_decay": "w_decay",
        "learning_rate": "lr",
        "per_device_train_batch_size": "train_bs/gpu",
        "num_train_epochs": "num_epochs",
        "lr_scheduler_type": "lr_scheduler",
    },
    metric_columns=["eval_acc", "eval_loss", "epoch", "training_iteration"],
)

# Create the trainer used for the search. model=None together with model_init
# lets the trainer build a fresh model through model_init.
# NOTE(review): this is the plain Trainer, not CustomTrainer, so the search does
# not use the class-weighted loss that the final fine-tuning uses - confirm intended.
# NOTE(review): compute_metrics comes from a Ray example helper keyed on task_name
# ("rte"); presumably it is what produces the "eval_acc" metric - verify.
trainer = Trainer(
    model=None,
    model_init=model_init,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=build_compute_metrics_fn(task_name),
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    tokenizer=image_processor,
)

# launch hyperparameter search: `samples` trials on the Ray backend, each given
# one CPU and `gpus_per_trial` GPUs; a stop condition is set only for smoke tests.
trainer.hyperparameter_search(
    hp_space=lambda _: tune_config,
    backend="ray",
    n_trials=samples,
    resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
    scheduler=scheduler,
    keep_checkpoints_num=1,
    checkpoint_score_attr="training_iteration",
    stop={"training_iteration": 1} if smoke_test else None,
    progress_reporter=reporter,
    local_dir="./ray_results/",
    name="tune_transformer_pbt",
    log_to_file=True,
)

### Fine-tuning on custom dataset using best hyperparameter results

In [None]:
# Load the checkpoint with a classification head sized for this dataset's labels
model = ViTForImageClassification.from_pretrained(
    "facebook/vit-mae-base",
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)},
    ignore_mismatched_sizes=True,
    # problem_type="multi_label_classification", #for nih dataset
)
model.to(device)

# Print model info (taken before freezing, so this counts every parameter
# that is trainable at load time)
num_params, size_all_mb = get_model_info(model)
print(f"Trainable parameters: {num_params:,} | model size: {size_all_mb:.2f} MB")

# Freeze the entire model:
for p in model.parameters():
    p.requires_grad = False

# Turn back on the classifier weights
for p in model.classifier.parameters():
    p.requires_grad = True

# Setup the training arguments for the final fine-tuning run
output_dir = "./finetune_vit"

# Evaluation and checkpointing both happen once per epoch.
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_total_limit=2,
    remove_unused_columns=False, #we need the unused features ('image' in particular) in order to create 'pixel_values'
    push_to_hub=False,
    load_best_model_at_end=True,
    dataloader_num_workers=0,  
#     gradient_accumulation_steps=8,
)

# Compute absolute learning rate: the base rate is scaled linearly by the
# effective batch size (per-step batch x accumulation steps x world size) over 256.
base_learning_rate = 1e-3
total_train_batch_size = (
    training_args.train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
)

# Overwrite the learning rate on the already-constructed arguments object.
training_args.learning_rate = base_learning_rate * total_train_batch_size / 256
print("Set learning rate to:", training_args.learning_rate)

# Metric objects are loaded once and reused across evaluations; calling
# evaluate.load inside compute_metrics repeated the loading work on every evaluation.
_METRIC_CACHE = {}


def _get_metric(name):
    """Return the `evaluate` metric called ``name``, loading it on first use only."""
    if name not in _METRIC_CACHE:
        _METRIC_CACHE[name] = evaluate.load(name)
    return _METRIC_CACHE[name]


# Setup a function to compute accuracy metrics
def compute_metrics(eval_pred):
    """Compute precision, recall, accuracy and F1 from logits and reference labels.

    Args:
        eval_pred: pair that unpacks into ``(logits, labels)``.

    Returns:
        Dict with the keys ``"precision"``, ``"recall"``, ``"accuracy"`` and
        ``"f1_score"``.
    """
    logits, labels = eval_pred
    # The predicted class is the index of the largest logit; every metric is
    # computed from this same array.
    predictions = np.argmax(logits, axis=-1)

    # (key in the returned dict, name of the metric and of its result entry)
    wanted = (
        ("precision", "precision"),
        ("recall", "recall"),
        ("accuracy", "accuracy"),
        ("f1_score", "f1"),
    )
    scores = {}
    for output_key, metric_name in wanted:
        result = _get_metric(metric_name).compute(predictions=predictions, references=labels)
        scores[output_key] = result[metric_name]
    return scores

# Create the trainer: CustomTrainer supplies the class-weighted loss; evaluation
# runs on the split named by test_ds.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=ds['train'],
    eval_dataset=ds[test_ds], 
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    # data_collator=DefaultDataCollator(),
)

# Train, then persist the model, the training metrics and the trainer state
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Inference on the evaluation split
predictions = trainer.predict(ds[test_ds])

# Evaluation on the same split as the prediction above; only the metrics are kept and logged
metrics = trainer.evaluate(ds[test_ds])
trainer.log_metrics("eval", metrics)

### Inference & Visualization

In [None]:
# !pip install scikit-plot==0.3.7
from transformers import pipeline
from sklearn import datasets, metrics
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import cross_val_predict, train_test_split
import scikitplot as skplt
import pandas as pd
import warnings
from transformers.pipelines.pt_utils import KeyDataset
from sklearn.metrics import roc_curve, roc_auc_score

# Reload the dataset from scratch for the held-out test split
ds = load_dataset("keremberke/chest-xray-classification", name="full")

# Inference using the model on unseen test data
test_split = ds['test']
test_split.set_transform(preprocess_val)
predictions = trainer.predict(test_split)

# metrics
for heading, key in (
    ("Accuracy:", "test_accuracy"),
    ("Precision:", "test_precision"),
    ("Recall:", "test_recall"),
    ("F1-Score:", "test_f1_score"),
):
    print(heading, predictions.metrics[key])

# Probability scores: softmax over the predicted logits gives per-class probabilities
y_probs = torch.nn.functional.softmax(torch.Tensor(predictions[0]), dim=-1)

# Ground truth
y_test = ds['test'].with_format("torch")['labels']

# Highest class probability per sample & predicted label ids
y_max_probs, y_preds = torch.max(y_probs, 1)

##### Visualization
plt.rc('font', size=12)  # controls default text sizes

# ROC AUC ranks samples by their score for the positive class (label 1).
# The per-sample maximum probability is wrong here: a confident negative
# prediction would be scored as strongly positive.
print("roc_auc_score:", metrics.roc_auc_score(y_test, y_probs[:, 1]))

# visualizing class distribution of training data
train_labels = ds['train']['labels']
positive_count = sum(train_labels)
negative_count = len(train_labels) - positive_count
plt.bar(["NORMAL", "PNEUMONIA"], [negative_count, positive_count])
plt.title("Class distribution of training data")

# ROC curve, precision/recall across classification thresholds, cumulative gain:
# each plot is drawn from the same labels and probabilities and shown in turn
for plot_fn, plot_title in (
    (skplt.metrics.plot_roc, 'ROC Plot for Chest X-Ray dataset'),
    (skplt.metrics.plot_precision_recall, 'PR Curve for Chest X-Ray dataset'),
    (skplt.metrics.plot_cumulative_gain, 'Cumulative Gains Chart for Chest X-Ray dataset'),
):
    plot_fn(y_test, y_probs, title=plot_title)
    plt.show()

# plotting confusion matrix
# skplt.metrics.plot_confusion_matrix(y_test, y_preds, normalize=False, title = 'Normalized Confusion Matrix for Chest X-Ray dataset')
# plt.show()
