In [1]:
import pytorch_lightning as pl
import torch
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights, convnext_base, ConvNeXt_Base_Weights, convnext_small, ConvNeXt_Small_Weights, efficientnet_b4, EfficientNet_B4_Weights
import torch.nn as nn
import torch.optim as optim
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from typing import List, Dict, Optional
import pandas as pd
import numpy as np
import os

import albumentations as albu
from albumentations.pytorch import ToTensorV2
import random
import matplotlib.pyplot as plt

from pathlib import Path
import random
import cv2


In [2]:
import timm
import torch.nn.functional as F

class AttnMIL(nn.Module):
    """Attention-based multiple-instance-learning classifier over a bag of images.

    Each image in the bag is encoded by a ResNet-34 backbone (via ``timm``),
    projected to an ``L``-dimensional embedding and scored by a small attention
    network. The softmax-normalised scores weight the embeddings into ``K`` bag
    embeddings, which a linear layer maps to ``num_classes`` logits.

    The projection layer is sized for 256x256 inputs, for which the selected
    backbone stage emits a 512x8x8 feature map. Backbone stage shapes for a
    single 256x256 image:
        [1, 64, 128, 128] -> [1, 64, 64, 64] -> [1, 128, 32, 32]
        -> [1, 256, 16, 16] -> [1, 512, 8, 8]

    Args:
        num_classes: Number of output logits.
        pretrained: Forwarded to ``timm.create_model`` as its ``pretrained`` flag.
    """

    # Shape of the backbone's last feature map for a 256x256 input.
    _FEATURE_CHANNELS = 512
    _FEATURE_MAP_SIDE = 8

    def __init__(self, num_classes, pretrained: bool):
        super().__init__()
        self.L = 500  # instance embedding size
        self.D = 128  # hidden size of the attention network
        self.K = 1    # number of attention heads / pooled bag embeddings

        flat_features = self._FEATURE_CHANNELS * self._FEATURE_MAP_SIDE * self._FEATURE_MAP_SIDE

        # features_only=True makes the backbone return a list of feature maps;
        # out_indices=[4] keeps only the last stage.
        self.feature_extractor_part1 = timm.create_model(
            'resnet34', pretrained=pretrained, features_only=True, out_indices=[4]
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(flat_features, self.L),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, num_classes),
        )

    def forward(self, x):
        """Classify one bag of images.

        Args:
            x: Bag of ``N`` images; the whole first dimension is treated as a
                single bag.

        Returns:
            Tensor of logits with one row per attention head (``K`` rows).

        Raises:
            RuntimeError: If the backbone feature map does not flatten to the
                size the projection layer was built for (e.g. the input is not
                256x256).
        """
        feature_map = self.feature_extractor_part1(x)[0]
        # Flatten per instance so the bag dimension N is always preserved. A
        # view(-1, C*H*W) would silently split a larger feature map into
        # several fake instances instead of failing.
        H = feature_map.flatten(start_dim=1)
        H = self.feature_extractor_part2(H)  # NxL

        A = self.attention(H)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, H)  # KxL

        return self.classifier(M)

In [3]:
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
)
from torchmetrics import MetricCollection


class CancerDetector(pl.LightningModule):
    """Lightning module that trains and evaluates a multiclass classifier.

    Wraps the network chosen by ``model_name`` with a cross-entropy loss,
    multiclass accuracy/F1/precision/recall metrics, and an AdamW optimizer
    whose learning rate is warmed up linearly and then decayed exponentially.

    Args:
        lr: Learning rate passed to AdamW.
        gamma: Multiplicative decay factor passed to ``ExponentialLR``.
        model_name: Architecture key; only ``"attnmil"`` is recognised.
        batch_size: Stored on the module and in the saved hyperparameters; not
            otherwise used by this class.
        warmup_epochs: Number of scheduler steps of linear warm-up before the
            exponential decay takes over. Values <= 0 disable the warm-up.
        num_classes: Number of target classes.
        init_weights: Forwarded to the model as its ``pretrained`` flag.
    """

    def __init__(
        self,
        lr: float,
        gamma: float,
        model_name: str,
        batch_size: int,
        warmup_epochs: int = 4,
        num_classes: int = 5,
        init_weights: bool = True,
    ):
        super().__init__()
        # TODO Use model preprocessing function
        self.model = self._get_model(model_name, num_classes, init_weights)

        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = lr
        self.gamma = gamma
        self.warmup_epochs = warmup_epochs
        self.batch_size = batch_size

        # Records the constructor arguments so they are stored with checkpoints.
        self.save_hyperparameters()

        # Should we use micro average? Default is macro
        metrics = MetricCollection(
            [
                MulticlassAccuracy(num_classes),
                MulticlassF1Score(num_classes),
                MulticlassPrecision(num_classes),
                MulticlassRecall(num_classes),
            ]
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.valid_metrics = metrics.clone(prefix="val/")

        # Per-step loss values, averaged and cleared at the end of every epoch.
        self.train_step_outputs = []
        self.validation_step_outputs = []

    def _get_model(self, model_name: str, num_classes: int, init_weights: bool):
        """Build the network for ``model_name``.

        Raises:
            ValueError: If ``model_name`` is not a known architecture key.
        """
        if model_name == "attnmil":
            return AttnMIL(num_classes, init_weights)

        raise ValueError(f"Unknown model name {model_name}")

    def forward(self, imgs: torch.Tensor):
        """Run the wrapped model on ``imgs`` and return its output."""
        return self.model(imgs)

    def _shared_step(self, batch, metrics, epoch_losses):
        """Compute the loss for one batch, updating ``metrics`` and ``epoch_losses``."""
        x, y = batch
        output = self(x)
        loss = self.loss_fn(output, y)

        metrics.update(output, y)
        epoch_losses.append(loss.detach().item())

        return loss

    def training_step(self, batch, batch_idx: int):
        """Return the training loss for ``batch``, an ``(inputs, targets)`` pair."""
        return self._shared_step(batch, self.train_metrics, self.train_step_outputs)

    def validation_step(self, batch, batch_idx: int):
        """Return the validation loss for ``batch``, an ``(inputs, targets)`` pair."""
        return self._shared_step(batch, self.valid_metrics, self.validation_step_outputs)

    def on_train_epoch_end(self):
        """Log the mean training loss and metrics, then reset the accumulators."""
        loss = np.mean(self.train_step_outputs)
        self.log("train/loss", loss, on_step=False, on_epoch=True)

        output = self.train_metrics.compute()
        self.log_dict(output)

        self.train_metrics.reset()
        self.train_step_outputs.clear()

    def on_validation_epoch_end(self):
        """Log the mean validation loss and metrics, then reset the accumulators.

        Nothing is logged while the trainer is running its sanity check, but
        the accumulators are still reset so those batches do not leak into the
        first real validation epoch.
        """
        if not self.trainer.sanity_checking:
            loss = np.mean(self.validation_step_outputs)
            self.log("val/loss", loss, on_step=False, on_epoch=True)

            output = self.valid_metrics.compute()
            self.log_dict(output)

        self.valid_metrics.reset()
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Create the AdamW optimizer and its warm-up + exponential-decay schedule."""
        optimizer = optim.AdamW(self.model.parameters(), lr=self.lr)
        exponential = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma)

        if self.warmup_epochs <= 0:
            # A zero-length LinearLR would still scale the initial learning
            # rate by its start factor, so skip the warm-up stage entirely.
            return [optimizer], [exponential]

        warmup = optim.lr_scheduler.LinearLR(optimizer, total_iters=self.warmup_epochs)
        scheduler = optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup, exponential], milestones=[self.warmup_epochs]
        )

        return [optimizer], [scheduler]

In [4]:
best_model = "/kaggle/input/cancer-detection-w-attnmil/cancer_classification_model.pt"

# Load model from checkpoint.
# - init_weights=False overrides the stored hyperparameter, so the backbone is
#   built without requesting pretrained weights (presumably the checkpoint
#   already holds every parameter — confirm).
# - map_location="cpu" makes loading independent of the device the checkpoint
#   was saved from.
model = CancerDetector.load_from_checkpoint(best_model, map_location="cpu", init_weights=False)

# Switch to inference mode; as the cell's last expression this also displays the module.
model.eval()

CancerDetector(
  (model): AttnMIL(
    (feature_extractor_part1): FeatureListNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act1): ReLU(inplace=True)
          (aa): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act2): ReLU(inplace=True)
        )
        (1): BasicBl

In [5]:
# Example bag passed to the exporter: 350 three-channel 256x256 images.
x = torch.randn(350, 3, 256, 256, requires_grad=True)

# Only the first input dimension (the bag size) is declared dynamic.
dynamic_axes = {'input': {0: 'bag_size'}}

# Export the model
model.to_onnx(
    file_path="attn_mil.onnx",
    input_sample=x,
    opset_version=17,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes=dynamic_axes,
)

verbose: False, log level: Level.ERROR



In [6]:
!pip install onnxruntime

Collecting onnxruntime
  Obtaining dependency information for onnxruntime from https://files.pythonhosted.org/packages/7a/cf/6aa8c56fd63f53c2c485921e411269c7b501a2b4e634bd02f226ab2d5d8e/onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Downloading onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
Downloading onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [7]:
!pip download onnxruntime

Collecting onnxruntime
  Obtaining dependency information for onnxruntime from https://files.pythonhosted.org/packages/7a/cf/6aa8c56fd63f53c2c485921e411269c7b501a2b4e634bd02f226ab2d5d8e/onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Using cached onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Collecting coloredlogs (from onnxruntime)
  Using cached coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
Collecting flatbuffers (from onnxruntime)
  Obtaining dependency information for flatbuffers from https://files.pythonhosted.org/packages/6f/12/d5c79ee252793ffe845d58a913197bfa02ae9a0b5c9bc3dc4b58d477b9e7/flatbuffers-23.5.26-py2.py3-none-any.whl.metadata
  Downloading flatbuffers-23.5.26-py2.py3-none-any.whl.metadata (850 bytes)
Collecting numpy>=1.21.6 (from onnxruntime)
  Obtaining dependency information for numpy>=1.21.6 from https://files.pythonhosted.org/packages/64/41/284783f1014685201e447ea

In [8]:
import onnxruntime

# Random bag of 100 images. This differs from the 350-image bag passed to
# to_onnx, so a successful run shows the exported graph accepts other bag sizes.
x = np.random.randn(100,3,256,256).astype(np.float32)

ort_session = onnxruntime.InferenceSession("attn_mil.onnx", providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name

# Passing None as the first argument requests every model output.
ort_outs = ort_session.run(None, {input_name: x})

# AttnMIL pools the whole bag with K = 1 attention head, so exactly one row of
# logits is expected regardless of the bag size.
bag_logits = ort_outs[0]
assert bag_logits.shape[0] == 1, f"expected one row of logits per bag, got shape {bag_logits.shape}"
assert np.isfinite(bag_logits).all(), "ONNX model produced non-finite logits"
bag_logits