In [1]:
%load_ext autoreload
%autoreload 2

from argparse import ArgumentParser

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch

from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from PIL import Image
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities.cloud_io import load as pl_load

from data_utils import Center0Dataset, calculate_stats, ImbalancedDatasetSampler, MultipleCentersSeq, OneCenterLoad
from augmentations import geom_augmentations, basic_augmentations, color_augmentations, no_augmentations, gan_augmentations, normalization
from model import Classifier

In [2]:
class Args():
    learning_rate = 1e-5
    l2_reg = 1e-5
    batch_size = 8
    gpus = 1
    fast_dev_run = True
    name = 'debug'
    description = 'weighted'
    weighted = True
    max_epochs = 1
args = Args()

In [3]:
data_dir = '/home/haicu/sophia.wagner/datasets/2101_camelyon17/'
data_file = '/home/haicu/sophia.wagner/datasets/2101_camelyon17/center0_level2.hdf5'
# %time train_dataset = OneCenterLoad(data_dir, 0, 'train', transform=gan_augmentations)
%time val_dataset = OneCenterLoad(data_dir, 0, 'val', transform=basic_augmentations)
# %time train_dataset = Center0Dataset(data_file, 'train', transform=None)
# %time val_dataset = Center0Dataset(data_file, 'val', transform=None)

CPU times: user 2.52 s, sys: 11.8 s, total: 14.3 s
Wall time: 14.9 s


In [4]:
# train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=6)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=6)

In [5]:
# checkpoint_callback = ModelCheckpoint(
#     monitor='AUC',
#     dirpath=logger.log_dir + '/checkpoints/',
#     filename='Classifier-Center0-{epoch:02d}-{AUC:.2f}',
#     save_top_k=1,
#     mode='max'
# )

# early_stop_callback = EarlyStopping(
#    monitor='val_metrics/AUC',
#    min_delta=0.01,
#    patience=3,
#    verbose=False,
#    mode='max'
#     )

In [8]:
# parser = Classifier.add_model_specific_args(parser)
model = Classifier(args.learning_rate, args.l2_reg, weighted=args.weighted, mdmm_aug=True)
# parser = pl.Trainer.add_argparse_args(args)
logger = TensorBoardLogger('lightning_logs', name=args.name)

trainer = pl.Trainer(
    gpus = args.gpus,
    logger = logger,
#     callbacks = [checkpoint_callback, early_stop_callback],
#     fast_dev_run = args.fast_dev_run,
    weights_summary=None,
    max_epochs = args.max_epochs,
    log_every_n_steps = 10,
#     num_sanity_val_steps=0,
#         limit_train_batches=0.1,
#         limit_val_batches=0.5,
)
trainer.fit(model, val_loader, val_loader)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.


mdmm initialized


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size([8, 3, 512, 512]) torch.float32
torch.Size(






1

In [13]:
checkpoint_path = '/home/haicu/sophia.wagner/projects/stain_color/stain_aug/lightning_logs/basic_augmentations/version_12/checkpoints/Classifier-Center0-epoch=11-PR_AUC=0.96.ckpt'
# checkpoint_path = '/home/haicu/sophia.wagner/projects/stain_color/stain_aug/lightning_logs/debug/version_29/checkpoints/epoch=0-step=240.ckpt'

In [14]:
logger = TensorBoardLogger('lightning_logs', name=args.name)
print(logger.log_dir)

trainer = pl.Trainer(
    gpus = args.gpus,
    logger = logger,
#     callbacks = [checkpoint_callback, early_stop_callback],
#     fast_dev_run = args.fast_dev_run,
    weights_summary=None,
#         limit_train_batches=0.1,
#     limit_test_batches=5,
)


center = 0
model = Classifier(args.learning_rate, args.l2_reg, weighted=args.weighted, mdmm_norm=center)
# model = model.load_from_checkpoint(checkpoint_path=checkpoint_path)

checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
# model.model.load_state_dict(checkpoint['state_dict'])
model.model.load_state_dict(torch.load(checkpoint_path), strict=False)
model.mdmm_norm = center

test_centers = [[i,] for i in range(5)]
#         test_centers.remove([center,])
test_all = list(range(5))
test_all.remove(center)
test_centers.append(test_all)

results = []
for c in test_centers:
    print(f'results for dataset {c}')
    if c == [center, ]:
#         test_dataset = OneCenterLoad('/home/haicu/sophia.wagner/datasets/2101_camelyon17/', center, 'val')
        continue
    else:
        test_dataset = MultipleCentersSeq('/home/haicu/sophia.wagner/datasets/2101_camelyon17/', c)
    test_loader = DataLoader(test_dataset, batch_size=8, num_workers=1)
    result = trainer.test(model, test_dataloaders=test_loader)
    results.append(result)

print(test_centers)
print('PR_AUC')
pr_auc = [round(res[0]['PR_AUC'], 4) for res in results]
print(pr_auc)
print('F1_tumor')
f1 = [round(res[0]['F1_tumor'], 4) for res in results]
print(f1)



GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


lightning_logs/debug/version_43
mdmm initialized
results for dataset [0]
results for dataset [1]


Set SLURM handle signals.


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized b




--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'F1_normal': 0.0355382040143013,
 'F1_tumor': 0.05234580487012863,
 'PR_AUC': 0.021323878318071365,
 'confusion_matrix_00': 537.0,
 'confusion_matrix_01': 28858.0,
 'confusion_matrix_10': 289.0,
 'confusion_matrix_11': 805.0,
 'precision_normal': 0.6501210927963257,
 'precision_tumor': 0.027138184756040573,
 'recall_normal': 0.018268413841724396,
 'recall_tumor': 0.7358317971229553,
 'test_loss': 1.1935093402862549,
 'test_metrics/F1_normal': 0.0355382040143013,
 'test_metrics/F1_tumor': 0.05234580487012863,
 'test_metrics/PR_AUC': 0.021323878318071365,
 'test_metrics/acc': 0.04401587322354317,
 'test_metrics/precision_normal': 0.6501210927963257,
 'test_metrics/precision_tumor': 0.027138184756040573,
 'test_metrics/recall_normal': 0.018268413841724396,
 'test_metrics/recall_tumor': 0.7358317971229553}
-----------------------------------------------------------------------------

Set SLURM handle signals.


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized by mdmm
normalized b

KeyboardInterrupt: 

In [15]:
t = torch.load(checkpoint_path)

In [16]:
t

{'epoch': 12,
 'global_step': 1897,
 'pytorch-lightning_version': '1.1.8',
 'callbacks': {pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint: {'monitor': 'PR_AUC',
   'best_model_score': tensor(0.9594, device='cuda:0'),
   'best_model_path': '/mnt/home/icb/sophia.wagner/projects/stain_color/stain_aug/lightning_logs/basic_augmentations/version_12/checkpoints/Classifier-Center0-epoch=07-PR_AUC=0.96.ckpt',
   'current_score': tensor(0.9583, device='cuda:0'),
   'dirpath': '/mnt/home/icb/sophia.wagner/projects/stain_color/stain_aug/lightning_logs/basic_augmentations/version_12/checkpoints'},
  pytorch_lightning.callbacks.early_stopping.EarlyStopping: {'wait_count': 6,
   'stopped_epoch': 0,
   'best_score': tensor(0.9594, device='cuda:0'),
   'patience': 20}},
 'optimizer_states': [{'state': {30: {'step': 1897,
     'exp_avg': tensor([[[[-3.6025e-04, -4.0131e-04, -4.2214e-04],
               [ 1.3477e-04, -4.1019e-04, -1.9500e-04],
               [ 1.8899e-05, -3.4020e-05, -7.545

In [10]:
checkpoint['state_dict'].keys()

odict_keys(['model.conv1.weight', 'model.bn1.weight', 'model.bn1.bias', 'model.bn1.running_mean', 'model.bn1.running_var', 'model.bn1.num_batches_tracked', 'model.layer1.0.conv1.weight', 'model.layer1.0.bn1.weight', 'model.layer1.0.bn1.bias', 'model.layer1.0.bn1.running_mean', 'model.layer1.0.bn1.running_var', 'model.layer1.0.bn1.num_batches_tracked', 'model.layer1.0.conv2.weight', 'model.layer1.0.bn2.weight', 'model.layer1.0.bn2.bias', 'model.layer1.0.bn2.running_mean', 'model.layer1.0.bn2.running_var', 'model.layer1.0.bn2.num_batches_tracked', 'model.layer1.1.conv1.weight', 'model.layer1.1.bn1.weight', 'model.layer1.1.bn1.bias', 'model.layer1.1.bn1.running_mean', 'model.layer1.1.bn1.running_var', 'model.layer1.1.bn1.num_batches_tracked', 'model.layer1.1.conv2.weight', 'model.layer1.1.bn2.weight', 'model.layer1.1.bn2.bias', 'model.layer1.1.bn2.running_mean', 'model.layer1.1.bn2.running_var', 'model.layer1.1.bn2.num_batches_tracked', 'model.layer2.0.conv1.weight', 'model.layer2.0.bn1.w

In [19]:
augmentations[None]

Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

In [40]:
center = 4

In [41]:
    test_centers = [[i,] for i in range(5)]
    test_centers.remove([center,])
    all = list(range(5))
    all.remove(center)
    test_centers.append(all)

In [42]:
test_centers

[[0], [1], [2], [3], [0, 1, 2, 3]]

In [None]:
data_dir = '/home/haicu/sophia.wagner/datasets/2101_camelyon17/'

for c in tqdm([[0, 1, 2, 3]]):
    dataset = MultipleCentersSeq(data_dir, c, transform=None)
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=1)
    mean, std = calculate_stats(loader, len(dataset))
    print(c)
    print(mean)
    print(std)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5808.0), HTML(value='')))

In [1]:
test = 0.36983745983745983745987

In [2]:
test

0.36983745983745986

In [3]:
round(test, 4)

0.3698

In [3]:
t = torch.randint(10, (10,))

In [4]:
t

tensor([4, 7, 0, 3, 7, 7, 2, 4, 6, 3])

In [5]:
t.bool()

tensor([ True,  True, False,  True,  True,  True,  True,  True,  True,  True])

In [71]:
patches = []
labels = []
domains = []
data_dir = '/home/haicu/sophia.wagner/datasets/2101_camelyon17/'
for c in tqdm(range(5)):
    data = {
        'patches': [],
        'tumor_ratio': [],
    }
    file_path = data_dir + f'center{c}_level2.hdf5'
    with h5py.File(file_path, 'r') as h5_file:
        # Walk through all groups, extracting datasets
        for gname, group in h5_file.items():
            data['patches'].append(torch.from_numpy(group['patches'][:]))
            data['tumor_ratio'].append(torch.from_numpy(group['tumor_ratio'][:]))    
    data['patches'] = torch.cat(data['patches']).permute(0, 3, 1, 2)
    data['tumor_ratio'] = torch.cat(data['tumor_ratio']).unsqueeze(-1)

    indices = torch.randint(0, len(data['patches']), (1000, ))
    patches.append(data['patches'][indices])
    labels.append((data['tumor_ratio'][indices] >= 0.01).byte())
    domains.append(torch.ones((1000, )) * c)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [72]:
patches = torch.cat(patches)
labels = torch.cat(labels)
domains = torch.cat(domains)

In [73]:
patches.shape

torch.Size([5000, 3, 512, 512])

In [78]:
labels.unique(return_counts=True)

(tensor([0, 1], dtype=torch.uint8), tensor([4672,  328]))

In [76]:
domains.unique()

tensor([0., 1., 2., 3., 4.])

In [81]:
img = Image.fromarray(patches[2349].permute(1, 2, 0).numpy())
img.save('test.png')

In [10]:
indices = np.random.randint(0, (len(train_dataset) + len(val_dataset)), 1000)

In [82]:
image_dir = '/home/haicu/sophia.wagner/datasets/sample_camelyon17_patches/'

In [84]:
csv = np.zeros((5000, 3))
for i in tqdm(range(5000)):
    img = Image.fromarray(patches[i].permute(1, 2, 0).numpy())
    img.save(image_dir + f'{i}.png')
    csv[i, 0] = i
    csv[i, 1] = labels[i]
    csv[i, 2] = domains[i]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))




In [85]:
csv.shape

(5000, 3)

In [11]:
del train_dataset, val_dataset

In [86]:
np.savetxt(image_dir + 'annotations.csv', csv, delimiter=',')

In [23]:
data['tumor_ratio'][indices].unique()

tensor([0.], dtype=torch.float64)

In [47]:
y = torch.ones(1) * 0.0005

In [48]:
y = (y >= 0.01).int()

In [68]:
len(data['patches'])

4136

In [70]:
indices.shape

torch.Size([1000])

In [50]:
y

tensor([0], dtype=torch.int32)

In [8]:
test = 1

In [9]:
if test in range(5):
    print(test)

1


In [15]:
            mean_domains = torch.Tensor([
                [ 0.3020, -2.6476, -0.9849, -0.7820, -0.2746,  0.3361,  0.1694, -1.2148],
                [ 0.1453, -1.2400, -0.9484,  0.9697, -2.0775,  0.7676, -0.5224, -0.2945],
                [ 2.1067, -1.8572,  0.0055,  1.2214, -2.9363,  2.0249, -0.4593, -0.9771],
                [ 0.8378, -2.1174, -0.6531,  0.2986, -1.3629, -0.1237, -0.3486, -1.0716],
                [ 1.6073,  1.9633, -0.3130, -1.9242, -0.9673,  2.4990, -2.2023, -1.4109],
            ])

            
            std_domains = torch.Tensor([
                [0.6550, 1.5427, 0.5444, 0.7254, 0.6701, 1.0214, 0.6245, 0.6886],
                [0.4143, 0.6543, 0.5891, 0.4592, 0.8944, 0.7046, 0.4441, 0.3668],
                [0.5576, 0.7634, 0.7875, 0.5220, 0.7943, 0.8918, 0.6000, 0.5018],
                [0.4157, 0.4104, 0.5158, 0.3498, 0.2365, 0.3612, 0.3375, 0.4214],
                [0.6154, 0.3440, 0.7032, 0.6220, 0.4496, 0.6488, 0.4886, 0.2989],
            ])

In [14]:
domain.shape

torch.Size([5])

In [19]:
std_domains[3].shape

torch.Size([8])

In [40]:
            z_attr = (torch.ones((3, 8, )) * std_domains[3] + mean_domains[3])

In [41]:
z_attr.shape

torch.Size([3, 8])

In [42]:
z_attr

tensor([[ 1.2535, -1.7070, -0.1373,  0.6484, -1.1264,  0.2375, -0.0111, -0.6502],
        [ 1.2535, -1.7070, -0.1373,  0.6484, -1.1264,  0.2375, -0.0111, -0.6502],
        [ 1.2535, -1.7070, -0.1373,  0.6484, -1.1264,  0.2375, -0.0111, -0.6502]])

In [37]:
domain = torch.eye(5)[3].unsqueeze(0).repeat(128, 1)

In [39]:
domain[:10]

tensor([[0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.]])