In [1]:
import UQ_toolbox as uq
from medMNIST.utils import train_load_datasets_resnet as tr
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedKFold
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
from gps_augment.utils.randaugment import BetterRandAugment


def train_val_loaders(train_dataset, batch_size, n_splits=5, random_state=42):
    """
    Build stratified K-fold train/validation DataLoaders over ``train_dataset``.

    Args:
        train_dataset (torch.utils.data.Dataset): Dataset yielding ``(input, label)`` pairs;
            the labels are used for stratification.
        batch_size (int): Batch size of every returned DataLoader.
        n_splits (int, optional): Number of folds. Defaults to 5.
        random_state (int, optional): Seed of the fold shuffling. Defaults to 42.

    Returns:
        tuple[list[DataLoader], list[DataLoader]]: ``(train_loaders, val_loaders)``,
        one loader per fold, in fold order.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Labels for stratification. NOTE: this iterates the whole dataset, so the
    # dataset's transform runs once on every sample.
    labels = [label for _, label in train_dataset]

    train_loaders = []
    val_loaders = []

    for train_index, val_index in skf.split(np.zeros(len(labels)), labels):
        train_subset = torch.utils.data.Subset(train_dataset, train_index)
        val_subset = torch.utils.data.Subset(train_dataset, val_index)

        # Drop the incomplete last batch only when the fold holds more than one
        # batch; with drop_last=True and batch_size >= fold size the loader
        # would yield no batches at all.
        train_loader = DataLoader(
            dataset=train_subset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=len(train_subset) > batch_size,
        )
        # Validation must see every held-out sample exactly once, in a stable order.
        val_loader = DataLoader(
            dataset=val_subset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )

        train_loaders.append(train_loader)
        val_loaders.append(val_loader)
    return train_loaders, val_loaders

# --- Configuration (each value defined once) ---
dataflag = 'breastmnist'
color = False  # True for color, False for grayscale
activation = 'sigmoid'
im_size = 224  # Image size for the models
size = im_size  # Same value, under the name passed to tr.load_datasets below
batch_size = 4000  # Batch size for the DataLoaders
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(f"Processing {dataflag} with color={color} and activation={activation}")
if color:
    # Model input: per-channel Normalize with mean/std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
    ])
    # TTA input: tensor conversion only, no normalization.
    transform_tta = transforms.Compose([
        transforms.ToTensor()
    ])
else:
    # For grayscale images, repeat the single channel to make it compatible with ResNet
    # ResNet expects 3 channels, so we repeat the single channel image
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    ])
    # TTA input: same channel repetition, but no normalization.
    transform_tta = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    ])

models = tr.load_models(dataflag, device=device)
[train_dataset, calibration_dataset, test_dataset], [train_loader, calibration_loader, test_loader], info = tr.load_datasets(dataflag, color, size, transform, batch_size)
train_loaders, val_loaders = train_val_loaders(train_dataset, batch_size=batch_size)
task_type = info['task']  # Determine the task type (binary-class or multi-class)
num_classes = len(info['label'])  # Number of classes
# Reload the calibration/test splits with the un-normalized TTA transform (train split discarded).
[_, calibration_dataset_tta, test_dataset_tta], [_, calibration_loader_tta, test_loader_tta], _ = tr.load_datasets(dataflag, color, size, transform_tta, batch_size)


Processing breastmnist with color=False and activation=sigmoid
Training dataset size: 499
Calibration dataset size: 125
Training dataset size: 499
Calibration dataset size: 125


In [None]:
class AddBatchDimension:
    """Transform that prepends a size-1 batch axis to a tensor and casts it to float."""

    def __call__(self, image):
        # Reject anything that is not already a tensor.
        if not isinstance(image, torch.Tensor):
            raise TypeError("Input should be a torch Tensor")
        batched = image.unsqueeze(0)
        return batched.float()


def get_prediction(model, image, device):
    """
    Run ``model`` on ``image`` in eval mode without autograd; return the output on CPU.

    Side effects: moves ``model`` to ``device`` and switches it to eval mode.

    Args:
        model (torch.nn.Module): Model to evaluate.
        image (torch.Tensor): Input passed unchanged (apart from the device move) to ``model``.
        device (torch.device): Device on which the forward pass runs.

    Returns:
        torch.Tensor: The model output, detached and moved to CPU.
    """
    model.to(device)
    model.eval()
    inputs = image.to(device, non_blocking=True)
    with torch.no_grad():
        output = model(inputs)
    return output.detach().cpu()


def extract_gps_augmentations_info(policies):
    """
    Extracts N, M values and the list of policies from a list of policy filenames.

    Args:
    - policies (list of str): List of filenames in the format 'N2_M45_[(op, magnitude), (op, magnitude)].npz'.

    Returns:
    - N (int or None): The value of N, parsed from the FIRST filename only (the others are
      assumed to share it); None if that filename has no 'N<digits>_M<digits>' token or
      ``policies`` is empty.
    - M (int or None): The value of M, parsed the same way as N.
    - formatted_policies (list of str): The bracketed policy part of each filename, as a string
      (e.g. '[(op, magnitude), (op, magnitude)]'); filenames without brackets are skipped.
    """
    if not policies:
        return None, None, []

    # Initialise so a first filename without an N/M token returns None instead of
    # raising UnboundLocalError.
    N = M = None
    match = re.search(r'N(\d+)_M(\d+)', policies[0])
    if match:
        N = int(match.group(1))
        M = int(match.group(2))

    formatted_policies = []
    for policy in policies:
        # Non-greedy match: the first '[...]' span of the filename.
        policy_match = re.search(r'\[(.*?)\]', policy)
        if policy_match:
            formatted_policies.append(f"[{policy_match.group(1)}]")  # Add brackets back around the tuple

    return N, M, formatted_policies


def to_3_channels(img):
    """Return an RGB copy of a grayscale ('L') PIL image; images in any other mode are returned as-is."""
    return img.convert('RGB') if img.mode == 'L' else img

def to_1_channel(img):
    """Return the image converted to single-channel grayscale ('L') mode."""
    grayscale = img.convert('L')
    return grayscale


def _build_rand_aug_policies(transformations, nb_augmentations, n, m, image_size):
    """Create one BetterRandAugment per explicit policy, or ``nb_augmentations`` random ones."""
    if isinstance(transformations, list):
        return [
            BetterRandAugment(n=n, m=m, resample=False, transform=policy, verbose=True,
                              randomize_sign=False, image_size=image_size)
            for policy in transformations
        ]
    if transformations is False:
        return [
            BetterRandAugment(n, m, True, False, randomize_sign=False, image_size=image_size)
            for _ in range(nb_augmentations)
        ]
    raise ValueError(
        "transformations must be a list of policies or False, "
        f"got {type(transformations).__name__}"
    )


def _build_augmentation_pipeline(rand_aug, image_normalization, nb_channels, mean, std):
    """Wrap one BetterRandAugment in the tensor -> PIL -> augment -> tensor pipeline."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.Lambda(lambda img: img.convert("RGB")),  # Ensure image is in RGB format
        *([to_3_channels] if nb_channels == 1 else []),  # Conditionally add to_3_channels
        rand_aug,
        *([to_1_channel] if nb_channels == 1 else []),  # Conditionally add to_1_channel
        transforms.PILToTensor(),
        # NOTE(review): for nb_channels == 1 this keeps the uint8 range [0, 255] as float,
        # whereas ConvertImageDtype rescales to [0, 1]; confirm which range the models expect.
        transforms.Lambda(lambda x: x.float()) if nb_channels == 1 else transforms.ConvertImageDtype(torch.float),
        *([transforms.Normalize(mean=mean, std=std)] if image_normalization else []),
    ])


def _first_batch(dataset, sub_datasets, transform, n_images):
    """Install ``transform`` on every sub-dataset and return the images of the first batch."""
    for subds in sub_datasets:
        subds.transform = transform
    loader = DataLoader(dataset=dataset, batch_size=n_images, shuffle=False, num_workers=0, pin_memory=True)
    return next(iter(loader))[0]


def _imshow_chw(ax, image, title):
    """Draw a channels-first image tensor on ``ax`` (gray colormap when it has one channel)."""
    array = image.cpu().numpy()
    if array.shape[0] == 1:
        ax.imshow(array.squeeze(0), cmap='gray')
    else:
        ax.imshow(np.transpose(array, (1, 2, 0)))
    ax.set_title(title)
    ax.axis('off')


def _show_side_by_side(original_images, augmented_images, max_rows=3):
    """Plot up to ``max_rows`` original/augmented pairs in a two-column grid."""
    n_rows = min(max_rows, original_images.shape[0], augmented_images.shape[0])
    if n_rows == 0:
        return
    # squeeze=False keeps `axes` 2-D even when a single row is drawn.
    fig, axes = plt.subplots(nrows=n_rows, ncols=2, figsize=(8, 4 * n_rows), squeeze=False)
    for idx in range(n_rows):
        _imshow_chw(axes[idx, 0], original_images[idx], f"Original Image {idx+1}")
        _imshow_chw(axes[idx, 1], augmented_images[idx], f"Augmented Image {idx+1}")
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across many augmentations


def apply_augmentations(dataset, nb_augmentations, usingBetterRandAugment, n, m, image_normalization, nb_channels, mean, std, image_size, transformations=False, batch_size=None):
    """
    Preview BetterRandAugment augmentations on the first images of ``dataset``.

    For each augmentation pipeline, the first images of ``dataset`` are loaded once
    with ``ToTensor`` only and once augmented, then plotted side by side (at most 3
    pairs). Nothing is returned; the function only prints and plots.

    ``dataset`` must expose ``dataset.dataset.datasets`` (e.g. a ``Subset`` of a
    ``ConcatDataset``). The ``transform`` of each of those sub-datasets is swapped
    temporarily and restored before returning, even on error.

    Args:
        dataset: Dataset to preview (layout described above).
        nb_augmentations (int): Number of random policies to sample when ``transformations`` is False.
        usingBetterRandAugment (bool): If False, nothing is done.
        n (int): Number of ops per BetterRandAugment policy.
        m (int): Magnitude passed to BetterRandAugment.
        image_normalization (bool): If True, append ``Normalize(mean, std)`` to each pipeline.
        nb_channels (int): If 1, images are converted back to single-channel after augmentation.
        mean (list or None): Mean for normalization.
        std (list or None): Standard deviation for normalization.
        image_size (int): Image size passed to BetterRandAugment.
        transformations (list or False): Explicit policies (one BetterRandAugment each), or
            False to sample ``nb_augmentations`` random policies.
        batch_size (int or None): Upper bound on the number of images loaded per preview.
            Only the displayed images (at most 3) are loaded; None means "load exactly those".

    Raises:
        ValueError: If ``usingBetterRandAugment`` is True and ``transformations`` is neither
            a list nor False.
    """
    if not usingBetterRandAugment:
        return

    rand_aug_policies = _build_rand_aug_policies(transformations, nb_augmentations, n, m, image_size)
    augmentations = [
        _build_augmentation_pipeline(rand_aug, image_normalization, nb_channels, mean, std)
        for rand_aug in rand_aug_policies
    ]
    if not augmentations:
        return

    # Only the first few images are displayed, so load only that many.
    max_display = 3
    n_load = max_display if batch_size is None else min(batch_size, max_display)

    sub_datasets = dataset.dataset.datasets
    saved_transforms = [subds.transform for subds in sub_datasets]
    try:
        for i, (rand_aug, augmentation) in enumerate(zip(rand_aug_policies, augmentations)):
            print(f"Applying augmentation n : {i}")
            original_images = _first_batch(dataset, sub_datasets, transforms.Compose([transforms.ToTensor()]), n_load)
            # Print the ops of the policy itself; indexing into the Compose is fragile
            # because the position of rand_aug depends on nb_channels.
            print(rand_aug.ops)
            augmented_images = _first_batch(dataset, sub_datasets, augmentation, n_load)

            print("Original and Augmented images (side by side):")
            _show_side_by_side(original_images, augmented_images, max_rows=max_display)
    finally:
        # Undo the transform swap so later cells see the dataset unchanged.
        for subds, saved in zip(sub_datasets, saved_transforms):
            subds.transform = saved
                

def get_batch_predictions(models, augmented_inputs, device):
    """
    Run a single model, or every model of a list, on ``augmented_inputs`` and stack the outputs.

    Args:
        models (torch.nn.Module or list): One model, or a list of models.
        augmented_inputs (torch.Tensor): Input batch given to every model.
        device (torch.device): Device used for the forward passes.

    Returns:
        torch.Tensor: Model outputs stacked along a new leading model axis
        (length 1 when a single model is given).
    """
    model_list = models if isinstance(models, list) else [models]
    outputs = [get_prediction(model, augmented_inputs, device) for model in model_list]
    return torch.stack(outputs, dim=0)


def average_predictions(batch_predictions, output_activation=None):
    """
    Average predictions over the leading (model) axis, then optionally apply an activation.

    Args:
        batch_predictions (torch.Tensor): Stacked predictions, model axis first,
            e.g. [num_models, batch_size * num_augmentations, num_classes].
        output_activation (str, optional): 'softmax' (over the last dim), 'sigmoid', or None.
            Any other value leaves the mean untouched.

    Returns:
        torch.Tensor: The mean over dim 0, with the requested activation applied.
    """
    averaged = batch_predictions.mean(dim=0)
    if output_activation == 'softmax':
        return torch.nn.functional.softmax(averaged, dim=-1)
    if output_activation == 'sigmoid':
        return torch.sigmoid(averaged)
    return averaged

In [3]:
def apply_randaugment_and_store_results(
    dataset, models, N, M, num_policies, device, folder_name='savedpolicies',
    image_normalization=False, mean=False, std=False, nb_channels=1, image_size=51,
    output_activation=None, batch_size=None
):
    """
    Create ``folder_name`` and preview one random BetterRandAugment policy on ``dataset``.

    NOTE(review): despite the name, only a single policy is previewed even when
    ``num_policies`` > 1, nothing is written to ``folder_name``, and ``models``,
    ``device`` and ``output_activation`` are unused. Confirm whether this is intended.
    """
    os.makedirs(folder_name, exist_ok=True)

    # First policy index, or None when num_policies yields an empty range.
    first_policy = next(iter(range(num_policies)), None)
    if first_policy is None:
        return

    print(f"Applying augmentation policy {first_policy+1}/{num_policies}")
    # Apply augmentation and get augmented images
    apply_augmentations(
        dataset, 1, True, N, M, image_normalization, nb_channels, mean, std, image_size,
        batch_size=batch_size,
    )


In [18]:
# Root folder of the saved GPS-augment policies. Set the GPS_AUGMENT_ROOT environment
# variable to relocate it instead of editing this absolute path.
gps_augment_root = os.environ.get(
    "GPS_AUGMENT_ROOT",
    "/mnt/data/psteinmetz/archive_notebooks/Documents/medMNIST/gps_augment",
)
# N=2 ops per policy, magnitude M=45, 500 policies requested.
# nb_channels=3: the pipeline converts images to RGB and keeps them RGB after augmentation.
apply_randaugment_and_store_results(
    calibration_dataset, models, 2, 45, 500, device,
    folder_name=f'{gps_augment_root}/{im_size}*{im_size}/{dataflag}_calibration_set',
    image_normalization=True, mean=[.5], std=[.5], image_size=im_size, nb_channels=3,
    output_activation=activation, batch_size=batch_size,
)

Applying augmentation policy 1/500
Applying augmentation n : 0
[(<function SolarizeAdd at 0x727acbbb1ee0>, 256, 128), (<function Cutout at 0x727acbbb22a0>, 0, 0.5)]


TypeError: color must be int or single-element tuple

In [22]:
# Sanity check: number of samples in the calibration split.
calibration_size = len(calibration_dataset)
calibration_size

125

In [23]:
# Inspect the underlying datasets the calibration subset is drawn from.
calibration_source = calibration_dataset.dataset
calibration_source.datasets

[Dataset BreastMNIST of size 224 (breastmnist_224)
     Number of datapoints: 546
     Root location: /home/psteinmetz/.medmnist
     Split: train
     Task: binary-class
     Number of channels: 1
     Meaning of labels: {'0': 'malignant', '1': 'normal, benign'}
     Number of samples: {'train': 546, 'val': 78, 'test': 156}
     Description: The BreastMNIST is based on a dataset of 780 breast ultrasound images. It is categorized into 3 classes: normal, benign, and malignant. As we use low-resolution images, we simplify the task into binary classification by combining normal and benign as positive and classifying them against malignant as negative. We split the source dataset with a ratio of 7:1:2 into training, validation and test set. The source images of 1×500×500 are resized into 1×28×28.
     License: CC BY 4.0,
 Dataset BreastMNIST of size 224 (breastmnist_224)
     Number of datapoints: 78
     Root location: /home/psteinmetz/.medmnist
     Split: val
     Task: binary-class
   