In [1]:
from qutree import *

def parabula(x):
    """Sum of squared components of ``x``: a bowl with minimum 0 at the origin.

    The spelling "parabula" is kept so existing callers keep working.

    Parameters
    ----------
    x : array_like
        Point (or array of values) to evaluate. NumPy arrays, lists and tuples
        are all accepted.

    Returns
    -------
    numpy scalar
        ``sum_i x_i**2`` over every element of ``x``.
    """
    # np.square accepts any array-like; the previous ``x**2`` raised TypeError
    # for plain Python lists/tuples. Results are identical for ndarrays.
    return np.sum(np.square(x))

def Err_xyz(x):
    """Objective for Cartesian coordinates: ``parabula`` evaluated at ``x`` as-is."""
    return parabula(x=x)

def Err_spherical(q):
    """Objective for spherical coordinates.

    ``q`` is converted to Cartesian coordinates with ``spherical_to_xyz``
    (not defined in this notebook; presumably provided by qutree) and the
    resulting point is scored with ``parabula``.
    """
    return parabula(spherical_to_xyz(q))

# Problem size for the runs below:
#   N - number of points in each 1-D grid (passed as linspace(..., N=N));
#       also handed to balanced_tree
#   r - handed to balanced_tree; presumably the rank / number of grid points
#       kept per edge — TODO confirm against qutree
#   f - number of coordinates; the per-coordinate linspace is repeated f times
N, r, f = 20, 1, 3

class Objective:
    """Bundle an error function with the grids it is sampled on.

    Parameters
    ----------
    Err : callable
        Objective evaluated on grid coordinates ``q``.
    linspace : object
        Per-coordinate grids, stored as-is (the notebook passes a list of
        qutree linspace objects and forwards it to ``tn_grid``).
    q_to_xyz : callable, optional
        Maps grid coordinates ``q`` to Cartesian ``xyz``; used for plotting.
        Identity by default.
    xyz_to_q : callable, optional
        Maps Cartesian ``xyz`` back to grid coordinates ``q``; used by
        ``__call__``. Identity by default, matching ``q_to_xyz``.
    """

    def __init__(self, Err, linspace, q_to_xyz=lambda x: x, xyz_to_q=lambda x: x):
        self.Err = Err
        self.q_to_xyz = q_to_xyz
        # Required by __call__; previously never assigned, so calling an
        # Objective always raised AttributeError.
        self.xyz_to_q = xyz_to_q
        self.linspace = linspace

    def __call__(self, x):
        """Evaluate the objective at Cartesian point ``x`` (converted to ``q`` first)."""
        q = self.xyz_to_q(x)
        return self.Err(q)

    def plot(self, G):
        """Plot graph ``G`` with ``plot_tn_xyz`` using this objective and ``q_to_xyz``."""
        plot_tn_xyz(G, self.Err, self.q_to_xyz)

# Alternative set-up in spherical coordinates:
#   O = Objective(Err_spherical, spherical_linspace(N), spherical_to_xyz)
# Cartesian set-up: the same [0, 1] grid is used for each of the f coordinates.
O = Objective(Err_xyz, [linspace(0., 1., N = N)] * f)
G = tn_grid(balanced_tree(f, r, N), O.linspace)
O.plot(G)

In [2]:
def ttnopt_step(G, F):
    """Run one sweep that recomputes the grid stored on each non-leaf edge.

    The work is done on ``G.copy()``. For every edge yielded by ``sweep`` that
    is not a leaf, the grids of its preceding edges are collected and the new
    edge grid is computed by ``maxvol_grids`` with the objective ``F``. Node
    grids are then rebuilt with ``build_node_grid``.

    Parameters
    ----------
    G : graph
        Graph built by ``balanced_tree``/``tn_grid`` whose edges carry a
        ``'grid'`` attribute.
    F : callable
        Objective forwarded to ``maxvol_grids``.

    Returns
    -------
    The updated copy of ``G``.
    """
    net = G.copy()
    for edge in sweep(net):
        if is_leaf(edge):
            continue  # leaf edges are not updated
        # Edges preceding `edge`, with flip(edge) moved to the back
        # (per permute_to_back's name — TODO confirm exact ordering contract).
        incoming = permute_to_back(pre_edges(net, edge), flip(edge))
        # Grids are read from `net`, which this loop also writes, so edges
        # visited later in the sweep see grids updated earlier in it.
        incoming_grids = collect(net, incoming, 'grid')
        net[edge[0]][edge[1]]['grid'] = maxvol_grids(incoming_grids, F)
    build_node_grid(net)
    return net

# Plot the grid before and after a single ttnopt_step sweep on the objective O.Err.
O.plot(G)
G2 = ttnopt_step(G, F=O.Err)
O.plot(G2)

In [3]:
# NOTE(review): dead cell — a commented-out plotly-express gapminder example;
# nothing here executes. Remove it or turn it into a working cell.
#import plotly.express as px
#df = px.data.gapminder()
#print(df[:25])
#px.scatter_3d(df, x="gdpPercap", y="lifeExp", z="pop", animation_frame="year", animation_group="country",
#           size="pop", color="continent", hover_name="country",
#           log_x=True, size_max=55, range_x=[100,100000], range_y=[25,90])

In [26]:
# Tabulate every grid point stored on the nodes of G, evaluate the objective
# there, and show the points as a 3-D scatter coloured by the objective value.
# NOTE(review): this visualises G (the grid before ttnopt_step); use G2 to
# inspect the grid after the sweep. qutree's tn_to_df(G, O.Err) may give a
# similar table — TODO check.

# Per-node grids; nodes whose 'grid' attribute is None are skipped.
node_attributes = {
    node: attr.grid
    for node, attr in nx.get_node_attributes(G, 'grid').items()
    if attr is not None
}

# One row per grid point: explode each node's grid into its points. An empty
# grid explodes to NaN, so those rows are dropped; drop=True renumbers rows
# 0..n-1 without leaving a throw-away 'index' column.
points = (
    pd.DataFrame(list(node_attributes.items()), columns=['node', 'grid'])
    .explode('grid')
    .dropna(subset=['grid'])
    .reset_index(drop=True)
)
points['f'] = points['grid'].apply(O.Err)

# Split every point into coordinate columns x1, x2, ...  Building the frame
# from a list is faster than .apply(pd.Series) and, unlike reading the length
# of the first row, does not raise IndexError when there are no points.
coords = pd.DataFrame(points['grid'].tolist(), index=points.index)
coords.columns = ['x{}'.format(i + 1) for i in range(coords.shape[1])]
df = pd.concat([points.drop(columns='grid'), coords], axis=1)

assert {'x1', 'x2', 'x3'} <= set(df.columns), \
    "3-D scatter needs grid points with at least 3 coordinates"

fig = go.Figure(data=[go.Scatter3d(
    x=df['x1'],
    y=df['x2'],
    z=df['x3'],
    mode='markers',
    marker=dict(
        size=5,
        color=df['f'],          # colour by objective value
        colorscale='Viridis',
        opacity=0.8,
        colorbar=dict(title='f'),
    ),
)])
fig.update_layout(
    title='Objective value at the grid points of G',
    scene=dict(xaxis_title='x1', yaxis_title='x2', zaxis_title='x3'),
)
fig.show()

df

    node         f     x1     x2     x3
0      0  0.331875  0.025  0.325  0.475
1      0  0.336875  0.075  0.325  0.475
2      0  0.346875  0.125  0.325  0.475
3      0  0.361875  0.175  0.325  0.475
4      0  0.381875  0.225  0.325  0.475
..   ...       ...    ...    ...    ...
57     2  1.551875  0.825  0.325  0.875
58     2  1.641875  0.825  0.325  0.925
59     2  1.736875  0.825  0.325  0.975
60     3  1.011875  0.825  0.325  0.475
61     4  1.011875  0.825  0.325  0.475

[62 rows x 5 columns]
