In [1]:
import warnings
warnings.filterwarnings("ignore")
import os
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import cv2
import warnings
import timeit
import pytorch_lightning as pl
from tqdm.auto import tqdm
from thop import profile
from efficientnet_pytorch import EfficientNet
from torch import optim
from scipy.special import expit
#from pytorch_lightning.loggers import WandbLogger
from skimage import io, transform, measure
from sklearn import metrics
from config import *
from util import *
from models import *
from models_ablations import *
# Apply seaborn's defaults first; the rcParams overrides below are layered on
# top of whatever sns.set() configured.
sns.set()
warnings.filterwarnings("ignore")

# Allow pandas to display up to 500 columns / rows before truncating.
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 500)

# Matplotlib overrides: 9x7 figures, all four axis spines hidden, no grid,
# ticks and tick labels kept on the bottom and left axes.
rc = {
    "figure.figsize": (9, 7),
    **{f"axes.spines.{side}": False for side in ("left", "right", "bottom", "top")},
    "xtick.bottom": True,
    "xtick.labelbottom": True,
    "ytick.labelleft": True,
    "ytick.left": True,
    "axes.grid": False,
}
plt.rcParams.update(rc)


def find_optimal_cutoff(target, predicted):
    """Return the ROC threshold where sensitivity and specificity are closest.

    The ROC curve of ``predicted`` against ``target`` is computed, and the
    threshold whose ``tpr - (1 - fpr)`` is nearest to zero is returned as a
    plain Python scalar.

    Parameters
    ----------
    target : array-like
        Ground-truth binary labels, passed straight to ``metrics.roc_curve``.
    predicted : array-like
        Scores for the positive class, passed straight to ``metrics.roc_curve``.
    """
    fpr, tpr, threshold = metrics.roc_curve(target, predicted)
    positions = np.arange(len(tpr))

    # tpr is the sensitivity and (1 - fpr) the specificity, so this is their gap.
    gap = pd.Series(tpr - (1 - fpr), index=positions)
    cutoffs = pd.Series(threshold, index=positions)

    # Position of the smallest absolute gap (first entry of the argsort).
    closest = gap.abs().argsort()[:1]
    return cutoffs.iloc[closest].item()


def infer_diagnosis(result, mel_class_labels_pred):
    """Derive a binary ``'prediction'`` column from per-label predictions.

    A row is predicted positive (1) when MORE THAN ONE of the columns listed
    in ``mel_class_labels_pred`` equals 1, and negative (0) otherwise.

    Parameters
    ----------
    result : pandas.DataFrame
        Frame containing the columns named in ``mel_class_labels_pred``.
        It is modified in place (the ``'prediction'`` column is added or
        overwritten).
    mel_class_labels_pred : list
        Column labels whose values are compared against 1.

    Returns
    -------
    pandas.DataFrame
        The same ``result`` object, for call chaining.
    """
    # Count the matching columns for every row at once instead of walking the
    # frame with iterrows(), which builds a Series per row.
    positive_counts = (result[mel_class_labels_pred] == 1).sum(axis=1)

    # Assign positionally (as the original list assignment did) so the frame's
    # index never takes part in alignment.
    result['prediction'] = (positive_counts > 1).to_numpy().astype('int64')
    return result

def get_thresholds(result, char_class_labels):
    """Compute one cutoff per characteristic and return them as a tensor.

    For every label in ``char_class_labels`` the column ``result[label]`` is
    used as the target and ``result[label + '_score']`` as the score passed to
    ``find_optimal_cutoff``. The cutoffs are returned in the same order as
    ``char_class_labels``.
    """
    cutoffs = [
        find_optimal_cutoff(result[label], result[label + '_score'])
        for label in char_class_labels
    ]
    return torch.tensor(cutoffs)

def display_scores(result):
    """Print balanced accuracy, sensitivity and specificity of the diagnosis.

    Compares the ``'prediction'`` column of ``result`` against its
    ``'benign_malignant'`` column; each score is rounded to 5 decimals.
    """
    y_true = result['benign_malignant']
    y_pred = result['prediction']

    balanced_acc = metrics.balanced_accuracy_score(y_true, y_pred)
    print('balanced acc: ', balanced_acc.round(5))

    # Sensitivity is the recall of the positive class.
    sensitivity = metrics.recall_score(y_true, y_pred)
    print('sensitivity: ', sensitivity.round(5))

    # Specificity is the recall of the negative class (label 0).
    specificity = metrics.recall_score(y_true, y_pred, pos_label=0)
    print('specificity: ', specificity.round(5))

In [3]:
%%time
torch.cuda.empty_cache()
model = Resnet50(img_dir=img_dir, annotations_dir=annotations_dir, metadata_file=metadata_file, weighted_sampling=False,
                                  batch_size=batch_size, learning_rate=learning_rate)
trainer = pl.Trainer(max_epochs=1, devices=1, accelerator="gpu", deterministic=True)
trainer.fit(model)

# Calculate FLOPs and parameters
sample_input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(sample_input,))
print(f"Number of FLOPs for {model_name}: {flops}")
print(f"Number of parameters for {model_name}: {params}")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name        | Type              | Params
--------------------------------------------------
0 | lossC       | BCEWithLogitsLoss | 0     
1 | lossA       | DiceLoss          | 0     
2 | base_model  | ResNet            | 23.5 M
3 | sigmoid     | Sigmoid           | 0     
4 | accuracy    | Accuracy          | 0     
5 | auroc       | AUROC             | 0     
6 | sensitivity | Recall            | 0     
7 | specificity | Specificity       | 0     
--------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.114    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [12]:
# Evaluate the model currently held in `model` / `trainer` on the validation
# and test splits.

# Validation pass with a fixed threshold.
# NOTE(review): the meaning/scale of threshold=-0.3 is defined inside
# get_char_predictions (not visible here) - confirm before changing it.
result_val = get_char_predictions(trainer, model, split='val', threshold=-0.3)
result_val = infer_diagnosis(result_val, mel_class_labels_pred)
# Per-characteristic cutoffs are derived from the validation results only ...
thresholds = get_thresholds(result_val, char_class_labels)

print('Validation:')
display_scores(result_val)

# ... and then passed to the test-split prediction call.
result_test = get_char_predictions(trainer, model, split='test', threshold=thresholds)
result_test = infer_diagnosis(result_test, mel_class_labels_pred)
print('Test:')
display_scores(result_test)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting: 83it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Validation:
balanced acc:  0.70256
sensitivity:  0.50581
specificity:  0.8993


Predicting: 83it [00:00, ?it/s]

Test:
balanced acc:  0.67
sensitivity:  0.68
specificity:  0.66
balanced acc:  0.67
sensitivity:  0.68
specificity:  0.66


In [6]:
# Persist the validation and test result frames as CSV (row index omitted).
# NOTE(review): the file names say "resnet101", but the single-model cell
# visible in this notebook builds a Resnet50 - confirm which model produced
# `result_val` / `result_test` before relying on these files.
# The paths carry no ".csv" extension, matching the sweep cell's naming.
result_val.to_csv('data/classifier/result_val_resnet101', index=False)
result_test.to_csv('data/classifier/result_test_resnet101', index=False)

In [2]:
def get_model(model_name):
    """Seed the RNGs and construct the model registered under ``model_name``.

    Every architecture is built with the same notebook-level settings
    (``img_dir``, ``annotations_dir``, ``metadata_file``, ``batch_size``,
    ``learning_rate``) and with ``weighted_sampling=False``.

    Parameters
    ----------
    model_name : str
        One of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``, ``'resnet101'``,
        ``'densenet121'``, ``'densenet161'``, ``'efficientnetb1'``,
        ``'efficientnetb3'``.

    Returns
    -------
    The newly constructed model instance.

    Raises
    ------
    ValueError
        If ``model_name`` is not a supported name. (Previously an unknown name
        fell through the if/elif chain and returned ``None``, which only
        failed later inside the trainer.)
    """
    # Seed before construction, as the original did, so any randomness used
    # while building the model is reproducible.
    seed_everything(seed)

    model_classes = {
        'resnet18': Resnet18,
        'resnet34': Resnet34,
        'resnet50': Resnet50,
        'resnet101': Resnet101,
        'densenet121': Densenet121,
        'densenet161': Densenet161,
        'efficientnetb1': EfficientnetB1,
        'efficientnetb3': EfficientnetB3,
    }
    if model_name not in model_classes:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(model_classes)}"
        )

    return model_classes[model_name](
        img_dir=img_dir,
        annotations_dir=annotations_dir,
        metadata_file=metadata_file,
        weighted_sampling=False,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )

In [3]:
%%time

model_names_list = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'densenet121', 'densenet161', 'efficientnetb1', 'efficientnetb3']

for model_name in model_names_list:
    
    torch.cuda.empty_cache()
    print(model_name)
    seed_everything(seed)
    model = get_model(model_name)
    trainer = pl.Trainer(max_epochs=num_epochs, devices=1, accelerator="gpu", deterministic=True, enable_checkpointing=False, logger=False)
    trainer.fit(model)
    trainer.save_checkpoint("models/"+model_name+".ckpt")
    
    
    result_val = get_char_predictions(trainer, model, split='val', threshold=-0.3)
    result_val = infer_diagnosis(result_val, mel_class_labels_pred)
    thresholds = get_thresholds(result_val, char_class_labels)

    print('Validation:')
    display_scores(result_val)

    result_test = get_char_predictions(trainer, model, split='test', threshold=thresholds)
    result_test = infer_diagnosis(result_test, mel_class_labels_pred)
    print('Test:')
    display_scores(result_test)
    
    
    result_val.to_csv('data/classifier/result_val_'+model_name, index=False)
    result_test.to_csv('data/classifier/result_test_'+model_name, index=False)
    
    # Calculate FLOPs and parameters
    sample_input = torch.randn(1, 3, 224, 224)
    flops, params = profile(model, inputs=(sample_input,))
    print(f"Number of FLOPs for {model_name}: {flops}")
    print(f"Number of parameters for {model_name}: {params}")

resnet18


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name        | Type              | Params
--------------------------------------------------
0 | lossC       | BCEWithLogitsLoss | 0     
1 | lossA       | DiceLoss          | 0     
2 | base_model  | ResNet            | 11.2 M
3 | sigmoid     | Sigmoid           | 0     
4 | accuracy    | Accuracy          | 0     
5 | auroc       | AUROC             | 0     
6 | sensitivity | Recall            | 0     
7 | specificity | Specificity       | 0     
--------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting: 83it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Validation:
balanced acc:  0.77644
sensitivity:  0.80814
specificity:  0.74473


Predicting: 83it [00:00, ?it/s]

Test:
balanced acc:  0.775
sensitivity:  0.74
specificity:  0.81
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn