# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Lambda, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Path of the .mat recording that main() passes to load_ninapro_data().
DATA_PATH = "../../data/DB6/DB6_s1_a/S1_D1_T1.mat"

def load_ninapro_data(file_path):
    """
    Read EMG samples and labels from a NinaPro ``.mat`` file.

    The EMG matrix is taken from the ``'emg'`` key, then ``'data'``, and
    otherwise from the first key that does not start with ``'__'``.
    Labels are taken from ``'stimulus'``, then ``'restimulus'``, then the
    first column of ``'glove'``, then any key containing "stimulus" or
    "label", then the second non-``'__'`` key, and finally default to zeros.

    If anything goes wrong (missing file, unexpected layout, ...), the error
    is printed and random demonstration data of shape ``(10000, 12)`` with
    integer labels in ``0..6`` is returned instead.

    Returns:
        tuple: ``(emg_data, labels)``.
    """
    try:
        mat = sio.loadmat(file_path)

        # Show what the file contains so the key choice can be checked.
        print("Available keys in the data:", list(mat.keys()))

        # --- EMG matrix -------------------------------------------------
        emg_key = next((k for k in ('emg', 'data') if k in mat), None)
        if emg_key is None:
            # No well-known key: use the first non-metadata entry.
            emg_key = [k for k in mat.keys() if not k.startswith('__')][0]
        emg_data = mat[emg_key]

        # --- Labels -----------------------------------------------------
        label_key = next(
            (k for k in ('stimulus', 'restimulus') if k in mat), None
        )
        if label_key is not None:
            labels = mat[label_key].flatten()
        elif 'glove' in mat:
            labels = mat['glove']
            if labels.ndim > 1:
                labels = labels[:, 0]  # use the first column
        else:
            candidates = [
                k for k in mat.keys()
                if 'stimulus' in k.lower() or 'label' in k.lower()
            ]
            if candidates:
                labels = mat[candidates[0]].flatten()
            else:
                # Last resort: treat the second non-metadata entry as labels.
                public_keys = [k for k in mat.keys() if not k.startswith('__')]
                if len(public_keys) > 1:
                    labels = mat[public_keys[1]].flatten()
                else:
                    labels = np.zeros(emg_data.shape[0])

        print(f"EMG data shape: {emg_data.shape}")
        print(f"Labels shape: {labels.shape}")
        print(f"Unique labels: {np.unique(labels)}")

        return emg_data, labels

    except Exception as e:
        print(f"Error loading data: {e}")
        # Fall back to synthetic data so the rest of the pipeline can run.
        print("Generating sample data for demonstration...")
        n_samples, n_channels = 10000, 12
        emg_data = np.random.randn(n_samples, n_channels) * 0.1
        # Superimpose the same slow sinusoid on every channel.
        carrier = np.sin(np.linspace(0, 100*np.pi, n_samples)) * 0.05
        emg_data += carrier[:, np.newaxis]
        labels = np.random.randint(0, 7, n_samples)  # classes 0-6
        return emg_data, labels

def emg_to_spike_encoding(emg_data, threshold_method='adaptive', time_steps=50):
    """
    Convert EMG values into rate-coded binary spike trains.

    Every value ``x`` of channel ``c`` is mapped to a firing probability
    ``min(|x| / threshold_c, 1)`` and an independent random draw decides, at
    each of ``time_steps`` steps, whether a spike (1.0) is emitted.

    The threshold is computed per channel over all samples. Computing it on
    a single scalar value makes the standard deviation zero and the
    percentile equal to the value itself, so every non-zero input would fire
    at (almost) every step.

    Args:
        emg_data: Array of shape ``(n_samples, n_channels)``.
        threshold_method: ``'adaptive'`` uses twice the per-channel standard
            deviation, ``'fixed'`` uses 0.1, any other value uses the 75th
            percentile of the per-channel absolute values.
        time_steps: Number of time steps to generate per sample.

    Returns:
        numpy.ndarray: Float array of shape
        ``(n_samples, time_steps, n_channels)`` containing 0.0 / 1.0.
    """
    emg_data = np.asarray(emg_data, dtype=float)
    n_samples, n_channels = emg_data.shape
    spike_data = np.zeros((n_samples, time_steps, n_channels))
    if n_samples == 0:
        # Nothing to encode; also avoids statistics over an empty axis.
        return spike_data

    magnitude = np.abs(emg_data)

    if threshold_method == 'adaptive':
        # Adaptive threshold: scale with each channel's variability.
        threshold = np.std(emg_data, axis=0) * 2.0
    elif threshold_method == 'fixed':
        threshold = np.full(n_channels, 0.1)
    else:
        # Percentile-based threshold per channel.
        threshold = np.percentile(magnitude, 75, axis=0)

    # Rate coding: firing probability proportional to signal magnitude.
    spike_rate = np.clip(magnitude / (threshold + 1e-8), 0, 1)

    # One uniform draw per (sample, channel) and time step; generating one
    # time slice at a time keeps peak memory at the size of the output.
    for t in range(time_steps):
        spike_data[:, t, :] = np.random.random((n_samples, n_channels)) < spike_rate

    return spike_data

def preprocess_data_for_snn(emg_data, labels, window_size=200, overlap=100, time_steps=50):
    """
    Turn a labelled EMG recording into spike-encoded, windowed samples.

    Samples whose label is 0 are dropped, the remaining samples are cut
    into sliding windows, each window is summarised by four statistics per
    channel (mean, standard deviation, range, mean absolute first
    difference) and labelled with its most frequent label. The feature
    matrix is then passed to ``emg_to_spike_encoding``.

    Args:
        emg_data: Array of shape ``(n_samples, n_channels)``.
        labels: One label per sample; any shape that flattens to
            ``n_samples`` entries.
        window_size: Number of samples per window (>= 1).
        overlap: Number of samples shared by consecutive windows; must be
            smaller than ``window_size``.
        time_steps: Forwarded to ``emg_to_spike_encoding``.

    Returns:
        tuple: ``(spike_data, window_labels)`` where ``spike_data`` has shape
        ``(n_windows, time_steps, 4 * n_channels)``.

    Raises:
        ValueError: If the window parameters are inconsistent or the
            recording is too short to yield a single window.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if overlap >= window_size:
        # A non-positive step would make the window loop invalid or empty.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than window_size ({window_size})"
        )

    # Float conversion keeps np.diff from wrapping on unsigned integer data.
    emg_data = np.asarray(emg_data, dtype=float)
    labels = np.asarray(labels).ravel()

    # Drop rest segments (label 0).
    active = labels != 0
    emg_data = emg_data[active]
    labels = labels[active]

    step_size = window_size - overlap
    feature_rows = []
    window_labels = []

    for start in range(0, len(emg_data) - window_size + 1, step_size):
        stop = start + window_size
        window = emg_data[start:stop]

        # The most frequent label inside the window becomes its label.
        values, counts = np.unique(labels[start:stop], return_counts=True)
        window_labels.append(values[np.argmax(counts)])

        deltas = np.abs(np.diff(window, axis=0))
        if len(deltas):
            mean_abs_delta = deltas.mean(axis=0)
        else:
            # One-sample window: there is no sample-to-sample change.
            mean_abs_delta = np.zeros(window.shape[1])

        # Shape (n_channels, 4); ravel() gives the channel-major layout
        # [ch0 mean, ch0 std, ch0 range, ch0 delta, ch1 mean, ...].
        per_channel = np.stack(
            [
                window.mean(axis=0),
                window.std(axis=0),
                window.max(axis=0) - window.min(axis=0),
                mean_abs_delta,
            ],
            axis=1,
        )
        feature_rows.append(per_channel.ravel())

    if not feature_rows:
        raise ValueError(
            f"Not enough non-rest samples ({len(emg_data)}) to build a single "
            f"window of size {window_size}"
        )

    # Encode the window features as spike trains.
    emg_features = np.array(feature_rows)
    spike_data = emg_to_spike_encoding(emg_features, time_steps=time_steps)

    return spike_data, np.array(window_labels)

class LIFNeuron(tf.keras.layers.Layer):
    """
    Leaky Integrate-and-Fire (LIF) spiking layer.

    The layer consumes a ``(batch, time, features)`` tensor and, for every
    time step, integrates a dense projection of the input into a leaky
    membrane potential. A unit emits a spike (1.0) when its potential
    reaches ``threshold``; its potential is then set to ``reset_value`` and
    it ignores input for ``refractory_period`` steps. The output has shape
    ``(batch, time, units)``.

    The forward pass uses a hard threshold, which has no useful gradient.
    To make the weights trainable with backpropagation, the backward pass
    replaces the threshold's derivative with the fast-sigmoid surrogate
    ``1 / (1 + surrogate_slope * |v - threshold|) ** 2``.

    Args:
        units: Number of LIF units.
        tau: Membrane time constant; the potential decays by ``1 - 1/tau``
            per step.
        threshold: Spike threshold.
        reset_value: Potential assigned after a spike.
        refractory_period: Number of steps during which input is ignored
            after a spike.
        surrogate_slope: Steepness of the surrogate gradient.
    """

    def __init__(self, units, tau=20.0, threshold=1.0, reset_value=0.0,
                 refractory_period=2, surrogate_slope=10.0, **kwargs):
        super(LIFNeuron, self).__init__(**kwargs)
        self.units = units
        self.tau = tau  # membrane time constant
        self.threshold = threshold  # spike threshold
        self.reset_value = reset_value  # reset potential
        self.refractory_period = refractory_period  # refractory steps
        self.surrogate_slope = surrogate_slope  # backward-pass steepness

        # Trainable weights, created in build().
        self.w = None
        self.b = None

    def build(self, input_shape):
        """Create the dense projection from input features to units."""
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='glorot_uniform',
            name='lif_weights',
            trainable=True
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer='zeros',
            name='lif_bias',
            trainable=True
        )
        super(LIFNeuron, self).build(input_shape)

    def _make_spike_fn(self):
        """Return the threshold function with a surrogate gradient."""
        threshold = self.threshold
        slope = self.surrogate_slope

        @tf.custom_gradient
        def spike_fn(membrane_potential):
            # Forward pass: exact hard threshold (0.0 / 1.0).
            spikes = tf.cast(membrane_potential >= threshold, tf.float32)

            def grad(upstream):
                # Backward pass: smooth stand-in for the step derivative.
                distance = tf.abs(membrane_potential - threshold)
                return upstream / tf.square(1.0 + slope * distance)

            return spikes, grad

        return spike_fn

    def _lif_step(self, spike_fn, synaptic_current, membrane_potential,
                  refractory_counter):
        """Advance the LIF state by one time step.

        Returns ``(spikes, membrane_potential, refractory_counter)``.
        """
        # Units that are still refractory ignore their input current.
        not_refractory = tf.cast(refractory_counter <= 0, tf.float32)

        # Leaky integration of the membrane potential.
        membrane_potential = (
            membrane_potential * (1 - 1/self.tau) +
            synaptic_current * not_refractory
        )

        spikes = spike_fn(membrane_potential)
        fired = spikes > 0

        # Reset the potential of units that fired.
        membrane_potential = tf.where(
            fired,
            tf.ones_like(membrane_potential) * self.reset_value,
            membrane_potential
        )

        # Restart the refractory counter on a spike, otherwise count down.
        refractory_counter = tf.where(
            fired,
            tf.ones_like(refractory_counter) * self.refractory_period,
            tf.maximum(refractory_counter - 1, 0)
        )
        return spikes, membrane_potential, refractory_counter

    def call(self, inputs, training=None):
        """Run the LIF dynamics over the time axis of ``inputs``.

        Args:
            inputs: Tensor of shape ``(batch, time, features)``.
            training: Unused; accepted for Keras compatibility.

        Returns:
            Tensor of shape ``(batch, time, units)`` holding 0.0 / 1.0.
        """
        batch_size = tf.shape(inputs)[0]

        # Initial state.
        membrane_potential = tf.zeros((batch_size, self.units))
        refractory_counter = tf.zeros((batch_size, self.units))
        spike_fn = self._make_spike_fn()

        static_steps = inputs.shape[1]
        if static_steps is not None:
            # Known sequence length: unroll in Python. This does not rely on
            # autograph turning range(<Tensor>) into a graph loop and keeps
            # static shapes for the following layers.
            step_outputs = []
            for t in range(int(static_steps)):
                synaptic_current = tf.matmul(inputs[:, t, :], self.w) + self.b
                spikes, membrane_potential, refractory_counter = self._lif_step(
                    spike_fn, synaptic_current,
                    membrane_potential, refractory_counter
                )
                step_outputs.append(spikes)
            if not step_outputs:
                return tf.zeros((batch_size, 0, self.units))
            return tf.stack(step_outputs, axis=1)  # (batch, time, units)

        # Unknown sequence length: tensor-driven loop writing to a
        # TensorArray, as in the original implementation.
        time_steps = tf.shape(inputs)[1]
        output_spikes = tf.TensorArray(tf.float32, size=time_steps)
        for t in range(time_steps):
            synaptic_current = tf.matmul(inputs[:, t, :], self.w) + self.b
            spikes, membrane_potential, refractory_counter = self._lif_step(
                spike_fn, synaptic_current,
                membrane_potential, refractory_counter
            )
            output_spikes = output_spikes.write(t, spikes)

        # (time, batch, units) -> (batch, time, units)
        return tf.transpose(output_spikes.stack(), [1, 0, 2])

    def get_config(self):
        """Return the constructor arguments for serialization."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'tau': self.tau,
            'threshold': self.threshold,
            'reset_value': self.reset_value,
            'refractory_period': self.refractory_period,
            'surrogate_slope': self.surrogate_slope,
        })
        return config

class SpikeReadout(tf.keras.layers.Layer):
    """
    Collapse axis 1 (the time axis) of a spike tensor into one vector per
    sample.

    ``readout_method`` selects the reduction: ``'rate'`` (mean), ``'sum'``
    (total), ``'last'`` (final step) or, for any other value, a weighted
    mean whose weights grow linearly from 0.1 to 1.0 along the axis.
    """

    def __init__(self, readout_method='rate', **kwargs):
        super().__init__(**kwargs)
        self.readout_method = readout_method

    def call(self, inputs):
        method = self.readout_method

        if method == 'rate':
            # Firing rate: average over time.
            return tf.reduce_mean(inputs, axis=1)

        if method == 'sum':
            # Total spike count.
            return tf.reduce_sum(inputs, axis=1)

        if method == 'last':
            # Only the final time step.
            return inputs[:, -1, :]

        # Weighted average that emphasises later time steps.
        n_steps = tf.shape(inputs)[1]
        ramp = tf.linspace(0.1, 1.0, n_steps)
        normalised = ramp / tf.reduce_sum(ramp)
        weighted = inputs * normalised[None, :, None]
        return tf.reduce_sum(weighted, axis=1)

    def get_config(self):
        config = super().get_config()
        config['readout_method'] = self.readout_method
        return config

def create_snn_model(input_shape, num_classes, hidden_units=(128, 64),
                     tau=20.0, threshold=1.0, readout_method='rate'):
    """
    Build and compile the spiking neural network classifier.

    The network stacks one ``LIFNeuron`` layer per entry of
    ``hidden_units`` (with dropout of 0.2 between hidden layers), an output
    ``LIFNeuron`` layer with ``num_classes`` units, a ``SpikeReadout`` that
    collapses the time axis, and a softmax.

    Args:
        input_shape: ``(time_steps, n_features)`` of one sample.
        num_classes: Number of output classes.
        hidden_units: Sizes of the hidden LIF layers.
        tau: Membrane time constant of the hidden layers; the output layer
            uses ``tau * 0.5``.
        threshold: Spike threshold of the hidden layers; the output layer
            uses ``threshold * 0.8``.
        readout_method: Passed to ``SpikeReadout``.

    Returns:
        A Keras ``Model`` compiled with Adam (learning rate 0.001),
        categorical cross-entropy and the accuracy metric.

    Raises:
        ValueError: If ``input_shape`` does not have exactly two entries.
    """
    if len(input_shape) != 2:
        raise ValueError(
            f"input_shape must be (time_steps, n_features), got {input_shape}"
        )

    # Copy so the caller's sequence (or the default) is never shared.
    hidden_units = list(hidden_units)

    inputs = Input(shape=input_shape, name='spike_input')

    # Hidden spiking layers.
    x = inputs
    last_hidden = len(hidden_units) - 1
    for i, units in enumerate(hidden_units):
        x = LIFNeuron(
            units=units,
            tau=tau,
            threshold=threshold,
            name=f'lif_layer_{i}'
        )(x)

        # Dropout on the spikes, except after the last hidden layer.
        if i < last_hidden:
            x = Dropout(0.2)(x)

    # Output LIF layer: faster time constant and lower threshold.
    output_spikes = LIFNeuron(
        units=num_classes,
        tau=tau * 0.5,
        threshold=threshold * 0.8,
        name='output_lif'
    )(x)

    # Collapse the spike trains into one value per class.
    readout = SpikeReadout(readout_method=readout_method)(output_spikes)

    # Softmax as a layer keeps the whole graph made of Keras layers.
    outputs = tf.keras.layers.Softmax()(readout)

    model = Model(inputs, outputs, name='SNN_EMG_Classifier')

    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

def plot_spike_raster(spike_data, sample_idx=0, max_neurons=20):
    """
    Draw a raster plot of the spikes of one sample.

    ``spike_data[sample_idx]`` is expected to be a 2-D array indexed as
    ``[time_step, channel]``. At most ``max_neurons`` channels are drawn,
    and a legend is shown only when ``max_neurons`` is 5 or less.
    """
    raster = spike_data[sample_idx]  # (time_steps, n_channels)
    _, n_channels = raster.shape
    shown_channels = min(n_channels, max_neurons)

    plt.figure(figsize=(12, 6))

    for channel in range(shown_channels):
        firing_steps = np.flatnonzero(raster[:, channel] > 0)
        # Only the first five channels receive a legend entry.
        legend_label = f'Channel {channel}' if channel < 5 else ""
        plt.scatter(
            firing_steps,
            np.full(len(firing_steps), channel),
            s=2,
            alpha=0.7,
            label=legend_label,
        )

    plt.xlabel('Time Steps')
    plt.ylabel('EMG Channel')
    plt.title(f'Spike Raster Plot - Sample {sample_idx}')
    plt.grid(True, alpha=0.3)
    if max_neurons <= 5:
        plt.legend()
    plt.tight_layout()
    plt.show()

def plot_training_history(history):
    """
    Plot the SNN training curves on a 2x2 figure.

    Panels: accuracy, loss, accuracy over the last ten epochs, and a static
    text box. ``history.history`` must provide the keys ``'accuracy'``,
    ``'val_accuracy'``, ``'loss'`` and ``'val_loss'``.
    """
    metrics = history.history
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    def finish_panel(ax, title, ylabel):
        # Decoration shared by the three curve panels.
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Accuracy over all epochs.
    acc_ax = axes[0, 0]
    acc_ax.plot(metrics['accuracy'], label='Training Accuracy', linewidth=2)
    acc_ax.plot(metrics['val_accuracy'], label='Validation Accuracy', linewidth=2)
    finish_panel(acc_ax, 'SNN Model Accuracy', 'Accuracy')

    # Loss over all epochs.
    loss_ax = axes[0, 1]
    loss_ax.plot(metrics['loss'], label='Training Loss', linewidth=2)
    loss_ax.plot(metrics['val_loss'], label='Validation Loss', linewidth=2)
    finish_panel(loss_ax, 'SNN Model Loss', 'Loss')

    # Accuracy over the last ten epochs.
    n_epochs = len(metrics['accuracy'])
    start_epoch = max(0, n_epochs - 10)
    epochs_range = range(start_epoch, n_epochs)

    tail_ax = axes[1, 0]
    tail_ax.plot(epochs_range, metrics['accuracy'][start_epoch:],
                 label='Training Accuracy', linewidth=2, marker='o')
    tail_ax.plot(epochs_range, metrics['val_accuracy'][start_epoch:],
                 label='Validation Accuracy', linewidth=2, marker='s')
    finish_panel(tail_ax, 'Last 10 Epochs - Accuracy', 'Accuracy')

    # Static description of the architecture.
    info_ax = axes[1, 1]
    info_text = (
        'SNN Architecture\n\n'
        '• LIF Neurons\n'
        '• Spike-based Processing\n'
        '• Temporal Dynamics\n'
        '• Bio-inspired Computing\n'
        '• Event-driven Processing'
    )
    box_style = dict(boxstyle="round,pad=0.3", facecolor="lightgreen", alpha=0.5)
    info_ax.text(0.5, 0.5, info_text,
                 ha='center', va='center', transform=info_ax.transAxes,
                 fontsize=11, bbox=box_style)
    info_ax.set_xlim(0, 1)
    info_ax.set_ylim(0, 1)
    info_ax.set_title('SNN Architecture Info', fontsize=14, fontweight='bold')

    plt.suptitle('Spiking Neural Network Training Results', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(y_true, y_pred, class_names):
    """
    Show the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated
    heatmap whose tick labels are ``class_names``.
    """
    matrix = confusion_matrix(y_true, y_pred)

    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap='Greens',
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, **heatmap_options)
    plt.title('SNN Model - Confusion Matrix', fontsize=16, fontweight='bold')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.show()

def main():
    """
    Run the full pipeline: load the recording, encode it as spikes, split
    it into train/validation/test sets, build and train the SNN, then
    evaluate it and plot the results.

    Returns:
        tuple: ``(model, history, label_encoder)``.
    """
    print("=== NinaPro EMG Spiking Neural Network Classification ===")

    # 1. Load data
    print("\n1. Loading data...")
    emg_data, labels = load_ninapro_data(DATA_PATH)

    # 2. Preprocess for the SNN
    print("\n2. Preprocessing data for SNN...")
    time_steps = 50
    X, y = preprocess_data_for_snn(emg_data, labels,
                                   window_size=200, overlap=100,
                                   time_steps=time_steps)

    print(f"Spike data shape: {X.shape}")  # (n_samples, time_steps, n_features)
    print(f"Labels shape: {y.shape}")
    print(f"Number of classes: {len(np.unique(y))}")
    print(f"Average spike rate: {np.mean(X):.4f}")

    # 3. Label encoding
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)
    num_classes = len(label_encoder.classes_)
    # Name classes after the original gesture ids rather than the encoded
    # indices, so reports match the dataset's labels.
    class_names = [f"Gesture {label}" for label in label_encoder.classes_]

    # 4. Train / validation / test split
    print("\n3. Splitting data...")
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y_encoded, test_size=0.15, random_state=42, stratify=y_encoded
    )

    # 0.176 of the remaining 85% is roughly 15% of the whole data set.
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=0.176, random_state=42, stratify=y_temp
    )

    print(f"Train set: {X_train.shape[0]} samples")
    print(f"Validation set: {X_val.shape[0]} samples")
    print(f"Test set: {X_test.shape[0]} samples")

    # 5. One-hot encoding
    y_train_cat = to_categorical(y_train, num_classes)
    y_val_cat = to_categorical(y_val, num_classes)
    y_test_cat = to_categorical(y_test, num_classes)

    # 6. Build the SNN model
    print("\n4. Creating SNN model...")
    input_shape = (time_steps, X_train.shape[-1])
    model = create_snn_model(
        input_shape=input_shape,
        num_classes=num_classes,
        hidden_units=[128, 64],
        tau=20.0,
        threshold=1.0,
        readout_method='rate'
    )

    # summary() prints by itself and returns None; do not print the result.
    model.summary()

    # 7. Visualize spike patterns
    print("\n5. Visualizing spike patterns...")
    plot_spike_raster(X_train, sample_idx=0, max_neurons=12)

    # 8. Callbacks
    callbacks = [
        EarlyStopping(monitor='val_accuracy', patience=25, restore_best_weights=True, verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=15, min_lr=1e-6, verbose=1)
    ]

    # 9. Train
    print("\n6. Training SNN model for 100 epochs...")
    history = model.fit(
        X_train, y_train_cat,
        batch_size=32,
        epochs=100,
        validation_data=(X_val, y_val_cat),
        callbacks=callbacks,
        verbose=1
    )

    # 10. Final evaluation
    print("\n7. Evaluating SNN model...")

    val_loss, val_accuracy = model.evaluate(X_val, y_val_cat, verbose=0)
    print(f"Final Validation Accuracy: {val_accuracy:.4f}")

    test_loss, test_accuracy = model.evaluate(X_test, y_test_cat, verbose=0)
    print(f"Final Test Accuracy: {test_accuracy:.4f}")

    # Predicted class = index of the largest probability.
    y_pred_proba = model.predict(X_test)
    y_pred = np.argmax(y_pred_proba, axis=1)

    # 11. Plots
    print("\n8. Plotting results...")
    plot_training_history(history)
    plot_confusion_matrix(y_test, y_pred, class_names)

    print("\n9. Classification Report:")
    print(classification_report(y_test, y_pred, target_names=class_names))

    # Final summary
    print("\n=== Final SNN Results Summary ===")
    print(f"Training Accuracy: {max(history.history['accuracy']):.4f}")
    print(f"Validation Accuracy: {val_accuracy:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"Total Epochs Trained: {len(history.history['accuracy'])}")
    print(f"Model Parameters: {model.count_params():,}")
    print(f"Average Spike Rate: {np.mean(X_train):.4f}")

    return model, history, label_encoder

# Run the pipeline only when this file is executed as a script; the results
# are bound to module-level names.
if __name__ == "__main__":
    model, history, label_encoder = main()