In [3]:
%load_ext autoreload
%autoreload 2
!git clone --branch=add-data https://github.com/schmidt-jake/kaggle.git
%cd kaggle
!pip install -U pip setuptools wheel
!pip install -U -r /kaggle/working/kaggle/requirements/main.in

Cloning into 'kaggle'...
remote: Enumerating objects: 216, done.[K
remote: Counting objects: 100% (216/216), done.[K
remote: Compressing objects: 100% (119/119), done.[K
remote: Total 216 (delta 99), reused 177 (delta 67), pack-reused 0[K
Receiving objects: 100% (216/216), 39.41 KiB | 733.00 KiB/s, done.
Resolving deltas: 100% (99/99), done.
/kaggle/working/kaggle
Collecting pip
  Downloading pip-22.1.2-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m0m
Collecting setuptools
  Downloading setuptools-63.2.0-py3-none-any.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m44.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: setuptools, pip
  Attempting uninstall: setuptools
    Found existing installation: setuptools 59.8.0
    Uninstalling setuptools-59.8.0:
      Successfully uninstalled setuptools-59.8.0
  Attemp

In [None]:
!git pull

In [None]:
from glob import glob

import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy
from torchmetrics.classification import CalibrationError
from torchvision.models import densenet161

from mayo_clinic_strip_ai.dataset import TifDataset
from mayo_clinic_strip_ai.model import Classifier
from mayo_clinic_strip_ai.model import FeatureExtractor
from mayo_clinic_strip_ai.model import Loss
from mayo_clinic_strip_ai.model import Model
from mayo_clinic_strip_ai.model import Normalizer
from mayo_clinic_strip_ai.metadata import load_metadata
from multiprocessing import cpu_count


torch.backends.cudnn.benchmark = True

train_meta = pd.merge(
    left=pd.read_csv("/kaggle/input/mayo-rois/train/ROIs.csv"),
    right=load_metadata("/kaggle/input/mayo-clinic-strip-ai/train.csv"),
    how="left",
    validate="m:1",
    on="image_id",
)

model = Model(
    normalizer=Normalizer(),
    feature_extractor=FeatureExtractor(backbone=densenet161()),
    classifier=Classifier(initial_logit_bias=0.0, in_features=2208),  # FIXME: auto-set in_features
)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=0.1)
loss_fn = Loss(pos_weight=0.5)  # FIXME: compute pos_weight from data
metrics = MetricCollection(Accuracy(), CalibrationError())

# move things to the right device
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
print("Device", device)
model.to(device=device, memory_format=torch.channels_last, non_blocking=True)
loss_fn.to(device=device, non_blocking=True)
metrics.to(device=device, non_blocking=True)

train_dataset = TifDataset(
    metadata=train_meta,
    training=True,
    data_dir="/kaggle/input/mayo-clinic-strip-ai/train/",
)
batch_size = 16
num_workers = cpu_count()
prefetch_batches = 4
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    pin_memory_device=str(device),
    prefetch_factor=max(2, prefetch_batches * batch_size // num_workers),
    num_workers=num_workers,
    drop_last=torch.backends.cudnn.benchmark,
)

grad_scaler = torch.cuda.amp.GradScaler()

print("num workers", num_workers)
print("prefetch factor", train_dataloader.prefetch_factor)
for epoch in range(10):
    print("Starting epoch", epoch)
    for img, label_id in train_dataloader:
        img = img.to(device=device, memory_format=torch.channels_last, non_blocking=True)
        label_id = label_id.to(device=device, non_blocking=True)
        with torch.autocast(device_type=img.device.type):
            optimizer.zero_grad(set_to_none=True)
            logit: torch.Tensor = model(img)
            loss: torch.Tensor = loss_fn(logit=logit, label=label_id)
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
            metrics.update(preds=logit.sigmoid(), target=label_id)
            m = {k: v.item() for k, v in metrics.compute().items()}
            metrics.reset()
        print(m)

[autoreload of pkg_resources failed: Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/opt/conda/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/opt/conda/lib/python3.7/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/opt/conda/lib/python3.7/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 630, in _exec
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/opt/conda/lib/python3.7/site-packages/pkg_resources/__init__.py", line 74, in <module>
    from pkg_resources.extern.jaraco.text import (
ModuleNotFoundError: No module named 'pkg_resources.extern.jaraco'
]


Device cuda:0
num workers 2
prefetch factor 32
Starting epoch 0
{'Accuracy': 0.375, 'CalibrationError': 0.1881103515625}
{'Accuracy': 0.5, 'CalibrationError': 0.2452802211046219}
{'Accuracy': 0.625, 'CalibrationError': 0.1963958740234375}
{'Accuracy': 0.4375, 'CalibrationError': 0.5358295440673828}
{'Accuracy': 0.8125, 'CalibrationError': 0.18795537948608398}
{'Accuracy': 0.6875, 'CalibrationError': 0.2509284019470215}
{'Accuracy': 0.6875, 'CalibrationError': 0.20892858505249023}
{'Accuracy': 0.6875, 'CalibrationError': 0.17264682054519653}
{'Accuracy': 0.6875, 'CalibrationError': 0.17568445205688477}
{'Accuracy': 0.8125, 'CalibrationError': 0.07723760604858398}
{'Accuracy': 0.75, 'CalibrationError': 0.1724081039428711}
{'Accuracy': 0.6875, 'CalibrationError': 0.12208367884159088}
{'Accuracy': 0.75, 'CalibrationError': 0.21883773803710938}
{'Accuracy': 0.625, 'CalibrationError': 0.20199203491210938}
{'Accuracy': 0.625, 'CalibrationError': 0.1725006252527237}
