# Reduced PyTorch to MATLAB Dataset Converter (Inverse Function)

This notebook converts reduced PyTorch `.pt` format back to MATLAB `.mat` format.

**Features:**
- Loads reduced PyTorch dataset from a folder
- Reconstructs full MATLAB structure from reduced data
- Reconstructs EIGENVECTOR_DATA from reduced displacements using indices
- Reconstructs EIGENVALUE_DATA as a NaN-filled placeholder (eigenvalues are not stored in the reduced format)
- Reconstructs designs with all three panes (the stored geometry pane is duplicated into every pane)
- Saves as MATLAB v7.3 `.mat` file

**Note:** This notebook performs the inverse of the conversion in `matlab_to_reduced_pt.ipynb`.


In [None]:
# Import required libraries
import numpy as np
import torch
import h5py
from pathlib import Path
import os
import time

# Custom utilities
import NO_utils
import NO_utils_multiple

print("Libraries imported successfully")


## Configuration

Set your input (reduced PT dataset folder) and output (MATLAB .mat file) paths:


In [None]:
# ============================================================================
# CONFIGURATION - Modify these as needed
# ============================================================================

# Folder holding the reduced PT dataset (the .pt files)
pt_input_folder = r"D:\Research\NO-2D-Metamaterials\data\dispersion_binarized_1_predictions"
pt_input_path = Path(pt_input_folder)

# The output folder sits next to the input folder and carries a "_mat" suffix
output_base_folder = pt_input_path.parent
output_folder_name = f"{pt_input_path.name}_mat"
output_folder = output_base_folder.joinpath(output_folder_name)

# The .mat file is named after the input folder and lives in the output folder
matlab_output_file = output_folder.joinpath(pt_input_path.name + ".mat")
matlab_output_path = Path(matlab_output_file)

# ============================================================================

# Echo the resolved locations so a wrong path is easy to spot
print(
    f"PT input folder: {pt_input_path}",
    f"Output folder: {output_folder}",
    f"MATLAB output file: {matlab_output_path}",
    sep="\n",
)

# Refuse to continue without the input folder; otherwise prepare the output folder
if pt_input_path.exists():
    output_folder.mkdir(parents=True, exist_ok=True)
else:
    raise FileNotFoundError(f"Input folder does not exist: {pt_input_path}")


## Inverse Conversion Function

Define a function to convert reduced PT format back to MATLAB format:


In [None]:
def interleave_arrays(arr1, arr2, dim):
    """
    Interleave two arrays along a specified dimension (inverse of split_array).

    Along ``dim``, the elements of ``arr1`` land at the even positions
    (0, 2, 4, ...) of the result and the elements of ``arr2`` at the odd
    positions (1, 3, 5, ...). All other dimensions are left untouched.

    Parameters:
    -----------
    arr1 : np.ndarray
        First array (elements at even indices)
    arr2 : np.ndarray
        Second array (elements at odd indices)
    dim : int
        Dimension along which to interleave (negative values count from the
        last dimension)

    Returns:
    --------
    np.ndarray : Interleaved array with the dtype of ``arr1``; its size along
        ``dim`` is twice that of the inputs

    Raises:
    -------
    ValueError : If the two arrays do not have the same shape
    """
    if arr1.shape != arr2.shape:
        raise ValueError(f"Arrays must have the same shape. Got {arr1.shape} and {arr2.shape}")

    # The result is always 2 * len(input) long along `dim`, so inputs of any
    # length (odd included) can be interleaved; no parity check is needed.
    output_shape = list(arr1.shape)
    output_shape[dim] *= 2

    # Every element is overwritten below, so zero-initialisation is unnecessary
    output = np.empty(output_shape, dtype=arr1.dtype)

    # Full slices everywhere except along `dim`, where we step by two
    even_index = [slice(None)] * arr1.ndim
    odd_index = [slice(None)] * arr1.ndim
    even_index[dim] = slice(0, None, 2)
    odd_index[dim] = slice(1, None, 2)

    output[tuple(even_index)] = arr1
    output[tuple(odd_index)] = arr2

    return output


def convert_reduced_pt_to_matlab(pt_input_path, matlab_output_path):
    """
    Convert reduced PyTorch dataset back to MATLAB .mat format.

    The reduced dataset only stores eigenvectors for a subset of
    (design, wavevector, band) triples. Those are placed back at their
    recorded indices; every other eigenvector entry is left as zero.
    Eigenvalues are not part of the reduced dataset and are written as NaN.
    The `const` struct is filled with default/placeholder values.

    Parameters:
    -----------
    pt_input_path : Path
        Path to the folder containing .pt files
    matlab_output_path : Path
        Path to output .mat file

    Returns:
    --------
    dict : Information about the conversion, with keys 'output_path',
        'file_size_mb', 'elapsed_time', 'n_designs', 'n_wavevectors'
        and 'n_bands'

    Raises:
    -------
    FileNotFoundError : If any of the required .pt files is missing from
        `pt_input_path`
    """
    # Accept plain strings as well as Path objects
    pt_input_path = Path(pt_input_path)
    matlab_output_path = Path(matlab_output_path)

    print("\n" + "=" * 80)
    print("Converting Reduced PT Dataset to MATLAB Format")
    print("=" * 80)

    start_time = time.time()

    # Step 1: Load Reduced PT Dataset
    print("\nStep 1: Loading Reduced PT Dataset")

    # Report every missing input at once instead of failing on the first load
    required_files = (
        "displacements_dataset.pt",
        "reduced_indices.pt",
        "geometries_full.pt",
        "waveforms_full.pt",
        "wavevectors_full.pt",
        "band_fft_full.pt",
        "design_params_full.pt",
    )
    missing_files = [name for name in required_files if not (pt_input_path / name).is_file()]
    if missing_files:
        raise FileNotFoundError(
            f"Missing required file(s) in {pt_input_path}: {', '.join(missing_files)}"
        )

    # Load all required files
    # NOTE(review): torch.load unpickles arbitrary objects, so only point this
    # at trusted files. Recent PyTorch releases default to weights_only=True,
    # which may refuse a pickled dataset object such as displacements_dataset
    # -- confirm against the installed PyTorch version.
    displacements_dataset = torch.load(pt_input_path / "displacements_dataset.pt", map_location='cpu')
    reduced_indices = torch.load(pt_input_path / "reduced_indices.pt", map_location='cpu')
    geometries = torch.load(pt_input_path / "geometries_full.pt", map_location='cpu')
    waveforms = torch.load(pt_input_path / "waveforms_full.pt", map_location='cpu')
    wavevectors = torch.load(pt_input_path / "wavevectors_full.pt", map_location='cpu')
    bands_fft = torch.load(pt_input_path / "band_fft_full.pt", map_location='cpu')
    design_params = torch.load(pt_input_path / "design_params_full.pt", map_location='cpu')

    # Convert to numpy
    eigenvector_x_real = displacements_dataset.tensors[0].numpy()
    eigenvector_x_imag = displacements_dataset.tensors[1].numpy()
    eigenvector_y_real = displacements_dataset.tensors[2].numpy()
    eigenvector_y_imag = displacements_dataset.tensors[3].numpy()

    geometries_np = geometries.numpy()
    wavevectors_np = wavevectors.numpy()
    design_params_np = design_params.numpy()

    # Get dimensions
    n_designs = geometries_np.shape[0]
    design_res = geometries_np.shape[1]
    n_wavevectors = wavevectors_np.shape[1]

    # Determine n_bands from bands_fft
    n_bands = bands_fft.shape[0]

    print(f"  Loaded dataset dimensions:")
    print(f"    n_designs: {n_designs}")
    print(f"    design_res: {design_res}")
    print(f"    n_wavevectors: {n_wavevectors}")
    print(f"    n_bands: {n_bands}")
    print(f"    n_reduced_samples: {len(reduced_indices)}")

    # Step 2: Reconstruct Full EIGENVECTOR_DATA
    print("\nStep 2: Reconstructing Full EIGENVECTOR_DATA")

    # Initialize full eigenvector arrays (entries without a reduced sample stay zero)
    component_shape = (n_designs, n_wavevectors, n_bands, design_res, design_res)
    EIGENVECTOR_DATA_x_full = np.zeros(component_shape, dtype=np.complex64)
    EIGENVECTOR_DATA_y_full = np.zeros(component_shape, dtype=np.complex64)

    # Place reduced eigenvectors at correct indices
    for sample_idx, (d_idx, w_idx, b_idx) in enumerate(reduced_indices):
        # int() handles both 0-d tensors and plain Python/NumPy integers
        d_idx, w_idx, b_idx = int(d_idx), int(w_idx), int(b_idx)

        # Reconstruct complex eigenvectors from their real and imaginary parts
        EIGENVECTOR_DATA_x_full[d_idx, w_idx, b_idx, :, :] = (
            eigenvector_x_real[sample_idx] + 1j * eigenvector_x_imag[sample_idx]
        )
        EIGENVECTOR_DATA_y_full[d_idx, w_idx, b_idx, :, :] = (
            eigenvector_y_real[sample_idx] + 1j * eigenvector_y_imag[sample_idx]
        )

    print(f"  Reconstructed EIGENVECTOR_DATA_x shape: {EIGENVECTOR_DATA_x_full.shape}")
    print(f"  Reconstructed EIGENVECTOR_DATA_y shape: {EIGENVECTOR_DATA_y_full.shape}")

    # Step 3: Combine x and y eigenvectors into single array
    print("\nStep 3: Combining x and y eigenvectors")

    # Flatten each (design_res, design_res) field and interleave the x and y
    # components along the last axis: x at even positions, y at odd positions.
    n_pixels = design_res * design_res
    n_dof = 2 * n_pixels
    flat_shape = (n_designs, n_wavevectors, n_bands, n_pixels)
    # Both strided slices together cover every element, so no zero-fill is needed
    EIGENVECTOR_DATA_combined = np.empty((n_designs, n_wavevectors, n_bands, n_dof), dtype=np.complex64)
    EIGENVECTOR_DATA_combined[:, :, :, 0::2] = EIGENVECTOR_DATA_x_full.reshape(flat_shape)
    EIGENVECTOR_DATA_combined[:, :, :, 1::2] = EIGENVECTOR_DATA_y_full.reshape(flat_shape)

    # The per-component arrays are no longer needed; release them before the
    # save step allocates another array of the combined size.
    del EIGENVECTOR_DATA_x_full, EIGENVECTOR_DATA_y_full

    # Transpose to match MATLAB format: (n_designs, n_eig, n_wv, n_dof)
    EIGENVECTOR_DATA = EIGENVECTOR_DATA_combined.transpose(0, 2, 1, 3)

    print(f"  Combined EIGENVECTOR_DATA shape: {EIGENVECTOR_DATA.shape}")

    # Step 4: Reconstruct EIGENVALUE_DATA
    print("\nStep 4: Reconstructing EIGENVALUE_DATA")

    # Eigenvalues are not stored in the reduced format; NaN marks them as
    # missing. Shape already matches MATLAB format: (n_designs, n_eig, n_wv)
    EIGENVALUE_DATA = np.full((n_designs, n_bands, n_wavevectors), np.nan, dtype=np.float32)

    print(f"  Reconstructed EIGENVALUE_DATA shape: {EIGENVALUE_DATA.shape}")
    print(f"  Note: EIGENVALUE_DATA filled with NaN (not available in reduced format)")

    # Step 5: Reconstruct designs with all panes
    print("\nStep 5: Reconstructing designs with all panes")

    # Original designs had 3 panes, but we only have the first pane (elastic modulus)
    # Duplicate it for all panes
    n_panes = 3
    designs_full = np.zeros((n_designs, n_panes, design_res, design_res), dtype=np.float32)
    for pane_idx in range(n_panes):
        designs_full[:, pane_idx, :, :] = geometries_np

    print(f"  Reconstructed designs shape: {designs_full.shape}")

    # Step 6: Reconstruct WAVEVECTOR_DATA
    print("\nStep 6: Reconstructing WAVEVECTOR_DATA")

    # Transpose to match MATLAB format: (n_designs, 2, n_wv)
    WAVEVECTOR_DATA = wavevectors_np.transpose(0, 2, 1)

    print(f"  Reconstructed WAVEVECTOR_DATA shape: {WAVEVECTOR_DATA.shape}")

    # Step 7: Reconstruct WAVEFORM_DATA
    print("\nStep 7: Reconstructing WAVEFORM_DATA")

    # WAVEFORM_DATA should be (n_designs, n_wavevectors, design_res, design_res)
    # The stored waveforms are broadcast to every design.
    # NOTE(review): WAVEFORM_DATA is rebuilt here but is not written to the
    # output file below -- confirm whether it should be saved.
    WAVEFORM_DATA = np.zeros((n_designs, n_wavevectors, design_res, design_res), dtype=np.float32)
    WAVEFORM_DATA[...] = waveforms.numpy()

    print(f"  Reconstructed WAVEFORM_DATA shape: {WAVEFORM_DATA.shape}")

    # Step 8: Reconstruct const dictionary
    print("\nStep 8: Reconstructing const dictionary")

    # Infer const from available data
    # Note: N_wv grid dimensions are not available in reduced format, so we'll use a default
    # The original had [25, 13] but we'll use [n_wavevectors, 1] as a placeholder
    const = {
        'N_pix': np.array([[float(design_res)]], dtype=np.float64),  # (1, 1)
        'N_ele': np.array([[1.0]], dtype=np.float64),  # (1, 1)
        'N_eig': np.array([[float(n_bands)]], dtype=np.float64),  # (1, 1)
        'N_wv': np.array([[float(n_wavevectors)], [1.0]], dtype=np.float64),  # (2, 1) - placeholder
        'a': np.array([[1.0]], dtype=np.float64),  # Default value
        'E_max': np.array([[1.0]], dtype=np.float64),  # Default value
        'E_min': np.array([[0.01]], dtype=np.float64),  # Default value
        'poisson_max': np.array([[0.3]], dtype=np.float64),  # Default value
        'poisson_min': np.array([[0.3]], dtype=np.float64),  # Default value
        'rho_max': np.array([[1.0]], dtype=np.float64),  # Default value
        'rho_min': np.array([[1.0]], dtype=np.float64),  # Default value
        't': np.array([[1.0]], dtype=np.float64),  # Default value
        'sigma_eig': np.array([[1e-2]], dtype=np.float64),  # Default value
        'design_scale': np.array([['linear']], dtype=object),  # Default value
        'symmetry_type': np.array([['none']], dtype=object),  # Default value
        'eigenvector_dtype': np.array([['single']], dtype=object),  # Default value
        'isSaveEigenvectors': np.array([[1.0]], dtype=np.float64),
        'isSaveKandM': np.array([[0.0]], dtype=np.float64),
        'isSaveMesh': np.array([[0.0]], dtype=np.float64),
        'isUseGPU': np.array([[0.0]], dtype=np.float64),
        'isUseImprovement': np.array([[1.0]], dtype=np.float64),
        'isUseParallel': np.array([[1.0]], dtype=np.float64),
        'isUseSecondImprovement': np.array([[0.0]], dtype=np.float64),
        'design': np.zeros((n_panes, design_res, design_res), dtype=np.float64),  # Placeholder
        'wavevectors': wavevectors_np[0, :, :].T  # Wavevectors from first design
    }

    print(f"  Reconstructed const dictionary with {len(const)} keys")

    # Step 9: Prepare other metadata
    print("\nStep 9: Preparing metadata")

    N_struct = np.array([[float(n_designs)]], dtype=np.float64)
    imag_tol = np.array([[1e-6]], dtype=np.float64)  # Default value
    rng_seed_offset = np.array([[0.0]], dtype=np.float64)  # Default value

    # Step 10: Save as MATLAB v7.3 format
    print("\nStep 10: Saving as MATLAB v7.3 format")

    # Prepare dataset dictionary
    dataset = {
        'WAVEVECTOR_DATA': WAVEVECTOR_DATA.astype(np.float32),
        'EIGENVALUE_DATA': EIGENVALUE_DATA.astype(np.float32),
        # Already complex64 by construction; an astype() here would only add
        # another full-size copy of the largest array.
        'EIGENVECTOR_DATA': EIGENVECTOR_DATA,
        'designs': designs_full.astype(np.float32),
        'design_params': design_params_np.astype(np.float64),
        'N_struct': N_struct.astype(np.float64),
        'imag_tol': imag_tol.astype(np.float64),
        'rng_seed_offset': rng_seed_offset.astype(np.float64),
    }

    # Make sure the destination folder exists when the function is used on its own
    matlab_output_path.parent.mkdir(parents=True, exist_ok=True)

    # Complex data is stored as a compound (real, imag) single-precision dtype
    structured_dtype = np.dtype([('real', np.float32), ('imag', np.float32)])

    with h5py.File(matlab_output_path, 'w') as f:
        # Save regular arrays
        for key, value in dataset.items():
            if key == 'EIGENVECTOR_DATA':
                # Fill the compound array straight from the real/imag views of
                # the complex data (no intermediate float copies)
                EIGENVECTOR_DATA_structured = np.empty(value.shape, dtype=structured_dtype)
                EIGENVECTOR_DATA_structured['real'] = value.real
                EIGENVECTOR_DATA_structured['imag'] = value.imag

                # Create dataset with structured dtype (compound datatype)
                dset = f.create_dataset(
                    'EIGENVECTOR_DATA',
                    data=EIGENVECTOR_DATA_structured,
                    dtype=structured_dtype
                )
                # Add MATLAB_class attribute to indicate it's a single-precision complex array
                dset.attrs['MATLAB_class'] = np.bytes_(b'single')
            else:
                f.create_dataset(key, data=value)

        # Save const as a struct group
        const_grp = f.create_group('const')
        for key, value in const.items():
            if isinstance(value, np.ndarray):
                if value.dtype == object:
                    # Handle string arrays
                    dt = h5py.special_dtype(vlen=str)
                    dset = const_grp.create_dataset(key, value.shape, dtype=dt)
                    dset[:] = value.astype(str)
                else:
                    const_grp.create_dataset(key, data=value)
            else:
                const_grp.attrs[key] = value

    elapsed_time = time.time() - start_time
    file_size = matlab_output_path.stat().st_size / (1024 * 1024)

    print(f"  Saved to: {matlab_output_path}")
    print(f"  File size: {file_size:.2f} MB")
    print(f"  Conversion completed in {elapsed_time:.2f} seconds")

    return {
        'output_path': matlab_output_path,
        'file_size_mb': file_size,
        'elapsed_time': elapsed_time,
        'n_designs': n_designs,
        'n_wavevectors': n_wavevectors,
        'n_bands': n_bands
    }

print("Inverse conversion function defined.")


## Run Conversion

Execute the conversion function:


In [None]:
# Run the conversion and keep the returned summary for the report below
result = convert_reduced_pt_to_matlab(pt_input_path, matlab_output_path)

# Print the whole summary in one call; each argument becomes one output line
print(
    "\n" + "=" * 80,
    "Conversion Summary",
    "=" * 80,
    f"Output file: {result['output_path']}",
    f"File size: {result['file_size_mb']:.2f} MB",
    f"Conversion time: {result['elapsed_time']:.2f} seconds",
    f"n_designs: {result['n_designs']}",
    f"n_wavevectors: {result['n_wavevectors']}",
    f"n_bands: {result['n_bands']}",
    "=" * 80,
    sep="\n",
)


## Verification

Verify that the reconstructed MATLAB file can be loaded and has the correct structure:


In [None]:
# Verify the reconstructed file
import shutil
import tempfile
import traceback

import h5py

print("=" * 80)
print("Verifying Reconstructed MATLAB File")
print("=" * 80)

# Inspect the reconstructed file directly with h5py
with h5py.File(matlab_output_path, 'r') as f:
    print("\nKeys in reconstructed file:")
    print(list(f.keys()))

    print("\nShapes of main arrays:")
    if 'EIGENVECTOR_DATA' in f:
        # Dataset.shape only reads metadata; indexing the 'real' field would
        # load the entire real part into memory just to obtain the same shape.
        ev_shape = f['EIGENVECTOR_DATA'].shape
        print(f"  EIGENVECTOR_DATA: {ev_shape}")

    if 'EIGENVALUE_DATA' in f:
        eigval_shape = f['EIGENVALUE_DATA'].shape
        print(f"  EIGENVALUE_DATA: {eigval_shape}")

    if 'WAVEVECTOR_DATA' in f:
        wv_shape = f['WAVEVECTOR_DATA'].shape
        print(f"  WAVEVECTOR_DATA: {wv_shape}")

    if 'designs' in f:
        designs_shape = f['designs'].shape
        print(f"  designs: {designs_shape}")

    if 'const' in f:
        print(f"\nconst keys: {list(f['const'].keys())}")

# Also try loading with NO_utils to verify compatibility
print("\n" + "=" * 80)
print("Testing compatibility with NO_utils.extract_data()")
print("=" * 80)

# Bound before the try block so the cleanup in `finally` can never hit an
# unbound name, even when creating the temporary directory itself fails.
temp_dir = None
try:
    # NO_utils.extract_data is given a directory, so place a copy of the
    # reconstructed file alone in a fresh temporary directory
    temp_dir = tempfile.mkdtemp(prefix="verify_reconstructed_")
    temp_mat_path = Path(temp_dir) / matlab_output_path.name

    # Copy the reconstructed file to temp directory
    shutil.copy2(matlab_output_path, temp_mat_path)

    # Try to extract data using NO_utils
    (designs, design_params, n_designs, n_panes, design_res,
     WAVEVECTOR_DATA, WAVEFORM_DATA, n_dim, n_wavevectors,
     EIGENVALUE_DATA, n_bands, EIGENVECTOR_DATA_x,
     EIGENVECTOR_DATA_y, const, N_struct,
     imag_tol, rng_seed_offset) = NO_utils.extract_data(temp_dir)

    print("\n✓ Successfully loaded with NO_utils.extract_data()!")
    print(f"  n_designs: {n_designs}")
    print(f"  n_panes: {n_panes}")
    print(f"  design_res: {design_res}")
    print(f"  n_wavevectors: {n_wavevectors}")
    print(f"  n_bands: {n_bands}")
    print(f"  EIGENVECTOR_DATA_x shape: {EIGENVECTOR_DATA_x.shape}")
    print(f"  EIGENVECTOR_DATA_y shape: {EIGENVECTOR_DATA_y.shape}")

    print("\n" + "=" * 80)
    print("✓ Verification successful! The reconstructed file is compatible.")
    print("=" * 80)

except Exception as e:
    # Verification is best-effort: report the failure with its traceback
    # instead of aborting the notebook
    print(f"\n✗ Error during verification: {e}")
    traceback.print_exc()

finally:
    # Single cleanup path for both success and failure
    if temp_dir is not None:
        shutil.rmtree(temp_dir, ignore_errors=True)
