# In [None]:
import pickle
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from ml_wireless_classification.base.AdvancedFeatureExtractor import AdvancedFeatureExtractor

# Global dictionary to store feature names and values
feature_dict = {}


def _iter_leaf_values(value):
    """Yield the leaf values of nested arrays, tuples, and lists in order."""
    if isinstance(value, np.ndarray):
        # ravel() handles any rank, including 0-d arrays that cannot be iterated.
        for item in value.ravel():
            yield from _iter_leaf_values(item)
    elif isinstance(value, (tuple, list)):
        for item in value:
            yield from _iter_leaf_values(item)
    else:
        yield value


def add_feature(name, func, *args):
    """Compute one feature and store it in the module-level ``feature_dict``.

    ``func(*args)`` is called and its result is stored as follows:

    * a scalar (per ``np.isscalar``) is stored under ``name``;
    * a single-element numpy array is stored under ``name`` as a Python scalar;
    * any other numpy array, tuple, or list is flattened (nested containers
      and multi-dimensional arrays included) and each leaf value is stored
      under ``f"{name}_{i}"`` so that every stored value is a single entry
      rather than a sub-array;
    * an empty container or an unsupported type only prints a warning.

    Args:
        name: Key (or key prefix) under which the feature is stored.
        func: Callable that computes the feature.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        None. Results are written to ``feature_dict``; any exception raised
        while computing or storing the feature is caught and reported with
        ``print`` so a single failing feature does not abort extraction.
    """
    try:
        value = func(*args)

        if np.isscalar(value):
            feature_dict[name] = value
        elif isinstance(value, np.ndarray) and value.size == 1:
            feature_dict[name] = value.item()
        elif isinstance(value, (np.ndarray, tuple, list)):
            leaves = list(_iter_leaf_values(value))
            if not leaves:
                # Keep the wording of the original per-type warnings.
                container = "list" if isinstance(value, list) else "array/tuple"
                print(f"Warning: Feature '{name}' returned an empty {container} and was not added.")
            else:
                for i, leaf in enumerate(leaves):
                    feature_dict[f"{name}_{i}"] = leaf
        else:
            print(f"Warning: Feature '{name}' has an unexpected shape/type and was not added.")
    except Exception as e:
        # Deliberate best-effort: report the failure and move on to the next feature.
        print(f"Error computing feature '{name}': {e}")

def extract_features(data):
    """Build a feature matrix and label list from a modulation dataset.

    Args:
        data: Mapping whose keys are ``(mod_type, snr)`` pairs and whose
            values are iterables of signals. For each signal, ``signal[0]``
            and ``signal[1]`` are combined as the real and imaginary parts
            of a complex signal.

    Returns:
        A ``(features, labels)`` tuple. ``features`` is a numpy array with one
        row per signal; each row holds the values collected by ``add_feature``
        followed by the SNR as the last column. ``labels`` is a list with the
        ``mod_type`` of each row.

    Raises:
        ValueError: If a signal yields a different set of feature names than
            the first signal (for example because ``add_feature`` skipped a
            feature that failed), which would otherwise produce ragged rows.

    Side effects:
        Rebinds the module-level ``feature_dict`` for every signal; after the
        call it holds the features of the last processed signal.
    """
    # Declared once up front; the dictionary is rebound for each signal below.
    global feature_dict

    features = []
    labels = []
    expected_names = None

    # The zeros array is only a placeholder: the real signal is supplied per
    # example through set_signal().
    feature_extractor = AdvancedFeatureExtractor(np.zeros(128))
    # Mapping of feature name -> callable invoked with no arguments.
    feature_methods = feature_extractor.get_features()

    for key, signals in data.items():
        mod_type, snr = key
        for signal in signals:
            real_part, imag_part = signal[0], signal[1]
            feature_extractor.set_signal(real_part + 1j * imag_part)

            # Start from an empty dictionary so features never leak between signals.
            feature_dict = {}
            for feature_name, feature_func in feature_methods.items():
                add_feature(feature_name, feature_func)

            # SNR goes last so callers can read it back as the final column.
            feature_dict["SNR"] = snr

            names = list(feature_dict)
            if expected_names is None:
                expected_names = names
            elif names != expected_names:
                mismatched = sorted(set(expected_names) ^ set(names))
                raise ValueError(
                    f"Inconsistent feature set for {key!r}: "
                    f"features present in only some signals: {mismatched}"
                )

            features.append(list(feature_dict.values()))
            labels.append(mod_type)

    return np.array(features), labels

# Read the pickled dataset. NOTE: unpickling can execute arbitrary code, so
# only load this file from a trusted source. The explicit "latin1" encoding is
# presumably needed because the pickle was written by Python 2 - confirm.
with open("../RML2016.10a_dict.pkl", "rb") as dataset_file:
    data = pickle.load(dataset_file, encoding="latin1")

# Turn every signal into a feature row (SNR is the last column).
features, labels = extract_features(data)

# Map the modulation names to integer class ids.
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(labels)

# Hold out 30% of the samples for testing; the seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    features,
    encoded_labels,
    test_size=0.3,
    random_state=42,
)

# One multi-class random forest trained on all SNR levels together.
clf = RandomForestClassifier(n_estimators=100, random_state=42)
print("Training...")
clf.fit(X_train, y_train)

# Score the classifier separately at every SNR level found in the test set.
# The SNR value is stored in the last feature column.
unique_snrs = list(np.unique(X_test[:, -1]))
accuracy_per_snr = []

for snr in unique_snrs:
    # Boolean mask of the test rows recorded at this SNR.
    at_this_snr = X_test[:, -1] == snr

    y_pred = clf.predict(X_test[at_this_snr])
    accuracy = accuracy_score(y_test[at_this_snr], y_pred)
    accuracy_per_snr.append(accuracy * 100)  # stored as a percentage

    print(f"SNR: {snr} dB, Accuracy: {accuracy * 100:.2f}%")

# Recognition accuracy as a function of SNR.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(unique_snrs, accuracy_per_snr, 'b-o', label='Recognition Accuracy')
ax.set_xlabel("SNR (dB)")
ax.set_ylabel("Recognition Accuracy (%)")
ax.set_title("Recognition Accuracy vs. SNR for Modulation Classification")
ax.legend()
ax.grid(True)
ax.set_ylim(0, 100)
plt.show()

# Horizontal bar chart of the forest's per-feature importances.
# Names come from the module-level feature_dict, which holds the features of
# the last signal processed by extract_features().
importances = clf.feature_importances_
feature_names = list(feature_dict)
fig, ax = plt.subplots(figsize=(10, 8))
ax.barh(feature_names, importances, color='skyblue')
ax.set_xlabel("Feature Importance")
ax.set_title("Feature Importance for Modulation Classification")
plt.show()

# Confusion matrix for overall test set
y_pred_test = clf.predict(X_test)

# LabelEncoder assigns ids 0..n_classes-1. Passing the full id range keeps the
# matrix rows/columns and the report aligned with label_encoder.classes_ even
# if some class is missing from the test split or from the predictions.
class_ids = np.arange(len(label_encoder.classes_))

conf_matrix = confusion_matrix(y_test, y_pred_test, labels=class_ids)
plt.figure(figsize=(12, 10))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
            xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix for Multi-Class Modulation Classification")
plt.show()

# Print Classification Report
print("Classification Report for Modulation Types:")
print(classification_report(y_test, y_pred_test, labels=class_ids,
                            target_names=label_encoder.classes_))
