In [11]:
import os
import sys

# Make the project's source directory importable (ood_data_gen, samplers are
# imported in the next cell). Set the ICL_SRC_DIR environment variable to point
# elsewhere on another machine; the membership check keeps re-runs of this cell
# from appending duplicate entries to sys.path.
ICL_SRC_DIR = os.environ.get('ICL_SRC_DIR', '/csproject/t3_lzengaf/lzengaf/ICL/src')
if ICL_SRC_DIR not in sys.path:
    sys.path.append(ICL_SRC_DIR)

In [12]:
import numpy as np
import plotly.graph_objects as go
import torch
from tqdm import tqdm

from ood_data_gen import gen_opposite_quadrants, gen_random_quadrants
from samplers import get_data_sampler, sample_transformation

In [13]:
# Sample points for the chosen OOD `task`; the generator returns the points
# plotted below as train data (`xs`) and test data (`test_xs`).
N_DIMS = 20
SEED = 0
task = "random_quadrants"

# Seed RNGs so Restart & Run All reproduces the same figure. `xs` is used as a
# torch tensor later (`.numpy()`), so seed torch; numpy is seeded as well in
# case the sampler or generators draw from it.
torch.manual_seed(SEED)
np.random.seed(SEED)

data_sampler = get_data_sampler('gaussian', n_dims=N_DIMS)

# Other generators (standard, orthogonal, projection, expansion) were previously
# listed here commented out; they are not imported above, so import them before
# registering them in this dict.
func_dict = {
    "opposite_quadrants": gen_opposite_quadrants,
    "random_quadrants": gen_random_quadrants,
}
if task not in func_dict:
    raise ValueError(f"unknown task {task!r}; expected one of {sorted(func_dict)}")

# NOTE(review): the meaning of the positional args (40, 1, 6) is defined by the
# generator functions in ood_data_gen — check their signatures before tuning.
xs, test_xs = func_dict[task](data_sampler, 40, 1, 6)

# Symmetric half-width shared by all three 3-D axes: the largest |coordinate|
# in the plotted dims 0-2 over both train and test points. max(|x|) over the
# slice equals the max of |min| and |max| taken per dimension. float(...)
# converts the tensor/array scalar to a plain Python number so plotly axis
# ranges and np.linspace receive an ordinary float.
max_range = max(
    float(abs(xs[..., :3]).max()),
    float(abs(test_xs[..., :3]).max()),
)

In [None]:
# 3-D scatter of input dims 0-2 for the train and test points. Only these three
# dimensions are drawn; any further input dimensions are not visible here.
# 0.1 padding keeps points at the extremes from sitting on the axis boundary.
axis_range = [-max_range - 0.1, max_range + 0.1]

fig = go.Figure()
fig.add_traces([
    go.Scatter3d(
        # np.asarray accepts both torch (CPU) tensors and numpy arrays.
        x=np.asarray(points[..., 0]).ravel(),
        y=np.asarray(points[..., 1]).ravel(),
        z=np.asarray(points[..., 2]).ravel(),
        mode='markers',
        marker=dict(size=7, color=color, opacity=0.6),
        name=label,
    )
    for points, color, label in [
        (xs, 'blue', 'Train Data'),
        (test_xs, 'red', 'Test Data'),
    ]
])

fig.update_layout(
    title=f'Scatter Plot of Datasets for {task}',
    scene=dict(
        xaxis_title='dim 0',
        yaxis_title='dim 1',
        zaxis_title='dim 2',
        xaxis=dict(range=axis_range),
        yaxis=dict(range=axis_range),
        zaxis=dict(range=axis_range),
        camera=dict(eye=dict(x=2, y=2, z=1)),
    ),
    margin=dict(l=0, r=0, b=30, t=30),
)
fig


In [None]:
# Overlay the three coordinate planes (z = 0, y = 0, x = 0) of the plotted dims
# so it is visible which orthant each point falls into.
# The planes are tagged via `meta`, and previously added copies are removed
# first, so re-running this cell does not stack duplicate planes onto `fig`.
PLANE_TAG = 'orthant_plane'
fig.data = tuple(trace for trace in fig.data if trace.meta != PLANE_TAG)

# Same padded extent as the axes; 10 grid points per side is enough for a flat plane.
grid = np.linspace(-max_range - 0.1, max_range + 0.1, 10)
grid_a, grid_b = np.meshgrid(grid, grid)
zeros = np.zeros_like(grid_a)

plane_coords = [
    dict(x=grid_a, y=grid_b, z=zeros),  # z = 0: plane along dim 0 - dim 1
    dict(x=grid_a, y=zeros, z=grid_b),  # y = 0: plane along dim 0 - dim 2
    dict(x=zeros, y=grid_a, z=grid_b),  # x = 0: plane along dim 1 - dim 2
]
fig.add_traces([
    go.Surface(**coords, colorscale='gray', opacity=0.2, showscale=False, meta=PLANE_TAG)
    for coords in plane_coords
])
fig