In [None]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from sklearn.cluster import KMeans

: 

In [None]:
def _build_star_figure(positions, labels, n_cols=2):
    """Build a grid of 3D scatter subplots, one per time period.

    Args:
        positions: array of shape (n_stars, n_periods, 3) holding x/y/z
            coordinates of every star at every period.
        labels: integer array of shape (n_stars, n_periods) with the cluster
            label of each (star, period) observation; used as marker colour.
        n_cols: maximum number of subplot columns in the grid.

    Returns:
        A plotly figure with one ``Scatter3d`` trace per time period.
    """
    n_periods = positions.shape[1]
    cols = min(n_cols, n_periods)
    rows = math.ceil(n_periods / cols)
    subplot_titles = [f"Time period {i + 1}" for i in range(n_periods)]

    fig = make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=subplot_titles,
        specs=[[{'type': 'scatter3d'}] * cols for _ in range(rows)],
    )

    fig.update_layout(
        # make_subplots anchors the top-row titles just above the plotting
        # area, so a zero top margin would push them off the canvas.
        margin=dict(l=0, r=0, b=0, t=40),
        paper_bgcolor='black',
        template='none',
    )
    # update_scenes reaches every 3D subplot (scene, scene2, ...); a plain
    # `scene=` layout entry would only style the first one.
    fig.update_scenes(
        xaxis=dict(title='X'),
        yaxis=dict(title='Y'),
        zaxis=dict(title='Z'),
        aspectmode='data',
        bgcolor='black',
        camera=dict(up=dict(x=0, y=0, z=1)),
    )

    # Fix the colour range so a given cluster gets the same colour in every
    # subplot, even if some label is absent from one period.
    color_max = int(labels.max())
    for period in range(n_periods):
        row, col = divmod(period, cols)
        fig.add_trace(
            go.Scatter3d(
                x=positions[:, period, 0],
                y=positions[:, period, 1],
                z=positions[:, period, 2],
                mode='markers',
                name="Time period " + str(period + 1),
                marker=dict(
                    color=labels[:, period],
                    colorscale='Viridis',
                    cmin=0,
                    cmax=color_max,
                    size=3,
                    opacity=1,
                ),
            ),
            row=row + 1,
            col=col + 1,
        )

    # The only annotations on the figure are the subplot titles; make them
    # readable against the black background.
    fig.update_annotations(font=dict(color='white'))
    return fig


def _largest_cluster_summary(positions, brightness, labels, n_clusters):
    """Summarise the most populous cluster.

    Args:
        positions: array of shape (n_stars, n_periods, 3).
        brightness: array of shape (n_stars, n_periods).
        labels: integer array of shape (n_stars, n_periods) with values in
            ``[0, n_clusters)``.
        n_clusters: number of clusters that were fitted.

    Returns:
        dict with ``'centroid'`` (mean x/y/z of all observations in the largest
        cluster) and ``'avg_brightness'`` (their mean brightness). Ties between
        equally sized clusters go to the lowest label.
    """
    cluster_sizes = np.bincount(labels.ravel(), minlength=n_clusters)
    largest = int(np.argmax(cluster_sizes))
    members = labels == largest
    return {
        'centroid': np.mean(positions[members], axis=0),
        'avg_brightness': np.mean(brightness[members]),
    }


def star_formation(n_stars=10000, n_periods=4, n_clusters=5, seed=None,
                   output_path='cluster/star_formation.html'):
    """Simulate random stars, cluster them, plot them and summarise the largest cluster.

    Each star gets a random position in ``[0, 100)^3`` and a brightness in
    ``[0, 1)`` at every time period. Every (star, period) observation is
    clustered with KMeans on ``[x, y, z, brightness]``. The "highest star
    formation" region is taken to be the most populous cluster.

    Args:
        n_stars: number of simulated stars (must be >= 1).
        n_periods: number of time periods (must be >= 1); one subplot each.
        n_clusters: number of KMeans clusters.
        seed: optional seed for a dedicated ``np.random.default_rng``. When
            ``None``, NumPy's global RNG is used, so ``np.random.seed(...)``
            called elsewhere still controls the result.
        output_path: where the interactive HTML figure is written; missing
            parent directories are created.

    Returns:
        dict with ``'centroid'`` and ``'avg_brightness'`` of the largest cluster.
        The dict is also printed.

    Raises:
        ValueError: if ``n_stars`` or ``n_periods`` is not positive.
    """
    if n_stars < 1 or n_periods < 1:
        raise ValueError("n_stars and n_periods must both be positive")

    # np.random.random draws from the same global stream as np.random.rand.
    rng = np.random if seed is None else np.random.default_rng(seed)
    positions = rng.random((n_stars, n_periods, 3)) * 100
    brightness = rng.random((n_stars, n_periods))

    # One row per (star, period): [x, y, z, brightness]. The C-order reshape
    # gives row index star * n_periods + period, which the label reshape
    # below relies on.
    features = np.hstack((positions.reshape(-1, 3), brightness.reshape(-1, 1)))

    # n_init is pinned explicitly because its default changed between
    # scikit-learn releases.
    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init=10).fit(features)
    labels = kmeans.labels_.reshape(n_stars, n_periods)

    fig = _build_star_figure(positions, labels)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.write_html(output_path)

    summary = _largest_cluster_summary(positions, brightness, labels, n_clusters)
    print(summary)
    return summary

In [None]:
# Run the whole pipeline once; the returned dict holds the largest cluster's
# 'centroid' and 'avg_brightness'.
data = star_formation()