In [None]:
import os

import gif
import matplotlib.pyplot as plt
import plotly.graph_objects as go

import numpy as np
from tqdm.notebook import trange

from pyece import (
    Corners, Point,
    PointInflation, PointShift, PointRotate, RandomUniform,
    Transformer,
)
from pyece.im import cutpatch

In [None]:
# Batch of generated LUNA16 cubes stored as a single .npy array (per the file
# path). The location can be overridden with the LUNA_PATCHES_PATH environment
# variable so the notebook is not tied to one machine's absolute path.
path = os.environ.get(
    "LUNA_PATCHES_PATH",
    "/anvar/public_datasets/luna2016/generated_cubes/bat_32_s_64x64x32_0.npy",
)
patches = np.load(path)
# Later cells take patches[i] as a 3-D volume, so the batch itself must be 4-D.
assert patches.ndim == 4, f"expected a 4-D batch of volumes, got shape {patches.shape}"
patches.shape

In [None]:
# Which cube of the batch to visualise and augment (index along axis 0).
PATCH_INDEX = 1
patch = patches[PATCH_INDEX]

In [None]:
def generate_new_patch(
    patch: np.ndarray,
    scale_range: tuple = (0.5, 1.5),
    shift_range: tuple = (-7.5, 7.5),
    stretch_range: tuple = (0.5, 1),
    angle_range: tuple = (-np.pi, np.pi),
):
    """Return a randomly augmented patch resampled from ``patch`` itself.

    Corner points built from ``patch.shape`` (via ``Corners.product``) are
    passed through a pyece ``Transformer`` chain: shift, scalar inflation,
    per-axis inflation, then rotation. The parameters of that chain are
    ``RandomUniform`` distributions. ``cutpatch`` then cuts the region
    described by the transformed corners out of ``patch`` onto a grid of
    ``patch.shape``.

    Args:
        patch: Volume to augment. Its shape defines both the corners and the
            output grid.
        scale_range: ``(low, high)`` bounds of the scalar inflation factor.
        shift_range: ``(low, high)`` bounds of the shift, in the same units as
            the corner coordinates (presumably voxels; TODO confirm).
        stretch_range: ``(low, high)`` bounds of the per-axis inflation factor.
        angle_range: ``(low, high)`` bounds of the rotation angle. The default
            ``(-pi, pi)`` suggests radians (TODO confirm against pyece).

    Returns:
        Whatever ``cutpatch`` returns for ``grid=patch.shape``. This is
        presumably an array of shape ``patch.shape`` (TODO confirm).
    """
    corners = Corners.product(patch.shape)

    scale = RandomUniform(*scale_range)
    shift = RandomUniform(*shift_range)
    stretch = RandomUniform(*stretch_range)
    angle = RandomUniform(*angle_range)

    # NOTE(review): each RandomUniform object is reused for all three axes.
    # This code does not show whether pyece draws an independent sample per
    # use. If it does not, the "per-axis" stretch is isotropic and merely
    # compounds ``scale``, and the rotation uses one angle about every axis.
    augmentator = Transformer(
        PointShift(shift=Point((shift, shift, shift))),
        PointInflation(factor=scale),
        PointInflation(factor=Point((stretch, stretch, stretch))),
        PointRotate(angle=Point((angle, angle, angle))),
    )

    # The transformed corners define the region cut out of the original patch.
    new_corners = augmentator(corners).value
    return cutpatch(data=patch, grid=patch.shape, corners=new_corners)

In [None]:
def plot_volume(
    volume: np.ndarray,
    title: str,
    cmin: float = 0,
    cmax: float = 1,
    colorscale: str = 'Gray',
    play_frame_duration: int = 50,
):
    """Build an animated Plotly figure that scrolls through ``volume`` slice by slice.

    Each slice ``volume[..., level]`` along the last axis is drawn as a flat
    ``go.Surface`` at height ``level``, coloured by the slice values. Frame
    ``k`` shows level ``z_dim - 1 - k``, so playback starts at the top slice.
    A play/pause button pair and a slider step through the frames.

    Args:
        volume: 3-D array ``(x_dim, y_dim, z_dim)``; slicing is along the last axis.
        title: Figure title.
        cmin: Lower colour-scale limit. The default assumes values in ``[0, 1]``.
        cmax: Upper colour-scale limit. The default assumes values in ``[0, 1]``.
        colorscale: Plotly colour scale used for every slice, including the
            animation frames.
        play_frame_duration: Frame/transition duration passed to the play button.

    Returns:
        The configured ``go.Figure``; call ``.show()`` to render it.

    Raises:
        ValueError: if ``volume`` is not 3-D or has no slices along its last axis.
    """
    if volume.ndim != 3:
        raise ValueError(f"expected a 3-D volume, got shape {volume.shape}")
    x_dim, y_dim, z_dim = volume.shape
    if z_dim == 0:
        raise ValueError("volume has no slices along its last axis")

    def slice_surface(level, **extra):
        # One constant-height plane at z=level, coloured by that slice.
        # Used for both the frames and the initial trace so they stay consistent.
        return go.Surface(
            z=level * np.ones((x_dim, y_dim)),
            surfacecolor=volume[..., level],
            colorscale=colorscale,
            cmin=cmin, cmax=cmax,
            **extra,
        )

    fig = go.Figure(
        frames=[
            go.Frame(data=slice_surface(z_dim - 1 - k), name=str(k))
            for k in range(z_dim)
        ]
    )

    # Trace displayed before the animation starts: same slice as frame "0",
    # plus a colorbar.
    fig.add_trace(slice_surface(z_dim - 1, colorbar=dict(thickness=20, ticklen=4)))

    def frame_args(duration):
        # Animation options shared by the buttons and the slider steps.
        return {
            "frame": {"duration": duration},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    sliders = [
        {
            "pad": {"b": 10, "t": 60},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {
                    # Jump directly to frame `f` with zero-duration animation.
                    "args": [[f.name], frame_args(0)],
                    "label": str(k),
                    "method": "animate",
                }
                for k, f in enumerate(fig.frames)
            ],
        }
    ]

    fig.update_layout(
        title=title,
        width=800,
        height=800,
        scene=dict(
            # Fixed z range so the axis does not rescale as the plane moves.
            zaxis=dict(range=[-1, z_dim], autorange=False),
            aspectratio=dict(x=1, y=1, z=1),
        ),
        updatemenus=[
            {
                "buttons": [
                    {
                        "args": [None, frame_args(play_frame_duration)],
                        "label": "&#9654;",  # play symbol
                        "method": "animate",
                    },
                    {
                        "args": [[None], frame_args(0)],
                        "label": "&#9724;",  # pause symbol
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 70},
                "type": "buttons",
                "x": 0.1,
                "y": 0,
            }
        ],
        sliders=sliders,
    )
    return fig

In [None]:
# Interactive slice-by-slice view of the unmodified cube.
fig = plot_volume(volume=patch, title="Default patch")
fig.show()

In [None]:
# Draw one random augmentation of the selected cube.
# NOTE(review): no seed is set, so this result may differ between runs unless
# pyece's RandomUniform is seeded elsewhere. TODO confirm.
new_patch = generate_new_patch(patch=patch)

In [None]:
# Same interactive view, for the augmented cube.
fig = plot_volume(volume=new_patch, title="Augmented patch")
fig.show()

In [None]:
@gif.frame
def plot_frame(k: int, volume: np.ndarray, title: str, cmin: float = 0, cmax: float = 1):
    """Render one slice of ``volume`` as a single GIF frame.

    Slices are counted from the top: ``k = 0`` shows ``volume[..., z_dim - 1]``
    drawn as a flat surface at height ``z_dim - 1``.

    Args:
        k: Slice index counted from the top; must satisfy ``0 <= k < z_dim``.
        volume: 3-D array ``(x_dim, y_dim, z_dim)``; slicing is along the last axis.
        title: Figure title.
        cmin: Lower colour-scale limit. The default assumes values in ``[0, 1]``.
        cmax: Upper colour-scale limit. The default assumes values in ``[0, 1]``.

    Returns:
        The ``go.Figure`` built here, as post-processed by the ``gif.frame``
        decorator.

    Raises:
        IndexError: if ``k`` is out of range. Without the check, ``k >= z_dim``
            yields a negative index that silently wraps to a different slice,
            and that slice is drawn below the fixed z-axis range.
    """
    x_dim, y_dim, z_dim = volume.shape
    if not 0 <= k < z_dim:
        raise IndexError(f"slice index {k} is out of range for a volume with {z_dim} slices")
    level = z_dim - 1 - k

    fig = go.Figure()
    fig.add_trace(
        go.Surface(
            z=level * np.ones((x_dim, y_dim)),
            surfacecolor=volume[..., level],
            colorscale='Gray',
            cmin=cmin, cmax=cmax,
        )
    )
    fig.update_layout(
        title=title,
        width=800,
        height=800,
        scene=dict(
            # Fixed z range so every frame of the GIF uses the same axes.
            zaxis=dict(range=[-1, z_dim], autorange=False),
            aspectratio=dict(x=1, y=1, z=1),
        ),
    )
    return fig

In [None]:
def save_gif(volume: np.ndarray, title: str, path: str, duration: int = 100):
    """Render every slice of ``volume`` with ``plot_frame`` and write them as a GIF.

    Args:
        volume: Array whose last axis is iterated; one frame per index.
        title: Base title; frame ``k`` is titled ``"{title}, frame {k}"``.
        path: Output file path handed to ``gif.save``.
        duration: Per-frame duration forwarded to ``gif.save``. Presumably in
            milliseconds (TODO confirm against the installed ``gif`` version).

    Raises:
        ValueError: if ``volume`` has no slices along its last axis, since there
            would be nothing to save.
    """
    n_slices = volume.shape[-1]
    if n_slices == 0:
        raise ValueError("cannot save a GIF for a volume with no slices")
    # trange shows a notebook progress bar; rendering each frame is the slow part.
    frames = [
        plot_frame(k, volume, f"{title}, frame {k}")
        for k in trange(n_slices)
    ]
    gif.save(frames, path, duration=duration)

In [None]:
# Export the unmodified cube as an animated GIF.
save_gif(volume=patch, title="Default patch", path="default_patch.gif")

In [None]:
# Export the augmented cube as an animated GIF.
save_gif(volume=new_patch, title="Augmented patch", path="new_patch.gif")