In [None]:
import pickle
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from ml_wireless_classification.base.AdvancedFeatureExtractor import AdvancedFeatureExtractor

# Global dictionary to store feature names and values
feature_dict = {}

def add_feature(name, func, *args):
    """Compute one feature and record it in the global ``feature_dict`` as scalar entries.

    ``func(*args)`` is evaluated and its result is stored as follows:

    * real scalar            -> ``feature_dict[name]``
    * complex scalar         -> ``feature_dict[f"{name}_magnitude"]`` and
                                ``feature_dict[f"{name}_phase"]``
    * ndarray / list / tuple -> flattened; element ``i`` is stored under
                                ``f"{name}_{i}"`` using the same scalar rules

    Complex values are recognised by *type*, not by value, so a complex result
    whose imaginary part happens to be zero still yields the magnitude/phase
    pair. This keeps the set of keys identical from one signal to the next.

    Args:
        name: Base key for the feature.
        func: Callable that computes the feature.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        None. Any exception raised while computing or storing the feature is
        caught and reported with ``print``; nothing is re-raised.
    """
    def _store(key, scalar):
        # np.iscomplexobj inspects the dtype; np.iscomplex would only be True
        # when the imaginary part is non-zero, making the key layout depend on
        # the data.
        if np.iscomplexobj(scalar):
            feature_dict[f"{key}_magnitude"] = np.abs(scalar)
            feature_dict[f"{key}_phase"] = np.angle(scalar)
        else:
            feature_dict[key] = scalar

    try:
        value = func(*args)

        if np.isscalar(value):
            _store(name, value)
        elif isinstance(value, (np.ndarray, list, tuple)):
            # Flatten so that nested containers become one entry per element
            for i, sub_value in enumerate(np.ravel(value)):
                if np.isscalar(sub_value):  # Ensure sub_value is scalar
                    _store(f"{name}_{i}", sub_value)
                else:
                    print(f"Warning: Non-scalar value found in '{name}_{i}' and was not added.")
        else:
            print(f"Warning: Feature '{name}' has unsupported type {type(value)} and was not added.")
    except Exception as e:
        # Best effort: one failing feature must not abort the whole extraction
        print(f"Error computing feature '{name}': {e}")

def extract_features(data):
    """Build the feature matrix and the label list for every signal in ``data``.

    Args:
        data: Mapping whose keys unpack as ``(mod_type, snr)`` and whose values
            are iterables of signals. ``signal[0]`` and ``signal[1]`` are
            combined into one complex signal (real + 1j * imaginary).

    Returns:
        ``(features, labels)`` where ``features`` is ``np.array`` built from
        one row per signal and ``labels`` is a list with the ``mod_type`` of
        each row. Columns are aligned by feature name in first-seen order, the
        SNR is always the last column, and a feature that could not be
        computed for a given signal is stored as ``np.nan``.

    Side effects:
        Rebinds the module-level ``feature_dict`` (``add_feature`` writes into
        it). On return it maps every column name, ``"SNR"`` last, to the value
        in the final row, so ``list(feature_dict.keys())`` matches the column
        order of ``features``.
    """
    global feature_dict

    # The extractor is created once with a placeholder signal and re-pointed
    # at each real signal through set_signal().
    feature_extractor = AdvancedFeatureExtractor(np.zeros(128))
    # Retrieve feature methods and names
    feature_methods = feature_extractor.get_features()

    columns = []      # feature names in column order (SNR is appended last)
    column_of = {}    # feature name -> column index
    rows = []
    snrs = []
    labels = []

    for key, signals in data.items():
        mod_type, snr = key
        for signal in signals:
            real_part, imag_part = signal[0], signal[1]
            feature_extractor.set_signal(real_part + 1j * imag_part)

            # add_feature() fills the global dict, so give each signal a fresh one
            feature_dict = {}
            for feature_name, feature_func in feature_methods.items():
                add_feature(feature_name, feature_func)

            # Place values by name rather than by position: if a feature fails
            # for this signal, its column stays NaN instead of shifting every
            # later column (and the SNR) one slot to the left.
            row = [np.nan] * len(columns)
            for feature_name, value in feature_dict.items():
                index = column_of.get(feature_name)
                if index is None:
                    column_of[feature_name] = len(columns)
                    columns.append(feature_name)
                    row.append(value)
                else:
                    row[index] = value

            rows.append(row)
            snrs.append(snr)
            labels.append(mod_type)

    width = len(columns)
    for row, snr in zip(rows, snrs):
        # Rows created before a column was first seen are right-padded
        row.extend([np.nan] * (width - len(row)))
        # Add SNR as a feature, always in the last column
        row.append(snr)

    if rows:
        feature_dict = dict(zip(columns + ["SNR"], rows[-1]))

    return np.array(rows), labels


# Load the RML2016.10a_dict.pkl file with explicit encoding
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only open a copy of the dataset that comes from a trusted source.
# encoding="latin1" is presumably needed because the file was pickled under
# Python 2 - confirm against the dataset's documentation.
with open("../RML2016.10a_dict.pkl", "rb") as f:
    data = pickle.load(f, encoding="latin1")

# Feature extraction for all signals
features, labels = extract_features(data)

# Encode labels for classification
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(labels)

# Split dataset into training and testing sets
# (70/30 split; the fixed random_state keeps the partition reproducible)
X_train, X_test, y_train, y_test = train_test_split(features, encoded_labels, test_size=0.3, random_state=42)

# Create a single classifier for multi-class classification.
# It is only instantiated here; fitting and evaluation happen in the following
# cells, after the feature matrices have been cleaned.
clf = RandomForestClassifier(n_estimators=100, random_state=42)


In [None]:
def clean_training_data(X, y):
    """
    Return numeric copies of a feature table and its label vector.

    Every cell of ``X`` is inspected and replaced as follows (each replacement
    prints a warning that names the row and column):

    * string / bytes                  -> 0
    * NaN, +/-inf, or a magnitude
      beyond the float32 range        -> 0
    * list / tuple / ndarray          -> its first element (0 if empty),
                                         subject to the same finite check
    * any other non-scalar type       -> 0

    Rows shorter than the longest row are right-padded with zeros.

    Args:
        X: Iterable of rows, each row an iterable of cell values.
        y: Iterable of labels; a non-scalar label contributes its first element.

    Returns:
        ``(X_cleaned, y_cleaned)`` as float ndarrays. Empty ``X`` / ``y`` give
        empty arrays.
    """
    float32_max = np.finfo(np.float32).max

    def is_unusable(number):
        """True for NaN, +/-inf, or a value outside the float32 range (either sign)."""
        return bool(np.isinf(number) or np.isnan(number) or abs(number) > float32_max)

    def check_and_clean_array(arr, array_name):
        cleaned_arr = []
        for i, row in enumerate(arr):
            cleaned_row = []
            for j, value in enumerate(row):
                if isinstance(value, (str, bytes)):
                    # Must be tested before np.isscalar(), which is also True
                    # for strings and would send them into the numeric checks.
                    print(f"Warning: {array_name}[{i}][{j}] is a string. Removing and setting to 0.")
                    cleaned_row.append(0)
                elif np.isscalar(value):
                    if is_unusable(value):
                        print(f"Warning: {array_name}[{i}][{j}] has an infinity or too large value. Setting to 0.")
                        cleaned_row.append(0)  # Replace infinities or overly large values with 0
                    else:
                        cleaned_row.append(value)
                elif isinstance(value, (list, tuple, np.ndarray)):
                    # If it's a sequence, take the first element as a workaround
                    sub_value = value[0] if len(value) > 0 else 0
                    if is_unusable(sub_value):
                        print(f"Warning: {array_name}[{i}][{j}] contains infinity or too large in sequence. Setting to 0.")
                        sub_value = 0
                    cleaned_row.append(sub_value)
                    print(f"Warning: {array_name}[{i}][{j}] is a sequence. Taking the first element.")
                else:
                    print(f"Warning: Unexpected data type at {array_name}[{i}][{j}]: {type(value)}")
                    cleaned_row.append(0)  # Default to 0 if type is unexpected
            cleaned_arr.append(cleaned_row)

        # Ensure cleaned_arr is a 2D array of fixed-length rows;
        # default=0 keeps an empty input from raising in max().
        max_length = max((len(row) for row in cleaned_arr), default=0)
        # Pad rows with zeros if they are shorter than max_length
        cleaned_arr = [row + [0] * (max_length - len(row)) for row in cleaned_arr]

        return np.array(cleaned_arr, dtype=float)

    # Clean X and y. The function is used for both training and test data, so
    # the warnings use the neutral name "X".
    X_cleaned = check_and_clean_array(X, "X")
    y_cleaned = np.array([elem if np.isscalar(elem) else elem[0] for elem in y], dtype=float)

    return X_cleaned, y_cleaned

def ensure_2d(arr, name):
    """
    Return ``arr`` as a column matrix when it is 1-dimensional.

    A 1-D array is reshaped to ``(n, 1)`` and a warning mentioning ``name`` is
    printed; an array of any other dimensionality is returned unchanged.
    """
    if arr.ndim != 1:
        return arr

    print(f"Warning: {name} is 1-dimensional. Reshaping to 2D.")
    return arr.reshape(-1, 1)

# Clean training and test data
# The arrays produced by the train/test split are overwritten with their
# cleaned versions; clean_training_data() returns the labels as float arrays.
X_train, y_train = clean_training_data(X_train, y_train)
X_test, y_test = clean_training_data(X_test, y_test)

# Ensure both X_train and X_test are 2D arrays
X_train = ensure_2d(X_train, "X_train")
X_test = ensure_2d(X_test, "X_test")

print("Training...")
clf.fit(X_train, y_train)

# Check if SNR is actually included as a feature
# NOTE(review): this condition only verifies that X_test has more than one
# column. The code below then treats the last column as the SNR, which is where
# extract_features() places it.
if X_test.ndim > 1 and X_test.shape[1] > 1:
    # Evaluate accuracy for each SNR level if the SNR column is present
    unique_snrs = sorted(set(X_test[:, -1]))  # Get unique SNR levels from test set
    accuracy_per_snr = []

    for snr in unique_snrs:
        # Select samples with the current SNR
        # (exact equality is safe: snr was taken from this very column)
        snr_indices = np.where(X_test[:, -1] == snr)
        X_snr = X_test[snr_indices]
        y_snr = y_test[snr_indices]

        # Predict and calculate accuracy
        y_pred = clf.predict(X_snr)
        accuracy = accuracy_score(y_snr, y_pred)
        accuracy_per_snr.append(accuracy * 100)  # Convert to percentage

        print(f"SNR: {snr} dB, Accuracy: {accuracy * 100:.2f}%")

    # Plot Recognition Accuracy vs. SNR
    plt.figure(figsize=(10, 6))
    plt.plot(unique_snrs, accuracy_per_snr, 'b-o', label='Recognition Accuracy')
    plt.xlabel("SNR (dB)")
    plt.ylabel("Recognition Accuracy (%)")
    plt.title("Recognition Accuracy vs. SNR for Modulation Classification")
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 100)  # accuracies are percentages
    plt.show()
else:
    print("SNR feature not found in X_test; skipping SNR-based evaluation.")


In [None]:
# Feature importance for the classifier
# feature_dict is the module-level dict left behind by extract_features(); its
# keys are used as the column names of the feature matrix.
# NOTE(review): this assumes the keys line up one-to-one with the columns the
# classifier was trained on - confirm if features can fail for some signals.
feature_names = list(feature_dict.keys())
importances = clf.feature_importances_

# Sort feature importances in descending order
sorted_indices = np.argsort(importances)[::-1]
sorted_feature_names = [feature_names[i] for i in sorted_indices]
sorted_importances = importances[sorted_indices]

# Plot sorted feature importances
plt.figure(figsize=(10, 8))
plt.barh(sorted_feature_names, sorted_importances, color='skyblue')
plt.xlabel("Feature Importance")
plt.title("Feature Importance for Modulation Classification")
plt.gca().invert_yaxis()  # Invert y-axis to show the highest importance at the top
plt.show()

# Confusion matrix for overall test set
# Pin the label set to every class known to the encoder. Without `labels=`,
# scikit-learn infers the classes from the data, so a class missing from
# y_test/y_pred_test would shrink the matrix (misaligning the tick labels) and
# make classification_report reject target_names. The ids are cast to the dtype
# of y_test so they compare equal to the stored labels.
class_ids = np.arange(len(label_encoder.classes_)).astype(y_test.dtype)

y_pred_test = clf.predict(X_test)
conf_matrix = confusion_matrix(y_test, y_pred_test, labels=class_ids)
plt.figure(figsize=(12, 10))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", 
            xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix for Multi-Class Modulation Classification")
plt.show()

# Print Classification Report
print("Classification Report for Modulation Types:")
print(classification_report(y_test, y_pred_test, labels=class_ids, target_names=label_encoder.classes_))


In [None]:
# Confusion matrix and classification report for the SNR > 5 dB subset

# Assuming SNR values are in the last column of X_test
snr_column_index = -1  # Adjust this if SNR is in a different column

# Find indices where SNR > 5
snr_above_5_indices = np.where(X_test[:, snr_column_index] > 5)[0]
X_test_snr_above_5 = X_test[snr_above_5_indices]
y_test_snr_above_5 = y_test[snr_above_5_indices]

if snr_above_5_indices.size == 0:
    # Nothing to evaluate; predicting on an empty array would raise.
    print("No test samples with SNR > 5 dB; skipping the high-SNR evaluation.")
else:
    # Pin the label set to every class known to the encoder. A subset may lack
    # some classes; without `labels=` the matrix would shrink (misaligning the
    # tick labels) and classification_report would reject target_names. The ids
    # are cast to the dtype of y_test so they compare equal to the stored labels.
    class_ids = np.arange(len(label_encoder.classes_)).astype(y_test.dtype)

    # Make predictions on the SNR > 5 dB subset
    y_pred_snr_above_5 = clf.predict(X_test_snr_above_5)

    # Plot confusion matrix for SNR > 5 dB
    conf_matrix_snr_above_5 = confusion_matrix(
        y_test_snr_above_5, y_pred_snr_above_5, labels=class_ids
    )
    plt.figure(figsize=(12, 10))
    sns.heatmap(conf_matrix_snr_above_5, annot=True, fmt="d", cmap="Blues",
                xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Confusion Matrix for Modulation Classification (SNR > 5 dB)")
    plt.show()

    # Print the classification report for SNR > 5 dB
    print("Classification Report for Modulation Types (SNR > 5 dB):")
    print(classification_report(
        y_test_snr_above_5,
        y_pred_snr_above_5,
        labels=class_ids,
        target_names=label_encoder.classes_,
    ))
