In [None]:
import numpy as np
from wirehead import WireheadGenerator 
import os
import pandas as pd
import trimesh
import torch
import gc
import multiprocessing
import time

from scipy.spatial.transform import Rotation as R
from scipy.ndimage import gaussian_filter

import numpy as np
import multiprocessing as mp
import os
from tqdm import tqdm

import sys
sys.path.append('/data/users2/yxiao11/model/satellite_project')
from moduler import *

In [None]:
# RGBA palette: one fully opaque colour per material. The position in this
# list (counted from 1) is the material id used by `color_to_material` and by
# the id-keyed spectra returned from `generate_fake_spectra`.
#
# NOTE(review): the group headings below are the author's original labels.
# `make_satellite` actually draws the *body* colour from the first three
# entries and the antenna / connector / panel colours from the remaining
# sixteen, so the headings do not describe which part gets which colour.
color_list = [
    # Group labelled "antenna": bold primaries
    [255, 0, 0, 255],      # id 1  - pure red
    [0, 255, 0, 255],      # id 2  - pure green
    [0, 0, 255, 255],      # id 3  - pure blue

    # Group labelled "body top": vibrant secondaries
    [255, 255, 0, 255],    # id 4  - pure yellow
    [255, 0, 255, 255],    # id 5  - pure magenta
    [0, 255, 255, 255],    # id 6  - pure cyan

    # Group labelled "body bottom": dark primaries
    [128, 0, 0, 255],      # id 7  - dark red
    [0, 128, 0, 255],      # id 8  - dark green
    [0, 0, 128, 255],      # id 9  - dark blue

    # Group labelled "lateral surface": assorted distinct hues
    [255, 140, 0, 255],    # id 10 - deep orange
    [75, 0, 130, 255],     # id 11 - indigo
    [139, 69, 19, 255],    # id 12 - saddle brown
    [34, 139, 34, 255],    # id 13 - forest green
    [128, 0, 128, 255],    # id 14 - purple
    [0, 139, 139, 255],    # id 15 - dark cyan

    # Group labelled "connectors"
    [165, 42, 42, 255],    # id 16 - brown

    # Group labelled "solar panels"
    [105, 105, 105, 255],  # id 17 - dim grey
    [30, 144, 255, 255],   # id 18 - dodger blue
    [218, 165, 32, 255],   # id 19 - goldenrod
]

# Reverse lookup: (R, G, B) with the alpha channel dropped -> 1-based material id.
color_to_material = {
    tuple(rgba[:3]): material_id
    for material_id, rgba in enumerate(color_list, start=1)
}

In [None]:
def get_material_id(data_type):
    """Translate a dataset flavour into the spreadsheet column index to read.

    Parameters
    ----------
    data_type : str
        ``"Pristine"`` selects column 1, ``"Irradiated"`` selects column 2 and
        ``"mixed"`` picks 1 or 2 at random via ``np.random.randint(1, 3)``.
        The comparison is case-sensitive.

    Returns
    -------
    int
        Column index, used by ``generate_fake_spectra`` as ``.iloc[:, id]``.

    Raises
    ------
    ValueError
        If ``data_type`` is not one of the three accepted strings. (The
        previous version fell through and failed with ``UnboundLocalError``.)
    """
    if data_type == "Pristine":
        return 1
    if data_type == "Irradiated":
        return 2
    if data_type == "mixed":
        # Upper bound is exclusive, so this yields 1 or 2.
        return np.random.randint(1, 3)
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected 'Pristine', 'Irradiated' or 'mixed'"
    )

def generate_fake_spectra(data_type, bands_length=50, material_path=None):
    """Build a ``{material id: spectrum}`` dictionary from the material workbooks.

    Every ``<material>.xlsx`` file in ``material_path`` whose stem appears in
    the table below is read sheet by sheet. For each sheet the column chosen
    by ``get_material_id(data_type)`` is taken, truncated to its first 1024
    rows, passed through ``list_to_numpy_with_mean`` and then sub-sampled with
    a stride of ``bands_length``. The n-th sheet of a workbook is stored under
    the n-th id listed for that material.

    Parameters
    ----------
    data_type : str
        Forwarded to ``get_material_id``; it is evaluated once per workbook,
        so with ``"mixed"`` different materials may use different columns.
    bands_length : int, optional
        Stride used when sub-sampling each spectrum (``spectrum[::bands_length]``).
    material_path : str, optional
        Directory containing the workbooks. Defaults to the project's
        ``material_spectral`` directory.

    Returns
    -------
    dict
        Spectra keyed by material id, ordered by ascending id.
        # NOTE(review): values are whatever ``list_to_numpy_with_mean``
        # returns after slicing - presumably 1-D numpy arrays; confirm.

    Raises
    ------
    IndexError
        If a workbook has more sheets than ids assigned to its material.
    """
    if material_path is None:
        material_path = '/data/users2/yxiao11/model/satellite_project/material_spectral/'

    # Workbook stem -> material ids, one id per sheet in sheet order.
    material_to_ids = {
        'Thermal_paints': [10, 11, 12, 13, 14, 15, 4, 7],
        'Kevlar': [5],
        'Polymers': [6, 8],
        'Metals': [9, 16],
        'Coverglasses': [17, 18, 19],
        'Nomex': [1],
        'Glass_Fiber_Reinforced_Polymer': [2],
        'Tedlar': [3],
    }

    fake_spectra = {}

    for file_name in os.listdir(material_path):
        material, extension = os.path.splitext(file_name)
        # Skip anything that is not a known material workbook (lock files,
        # checkpoint folders, unrelated spreadsheets). Blindly chopping five
        # characters off such names used to end in FileNotFoundError.
        if extension != '.xlsx' or material not in material_to_ids:
            continue

        color_index = material_to_ids[material]
        material_id = get_material_id(data_type)
        path = os.path.join(material_path, file_name)

        # Open the workbook once and close it deterministically.
        with pd.ExcelFile(path) as workbook:
            sheet_names = workbook.sheet_names
            if len(sheet_names) > len(color_index):
                raise IndexError(
                    f"{file_name} has {len(sheet_names)} sheets but only "
                    f"{len(color_index)} material ids are assigned to {material!r}"
                )
            for spectrum_id, sheet_name in zip(color_index, sheet_names):
                sheet = pd.read_excel(workbook, sheet_name=sheet_name)
                column = sheet.iloc[:, material_id].tolist()[0:1024]
                fake_spectra[spectrum_id] = list_to_numpy_with_mean(column)[::bands_length]

    return dict(sorted(fake_spectra.items()))

def make_satellite(color_list):
    """Assemble a randomly coloured toy satellite as a single trimesh mesh.

    The model is a cylinder (body) with a sphere above it (antenna), two small
    boxes on the +x / -x sides (connectors) and two thin boxes further out
    (solar panels). The body colour is drawn from ``color_list[0:3]``; the
    antenna, connector and panel colours are three distinct entries drawn from
    ``color_list[3:19]``.

    :param color_list: sequence of at least 19 RGBA colours
    :return: the concatenation of body, antenna, connectors and solar panels
    """
    # Draw order (choice first, then randint) is kept so that a seeded run
    # reproduces the same satellites.
    part_picks = np.random.choice(np.arange(3, 19), 3, replace=False)
    body_color = color_list[np.random.randint(0, 3)]
    antenna_color, connector_color, panel_color = (color_list[k] for k in part_picks)

    def _colored_box(extents, offset, rgba):
        """Create a box of the given extents, move it by ``offset`` and colour it."""
        box = trimesh.creation.box(extents=extents)
        box.apply_translation(offset)
        box.visual.vertex_colors = rgba
        return box

    # --- Body: cylinder along z, 20 radial sections ---
    body_height = 3.0
    body_radius = 0.8
    sections = 20
    body = trimesh.creation.cylinder(radius=body_radius, height=body_height, sections=sections)
    body.visual.vertex_colors = body_color

    # --- Antenna: icosphere floating just above the top of the body ---
    antenna = trimesh.creation.icosphere(subdivisions=2, radius=0.4)
    antenna.apply_translation([0, 0, body_height / 2 + 0.5])
    antenna.visual.vertex_colors = antenna_color

    # --- Connectors: one box on each side of the body ---
    connector_size = [1.0, 0.3, 0.3]  # [length, width, height]
    connectors = trimesh.util.concatenate([
        _colored_box(connector_size, [body_radius + 0.5, 0, 0], connector_color),
        _colored_box(connector_size, [-body_radius - 0.5, 0, 0], connector_color),
    ])

    # --- Solar panels: thin slabs further out along x ---
    panel_size = [3.5, 2.0, 0.01]  # [length, width, thickness]
    solar_panels = trimesh.util.concatenate([
        _colored_box(panel_size, [2.5, 0, 0], panel_color),
        _colored_box(panel_size, [-2.5, 0, 0], panel_color),
    ])

    # Merge every part into one mesh.
    return trimesh.util.concatenate([body, antenna, connectors, solar_panels])


def make_scene(satellite, image_size=256):
    """Wrap a satellite mesh in a flat-lit pyrender scene with a fixed camera.

    :param satellite: trimesh mesh to display
    :param image_size: accepted for interface compatibility; not used here
    :return: pyrender scene holding the mesh and one perspective camera
    """
    # Uniform white ambient light.
    scene = pyrender.Scene(ambient_light=[1.0, 1.0, 1.0])
    scene.add(pyrender.Mesh.from_trimesh(satellite))

    # Give every primitive a flat, self-lit, double-sided material.
    mesh_nodes = [n for n in scene.get_nodes() if isinstance(n.mesh, pyrender.Mesh)]
    for mesh_node in mesh_nodes:
        for primitive in mesh_node.mesh.primitives:
            material = primitive.material
            material.baseColorFactor = [1.0, 1.0, 1.0, 1.0]
            material.emissiveFactor = [1.0, 1.0, 1.0]
            material.doubleSided = True

    # Camera sits at z = 25 with an identity rotation.
    camera_pose = np.eye(4)
    camera_pose[2, 3] = 25
    scene.add(
        pyrender.PerspectiveCamera(yfov=np.pi / 3.0, znear=0.1, zfar=50.0),
        pose=camera_pose,
    )

    return scene

def render_with_rotation(scene, angles):
    """
    Rotate the model in the scene about the origin and return the scene.

    Despite its name, this function neither renders nor saves an image: it
    only replaces the pose of the first mesh node found in ``scene`` with a
    pure rotation (no translation). The scene is modified in place.

    :param scene: Pyrender scene containing at least one mesh node
    :param angles: Sequence of (x, y, z) rotation angles in degrees, applied
                   with scipy's lowercase ``'xyz'`` Euler convention
    :return: The same scene object that was passed in
    :raises IndexError: If the scene contains no mesh node
    """
    # 4x4 homogeneous transform whose upper-left 3x3 block is the rotation.
    model_pose = np.eye(4)
    model_pose[:3, :3] = R.from_euler('xyz', angles, degrees=True).as_matrix()

    # Only the first mesh node is rotated (the scene is expected to hold one model).
    mesh_node = next(
        (node for node in scene.get_nodes() if isinstance(node.mesh, pyrender.Mesh)),
        None,
    )
    if mesh_node is None:
        # Same exception type the old ``[...][0]`` lookup produced, with a usable message.
        raise IndexError("render_with_rotation: scene contains no mesh node to rotate")

    scene.set_pose(mesh_node, pose=model_pose)

    return scene

In [None]:
def simulator(image_size, data_type="Pristine"):
    """Generate one blurred spectral cube and its material-presence label.

    Steps: build a random satellite, rotate it by random Euler angles, render
    a material-id mask at ``image_size`` x ``image_size``, fill a spectral cube
    from per-material spectra, blur every band with a Gaussian whose width
    grows linearly with the band index, and centre-crop the result. The label
    is derived from a second render of the same scene at 800 x 800.

    Parameters
    ----------
    image_size : int
        Side length of the rendered mask / spectral cube in pixels.
    data_type : str, optional
        Forwarded to ``generate_fake_spectra``.

    Returns
    -------
    tuple
        ``(cube, label)``. ``cube`` is the central 32 x 32 crop of the blurred
        cube with the band axis moved first (bands, rows, cols). ``label`` is
        an integer array with one row per palette colour: ``[1, 0]`` for
        materials found in the high-resolution render, ``[0, 1]`` otherwise.
    """
    satellite = make_satellite(color_list)
    scene = make_scene(satellite, image_size)

    # NOTE(review): pyrender's offscreen guide asks for PYOPENGL_PLATFORM to be
    # set before pyrender is imported. pyrender has already been used above,
    # so this assignment presumably comes too late to matter - confirm and
    # move it to the top of the notebook if so.
    os.environ["PYOPENGL_PLATFORM"] = "egl"

    angle = np.random.randint(0, 360, 3)
    renderer = pyrender.OffscreenRenderer(viewport_width=image_size, viewport_height=image_size)
    try:
        scene = render_with_rotation(scene, angle)
        material_mask, _ = render_material_mask_with_tolerance(scene, renderer, color_to_material)
    finally:
        # Release the GL context even when rendering fails.
        renderer.delete()

    # Combine the 2D material mask and the spectra into a spectral cube.
    image_shape = (image_size, image_size)
    fake_spectra = generate_fake_spectra(data_type, bands_length=20)
    spectral_cube = create_spectral_cube(image_shape, material_mask, fake_spectra)

    # Per-sample blur model: kernel size = k * band_index + b, forced odd.
    k = np.random.uniform(1, 3)
    b = np.random.uniform(3, 5)
    n_slices = len(fake_spectra[1])
    kernel_sizes = (k * np.arange(n_slices) + b).astype(int)
    kernel_sizes[kernel_sizes % 2 == 0] += 1

    # Convert kernel sizes to Gaussian sigmas (empirical scaling factor).
    sigmas = kernel_sizes / 2.5

    blurred_cube = np.stack(
        [
            gaussian_filter(spectral_cube[:, :, j], sigma=sigmas[j], mode="mirror")
            for j in range(n_slices)
        ],
        axis=-1,
    )

    # Second, high-resolution render used only to list the visible materials.
    # This renderer used to be leaked on every call (its delete() was
    # commented out), which exhausts GL resources in long generation runs.
    label_renderer = pyrender.OffscreenRenderer(viewport_width=800, viewport_height=800)
    try:
        _, labels = render_material_mask_with_tolerance(scene, label_renderer, color_to_material)
    finally:
        label_renderer.delete()

    n_materials = len(color_list)
    # Material ids are 1-based; shift to 0-based row indices. The cast keeps
    # fancy indexing valid when `labels` is empty or holds floats.
    label_index = (np.unique(np.array(labels)) - 1).astype(np.intp)
    # Guard against a background id of 0, should the helper report one:
    # 0 - 1 == -1 would wrap around and wrongly flag the last material.
    label_index = label_index[(label_index >= 0) & (label_index < n_materials)]

    # Two-column "probability" label: [1, 0] = present, [0, 1] = absent.
    label = np.tile(np.array([0, 1]), (n_materials, 1))
    label[label_index] = np.array([1, 0])

    # Centre crop of crop_size x crop_size pixels.
    crop_size = 32
    center_x, center_y = image_size // 2, image_size // 2
    start_x, end_x = center_x - crop_size // 2, center_x + crop_size // 2
    start_y, end_y = center_y - crop_size // 2, center_y + crop_size // 2

    return blurred_cube[start_x:end_x, start_y:end_y].transpose(2, 0, 1), label

In [None]:
# for i in range(1000):
#     blurred_cube, label = simulator(100, 'mixed') 
    
#     np.save(f'/data/users2/yxiao11/model/satellite_project/database/mixed/blur_cube/{i}.npy', blurred_cube)
#     np.save(f'/data/users2/yxiao11/model/satellite_project/database/mixed/label/{i}.npy', label)

In [None]:
# ---- Worker Function ----
def simulator_worker(i):
    """Generate one 'mixed' sample at image size 100 and save it as sample ``i``.

    Intended as the task function for a multiprocessing pool. Writes
    ``blur_cube/{i}.npy`` and ``label/{i}.npy`` under the project's
    ``database/mixed`` directory and returns nothing.

    :param i: sample index, used verbatim in the output file names
    """
    # Forked pool workers start with an identical copy of the parent's global
    # NumPy RNG state, so without reseeding every worker would draw the same
    # "random" satellites, angles and blur parameters and the dataset would
    # be full of duplicates. seed() with no argument pulls fresh OS entropy.
    np.random.seed()

    blurred_cube, label = simulator(100, 'mixed')

    # NOTE(review): an earlier comment here said files are overwritten at
    # index (i % 1000), but the code has always saved under `i`; the unused
    # `index` local was removed and the file naming left unchanged. Switch to
    # `i % 1000` below if a fixed-size ring buffer is what is wanted.
    np.save(f'/data/users2/yxiao11/model/satellite_project/database/mixed/blur_cube/{i}.npy', blurred_cube)
    np.save(f'/data/users2/yxiao11/model/satellite_project/database/mixed/label/{i}.npy', label)

In [None]:
# Size of the worker pool and number of samples to generate.
num_processes = 50
total_iterations = 10000  # or while True if you want it to run forever

# Fan the sample indices out to the pool; results are consumed as they
# complete purely to drive the progress bar.
with mp.Pool(processes=num_processes) as pool:
    completed = pool.imap_unordered(simulator_worker, range(total_iterations))
    for _finished in tqdm(completed, total=total_iterations):
        pass

print("Data generation complete.")