In [None]:
# ================================================
# AMLS ECG Dataset Exploration - Task 1.1
# ================================================

import numpy as np
import pandas as pd
import struct
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedShuffleSplit

# ==========================================
# 1. Data Loader for .bin files
# ==========================================

def load_bin_file(filepath):
    """Load ECG signals from a length-prefixed ``.bin`` file.

    The file is parsed as a sequence of records.  Each record is a 4-byte
    signed integer ``length`` (``struct`` format ``'i'``, native byte order)
    followed by ``length`` signed 16-bit samples (native byte order).

    Args:
        filepath: Path of the ``.bin`` file to read.

    Returns:
        A list of 1-D ``np.int16`` arrays, one per record, in file order.
        The arrays come from ``np.frombuffer`` over immutable ``bytes`` and
        are therefore read-only; call ``.copy()`` before modifying them.

    Raises:
        ValueError: If a record header is truncated, declares a negative
            length, or its sample payload is shorter than declared.
    """
    header = struct.Struct('i')
    bytes_per_sample = np.dtype(np.int16).itemsize  # 2 bytes per sample
    signals = []
    with open(filepath, 'rb') as f:
        while True:
            length_bytes = f.read(header.size)
            if not length_bytes:
                break  # clean EOF exactly between records
            if len(length_bytes) != header.size:
                raise ValueError(
                    f"{filepath}: truncated record header "
                    f"({len(length_bytes)} of {header.size} bytes) "
                    f"after {len(signals)} signals"
                )
            (length,) = header.unpack(length_bytes)
            # A negative length would turn f.read() into "read to EOF",
            # silently swallowing every remaining record.
            if length < 0:
                raise ValueError(
                    f"{filepath}: negative signal length {length} "
                    f"in record {len(signals)}"
                )
            expected = length * bytes_per_sample
            signal_bytes = f.read(expected)
            if len(signal_bytes) != expected:
                raise ValueError(
                    f"{filepath}: record {len(signals)} declares {length} "
                    f"samples but only {len(signal_bytes)} of {expected} "
                    f"bytes remain"
                )
            signals.append(np.frombuffer(signal_bytes, dtype=np.int16))
    return signals

# ==========================================
# 2. Load Dataset
# ==========================================

X_train = load_bin_file('data/X_train.bin')
y_train = pd.read_csv('data/y_train.csv', header=None).iloc[:, 0].values

# Later cells pair signals with labels via zip(), which silently truncates
# to the shorter input; a count mismatch must stop the notebook here instead.
if len(X_train) != len(y_train):
    raise ValueError(
        f"Signal/label count mismatch: {len(X_train)} signals in "
        f"data/X_train.bin but {len(y_train)} labels in data/y_train.csv"
    )

print(f"Loaded {len(X_train)} ECG signals.")
# np.bincount requires non-negative integer labels.
print(f"Class Distribution: {np.bincount(y_train)}")

# ==========================================
# 3. Analyze Signal Lengths
# ==========================================

# Per-signal sample counts, in the same order as X_train.
lengths = np.array(list(map(len, X_train)))

print(f"Signal Lengths: min={lengths.min()}, max={lengths.max()}, mean={lengths.mean():.2f}")

# Histogram of how many recordings fall into each length bin.
fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(lengths, bins=30, color='gray')
ax.set_title('Distribution of ECG Signal Lengths')
ax.set_xlabel('Signal Length')
ax.set_ylabel('Count')
plt.show()

# ==========================================
# 4. Class-wise Statistics
# ==========================================

# Labels are assumed to be the integers 0..3 (four classes) -- confirm
# against the class distribution printed when the data is loaded.
for cls in range(4):
    signals_cls = [sig for sig, label in zip(X_train, y_train) if label == cls]
    print(f"Class {cls}:")
    print(f"  Samples: {len(signals_cls)}")
    if not signals_cls:
        # np.min / np.max / np.concatenate raise ValueError on empty input,
        # which would abort the whole cell; report the empty class instead.
        print("  (no signals with this label)")
        print()
        continue

    cls_lengths = [len(sig) for sig in signals_cls]
    print(f"  Length - min: {np.min(cls_lengths)}, max: {np.max(cls_lengths)}, mean: {np.mean(cls_lengths):.2f}")

    # Pool every sample of every signal in the class for amplitude stats.
    all_values = np.concatenate(signals_cls)
    if all_values.size:
        print(f"  Signal Values - mean: {np.mean(all_values):.2f}, std: {np.std(all_values):.2f}, "
              f"min: {np.min(all_values)}, max: {np.max(all_values)}")
    else:
        # Every signal in this class has length 0: nothing to summarise.
        print("  Signal Values - (all signals are empty)")
    print()

# ==========================================
# 5. Plot Example Signals per Class
# ==========================================

def plot_example_signals(X, y, num_examples=3, num_classes=4):
    """Plot the first few signals of each class as a class-by-example grid.

    Args:
        X: Sequence of 1-D signals, one per sample.
        y: Class labels aligned with ``X``.
        num_examples: Signals shown per class (grid columns).
        num_classes: Number of classes (grid rows); labels
            ``0 .. num_classes - 1`` are plotted. Defaults to 4, the value
            that was previously hard-coded.

    Rows for classes with fewer than ``num_examples`` signals are left
    partially empty. Axes are hidden and the figure is shown with
    ``plt.show()``.
    """
    # 2.5 inches per class row keeps the original 15x10 figure for 4 classes.
    plt.figure(figsize=(15, 2.5 * num_classes))
    for cls in range(num_classes):
        # First `num_examples` signals of this class, in dataset order.
        cls_signals = [sig for sig, label in zip(X, y) if label == cls][:num_examples]
        for i, signal in enumerate(cls_signals):
            plt.subplot(num_classes, num_examples, cls * num_examples + i + 1)
            plt.plot(signal)
            plt.title(f'Class {cls} - Example {i+1}')
            plt.axis('off')
    plt.tight_layout()
    plt.show()

# Three example traces per class from the full training set.
plot_example_signals(X_train, y_train, num_examples=3)

# ==========================================
# 6. Stratified Train-Validation Split (Reproducible)
# ==========================================

# Fixed random_state makes the 80/20 split reproducible across runs.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# n_splits=1, so take the single (train, val) pair directly instead of
# looping over a one-element generator.
train_idx, val_idx = next(splitter.split(X_train, y_train))

# Signals have varying lengths, so X stays a Python list indexed per sample.
X_train_split = [X_train[i] for i in train_idx]
y_train_split = y_train[train_idx]
X_val_split = [X_train[i] for i in val_idx]
y_val_split = y_train[val_idx]

print("Train Class Distribution:", np.bincount(y_train_split))
print("Validation Class Distribution:", np.bincount(y_val_split))

# ==========================================
# 7. Save Splits for Future Tasks
# ==========================================

# Persist the split indices and their labels (written in this order) so
# later tasks can rebuild exactly the same train/validation partition.
split_artifacts = {
    'data/train_indices.npy': train_idx,
    'data/val_indices.npy': val_idx,
    'data/y_train_split.npy': y_train_split,
    'data/y_val_split.npy': y_val_split,
}
for out_path, arr in split_artifacts.items():
    np.save(out_path, arr)

print("✅ Saved indices and labels for train/val splits.")
