In [None]:
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import numpy as np

from pyece import (
    RandomUniform, RandomChoice,
    PointShift, PointInflation, PointRotate,
    Transformer, Point
)
from pyece import Corners

In [None]:
def draw_corners(ax, corners, color="r", alpha=1., marker_size=50):
    """Draw a quadrilateral given by four 2D corners on a matplotlib Axes.

    The outline is traced p1 -> p2 -> p4 -> p3 -> p1 (p3 and p4 swapped
    relative to the input order). This presumably turns corners given in
    grid/product order, as produced by ``Corners.product``, into a closed
    loop rather than a bow-tie. The first corner is marked with a circle
    and the second with a star, so the orientation of a transformed box
    stays visible.

    Parameters
    ----------
    ax : object with matplotlib-style ``plot`` and ``scatter`` methods.
    corners : sequence of exactly four 2D points.
    color : colour used for both the outline and the corner markers.
    alpha : opacity applied to the outline and to the corner markers.
    marker_size : marker area passed to ``scatter`` as ``s``.
    """
    p1, p2, p3, p4 = corners
    path = np.asarray([p1, p2, p4, p3, p1])
    ax.plot(path[:, 0], path[:, 1], c=color, alpha=alpha)
    # Markers share the outline's alpha so that faded (e.g. augmented)
    # boxes are faded consistently instead of keeping opaque corners.
    ax.scatter([p1[0]], [p1[1]], c=color, marker='o', s=marker_size, alpha=alpha)
    ax.scatter([p2[0]], [p2[1]], c=color, marker='*', s=marker_size, alpha=alpha)

In [None]:
# 2D demo: corners of a 10x10 box and a random augmentation pipeline for it.
corners = Corners.product((10, 10))

# Random parameters consumed by the pipeline below.
scale = RandomUniform(0.5, 1.2)          # overall inflation factor
stretching = RandomUniform(0.5, 1)       # stretch factor, used for both axes
rndshift = RandomUniform(-7, 7)          # translation offset, used for both axes
rndangl = RandomUniform(-np.pi, np.pi)   # rotation angle (presumably radians)

# Steps in the order they are handed to Transformer:
# shift -> inflate -> stretch -> rotate.
aug = Transformer(
    PointShift((rndshift, rndshift)),
    PointInflation(factor=scale),
    PointInflation(factor=Point([stretching, stretching])),
    PointRotate(rndangl),
)

In [None]:
# Original box in red, ten random augmentations of it in faded black.
fig_2d, ax_2d = plt.subplots(figsize=(10, 10))
draw_corners(ax_2d, corners.value, "red")

for _ in range(10):
    draw_corners(ax_2d, aug(corners).value, color="black", alpha=0.2)

ax_2d.set_xlim((-5, 15))
ax_2d.set_ylim((-5, 15))
# Equal aspect so rotated boxes are not visually sheared by unequal axis scaling.
ax_2d.set_aspect("equal")
ax_2d.set_title("2D corner augmentation (red: original, black: augmented)")
ax_2d.grid()

In [None]:
def plot_corners(fig, box: np.ndarray, color=None):
    """Add the wireframe of a 3D box to a plotly figure as a single trace.

    ``box`` is indexed as ``box[i, axis]`` for i in 0..7 and axis in 0..2.
    The walk below steps only between indices that differ in exactly one
    bit, and it covers all 12 such pairs. So, assuming the 8 corners are
    listed in binary/product order, it traces every cube edge. A few edges
    are traversed twice.

    Parameters
    ----------
    fig : plotly figure the trace is added to (mutated in place).
    box : array of corner coordinates, at least 8 rows and 3 columns.
    color : colour for the line and the (zero-size) markers; ``None``
        leaves it to plotly.
    """
    edge_walk = [0, 1, 3, 2, 6, 4, 5, 7, 3, 2, 0, 4, 5, 1, 5, 7, 6]
    xs, ys, zs = (box[edge_walk, axis] for axis in range(3))
    wireframe = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        marker={"size": 0, "color": color},
        line={"color": color},
    )
    fig.add_trace(wireframe)

In [None]:
def generate(transformator, point_cloud, n):
    """Lazily yield ``n`` transformed versions of ``point_cloud``.

    ``transformator`` is called anew on the same ``point_cloud`` for every
    sample. The ``.value`` attribute of each result is yielded (presumably
    the raw coordinate array). Nothing is computed until iteration starts.
    """
    yield from (transformator(point_cloud).value for _ in range(n))

In [None]:
# 3D demo: corners of a 10x10x10 box and a random augmentation pipeline.
# NOTE: this rebinds `corners` (the 2D box from the earlier cell). Re-running
# the 2D plotting cell after this one will presumably fail, because
# draw_corners unpacks exactly four corners.
corners = Corners.product((10, 10, 10))

scale = RandomUniform(0.5, 1.5)
shift = RandomUniform(-7.5, 7.5)
stretch = RandomUniform(0.5, 1)
angle = RandomUniform(-np.pi, np.pi)

# The same random parameter object is passed for each of the three axes.
augmentator = Transformer(
    PointShift(shift=Point((shift,) * 3)),
    PointInflation(factor=scale),
    PointInflation(factor=Point((stretch,) * 3)),
    PointRotate(angle=Point((angle,) * 3)),
)

In [None]:
# Original box in red, three random augmentations of it in black.
fig_3d = go.Figure()
plot_corners(fig_3d, corners.value, color='red')
for augmented_box in generate(augmentator, corners, 3):
    plot_corners(fig_3d, augmented_box, color='black')
# plot_corners sets no trace names, so the legend would only list
# uninformative "trace N" entries; a title carries the colour key instead.
fig_3d.update_layout(
    title="3D corner augmentation (red: original, black: augmented)",
    showlegend=False,
)
fig_3d.show()