In [1]:
import numpy as np
import pandas as pd
import torch
import scipy.io as sio
from typing import List
from utils import get_palm_mask_484, get_true_indices, sample_to_events, save_spike_data, get_filename_from_params, load_spike_data
from itertools import product
import os

# --- Input files -------------------------------------------------------------
MAT_PATH = "smarthand_dataset.mat"      # MATLAB file read by preprocess_data
RANK_CSV = "res95_ranked_taxels.csv"    # taxel ranking produced by the PCA+Pearson script

# --- Acquisition timing ------------------------------------------------------
SAMPLING_FREQUENCY = 100.0                        # frames per second (Hz)
FRAME_DURATION_MS = 1000.0 / SAMPLING_FREQUENCY   # duration of one frame: 10 ms
NUM_FRAMES = 1200                                 # frame cap per (session, class) group

# --- Compute device ----------------------------------------------------------
# Use the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [2]:
def preprocess_data(topn: int = None, num_frames: int = 50, threshold: float = 0.03, session_id: List[int] = [0, 1] , channels: bool = True, down_spike: float = 0.1, encoding: str = "spike", rand_pixels: int = None):
    """
    Load the SmartHand recording, select taxels, cut it into fixed-length
    windows and encode every window as a spike / raw / hybrid tensor.

    Pipeline: load the MATLAB file at MAT_PATH -> normalise -> keep valid
    frames -> select taxel columns -> subtract baseline -> take up to
    ``frames_per_sample`` frames per (session, class) -> cut each group into
    windows of ``num_frames`` frames -> encode -> sort windows by class.

    Args:
        topn: If given (and ``rand_pixels`` is None), keep the first ``topn``
            entries of the "Orig_idx" column of RANK_CSV as taxel columns.
        num_frames: Number of consecutive frames in each output window.
            Must be >= 1.
        threshold: Threshold forwarded to ``sample_to_events`` (the
            delta-modulation threshold).
        session_id: Session IDs to include, processed in the given order.
            The default list is never mutated.
        channels: Compared with ``== 1`` (so both ``1`` and ``True`` match).
            When it matches, the UP and DOWN spike maps are summed into one
            (T, F) map; otherwise they are laid out side by side as (T, 2*F).
        down_spike: Value written for a DOWN event (UP events are 1.0).
        encoding: One of "spike", "raw", "hybrid".
        rand_pixels: If given, pick this many palm taxels at random
            (fixed seed 42) instead of using ``topn``. Takes precedence
            over ``topn``.

    Returns:
        Tuple ``(final_tensors, y_tensors)``:
            final_tensors: list of float32 arrays, one per window, each of
                shape (num_frames, feature_dim), ordered by class label.
                feature_dim is F or 2*F for "spike", F for "raw", and
                2*F or 3*F for "hybrid" (F = number of selected taxels).
            y_tensors: 1-D array with the class label of each window, in the
                same order.

    Raises:
        ValueError: on an unknown ``encoding``, ``num_frames < 1``,
            ``rand_pixels`` larger than the number of palm taxels, or when no
            (session, class) group could be formed.
    """
    valid_encodings = ("spike", "raw", "hybrid")
    if encoding not in valid_encodings:
        # Previously an unknown mode produced an empty result and an opaque
        # IndexError further down; fail early with a clear message instead.
        raise ValueError(f"Unknown encoding {encoding!r}; expected one of {valid_encodings}")
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # ==============================================================
    # 1. Load & Normalize
    # ==============================================================
    data = sio.loadmat(MAT_PATH)
    tactile = data['tactile_data'].astype(np.float32)
    baseline = data["threshold"].flatten().astype(np.float32)
    valid = data["valid_flag"].flatten().astype(bool)
    y = data["object_id"].flatten().astype(np.int64)
    sessions = data["session_id"].flatten().astype(np.int64)

    # Normalize tactile and baseline values to [0, 1].
    # NOTE(review): tactile is scaled by (2700-1500) but baseline by
    # (1800-1500), so the two are on different scales when subtracted below.
    # Kept as in the original; confirm this is intentional.
    tactile_norm = np.clip((tactile - 1500) / (2700 - 1500), 0.0, 1.0)
    baseline_norm = np.clip((baseline - 1500) / (1800 - 1500), 0.0, 1.0)

    # ==============================================================
    # 2. Filter & Taxel Selection
    # ==============================================================
    tactile_valid = tactile_norm[valid]
    y_valid = y[valid]
    sessions_valid = sessions[valid]

    # Assumed to be the integer column indices of the palm taxels (that is how
    # the random branch below uses them) — TODO confirm against utils.
    palm_indices = get_true_indices(get_palm_mask_484())

    if rand_pixels is not None:
        # Random selection from palm
        if rand_pixels > len(palm_indices):
            raise ValueError(f"rand_pixels={rand_pixels} > available palm taxels ({len(palm_indices)})")
        rng = np.random.default_rng(seed=42)  # reproducible
        taxel_idx = rng.choice(palm_indices, size=rand_pixels, replace=False)
        print(f"Selected {rand_pixels} RANDOM palm taxels")
    elif topn is not None:
        rank_df = pd.read_csv(RANK_CSV)
        taxel_idx = rank_df["Orig_idx"].to_numpy()[:topn]
        print(f"Selected top-{topn} taxels: {taxel_idx}")
    else:
        # Use the palm indices directly. The original applied np.where() to
        # the index array, which returns positions 0..len-1 (minus any zero
        # entry) rather than the palm taxel columns themselves.
        taxel_idx = np.asarray(palm_indices)

    tactile_hand = tactile_valid[:, taxel_idx]
    baseline_hand = baseline_norm[taxel_idx]

    print(f"Shape of tactile_hand before baseline subtraction: {tactile_hand.shape}")  # (num_samples, num_pixels)
    # Subtract the per-taxel baseline (broadcast over frames)
    tactile_hand = tactile_hand - baseline_hand

    # ==============================================================
    # 3. Balance per (session, class)
    # ==============================================================
    unique_classes = np.unique(y_valid)
    num_pixels = tactile_hand.shape[1]

    # Per-group frame cap; 100 fewer frames whenever session 4 is requested.
    # The reason is not visible here — presumably session 4 is shorter (TODO confirm).
    frames_per_sample = NUM_FRAMES if 4 not in session_id else NUM_FRAMES - 100

    tactile_balanced_list = []
    y_balanced_list = []
    for sess in session_id:
        for cls in unique_classes:
            mask = (sessions_valid == sess) & (y_valid == cls)
            data_cls = tactile_hand[mask]
            labels_cls = y_valid[mask]
            # Take the first frames_per_sample frames for this class-session pair
            n_take = min(frames_per_sample, len(data_cls))
            tactile_balanced_list.append(data_cls[:n_take])
            y_balanced_list.append(labels_cls[:n_take])

            print(f"Session {sess}, Class {cls}: Selected {n_take} frames")

    if not tactile_balanced_list:
        raise ValueError("No (session, class) groups to process: check session_id and the valid frames")

    tactile_balanced = np.concatenate(tactile_balanced_list, axis=0)
    y_balanced = np.concatenate(y_balanced_list, axis=0)

    # Print final shapes
    print(f"Shape of tactile_balanced: {tactile_balanced.shape}")  # (num_frames, num_pixels)
    print(f"Shape of y_balanced: {y_balanced.shape}")  # (num_frames,)

    # ==============================================================
    # 4. Batching
    # ==============================================================
    # Windows are cut inside each (session, class) group so that a window can
    # never straddle two classes or sessions (which would mislabel it). The
    # incomplete tail of a group is dropped. When every group length is a
    # multiple of num_frames this is identical to cutting the concatenated
    # stream globally.
    b_size = num_frames
    batch_chunks = []
    label_chunks = []
    for frames_grp, labels_grp in zip(tactile_balanced_list, y_balanced_list):
        n_full = len(frames_grp) // b_size
        if n_full == 0:
            continue
        usable = n_full * b_size
        batch_chunks.append(frames_grp[:usable].reshape(n_full, b_size, num_pixels))
        label_chunks.append(labels_grp[:usable:b_size])  # one label per window

    if batch_chunks:
        tactile_batches = np.concatenate(batch_chunks, axis=0)
        y_batches = np.concatenate(label_chunks, axis=0)
    else:
        tactile_batches = np.zeros((0, b_size, num_pixels), dtype=tactile_hand.dtype)
        y_batches = np.zeros((0,), dtype=y_valid.dtype)
    num_batches = len(tactile_batches)

    # Print new batch shapes
    print(f"Number of batches: {num_batches}")
    print(f"Shape of tactile_batches: {tactile_batches.shape}")
    print(f"Shape of y_batches: {y_batches.shape}")

    # ==============================================================
    # 5. Generate Spike Tensors (only for "spike" and "hybrid")
    # ==============================================================
    num_channels = 2            # channel 0 = UP events, channel 1 = DOWN events
    num_time_steps = num_frames  # one time bin per frame of the window
    # Width of one time bin. Assumes sample_to_events emits timestamps in the
    # same unit as FRAME_DURATION_MS — TODO confirm against utils.
    time_bin = FRAME_DURATION_MS

    spike_tensors_2ch = []
    if encoding in {"spike", "hybrid"}:
        print("Generating 2-channel spike tensors...")
        for batch in tactile_batches:
            # Every window starts its own clock at 0.
            events = sample_to_events(batch, 0.0, threshold)
            spike_tensor = np.zeros((num_time_steps, num_channels, num_pixels), dtype=np.float32)
            if events.size > 0:
                for x, t, p in events:
                    # Clamp late events into the last bin.
                    time_step = min(int(t // time_bin), num_time_steps - 1)
                    if p > 0:
                        spike_tensor[time_step, 0, int(x)] = 1.0
                    else:
                        spike_tensor[time_step, 1, int(x)] = down_spike
            spike_tensors_2ch.append(spike_tensor)

    def _merge_channels(spk):
        """Collapse a (T, 2, F) spike tensor to (T, F) or flatten it to (T, 2*F)."""
        if channels == 1:
            return spk[:, 0, :] + spk[:, 1, :]  # (T, F)
        return spk.reshape(num_time_steps, num_channels * num_pixels)  # (T, 2*F)

    # ==============================================================
    # 6. Final Tensor Construction by Mode
    # ==============================================================
    final_tensors = []

    if encoding == "spike":
        final_tensors = [_merge_channels(spk) for spk in spike_tensors_2ch]

    elif encoding == "raw":
        print("Generating normalized raw pressure...")
        for batch in tactile_batches:
            # Centre each taxel around 0 over the window's time axis
            raw_centered = batch - np.mean(batch, axis=0, keepdims=True)
            final_tensors.append(raw_centered.astype(np.float32))

    elif encoding == "hybrid":
        print("Generating hybrid: UP + DOWN + raw (3× features)...")
        for spk, raw_batch in zip(spike_tensors_2ch, tactile_batches):
            up_down = _merge_channels(spk)
            # Raw centered around 0 (per-pixel mean subtraction over time)
            raw_centered = raw_batch - np.mean(raw_batch, axis=0, keepdims=True)
            # Concat: [UP+DOWN, raw] -> (T, 2*F) or (T, 3*F)
            combined = np.concatenate([up_down, raw_centered], axis=1)
            final_tensors.append(combined.astype(np.float32))

    # ==============================================================
    # 7. Reorder by Class
    # ==============================================================
    # A stable sort groups windows by ascending class label while keeping
    # their original relative order inside each class.
    order = np.argsort(y_batches, kind="stable")
    final_tensors = [final_tensors[i] for i in order]
    y_tensors = y_batches[order]

    print(f"Final: {len(final_tensors)} samples")
    if final_tensors:
        print(f"Sample shape: {final_tensors[0].shape}")
        print(f"Feature dim: {final_tensors[0].shape[1]}")

    return final_tensors, y_tensors

In [6]:
# Sweep definition: every key maps to the list of values to try; the Cartesian
# product of all lists is generated below.
param_grid = dict(
    topn=[64, 32],  # Include multiple topn values for iteration
    num_frames=[25],
    threshold=[0.01],
    session_id=[[0, 1]],
    channels=[1],
    down_spike=[1.0],
    encoding=["spike"],
    rand_pixels=[None],
)

output_dir = "preprocessed_data"
for params in product(*param_grid.values()):
    # Pair each grid key with the value chosen for this combination.
    params_dict = dict(zip(param_grid.keys(), params))
    print(f"Generating data for params: {params_dict}")

    # Skip combinations whose output file is already on disk.
    filename = get_filename_from_params(params_dict, output_dir)
    if os.path.exists(filename):
        print(f"Data already exists for params {params_dict}: {filename}. Skipping generation.")
        continue

    # Run the preprocessing pipeline for this combination.
    spike_tensors, y_tensors = preprocess_data(**params_dict)

    # Parameters stored next to the data, listed explicitly so the saved
    # metadata does not silently grow if the grid gains extra keys.
    save_params = {
        key: params_dict[key]
        for key in (
            "topn",
            "num_frames",
            "threshold",
            "session_id",
            "channels",
            "down_spike",
            "encoding",
            "rand_pixels",
        )
    }

    # Persist tensors and labels.
    save_spike_data(
        spike_tensors=spike_tensors,
        y=y_tensors,
        output_dir=output_dir,
        params=save_params,
    )
   

Generating data for params: {'topn': 64, 'num_frames': 25, 'threshold': 0.01, 'session_id': [0, 1], 'channels': 1, 'down_spike': 1.0, 'encoding': 'spike', 'rand_pixels': None}
Selected top-64 taxels: [114  82  50 178 146  18 210 214 182 246 150 269 149 309 312 313  14 277
 152 560 407 237 406 181 245 205 120 374 377 213 118  19 117 376 270 286
 282 467 375 430 310 344  46 405 278 302  86 284 242 142 342  51 341 345
 408 398 437  54 180  88 462 435 153  83]
Shape of tactile_hand before baseline subtraction: (195072, 64)
Session 0, Class 0: Selected 1200 frames
Session 0, Class 1: Selected 1200 frames
Session 0, Class 2: Selected 1200 frames
Session 0, Class 3: Selected 1200 frames
Session 0, Class 4: Selected 1200 frames
Session 0, Class 5: Selected 1200 frames
Session 0, Class 6: Selected 1200 frames
Session 0, Class 7: Selected 1200 frames
Session 0, Class 8: Selected 1200 frames
Session 0, Class 9: Selected 1200 frames
Session 0, Class 10: Selected 1200 frames
Session 0, Class 11: Se