In [79]:
import os
from pathlib import Path

import numpy as np
import torch
from sklearn.metrics import f1_score, roc_auc_score
from torch import optim, nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets.rsna_breast_cancer import BreastCancerDataset
from src import backbones

In [80]:
# Prefer the Apple-silicon GPU, then CUDA, then CPU, so the notebook also runs off-Mac
# instead of failing on a hard-coded 'mps' device.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

SEED = 0
torch.manual_seed(SEED)  # seed before building the model so the new head's init is reproducible

backbone = backbones.load("resnet50")
# Single-logit binary head on top of the backbone output.
# NOTE(review): the input size 1000 assumes backbones.load returns the network with its
# 1000-way ImageNet classifier still attached (the repr below shows a full ResNet) — confirm.
model = nn.Sequential(backbone, nn.Linear(1000, 1, bias=False))
model.to(device).train()



Sequential(
  (0): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0)

In [81]:
# Root of the SimpleNet checkout. The default is the original author's machine-specific
# path; set the SIMPLENET_ROOT environment variable to run elsewhere.
PROJECT_ROOT = Path(os.environ.get("SIMPLENET_ROOT", "/Users/ksoll/Documents/git/SimpleNet"))
# Kept as str because that is the type the dataset originally received.
IMG_DIR = str(PROJECT_ROOT / "data" / "rsna_breast_cancer")
META_DATA_CSV = str(PROJECT_ROOT / "train.csv")
BATCH_SIZE = 64


def make_dataset(num_images):
    """Build a BreastCancerDataset over the shared image directory and metadata CSV.

    ``num_images`` is forwarded unchanged to ``BreastCancerDataset``; its per-slot
    meaning is defined by that class.
    """
    return BreastCancerDataset(
        img_dir=IMG_DIR,
        meta_data_csv_path=META_DATA_CSV,
        num_images=num_images,
    )


# NOTE(review): both splits read the same CSV/image dir; confirm BreastCancerDataset
# guarantees the train and validation selections below do not contain the same images.
ds = make_dataset((1024, 0, 128, 0))
val = make_dataset((128, 1024, 128, 128))

dataloader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)

In [82]:
loss_fn = nn.CrossEntropyLoss()

In [83]:
optimizer = optim.Adam(model.parameters(), lr=0.5)

for epoch in range(10):
    for data in tqdm(dataloader):
        images = data['image'].to(device)
        labels = data['anomaly'].to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
    
    preds = []
    labels_ = []
    for data in tqdm(val_loader):
        images = data['image'].to(device)
        labels = data['anomaly'].to(device)
        with torch.no_grad():
            output = model(images)
        preds.append(output)
        labels_.append(labels)
    preds = torch.cat(preds, dim=0).cpu().numpy()
    
    preds = (preds - preds.min()) / (preds.max() - preds.min())
    
    labels = torch.cat(labels_, dim=0).cpu().numpy()
    preds_bin = np.where(preds >= 0.5, 1, 0)
    f1 = f1_score(labels, preds_bin)
    auc = roc_auc_score(labels, preds)
    print(f1, auc)

100%|██████████| 18/18 [00:50<00:00,  2.78s/it]
100%|██████████| 4/4 [00:02<00:00,  1.51it/s]


0.5666666666666667 0.5059814453125


100%|██████████| 18/18 [00:45<00:00,  2.53s/it]
100%|██████████| 4/4 [00:02<00:00,  1.64it/s]


0.5666666666666667 0.5059814453125


 11%|█         | 2/18 [00:05<00:47,  2.99s/it]


KeyboardInterrupt: 