In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random

import cv2
import numpy as np
import torch
import torchvision.transforms as tvtransforms

from pybbbc import BBBC021

import holoviews as hv

from pytorch_hcs.vis import set_hv_defaults

set_hv_defaults()

In [None]:
# Dataset accessor from pybbbc; the next cell indexes it to pull out a single
# example sample for the augmentation previews below.
bbbc021 = BBBC021()

In [None]:
# Seed both RNGs the augmentation helpers draw from (`random` and `np.random`)
# so that Restart & Run All reproduces the same augmented figures.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

# Arbitrary example sample. Indexing the dataset returns a pair; only the first
# element (the image) is kept. The plotting cells iterate over `im` per channel
# and transpose it with (1, 2, 0), i.e. they treat it as channels-first (C, H, W).
SAMPLE_IDX = 4000
im, _ = bbbc021[SAMPLE_IDX]

In [None]:
def normalize_image(im, new_min, new_max):
    """Linearly rescale ``im`` so its minimum maps to ``new_min`` and its maximum to ``new_max``.

    Parameters
    ----------
    im : array_like
        Input values of any shape.
    new_min, new_max : float
        Endpoints of the target range.

    Returns
    -------
    np.ndarray
        Floating-point array with the same shape as ``im``. If ``im`` is
        constant (max == min) there is no range to stretch, so every element is
        set to ``new_min`` instead of producing NaN/inf from a 0/0 division.
        An empty input is returned (as float) unchanged.
    """
    im = np.asarray(im)
    # Compute in floating point: with integer inputs the subtraction and the
    # multiplication by (new_max - new_min) below could otherwise wrap around.
    im = im.astype(np.result_type(im.dtype, np.float32), copy=False)
    if im.size == 0:
        return im

    im_min, im_max = im.min(), im.max()
    span = im_max - im_min
    if span == 0:
        return np.full_like(im, new_min)

    return (im - im_min) * (new_max - new_min) / span + new_min

In [None]:
def random_gamma(im, gamma_low=0.7, gamma_high=1.3):
    """Apply an independent random gamma curve to each channel of an image.

    Each channel is rescaled to [0, 1], raised to a gamma drawn uniformly from
    ``[gamma_low, gamma_high)``, and mapped back to its original [min, max]
    range. The channel's intensity range is kept while its contrast curve
    changes.

    Parameters
    ----------
    im : np.ndarray
        Channels-last image. Channels are indexed on the last axis, and their
        count is read from ``im.shape[2]``. In this notebook the shape is (H, W, C).
    gamma_low, gamma_high : float
        Bounds of the uniform distribution the per-channel gamma is drawn from.

    Returns
    -------
    np.ndarray
        Array with the same (H, W, C) layout and the per-channel gamma applied.
        Constant channels are returned unchanged.
    """
    adjusted = []

    for channel_idx in range(im.shape[2]):
        im_channel = im[..., channel_idx]
        im_min, im_max = im_channel.min(), im_channel.max()

        # Draw before the constant-channel check so the RNG stream advances the
        # same way regardless of image content.
        gamma = np.random.uniform(gamma_low, gamma_high)

        if im_max == im_min:
            # No range to rescale: normalizing would divide by zero and
            # produce NaNs, so keep the channel as-is.
            adjusted.append(im_channel)
            continue

        adjusted.append(
            normalize_image(
                normalize_image(im_channel, 0, 1) ** gamma, im_min, im_max
            )
        )

    return np.stack(adjusted, axis=-1)

In [None]:
def random_gauss_noise(im, sigma_max=0.025, p_apply=0.8):
    """Add zero-mean Gaussian noise with a random standard deviation per channel.

    With probability ``p_apply`` each channel gets its own noise sigma drawn
    uniformly from ``[0, sigma_max)``; otherwise all sigmas are zero. The
    result is clipped from below at 0 (there is no upper clip). The clip is
    applied even when no noise is added.

    Parameters
    ----------
    im : np.ndarray
        Channels-last image, e.g. (H, W, C). Any number of channels is supported.
    sigma_max : float
        Upper bound for the per-channel noise standard deviation. It is in the
        same units as ``im``. NOTE(review): the default presumably assumes
        intensities roughly in [0, 1] — confirm against the dataset's scaling.
    p_apply : float
        Probability that noise is applied at all.

    Returns
    -------
    np.ndarray
        Noisy image with the same shape as ``im``, clipped to be >= 0.
    """
    n_channels = im.shape[-1]

    # Draw order (random.random, then np.random.uniform, then randn) matches
    # the original cell so seeded runs stay reproducible.
    if random.random() > p_apply:
        sigmas = np.zeros(n_channels)
    else:
        sigmas = np.random.uniform(0, sigma_max, n_channels)

    # sigmas has shape (C,) and broadcasts over the trailing channel axis.
    noise = sigmas * np.random.randn(*im.shape)
    return np.clip(im + noise, 0, None)

In [None]:
def random_brightness(im, brightness=0.4, p_apply=0.8):
    """Randomly rescale the intensity of each channel independently.

    With probability ``p_apply`` every channel is multiplied by its own factor
    drawn uniformly from ``[1 - brightness, 1 + brightness)``. Otherwise the
    image is multiplied by ones, i.e. left unscaled.

    Parameters
    ----------
    im : np.ndarray
        Channels-last image, e.g. (H, W, C). Any number of channels is supported.
    brightness : float
        Maximum relative change of each channel's scale factor.
    p_apply : float
        Probability that the random scaling is applied at all.

    Returns
    -------
    np.ndarray
        Scaled image with the same shape as ``im``. Values are not clipped, so
        they can exceed the input's original range.
    """
    n_channels = im.shape[-1]

    if random.random() > p_apply:
        alphas = np.ones(n_channels, dtype=int)
    else:
        alphas = 1.0 + np.random.uniform(-brightness, brightness, n_channels)

    # alphas has shape (C,) and broadcasts over the trailing channel axis.
    return im * alphas

In [None]:
im_bright = random_brightness(im.transpose(1, 2, 0)).transpose(2, 0, 1)

plots = []

cmaps = ["fire", "kg", "kb"]

for channel_idx, im_channel in enumerate(im_bright):
    plot = hv.Image(
        im_channel,
        bounds=(0, 0, im_channel.shape[1], im_channel.shape[0]),
        label=f"{bbbc021.CHANNELS[channel_idx]}",
    ).opts(cmap=cmaps[channel_idx])
    plots.append(plot)

plots.append(
    hv.RGB(
        im_bright.transpose(1, 2, 0),
        bounds=(0, 0, im_channel.shape[1], im_channel.shape[0]),
        label="Channel overlay",
    )
)

hv.Layout(plots).cols(2)

In [None]:
im_noise = random_gauss_noise(im.transpose(1, 2, 0)).transpose(2, 0, 1)

plots = []

cmaps = ["fire", "kg", "kb"]

for channel_idx, im_channel in enumerate(im_noise):
    plot = hv.Image(
        im_channel,
        bounds=(0, 0, im_channel.shape[1], im_channel.shape[0]),
        label=f"{bbbc021.CHANNELS[channel_idx]}",
    ).opts(cmap=cmaps[channel_idx])
    plots.append(plot)

plots.append(
    hv.RGB(
        im_noise.transpose(1, 2, 0),
        bounds=(0, 0, im_channel.shape[1], im_channel.shape[0]),
        label="Channel overlay",
    )
)

hv.Layout(plots).cols(2)

In [None]:
im_gamma = random_gamma(im.transpose(1, 2, 0)).transpose(2, 0, 1)

plots = []

cmaps = ["fire", "kg", "kb"]

for channel_idx, im_channel in enumerate(im_gamma):
    plot = hv.Image(
        im_channel,
        bounds=(0, 0, im_channel.shape[1], im_channel.shape[0]),
        label=f"{bbbc021.CHANNELS[channel_idx]}",
    ).opts(cmap=cmaps[channel_idx])
    plots.append(plot)

plots.append(
    hv.RGB(
        im_gamma.transpose(1, 2, 0),
        bounds=(0, 0, im_channel.shape[1], im_channel.shape[0]),
        label="Channel overlay",
    )
)

hv.Layout(plots).cols(2)

In [None]:
plots = []

cmaps = ["fire", "kg", "kb"]

for channel_idx, im_channel in enumerate(im):
    plot = hv.Image(
        im_channel,
        bounds=(0, 0, im_channel.shape[1], im_channel.shape[0]),
        label=f"{bbbc021.CHANNELS[channel_idx]}",
    ).opts(cmap=cmaps[channel_idx])
    plots.append(plot)

plots.append(
    hv.RGB(
        im.transpose(1, 2, 0),
        bounds=(0, 0, im_channel.shape[1], im_channel.shape[0]),
        label="Channel overlay",
    )
)

hv.Layout(plots).cols(2)