# Read prediction results from .pkl file

In [None]:
import pickle
import random
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Location of the rollout that the rest of the notebook analyses.
root_dir = '../rollouts/Fragment/Step-0-100-3-AllTest/'
file_name = '001_120_5_0.4C30.pkl'

# The case id is the file name without its extension.  `Path.stem` removes only
# the final suffix, so the decimal point inside "0.4C30" survives.  Splitting on
# the first '.' truncated the id to "001_120_5_0", which made cases differing
# only after that point share one plot title and one output directory.
case = Path(file_name).stem
# The third underscore-separated field of the case id is parsed as the charge weight.
charge_weight = int(case.split('_')[2])

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only open rollouts that come from a trusted source.
with open(Path(root_dir) / file_name, "rb") as file:
    rollout_data = pickle.load(file)

init_pos = rollout_data['initial_positions']
pred_pos = rollout_data['predicted_rollout']
gt_pos = rollout_data['ground_truth_rollout']

# Prepend the initial positions along axis 0 so both sequences start from the
# same frames (axis 0 is indexed by step further down in the notebook).
gt_pos = np.concatenate((init_pos, gt_pos), axis=0)
pred_pos = np.concatenate((init_pos, pred_pos), axis=0)

print(f"{case}, {charge_weight}, shape: {pred_pos.shape}")

In [None]:
import networkx as nx
from scipy.spatial.distance import cdist
from sklearn.neighbors import KDTree
from sklearn.metrics import pairwise_distances
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import trimesh
from matplotlib import cm, colors
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable


# Mass attributed to each particle; a fragment's mass is its particle count
# times this value (unit presumably kg, as plot titles print "kg" — TODO confirm).
MASS_PER_PARTICLE = 0.0024
# Multiplier applied to per-step displacements before they are reported as
# speeds (the colour bar is labelled "m/s"; origin of 100 / 6 is not shown
# here — TODO confirm).
VELOCITY_SCALE_FACTOR = 100 / 6
# Same value as the default `max_fragment_size` of the fragment finders below.
MAX_FRAGMENT_SIZE = 100
# Same value as the default `dist_thres` of the fragment finders below.
DIST_THRES = 10.12 

def compute_particle_mask(init_pos, charge_weight):
    """Flag particles whose initial x and y lie strictly inside a centred square.

    The square is centred on the origin and has half-width
    ``150 * sqrt(charge_weight)``.

    Args:
        init_pos: Array whose columns 0 and 1 are compared against the bound.
        charge_weight: Scalar that scales the half-width of the square.

    Returns:
        Boolean array with one entry per row of ``init_pos``.
    """
    half_width = np.sqrt(charge_weight) * 150
    # |v| < h is the same test as (v < h) & (v > -h), applied to x and y at once.
    planar_offset = np.abs(init_pos[:, :2])
    return (planar_offset < half_width).all(axis=1)


def compute_fragment_by_graph(particle_pos, previous_affinity_matrix, dist_thres=10.12, max_fragment_size=100):
    """Find fragments as connected components of a persistent bond graph.

    Two particles are bonded at this step only if they were bonded in
    ``previous_affinity_matrix`` AND are currently closer than ``dist_thres``
    (strict ``<``).  A bond that has broken can therefore never re-form.

    Args:
        particle_pos: (N, D) array of current particle coordinates.
        previous_affinity_matrix: (N, N) boolean bond matrix from the previous
            step (all ``True`` for the first step).
        dist_thres: Maximum distance for two particles to stay bonded.
        max_fragment_size: Components with more particles than this are
            discarded instead of being reported as fragments.

    Returns:
        A tuple ``(fragments, adjacency)`` where ``fragments`` is a list of sets
        of particle indices, ordered by the smallest index in each set, and
        ``adjacency`` is the updated (N, N) boolean bond matrix to pass to the
        next call.
    """
    # Function-scope import: SciPy is already a dependency of this file, and
    # its compiled component search avoids building a Python-level graph.
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components

    # Logical AND with the previous bonds: connections can only go from 1 to 0.
    within_range = cdist(particle_pos, particle_pos) < dist_thres
    current_adjacency_matrix = previous_affinity_matrix & within_range

    # Nothing to label when there are no particles.
    if current_adjacency_matrix.shape[0] == 0:
        return [], current_adjacency_matrix

    _, labels = connected_components(csr_matrix(current_adjacency_matrix), directed=False)

    # Group indices by label; dict insertion order keeps components sorted by
    # their first (smallest) particle index.
    components = {}
    for particle_idx, label in enumerate(labels.tolist()):
        components.setdefault(label, set()).add(particle_idx)

    # Filter out components that are too large to count as fragments.
    fragments = [members for members in components.values() if len(members) <= max_fragment_size]

    return fragments, current_adjacency_matrix


def compute_fragment_by_tree(particle_pos, dist_thres=10.12, max_fragment_size=100):
    """Find fragments by flood-filling radius neighbourhoods from a KD-tree.

    Every particle is linked to the particles returned by a radius query of
    ``dist_thres`` around it; each linked group is one candidate fragment.

    Args:
        particle_pos: Array of particle coordinates, one row per particle.
        dist_thres: Radius used for the neighbour query.
        max_fragment_size: Groups with more particles than this are dropped.

    Returns:
        List of sets of particle indices, in order of each group's seed index.
    """
    tree = KDTree(particle_pos)
    neighbour_lists = tree.query_radius(particle_pos, r=dist_thres)

    seen = set()
    fragments = []

    for seed in range(len(neighbour_lists)):
        if seed in seen:
            continue

        # Depth-first walk over everything reachable from this seed.
        component = set()
        pending = [seed]
        while pending:
            node = pending.pop()
            if node in seen:
                continue
            seen.add(node)
            component.add(node)
            pending.extend([n for n in neighbour_lists[node] if n not in seen])

        # Oversized groups are not reported, but their particles stay marked
        # as seen so they are never walked again.
        if len(component) <= max_fragment_size:
            fragments.append(component)

    return fragments

def compute_fragment_property(current_pos, previous_pos, fragments):
    """Compute centre, mass, diameter and speed for each fragment.

    Args:
        current_pos: Array of particle coordinates at the current step.
        previous_pos: Array of particle coordinates at the previous step,
            row-aligned with ``current_pos``.
        fragments: Iterable of collections of particle indices.

    Returns:
        Tuple ``(centres, masses, diameters, vels)`` of NumPy arrays with one
        entry per fragment:
          * centres   - mean position of the fragment's particles,
          * masses    - particle count times ``MASS_PER_PARTICLE``,
          * diameters - largest pairwise distance inside the fragment
                        (fixed value 10 for a single-particle fragment),
          * vels      - norm of the mean per-step displacement scaled by
                        ``VELOCITY_SCALE_FACTOR``.
        All four arrays are empty when ``fragments`` is empty.
    """
    centres, masses, diameters, vels = [], [], [], []

    for fragment in fragments:
        members = list(fragment)
        fragment_pos = current_pos[members]

        centres.append(fragment_pos.mean(axis=0))
        masses.append(len(members) * MASS_PER_PARTICLE)

        # Spatial size: the largest distance between any two member particles.
        if len(members) >= 2:
            diameters.append(pairwise_distances(fragment_pos, fragment_pos).max())
        else:
            diameters.append(10)  # single element diameter

        # Only this fragment's rows are subtracted; the full-array displacement
        # used to be recomputed for every fragment.
        displacement = fragment_pos - previous_pos[members]
        fragment_vels = displacement * VELOCITY_SCALE_FACTOR
        vels.append(np.mean(fragment_vels, axis=0))

    # Convert lists to NumPy arrays.
    centres = np.array(centres)
    masses = np.array(masses)
    diameters = np.array(diameters)
    vels = np.array(vels)
    # Reduce velocity vectors to speeds; skipped when there are no fragments
    # because the empty array has no axis 1.
    if vels.shape[0]:
        vels = np.linalg.norm(vels, axis=1)

    return centres, masses, diameters, vels

def plot_fragment_by_plt(particle_pos, fragments, fragments_vel):
    """Draw all particles and the convex hulls of fragments with Matplotlib.

    Particles are shown as a faint grey scatter.  Every fragment with more
    than three particles is drawn as a convex-hull mesh coloured by its speed.

    Args:
        particle_pos: (N, 3) array of particle coordinates.
        fragments: List of collections of particle indices.
        fragments_vel: Sequence of speeds, index-aligned with ``fragments``.

    Note:
        The plot title is taken from the module-level ``case`` variable.
    """
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(particle_pos[:, 0], particle_pos[:, 1], particle_pos[:, 2], s=0.1, c='grey', alpha=0.1)

    # The colour scale is only defined when there is something to colour.
    if fragments:
        cmap = cm.jet
        norm = colors.Normalize(vmin=np.min(fragments_vel), vmax=np.max(fragments_vel))

    # Loop over fragments and plot each mesh
    for idx, fragment in enumerate(fragments):
        fragment_positions = particle_pos[list(fragment)]
        fragment_vel = fragments_vel[idx]

        # A 3-D hull needs at least four points.
        if len(fragment) > 3:
            mesh = trimesh.Trimesh(vertices=fragment_positions, process=True)
            hull = mesh.convex_hull

            # Get the vertices and faces from the hull
            vertices = hull.vertices
            faces = hull.faces
            color = cmap(norm(fragment_vel))

            # Create a Poly3DCollection from the vertices and faces
            mesh_plot = Poly3DCollection(vertices[faces], edgecolor='k', facecolors=color, linewidths=0.1, alpha=0.9)
            ax.add_collection3d(mesh_plot)

    # Set plot limits.  The box aspect is derived from the positions passed in;
    # it previously read an unrelated module-level `positions` array.
    ax.set_box_aspect([np.ptp(a) for a in particle_pos.T])
    ax.set_xlim(particle_pos[:, 0].min(), particle_pos[:, 0].max())
    ax.set_ylim(particle_pos[:, 1].min(), particle_pos[:, 1].max())
    ax.set_zlim(particle_pos[:, 2].min(), particle_pos[:, 2].max())

    # Set labels
    ax.set_title(case)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    # Add color bar
    if fragments:
        ax_colorbar = fig.add_axes([0.9, 0.15, 0.02, 0.7])
        mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
        mappable.set_array(fragments_vel)
        cbar = plt.colorbar(mappable, shrink=0.8, aspect=50, cax=ax_colorbar)
        cbar.set_label('fragment velocity (m/s)')
    
    
def plot_fragment_by_plotly(particle_pos, particle_vel, fragments, xyz_min, xyz_max,
                            min_particles=3, v_max=100):
    """Build a Plotly 3-D figure of particles and fragment convex hulls.

    Particles are drawn as small markers coloured by speed.  Every fragment
    with more than ``min_particles`` particles is drawn as a convex-hull mesh
    coloured by the mean speed of its particles.  The title reports the total
    mass of all fragments, including those too small to be drawn.

    Args:
        particle_pos: (N, 3) array of particle coordinates.
        particle_vel: Length-N array of per-particle speeds.
        fragments: Iterable of collections of particle indices.
        xyz_min: Accepted for interface compatibility; currently unused (the
            axis ranges are the fixed box defined below).
        xyz_max: Accepted for interface compatibility; currently unused.
        min_particles: A fragment is meshed only if it has more particles than
            this.  Previously read from a module-level variable that is only
            defined in a later cell.
        v_max: Upper end of the colour scale; the lower end is 0.  Previously
            read from a module-level variable defined in a later cell.

    Returns:
        The ``plotly.graph_objects.Figure`` that was built.

    Note:
        The title text reads the module-level ``step`` variable set by the
        calling loop.
    """
    # Imported here so the function works even when the cell that imports
    # plotly at module level has not been executed yet.
    import plotly.graph_objects as go

    xmin, ymin, zmin = -300, -300, 0
    xmax, ymax, zmax = 300, 300, 300

    fig = go.Figure(data=[go.Scatter3d(
            x=particle_pos[:, 0],
            y=particle_pos[:, 1],
            z=particle_pos[:, 2],
            mode='markers',
            marker=dict(
                size=1,
                color=particle_vel,  # per-particle speed drives the colour
                colorscale='Jet',
                cmin=0,
                cmax=v_max,
                opacity=0.3
            )
        )])

    # Running count of particles that belong to any fragment
    num_fragment_particles = 0

    for fragment in fragments:
        members = list(fragment)
        num_fragment_particles += len(members)

        # Small fragments count towards the mass but are not meshed.
        if len(members) <= min_particles:
            continue

        fragment_positions = particle_pos[members]
        avg_velocity = np.mean(particle_vel[members])

        # Map the average speed to a colour on the jet colormap
        color_rgb = cm.jet(avg_velocity / v_max)[:3]
        color_rgba = f"rgba({color_rgb[0]*255}, {color_rgb[1]*255}, {color_rgb[2]*255}, 0.8)"

        mesh = trimesh.Trimesh(vertices=fragment_positions, process=True)
        hull = mesh.convex_hull

        fig.add_trace(go.Mesh3d(
            x=hull.vertices[:, 0],
            y=hull.vertices[:, 1],
            z=hull.vertices[:, 2],
            i=hull.faces[:, 0],
            j=hull.faces[:, 1],
            k=hull.faces[:, 2],
            color=color_rgba,
            intensity=[avg_velocity]*hull.faces.shape[0],
            colorscale='Jet',
            cmin=0,
            cmax=v_max,
            showscale=True,
        ))

    # Use the shared constant instead of repeating the literal 0.0024.
    fragment_mass = num_fragment_particles * MASS_PER_PARTICLE

    fig.update_scenes(
        xaxis=dict(range=[xmin, xmax]),
        yaxis=dict(range=[ymin, ymax]),
        zaxis=dict(range=[zmin, zmax])
    )

    # Update layout with the title
    fig.update_layout(
        autosize=False,
        width=1920,
        height=1080,
        scene=dict(
            xaxis=dict(title='X', title_font=dict(size=22), tickfont=dict(size=16)),
            yaxis=dict(title='Y', title_font=dict(size=22), tickfont=dict(size=16)),
            zaxis=dict(title='Z', title_font=dict(size=22), tickfont=dict(size=16)),
            aspectmode='manual',
            aspectratio=dict(x=1, y=1, z=0.5),
            camera=dict(
                up=dict(x=0, y=0, z=1),  # this is the 'up' direction for the camera
                center=dict(x=0, y=0, z=0),  # this will move the camera itself
                eye=dict(x=1.5, y=1.5, z=0.3)  # this moves the 'eye' of the camera
            ),
        ),
        margin=dict(l=0, r=0, b=0, t=0),  # tight layout
        title=dict(
            text=f"Step: {step:02}, time: {(step+10)*0.06:.3f} ms, fragment mass: {fragment_mass:.3f} kg",
            x=0.48,
            y=0.75,
            xanchor='center',
            yanchor='top',
            font=dict(
                size=35,  # Adjust the font size here
                family="Courier New, monospace",  # Optional: specify font family
            )
        ),
    )

    return fig

In [None]:
# Render one PNG per ground-truth step, tracking bonds across steps.

# Everything that does not depend on the step is computed once, up front,
# instead of being recomputed on every iteration.
init_pos = gt_pos[0, :]
last_pos = gt_pos[-1, :]
xyz_min = last_pos.min(axis=0)
xyz_max = last_pos.max(axis=0)
mask = compute_particle_mask(init_pos, charge_weight)

mode = 'gt'
out_path = Path(root_dir) / 'fragment' / case
out_path.mkdir(parents=True, exist_ok=True)

# Start fully bonded; compute_fragment_by_graph can only remove bonds.
num_particles = int(np.count_nonzero(mask))
current_adjacency_matrix = np.ones((num_particles, num_particles), dtype=bool)

for step in range(1, 34):
    current_pos = gt_pos[step, :]
    previous_pos = gt_pos[step-1, :]

    current_pos_masked = current_pos[mask]
    previous_pos_masked = previous_pos[mask]
    # Per-particle speed: displacement over one step, scaled.
    particles_vel = np.linalg.norm(current_pos_masked - previous_pos_masked, axis=1) * VELOCITY_SCALE_FACTOR

    fragments, current_adjacency_matrix = compute_fragment_by_graph(current_pos_masked, current_adjacency_matrix)

    fig = plot_fragment_by_plotly(current_pos_masked, particles_vel, fragments, xyz_min, xyz_max)

    # Zero-padded step keeps the frames in order when sorted by name.
    file_name = f"{mode}-{step:02}.png"
    save_path = out_path / file_name

    fig.write_image(str(save_path), scale=2)

In [None]:
from PIL import Image
import glob

from contextlib import ExitStack

# Collect the rendered frames; their names carry a zero-padded step number,
# so sorting by name puts them in step order.
img_paths = sorted(glob.glob(str(out_path / f'{mode}*.png')))
if not img_paths:
    # Fail with a clear message instead of an IndexError on imgs[0].
    raise FileNotFoundError(f"No '{mode}*.png' frames found in {out_path}")

# ExitStack closes every opened image once the GIF is written, so file
# handles are not left open.
with ExitStack() as stack:
    imgs = [stack.enter_context(Image.open(img_path)) for img_path in img_paths]

    # Save the first frame and append the remaining ones as animation frames.
    imgs[0].save(str(out_path / f'{mode}.gif'), format='GIF', append_images=imgs[1:], save_all=True, duration=500, loop=0)

# Generate images for the fragmentation process

In [None]:
import plotly.graph_objects as go
import plotly.io as pio
import pandas as pd
import trimesh
from scipy.spatial import ConvexHull
import numpy as np
from matplotlib import cm
from pathlib import Path

# Settings for this stand-alone rendering: minimum fragment size that gets a
# mesh, and the label stored in `mode`.
min_particles = 3
mode = 'pred'

# Work on the masked positions left in scope by the rendering loop.
positions = current_pos_masked
particles_vel = np.linalg.norm(current_pos_masked - previous_pos_masked, axis=1) * VELOCITY_SCALE_FACTOR

# Axis ranges follow the data extent.
xmin, ymin, zmin = positions.min(axis=0)
xmax, ymax, zmax = positions.max(axis=0)

# Colour-scale limits shared by the markers and the hull meshes.
v_min = 0
v_max = 100

out_dir = f"../rollouts/Fragment/fragmentation/{case}"
Path(out_dir).mkdir(parents=True, exist_ok=True)

# Particle cloud, coloured by per-particle speed.
particle_cloud = go.Scatter3d(
    x=positions[:, 0],
    y=positions[:, 1],
    z=positions[:, 2],
    mode='markers',
    marker=dict(
        size=1,
        color=particles_vel,
        colorscale='Jet',
        cmin=v_min,
        cmax=v_max,
        opacity=0.8,
    ),
)
fig = go.Figure(data=[particle_cloud])

# Total number of particles that belong to any fragment.
num_fragment_particles = 0

for i, fragment in enumerate(fragments):
    member_ids = list(fragment)
    fragment_positions = positions[member_ids]
    fragment_velocities = particles_vel[member_ids]
    num_fragment_particles += len(fragment)

    # Fragments at or below the size threshold are counted but not meshed.
    if len(fragment) <= min_particles:
        continue

    # Mean speed of the fragment, mapped onto the jet colormap.
    avg_velocity = np.mean(fragment_velocities)
    red, green, blue = cm.jet(avg_velocity / v_max)[:3]
    color_rgba = f"rgba({red*255}, {green*255}, {blue*255}, 0.8)"

    hull = trimesh.Trimesh(vertices=fragment_positions, process=True).convex_hull

    hull_trace = go.Mesh3d(
        x=hull.vertices[:, 0],
        y=hull.vertices[:, 1],
        z=hull.vertices[:, 2],
        i=hull.faces[:, 0],
        j=hull.faces[:, 1],
        k=hull.faces[:, 2],
        color=color_rgba,
        intensity=[avg_velocity]*hull.faces.shape[0],
        colorscale='Jet',
        cmin=v_min,
        cmax=v_max,
        showscale=True,
    )
    fig.add_trace(hull_trace)

fragment_mass = num_fragment_particles * 0.0024

fig.update_scenes(
    xaxis=dict(range=[xmin, xmax]),
    yaxis=dict(range=[ymin, ymax]),
    zaxis=dict(range=[zmin, zmax]),
)

# Shared font settings for the three axes.
axis_fonts = dict(title_font=dict(size=22), tickfont=dict(size=16))
camera_setup = dict(
    up=dict(x=0, y=0, z=1),        # 'up' direction for the camera
    center=dict(x=0, y=0, z=0),    # point the camera looks at
    eye=dict(x=1.5, y=1.5, z=0.3)  # camera position
)
title_setup = dict(
    text=f"Step: {step:02}, time: {(step+10)*0.06:.3f} ms, fragment mass: {fragment_mass:.3f} kg",
    x=0.48,
    y=0.75,
    xanchor='center',
    yanchor='top',
    font=dict(
        size=35,
        family="Courier New, monospace",
    ),
)

# Figure size, scene styling and title.
fig.update_layout(
    autosize=False,
    width=1920,
    height=1080,
    scene=dict(
        xaxis=dict(title='X', **axis_fonts),
        yaxis=dict(title='Y', **axis_fonts),
        zaxis=dict(title='Z', **axis_fonts),
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=0.5),
        camera=camera_setup,
    ),
    margin=dict(l=0, r=0, b=0, t=0),  # tight layout
    title=title_setup,
)
fig.show()

# file_name = f"{mode}-{step:02}.png"
# save_path = Path(out_dir) / file_name
# fig.write_image(str(save_path), scale=1)

In [None]:
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
import trimesh
from matplotlib import cm, colors


# Inputs for this figure are expected to exist already; they are not defined
# in the cells shown above.
# NOTE(review): the names suggest the fragments come from the tree search but
# the speeds from the graph search — confirm the two lists are index-aligned.
positions = particle_pos
fragments_vel = vels_graph
fragments = fragments_tree

# Colour scale spanning the observed fragment speeds.
cmap = cm.jet
norm = colors.Normalize(vmin=np.min(fragments_vel), vmax=np.max(fragments_vel))

fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

# All particles as a faint grey backdrop.
xs, ys, zs = positions[:, 0], positions[:, 1], positions[:, 2]
ax.scatter(xs, ys, zs, s=0.5, c='grey', alpha=0.3)

# One convex-hull mesh per fragment that has more than three particles.
for idx, fragment in enumerate(fragments):
    fragment_positions = positions[list(fragment)]
    fragment_vel = fragments_vel[idx]

    if len(fragment) <= 3:
        continue

    hull = trimesh.Trimesh(vertices=fragment_positions, process=True).convex_hull
    color = cmap(norm(fragment_vel))

    # Triangles are given as the hull vertices indexed by face.
    mesh_plot = Poly3DCollection(hull.vertices[hull.faces], edgecolor='k', facecolors=color, linewidths=0.1, alpha=0.9)
    ax.add_collection3d(mesh_plot)

# Axis box proportional to the data extent, limits tight around the data.
ax.set_box_aspect([np.ptp(a) for a in positions.T])
ax.set_xlim(xs.min(), xs.max())
ax.set_ylim(ys.min(), ys.max())
ax.set_zlim(zs.min(), zs.max())

# Title and axis labels
ax.set_title(case)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Colour bar in its own axes on the right-hand side.
ax_colorbar = fig.add_axes([0.9, 0.15, 0.02, 0.7])
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array(fragments_vel)
cbar = plt.colorbar(mappable, shrink=0.8, aspect=50, cax=ax_colorbar)
cbar.set_label('fragment velocity (m/s)')

In [None]:
import pandas as pd
import plotly.express as px

# Load the ejection table and keep a random 10 % of its rows, with the
# 'Danger' flag converted to strings so it is plotted as a category.
data_path = '/home/jovyan/share/8TB-share/qilin/fragment/Ejection.csv'
data = pd.read_csv(data_path)
filtered_data = data.sample(frac=0.1).assign(
    Danger=lambda frame: frame['Danger'].astype(str)
)

# One fixed colour per 'Danger' value (assumed to be binary: 0 or 1).
color_discrete_map = {'0': 'green', '1': 'red'}

fig = px.scatter_3d(
    filtered_data,
    x='Centre X',
    y='Centre Y',
    z='Centre Z',
    color='Danger',
    title='3D Visualization of Fragment Ejection',
    color_discrete_map=color_discrete_map,
    opacity=0.5,
)

# Small markers on a large canvas (width x height in pixels).
fig.update_traces(marker={'size': 2})
fig.update_layout(width=1620, height=1440)

fig.show()

# To save the plot as an HTML file which you can open in a web browser
# fig.write_html("fragment_ejection_visualization.html")

In [55]:
import pandas as pd
import plotly.graph_objects as go
import trimesh
import numpy as np

# Load the ejection table, keep a random 10 % of the rows and turn the
# 'Danger' flag into strings.
data_path = '/home/jovyan/share/8TB-share/qilin/fragment/Ejection.csv'
data = pd.read_csv(data_path)
filtered_data = data.sample(frac=0.1)
filtered_data['Danger'] = filtered_data['Danger'].astype(str)

# Work on an independent copy of the dangerous rows so that adding a column
# below does not touch `filtered_data`.
is_dangerous = filtered_data['Danger'] == '1'
dangerous_data = filtered_data[is_dangerous].copy()

# The 70th and 90th mass percentiles split the dangerous fragments into
# three groups.
mass_thresholds = dangerous_data['Mass'].quantile([0.7, 0.9]).tolist()
lower_cut, upper_cut = mass_thresholds
mass_bins = [0, lower_cut, upper_cut, float('inf')]
dangerous_data.loc[:, 'mass_group'] = pd.cut(dangerous_data['Mass'],
                                             bins=mass_bins,
                                             labels=['small', 'medium', 'large'])

# Updated function to compute and add mesh for each group
def add_group_mesh(fig, points, color, name):
    """Add the convex hull of ``points`` to ``fig`` as a translucent mesh.

    Groups with fewer than four points are skipped without any message.  Any
    error raised while building or adding the hull is printed and swallowed,
    so one bad group does not abort the whole figure.

    Args:
        fig: Figure that receives the mesh trace.
        points: Sequence of 3-D points forming the group.
        color: Colour of the mesh.
        name: Legend entry, also used in the failure message.
    """
    # A 3-D hull needs at least four points.
    if len(points) < 4:
        return

    try:
        hull = trimesh.Trimesh(vertices=points).convex_hull
        fig.add_trace(go.Mesh3d(
            x=hull.vertices[:, 0],
            y=hull.vertices[:, 1],
            z=hull.vertices[:, 2],
            i=hull.faces[:, 0],
            j=hull.faces[:, 1],
            k=hull.faces[:, 2],
            color=color,
            opacity=0.5,
            name=name,
            showlegend=True,
        ))
    except Exception as e:
        print(f"Could not create mesh for {name} due to: {e}")

# Build the figure that shows one convex hull per region.
fig = go.Figure()

# Hull colour per mass group.  Hulls are added in the order small, medium,
# large, and the safe region is added after them.
color_map = {'large': 'red', 'medium': 'orange', 'small': 'yellow'}
for group_label in ['small', 'medium', 'large']:
    color = color_map[group_label]
    in_group = dangerous_data['mass_group'] == group_label
    group_points = dangerous_data[in_group][['Centre X', 'Centre Y', 'Centre Z']].values
    add_group_mesh(fig, group_points, color, f'Dangerous - {group_label.capitalize()} Mass')

# Hull around every sampled fragment that is not flagged as dangerous.
is_safe = filtered_data['Danger'] == '0'
safe_points = filtered_data[is_safe][['Centre X', 'Centre Y', 'Centre Z']].values
add_group_mesh(fig, safe_points, 'green', 'Safe Region')

# Axis titles, camera position, figure title and canvas size.
fig.update_layout(
    scene=dict(
        xaxis_title='X Axis',
        yaxis_title='Y Axis',
        zaxis_title='Z Axis',
        camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
    ),
    title='3D Visualization of Fragment Mass and Danger Regions',
    width=1400,
    height=1000,
)

fig.show()


In [45]:
# Inspect the 70th and 90th percentiles of 'Mass' over the sampled rows
# (all rows of `filtered_data`, not only the dangerous ones).
filtered_data['Mass'].quantile([0.7, 0.9]).tolist()

[0.0024, 0.0048]