In [1]:
import UQ_toolbox as uq
from medMNIST.utils import train_load_datasets_resnet as tr
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedKFold
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
from gps_augment.utils.randaugment import BetterRandAugment


def train_val_loaders(train_dataset, batch_size, n_splits=5, random_state=42):
    """
    Build one (train, validation) DataLoader pair per stratified K-fold split.

    Parameters
    ----------
    train_dataset : indexable dataset yielding ``(sample, label)`` pairs
        Every item is read once to collect the labels used for stratification.
    batch_size : int
        Batch size shared by all returned loaders.
    n_splits : int, default 5
        Number of folds.
    random_state : int, default 42
        Seed forwarded to ``StratifiedKFold`` so the fold assignment is reproducible.

    Returns
    -------
    (train_loaders, val_loaders) : tuple of two lists of DataLoader
        ``train_loaders[k]`` and ``val_loaders[k]`` belong to fold ``k``.
        Training loaders shuffle; validation loaders keep a fixed order and
        never drop samples.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Iterating the dataset materialises every item; only the label is kept.
    labels = [label for _, label in train_dataset]

    train_loaders = []
    val_loaders = []

    # The features play no role in stratification, so a dummy array of the
    # right length stands in for X.
    for train_index, val_index in skf.split(np.zeros(len(labels)), labels):
        train_subset = torch.utils.data.Subset(train_dataset, train_index)
        val_subset = torch.utils.data.Subset(train_dataset, val_index)

        # Drop the trailing partial batch only when at least one full batch
        # exists. With batch_size larger than the fold, an unconditional
        # drop_last=True would leave the loader with zero batches.
        train_loader = DataLoader(
            dataset=train_subset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=len(train_subset) >= batch_size,
        )
        # Validation must cover every sample exactly once, in a stable order,
        # so metrics are computed on the whole fold and are comparable between runs.
        val_loader = DataLoader(
            dataset=val_subset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )

        train_loaders.append(train_loader)
        val_loaders.append(val_loader)
    return train_loaders, val_loaders

# --- Experiment configuration ---
dataflag = 'breastmnist'   # dataset identifier handed to the `tr` loaders
activation = 'sigmoid'     # forwarded later as `output_activation`
color = False              # True for color, False for grayscale

# Image side length. Both names are read further down, so they share one value.
im_size = size = 224
batch_size = 4000  # Batch size for the DataLoader

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print(f"Processing {dataflag} with color={color} and activation={activation}")

# Two pipelines are built:
#   transform     -> ToTensor + Normalize(0.5, 0.5)
#   transform_tta -> ToTensor only (no Normalize)
# NOTE(review): transform_tta presumably skips Normalize because normalisation is
# applied later by the augmentation step (the call below passes
# image_normalization=True with mean/std) — confirm.
if color is not True:
    # Grayscale input: ResNet expects 3 channels, so the single channel is
    # repeated three times after the per-channel normalisation.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ])

    transform_tta = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ])
else:
    # Color input already has 3 channels; normalise each one.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
    ])

    transform_tta = transforms.Compose([
        transforms.ToTensor(),
    ])
# Models for this dataset, loaded with the selected device.
models = tr.load_models(dataflag, device=device)

# Datasets and loaders built with the normalising `transform`.
(
    (train_dataset, calibration_dataset, test_dataset),
    (train_loader, calibration_loader, test_loader),
    info,
) = tr.load_datasets(dataflag, color, size, transform, batch_size)

# Per-fold cross-validation loaders over the training split.
train_loaders, val_loaders = train_val_loaders(train_dataset, batch_size=batch_size)

task_type = info['task']  # Determine the task type (binary-class or multi-class)
num_classes = len(info['label'])  # Number of classes

# Same loading call with `transform_tta`; the training parts and the info
# dictionary are discarded.
(
    (_, calibration_dataset_tta, test_dataset_tta),
    (_, calibration_loader_tta, test_loader_tta),
    _,
) = tr.load_datasets(dataflag, color, size, transform_tta, batch_size)


Processing breastmnist with color=False and activation=sigmoid
Training dataset size: 499
Calibration dataset size: 125
Training dataset size: 499
Calibration dataset size: 125


In [None]:
Manually set transform. Current transform: 
[(14,np.float64(34.55240317163067)),(5,np.float64(-41.84148881205453))]
Manually set transform. Current transform: 
[(9,np.float64(-40.01514400895359)),(6,np.float64(28.997648027972176))]
Applying augmentation n : 0
Applying augmentation n : 1
Manually set transform. Current transform: 
[(3,np.float64(22.304599753771143)),(8,np.float64(40.295612709687944))]
Manually set transform. Current transform: 
[(9,np.float64(-40.01514400895359)),(6,np.float64(28.997648027972176))]
Applying augmentation n : 0
Applying augmentation n : 1
Manually set transform. Current transform: 
[(2,np.float64(23.66661945829418)),(8,np.float64(31.904003985672745))]
Manually set transform. Current transform: 
[(9,np.float64(-40.01514400895359)),(6,np.float64(28.997648027972176))]

In [3]:
def apply_randaugment_and_store_results(
    dataset, models, N, M, num_policies, device, folder_name='savedpolicies',
    image_normalization=False, mean=False, std=False, nb_channels=1, image_size=51,
    output_activation=None, batch_size=None
):
    """
    Create ``folder_name`` and run ``apply_augmentations`` for the first policy only.

    The policy loop is limited to a single pass, so at most one policy is
    processed regardless of ``num_policies``. The result of
    ``apply_augmentations`` is discarded; nothing is returned and nothing is
    written apart from creating the directory.

    Parameters
    ----------
    dataset
        Forwarded unchanged to ``apply_augmentations``.
    models, device, output_activation
        Accepted but not used by this function.
    N, M
        Forwarded to ``apply_augmentations`` (presumably the RandAugment number
        of operations and magnitude — confirm against that helper).
    num_policies : int
        Bounds the policy loop and is shown as the total in the progress message.
    folder_name : str
        Directory that is created if it does not exist yet.
    image_normalization, mean, std, nb_channels, image_size, batch_size
        Forwarded to ``apply_augmentations``.
    """
    os.makedirs(folder_name, exist_ok=True)

    # NOTE(review): only the first policy ever runs. This looks like a debugging
    # leftover; confirm before relying on num_policies > 1.
    first_policy_only = range(num_policies)[:1]

    for i in first_policy_only:
        print(f"Applying augmentation policy {i+1}/{num_policies}")
        # Positional order expected by apply_augmentations, kept exactly as before.
        augmentation_args = (
            dataset, 1, True, N, M,
            image_normalization, nb_channels, mean, std, image_size,
        )
        apply_augmentations(*augmentation_args, batch_size=batch_size)


In [18]:
# Run the augmentation step on the calibration split.
apply_randaugment_and_store_results(
    calibration_dataset,  # dataset
    models,               # models
    2,                    # N
    45,                   # M
    500,                  # num_policies
    device,               # device
    folder_name=f'/mnt/data/psteinmetz/archive_notebooks/Documents/medMNIST/gps_augment/{im_size}*{im_size}/{dataflag}_calibration_set',
    image_normalization=True,
    mean=[.5],
    std=[.5],
    image_size=im_size,
    nb_channels=3,
    output_activation=activation,
    batch_size=batch_size,
)

Applying augmentation policy 1/500
Applying augmentation n : 0
[(<function SolarizeAdd at 0x727acbbb1ee0>, 256, 128), (<function Cutout at 0x727acbbb22a0>, 0, 0.5)]


TypeError: color must be int or single-element tuple

In [22]:
# Sanity check: number of samples in the calibration split.
len(calibration_dataset)

125

In [23]:
# Inspect the source datasets wrapped by calibration_dataset
# (the output below lists BreastMNIST train and val splits).
calibration_dataset.dataset.datasets

[Dataset BreastMNIST of size 224 (breastmnist_224)
     Number of datapoints: 546
     Root location: /home/psteinmetz/.medmnist
     Split: train
     Task: binary-class
     Number of channels: 1
     Meaning of labels: {'0': 'malignant', '1': 'normal, benign'}
     Number of samples: {'train': 546, 'val': 78, 'test': 156}
     Description: The BreastMNIST is based on a dataset of 780 breast ultrasound images. It is categorized into 3 classes: normal, benign, and malignant. As we use low-resolution images, we simplify the task into binary classification by combining normal and benign as positive and classifying them against malignant as negative. We split the source dataset with a ratio of 7:1:2 into training, validation and test set. The source images of 1×500×500 are resized into 1×28×28.
     License: CC BY 4.0,
 Dataset BreastMNIST of size 224 (breastmnist_224)
     Number of datapoints: 78
     Root location: /home/psteinmetz/.medmnist
     Split: val
     Task: binary-class
   