In [1]:
import numpy as np
import pandas as pd
import os
import torch
from torch import nn, optim
from sklearn.model_selection import StratifiedKFold

from work.utils.dataset import PandasDataset
from work.utils.dataset import RemovePenMarkAlbumentations
from warmup_scheduler import GradualWarmupScheduler
import albumentations as A

from torch.utils.data import DataLoader
from work.utils.models import EfficientNet
from work.utils.train import apply_active_learning, remove_images_by_entropy

In [2]:
backbone_model = 'efficientnet-b0'
pretrained_model = {
    backbone_model: 'pre-trained-models/efficientnet-b0-08094119.pth'
}
data_dir = 'data'
images_dir = os.path.join(data_dir, 'tiles')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
df = pd.read_csv(f"{data_dir}/train_val.csv")
df.head()

Unnamed: 0,image_id,data_provider,isup_grade,gleason_score
0,aa9be7d9f82e983d21e2746078b877d9,radboud,4,4+4
1,34a98ca2d4eb1a91e428bf2112e26543,karolinska,1,3+3
2,95eeb46ecc4a9693119627fedb8df55c,radboud,4,4+4
3,1df32b02eaa3cfad5d8c51a3e289cfc1,radboud,1,3+3
4,ebb6d5ca45942536f78beb451ee43cc4,radboud,2,3+4


In [4]:
batch_size = 2
num_workers = 4
output_classes = 5
init_lr = 3e-4
loss_function = nn.BCEWithLogitsLoss()
epochs = 10
n_folds = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
df.columns = df.columns.str.strip()

stratified_k_fold = StratifiedKFold(n_folds, shuffle=True, random_state=42)

df['fold'] = -1

for i, (train_indexes, valid_indexes) in enumerate(stratified_k_fold.split(df, df['isup_grade'])):
    df.loc[valid_indexes, 'fold'] = i

df.head()

Unnamed: 0,image_id,data_provider,isup_grade,gleason_score,fold
0,aa9be7d9f82e983d21e2746078b877d9,radboud,4,4+4,0
1,34a98ca2d4eb1a91e428bf2112e26543,karolinska,1,3+3,3
2,95eeb46ecc4a9693119627fedb8df55c,radboud,4,4+4,1
3,1df32b02eaa3cfad5d8c51a3e289cfc1,radboud,1,3+3,0
4,ebb6d5ca45942536f78beb451ee43cc4,radboud,2,3+4,4


In [6]:
transforms = A.Compose([
    RemovePenMarkAlbumentations(),
    A.Transpose(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
])

In [7]:
images = {}
for fold in range(n_folds):
    train_indexes = np.where((df["fold"]!=fold))[0]
    valid_indexes = np.where((df["fold"] == fold))[0]

    df_train = df.loc[train_indexes]
    df_val = df.loc[valid_indexes]

    dataset_train = PandasDataset("../dataset/tiles", df_train, transforms=transforms)
    dataset_valid = PandasDataset("../dataset/tiles", df_val)

    print(f"fold: {fold+1} train: {len(dataset_train)} images | validation: {len(dataset_valid)} images ")

    train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    valid_loader = DataLoader(dataset_valid, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    model = EfficientNet(backbone_model, output_classes, weights_path=pretrained_model.get(backbone_model))
    optimizer = optim.Adam(model.parameters(), lr=init_lr, weight_decay=1e-2)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs - 1)

    scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch = 1, after_scheduler=scheduler_cosine)

    apply_active_learning(
        model,
        epochs=epochs,
        optimizer=optimizer,
        scheduler=scheduler,
        train_dataloader=train_loader,
        device=device,
        loss_function=loss_function
    )

    local_images:dict = remove_images_by_entropy(model, valid_loader, device)
    images = {**images, **local_images}
    print("images", images)

# df_remove_images = pd.DataFrame(data={"image_id": images})
# df_remove_images.head()


fold: 1 train: 7219 images | validation: 1805 images 
Loaded pretrained weights for efficientnet-b0


loss: 0.37412, smooth loss: 0.56673: 100%|██████████| 3610/3610 [27:08<00:00,  2.22it/s]
loss: 0.42867, smooth loss: 0.59254: 100%|██████████| 3610/3610 [26:48<00:00,  2.24it/s]
loss: 0.40392, smooth loss: 0.56744: 100%|██████████| 3610/3610 [26:37<00:00,  2.26it/s]
loss: 1.02030, smooth loss: 0.55751: 100%|██████████| 3610/3610 [26:42<00:00,  2.25it/s]
loss: 0.40656, smooth loss: 0.56112: 100%|██████████| 3610/3610 [26:44<00:00,  2.25it/s]
loss: 0.40762, smooth loss: 0.57237: 100%|██████████| 3610/3610 [26:45<00:00,  2.25it/s]
loss: 0.56062, smooth loss: 0.56967: 100%|██████████| 3610/3610 [26:45<00:00,  2.25it/s]
loss: 0.56322, smooth loss: 0.56843: 100%|██████████| 3610/3610 [26:46<00:00,  2.25it/s]
loss: 0.56220, smooth loss: 0.55096: 100%|██████████| 3610/3610 [26:46<00:00,  2.25it/s]
loss: 0.56425, smooth loss: 0.59203: 100%|██████████| 3610/3610 [26:46<00:00,  2.25it/s]
100%|██████████| 903/903 [02:08<00:00,  7.03it/s]


images {'991b6a11393bf5d10e91a50f581cba68': 1.3184186, '7a5f4e79d1efd7b32518909b00753bd5': 1.3184186, '506eec1a2c0a806ef10ff76a9934cfe0': 1.3184186, '0470680842880ac7493bbf5e7cfbc26b': 1.3184186, '93b3300fd244ac326a9f11a6a47ed016': 1.3184186, 'dd26b0dd91cb20493617a9b82ab9d8fd': 1.3184186, 'b272895e4dc95112031833596f94621a': 1.3184186, 'ac0d8bc88d9ec900be978a70b90e425b': 1.3184186, '65b25c89a74ee86c54185060a6bfc25b': 1.3184186, '8364aaf55439cd2b1207b7bfa35d79a2': 1.3184186, 'ab7f757a2343686945a621a662b6f667': 1.3184186, '8dbbf6f6aebb8aea3339f70dbfd7cf07': 1.3184186, '48896f2b866e9694586b3377025c4927': 1.3184186, 'd49ec86c92a9a403ccd971a0e12fe734': 1.3184186, 'f709898cdc46cedb62cef65b87039218': 1.3184186, '331df8bf3a7f766dfe45ed525040fa85': 1.3184186, 'e7ded102214a944d051a2628d5452faa': 1.3184186, '8addcf68a92fde8dba0b2a99caf9bff9': 1.3184186, 'e2876d9ebfde8e9b19701ac52b33dcd4': 1.3184186, '6d5cb5353d700dfd4e6380508535e41b': 1.3184186, 'cf481e35314bb3e93f3716b6c456e85d': 1.3184186, '0610

loss: 0.48054, smooth loss: 0.50115: 100%|██████████| 3610/3610 [26:52<00:00,  2.24it/s]
loss: 0.39679, smooth loss: 0.56453: 100%|██████████| 3610/3610 [26:51<00:00,  2.24it/s]
loss: 0.99262, smooth loss: 0.57822: 100%|██████████| 3610/3610 [26:48<00:00,  2.24it/s]
loss: 0.40883, smooth loss: 0.56564: 100%|██████████| 3610/3610 [26:48<00:00,  2.24it/s]
loss: 0.70930, smooth loss: 0.57625: 100%|██████████| 3610/3610 [26:47<00:00,  2.25it/s]
loss: 0.41037, smooth loss: 0.55649: 100%|██████████| 3610/3610 [26:47<00:00,  2.25it/s]
loss: 0.42756, smooth loss: 0.59433: 100%|██████████| 3610/3610 [26:48<00:00,  2.24it/s]
loss: 0.40950, smooth loss: 0.55365: 100%|██████████| 3610/3610 [26:48<00:00,  2.24it/s]
loss: 0.40763, smooth loss: 0.58667: 100%|██████████| 3610/3610 [26:47<00:00,  2.25it/s]
loss: 0.56686, smooth loss: 0.57917: 100%|██████████| 3610/3610 [26:48<00:00,  2.24it/s]
100%|██████████| 903/903 [02:08<00:00,  7.05it/s]


images {'991b6a11393bf5d10e91a50f581cba68': 1.3184186, '7a5f4e79d1efd7b32518909b00753bd5': 1.3184186, '506eec1a2c0a806ef10ff76a9934cfe0': 1.3184186, '0470680842880ac7493bbf5e7cfbc26b': 1.3184186, '93b3300fd244ac326a9f11a6a47ed016': 1.3184186, 'dd26b0dd91cb20493617a9b82ab9d8fd': 1.3184186, 'b272895e4dc95112031833596f94621a': 1.3184186, 'ac0d8bc88d9ec900be978a70b90e425b': 1.3184186, '65b25c89a74ee86c54185060a6bfc25b': 1.3184186, '8364aaf55439cd2b1207b7bfa35d79a2': 1.3184186, 'ab7f757a2343686945a621a662b6f667': 1.3184186, '8dbbf6f6aebb8aea3339f70dbfd7cf07': 1.3184186, '48896f2b866e9694586b3377025c4927': 1.3184186, 'd49ec86c92a9a403ccd971a0e12fe734': 1.3184186, 'f709898cdc46cedb62cef65b87039218': 1.3184186, '331df8bf3a7f766dfe45ed525040fa85': 1.3184186, 'e7ded102214a944d051a2628d5452faa': 1.3184186, '8addcf68a92fde8dba0b2a99caf9bff9': 1.3184186, 'e2876d9ebfde8e9b19701ac52b33dcd4': 1.3184186, '6d5cb5353d700dfd4e6380508535e41b': 1.3184186, 'cf481e35314bb3e93f3716b6c456e85d': 1.3184186, '0610

loss: 0.70448, smooth loss: 0.51949: 100%|██████████| 3610/3610 [26:50<00:00,  2.24it/s]
loss: 0.44012, smooth loss: 0.57850: 100%|██████████| 3610/3610 [27:18<00:00,  2.20it/s]
loss: 0.41082, smooth loss: 0.57931: 100%|██████████| 3610/3610 [27:56<00:00,  2.15it/s]
loss: 0.55089, smooth loss: 0.57347: 100%|██████████| 3610/3610 [28:29<00:00,  2.11it/s]
loss: 0.70876, smooth loss: 0.58517: 100%|██████████| 3610/3610 [28:05<00:00,  2.14it/s]
loss: 0.56686, smooth loss: 0.57547: 100%|██████████| 3610/3610 [28:34<00:00,  2.11it/s]
loss: 0.40366, smooth loss: 0.55735: 100%|██████████| 3610/3610 [29:16<00:00,  2.06it/s]
loss: 0.56689, smooth loss: 0.57968: 100%|██████████| 3610/3610 [29:55<00:00,  2.01it/s]
loss: 0.40820, smooth loss: 0.54990: 100%|██████████| 3610/3610 [30:59<00:00,  1.94it/s]
loss: 0.40723, smooth loss: 0.55983: 100%|██████████| 3610/3610 [30:10<00:00,  1.99it/s]
100%|██████████| 903/903 [02:16<00:00,  6.64it/s]


images {'991b6a11393bf5d10e91a50f581cba68': 1.3184186, '7a5f4e79d1efd7b32518909b00753bd5': 1.3184186, '506eec1a2c0a806ef10ff76a9934cfe0': 1.3184186, '0470680842880ac7493bbf5e7cfbc26b': 1.3184186, '93b3300fd244ac326a9f11a6a47ed016': 1.3184186, 'dd26b0dd91cb20493617a9b82ab9d8fd': 1.3184186, 'b272895e4dc95112031833596f94621a': 1.3184186, 'ac0d8bc88d9ec900be978a70b90e425b': 1.3184186, '65b25c89a74ee86c54185060a6bfc25b': 1.3184186, '8364aaf55439cd2b1207b7bfa35d79a2': 1.3184186, 'ab7f757a2343686945a621a662b6f667': 1.3184186, '8dbbf6f6aebb8aea3339f70dbfd7cf07': 1.3184186, '48896f2b866e9694586b3377025c4927': 1.3184186, 'd49ec86c92a9a403ccd971a0e12fe734': 1.3184186, 'f709898cdc46cedb62cef65b87039218': 1.3184186, '331df8bf3a7f766dfe45ed525040fa85': 1.3184186, 'e7ded102214a944d051a2628d5452faa': 1.3184186, '8addcf68a92fde8dba0b2a99caf9bff9': 1.3184186, 'e2876d9ebfde8e9b19701ac52b33dcd4': 1.3184186, '6d5cb5353d700dfd4e6380508535e41b': 1.3184186, 'cf481e35314bb3e93f3716b6c456e85d': 1.3184186, '0610

loss: 0.58304, smooth loss: 0.56381: 100%|██████████| 3610/3610 [27:13<00:00,  2.21it/s]
loss: 0.58137, smooth loss: 0.57498: 100%|██████████| 3610/3610 [27:12<00:00,  2.21it/s]
loss: 1.00608, smooth loss: 0.56334: 100%|██████████| 3610/3610 [27:10<00:00,  2.21it/s]
loss: 0.42102, smooth loss: 0.54996: 100%|██████████| 3610/3610 [27:08<00:00,  2.22it/s]
loss: 0.52594, smooth loss: 0.57120: 100%|██████████| 3610/3610 [27:44<00:00,  2.17it/s]
loss: 0.55783, smooth loss: 0.56038: 100%|██████████| 3610/3610 [26:50<00:00,  2.24it/s]
loss: 0.42481, smooth loss: 0.56748: 100%|██████████| 3610/3610 [25:43<00:00,  2.34it/s]
loss: 0.40490, smooth loss: 0.55963: 100%|██████████| 3610/3610 [25:37<00:00,  2.35it/s]
loss: 0.40455, smooth loss: 0.54567: 100%|██████████| 3610/3610 [25:36<00:00,  2.35it/s]
loss: 0.40563, smooth loss: 0.59151: 100%|██████████| 3610/3610 [25:35<00:00,  2.35it/s]
100%|██████████| 903/903 [02:03<00:00,  7.32it/s]


images {'991b6a11393bf5d10e91a50f581cba68': 1.3184186, '7a5f4e79d1efd7b32518909b00753bd5': 1.3184186, '506eec1a2c0a806ef10ff76a9934cfe0': 1.3184186, '0470680842880ac7493bbf5e7cfbc26b': 1.3184186, '93b3300fd244ac326a9f11a6a47ed016': 1.3184186, 'dd26b0dd91cb20493617a9b82ab9d8fd': 1.3184186, 'b272895e4dc95112031833596f94621a': 1.3184186, 'ac0d8bc88d9ec900be978a70b90e425b': 1.3184186, '65b25c89a74ee86c54185060a6bfc25b': 1.3184186, '8364aaf55439cd2b1207b7bfa35d79a2': 1.3184186, 'ab7f757a2343686945a621a662b6f667': 1.3184186, '8dbbf6f6aebb8aea3339f70dbfd7cf07': 1.3184186, '48896f2b866e9694586b3377025c4927': 1.3184186, 'd49ec86c92a9a403ccd971a0e12fe734': 1.3184186, 'f709898cdc46cedb62cef65b87039218': 1.3184186, '331df8bf3a7f766dfe45ed525040fa85': 1.3184186, 'e7ded102214a944d051a2628d5452faa': 1.3184186, '8addcf68a92fde8dba0b2a99caf9bff9': 1.3184186, 'e2876d9ebfde8e9b19701ac52b33dcd4': 1.3184186, '6d5cb5353d700dfd4e6380508535e41b': 1.3184186, 'cf481e35314bb3e93f3716b6c456e85d': 1.3184186, '0610

loss: 0.80060, smooth loss: 0.55012: 100%|██████████| 3610/3610 [25:38<00:00,  2.35it/s]
loss: 0.55173, smooth loss: 0.58136: 100%|██████████| 3610/3610 [25:38<00:00,  2.35it/s]
loss: 0.50197, smooth loss: 0.57680: 100%|██████████| 3610/3610 [25:40<00:00,  2.34it/s]
loss: 0.48666, smooth loss: 0.54661: 100%|██████████| 3610/3610 [25:35<00:00,  2.35it/s]
loss: 0.56067, smooth loss: 0.54210: 100%|██████████| 3610/3610 [25:34<00:00,  2.35it/s]
loss: 0.49603, smooth loss: 0.56563: 100%|██████████| 3610/3610 [25:35<00:00,  2.35it/s]
loss: 0.41554, smooth loss: 0.56841: 100%|██████████| 3610/3610 [25:41<00:00,  2.34it/s]
loss: 0.41986, smooth loss: 0.56920: 100%|██████████| 3610/3610 [25:35<00:00,  2.35it/s]
loss: 0.48887, smooth loss: 0.59611: 100%|██████████| 3610/3610 [25:35<00:00,  2.35it/s]
loss: 0.56706, smooth loss: 0.60734: 100%|██████████| 3610/3610 [25:34<00:00,  2.35it/s]
100%|██████████| 902/902 [02:03<00:00,  7.32it/s]

images {'991b6a11393bf5d10e91a50f581cba68': 1.3184186, '7a5f4e79d1efd7b32518909b00753bd5': 1.3184186, '506eec1a2c0a806ef10ff76a9934cfe0': 1.3184186, '0470680842880ac7493bbf5e7cfbc26b': 1.3184186, '93b3300fd244ac326a9f11a6a47ed016': 1.3184186, 'dd26b0dd91cb20493617a9b82ab9d8fd': 1.3184186, 'b272895e4dc95112031833596f94621a': 1.3184186, 'ac0d8bc88d9ec900be978a70b90e425b': 1.3184186, '65b25c89a74ee86c54185060a6bfc25b': 1.3184186, '8364aaf55439cd2b1207b7bfa35d79a2': 1.3184186, 'ab7f757a2343686945a621a662b6f667': 1.3184186, '8dbbf6f6aebb8aea3339f70dbfd7cf07': 1.3184186, '48896f2b866e9694586b3377025c4927': 1.3184186, 'd49ec86c92a9a403ccd971a0e12fe734': 1.3184186, 'f709898cdc46cedb62cef65b87039218': 1.3184186, '331df8bf3a7f766dfe45ed525040fa85': 1.3184186, 'e7ded102214a944d051a2628d5452faa': 1.3184186, '8addcf68a92fde8dba0b2a99caf9bff9': 1.3184186, 'e2876d9ebfde8e9b19701ac52b33dcd4': 1.3184186, '6d5cb5353d700dfd4e6380508535e41b': 1.3184186, 'cf481e35314bb3e93f3716b6c456e85d': 1.3184186, '0610




In [8]:
df_remove_images = pd.DataFrame(list(images.items()), columns=['image_id', 'entropy'])

df_remove_images.to_csv("data/remove-images.csv")
df_remove_images.shape

(9024, 2)