# ViT / VViT Transformer Models — Phase-Spectrum Based Recognition

### Reference
This notebook implements the transformer-based automatic modulation recognition models presented by  
**Bhatti, Sidra Ghayour; Taj, Imtiaz Ahmad; Ullah, Mohsin; and Bhatti, Aamer Iqbal (2024)** —  
*“Transformer-Based Models for Intrapulse Modulation Recognition of Radar Waveforms.”*  
*Engineering Applications of Artificial Intelligence*, **136**, 108989.  
**BibTeX:** [@RN181]


### 1 Setup

#### 1.1 Imports

In [None]:
import os
import time
import gc
import h5py
import joblib
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from ViT import MainModel

#### 1.2 Device Selection

In [None]:
# Device Selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

#### 1.3 Data Input

In [None]:
# Data Loading
def load_algorithm_snr_h5s(root_folder, mod_types):
    """
    Load every image stored in the .h5 files of one snr_X folder.

    The expected layout is ``root_folder/<mod_type>/<mod_name>.h5``, where each
    file holds a top-level group named ``<mod_name>`` whose members are the
    individual images. The file name (without ``.h5``) is used as the label.

    Files are visited in sorted order so that the returned sample order, and
    therefore any seeded split performed on it, is reproducible across
    platforms (``os.listdir`` gives no ordering guarantee).

    Loading is best-effort: missing modulation folders, files without the
    expected group, and files that fail to open are reported on stdout and
    skipped. If a file fails part-way through, the images read before the
    failure are kept (each image is appended together with its label, so X
    and y always stay aligned).

    Parameters:
    - root_folder (str): Path to the snr_X directory (e.g., .../preprocessed_images/vit/snr_0)
    - mod_types (list): List of modulation categories to include, e.g., ['FM', 'PM']

    Returns:
    - X: np.ndarray of images (empty, shape (0,), when nothing was loaded)
    - y: np.ndarray of labels (modulation names as strings)
    """
    X = []
    y = []

    for mod_type in mod_types:
        mod_path = os.path.join(root_folder, mod_type)
        if not os.path.exists(mod_path):
            print(f"⚠️ Warning: {mod_path} does not exist. Skipping.")
            continue

        print(f"📂 Loading from {mod_type}...")
        # Sort: directory listing order is arbitrary and would otherwise make
        # the sample order (and downstream seeded splits) machine-dependent.
        files = sorted(f for f in os.listdir(mod_path) if f.endswith(".h5"))

        for file in tqdm(files, desc=f"   {mod_type}", unit="file"):
            mod_name = os.path.splitext(file)[0]  # Strip '.h5'
            file_path = os.path.join(mod_path, file)

            try:
                with h5py.File(file_path, "r") as h5f:
                    if mod_name not in h5f:
                        print(f"⚠️ Warning: No top-level group named '{mod_name}' in {file_path}")
                        continue
                    group = h5f[mod_name]
                    for key in group.keys():
                        X.append(np.array(group[key]))
                        y.append(mod_name)
            except Exception as e:
                # Deliberately broad: one unreadable file must not abort the
                # whole load; the failure is reported and the file skipped.
                print(f"❌ Failed to load {file_path}: {e}")

    X = np.array(X)
    y = np.array(y)
    # Cleanup after loading
    gc.collect()
    return X, y


#### 1.4 Data Loader

In [None]:
# Data Loader
def prepare_dataloader(X, y, batch_size=32, shuffle=False, num_workers=2):
    """
    Prepares a DataLoader from X and y, keeping data on CPU until batches are moved to GPU.

    NumPy inputs are converted to tensors: X is cast to uint8 and y to int64
    (torch.long). Tensors passed in are used as they are, without any dtype
    conversion. Consumers are expected to cast the uint8 batches to float
    before feeding them to a model.

    X is brought to channel-first layout:
    - (N, H, W) gets a singleton channel axis -> (N, 1, H, W)
    - (N, H, W, C) with C in {1, 3} is permuted -> (N, C, H, W)
    - any other shape is passed through unchanged

    Parameters:
    - X: np.ndarray or torch.Tensor of input data
    - y: np.ndarray or torch.Tensor of labels
    - batch_size: int, size of each batch
    - shuffle: bool, whether to shuffle the data
    - num_workers: int, number of workers for loading data

    Returns:
    - DataLoader object

    Raises:
    - TypeError: if X or y is neither a NumPy array nor a PyTorch tensor
    """
    if isinstance(X, np.ndarray):
        X = torch.tensor(X, dtype=torch.uint8)
    elif not isinstance(X, torch.Tensor):
        raise TypeError("Input X must be a NumPy array or PyTorch tensor")

    if isinstance(y, np.ndarray):
        y = torch.tensor(y, dtype=torch.long)
    elif not isinstance(y, torch.Tensor):
        raise TypeError("Labels y must be a NumPy array or PyTorch tensor")

    # Ensure X has four dimensions (N, C, H, W)
    if X.ndim == 3:  # If (N, H, W), add a channel dimension
        X = X.unsqueeze(1)  # (N, 1, H, W)
    elif X.ndim == 4 and X.shape[-1] in [1, 3]:  # (N, H, W, C) case
        X = X.permute(0, 3, 1, 2)  # Convert to (N, C, H, W)

    dataset = TensorDataset(X, y)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just triggers a warning from the DataLoader.
        pin_memory=torch.cuda.is_available(),
    )
    # Cleanup after creating DataLoader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return loader


### 2 Training

#### 2.1 Training Setup

In [None]:
# Training Function
def train_model(
    model,
    train_loader,
    device,
    criterion,
    optimizer,
    scheduler=None,
    epochs=10,
    patience=3,
    min_delta=0.0,
    model_save_path=None
):
    """
    Train ``model`` with early stopping driven by the average training loss.

    Each epoch runs one full pass over ``train_loader``; input batches are
    cast to float before the forward pass. After the epoch the scheduler (if
    given) is stepped once, and the epoch's average loss is compared with the
    best seen so far: an improvement larger than ``min_delta`` resets the
    patience counter and, when ``model_save_path`` is set, writes the model's
    state_dict there. Training stops early once ``patience`` consecutive
    epochs bring no such improvement.

    Note that the best weights are only written to disk; when this function
    returns, ``model`` holds the weights of the last epoch that was run.

    Parameters:
    - model: torch.nn.Module to train (moved to ``device``)
    - train_loader: sized iterable yielding (inputs, labels) batches
    - device: torch.device used for the model and the batches
    - criterion: loss callable taking (outputs, labels)
    - optimizer: torch optimizer bound to the model's parameters
    - scheduler: optional LR scheduler, stepped once per epoch
    - epochs: int, maximum number of epochs
    - patience: int, epochs without improvement tolerated before stopping
    - min_delta: float, minimum decrease in loss that counts as improvement
    - model_save_path: optional path for the best state_dict

    Returns:
    - list of per-epoch average training losses (its length is the number of
      epochs actually run)

    Raises:
    - ValueError: if ``train_loader`` contains no batches
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        # Fail fast with a clear message instead of a ZeroDivisionError when
        # the first epoch average is computed.
        raise ValueError("train_loader yields no batches; cannot train on an empty dataset")

    model.to(device)
    best_loss = float("inf")
    patience_counter = 0
    loss_history = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True, dynamic_ncols=True)

        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            # Transfer first, convert second: batches may arrive as uint8.
            inputs = inputs.float()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            progress_bar.set_postfix({"Loss": f"{batch_loss:.4f}"})
            # No per-batch gc / cache flushing here: the batch tensors are
            # released when the names are rebound on the next iteration.

        avg_loss = total_loss / num_batches
        loss_history.append(avg_loss)
        print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

        if scheduler is not None:
            scheduler.step()

        # Save best model
        if avg_loss < best_loss - min_delta:
            best_loss = avg_loss
            patience_counter = 0
            if model_save_path:
                torch.save(model.state_dict(), model_save_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # Final cleanup after training
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return loss_history

##### 2.1.1 Loss Curve

In [None]:
# Plot Loss Curve
def plot_loss_curve(loss_history, output_path, title="Training Loss Over Epochs"):
    """
    Draw the per-epoch training loss and save it as ``loss_curve.png``.

    Parameters:
    - loss_history: sequence of per-epoch loss values, plotted against
      epoch numbers starting at 1
    - output_path: str prefix the file name is appended to (a folder path
      must therefore end with a path separator)
    - title: str, figure title
    """
    epoch_numbers = range(1, len(loss_history) + 1)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epoch_numbers, loss_history, marker="o", label="Training Loss")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.legend()
    ax.grid()

    fig.savefig(output_path + "loss_curve.png")
    plt.close(fig)
    # Release figure memory promptly
    gc.collect()

##### 2.1.2 Confusion Matrix

In [None]:
# Confusion Matrix Display
def display_confusion_matrix(
    model, data_loader, device, output_path, class_names=None, title="Confusion Matrix"
):
    """
    Run the model over ``data_loader`` and save a row-normalised confusion matrix.

    Predictions are the argmax over dim 1 of the model output. Each row of the
    matrix (one true class) is scaled to percentages; a row with no samples is
    shown as all zeros instead of NaN. The figure is written to
    ``output_path + "conf_matrix.png"``.

    Parameters:
    - model: torch.nn.Module (moved to ``device`` and put in eval mode)
    - data_loader: iterable yielding (inputs, labels) batches
    - device: torch.device used for inference
    - output_path: str prefix the file name is appended to
    - class_names: optional tick labels; defaults to the class indices as strings
    - title: str, figure title

    Raises:
    - ValueError: if ``data_loader`` yields no samples
    """
    model.to(device)
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in data_loader:
            # Cast to float exactly as the training and evaluation loops do;
            # batches may arrive as uint8, which the model cannot consume.
            inputs = inputs.to(device).float()
            output_class = model(inputs)
            preds = torch.argmax(output_class, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())
            # No per-batch gc.collect()/empty_cache(): it only slows inference
            # down, the batch tensors are released on the next iteration.

    if not all_labels:
        raise ValueError("data_loader yielded no samples; nothing to plot")

    cm = confusion_matrix(all_labels, all_preds)
    num_classes = cm.shape[0]
    # A class that was predicted but never occurs as a true label has an
    # all-zero row; leave it at 0 % rather than dividing by zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm.astype(np.float64),
        row_totals,
        out=np.zeros(cm.shape, dtype=np.float64),
        where=row_totals != 0,
    ) * 100

    if class_names is None:
        class_names = [str(i) for i in range(num_classes)]

    tick_fontsize = max(8, 12 - num_classes // 5)

    plt.figure(figsize=(max(10, num_classes * 0.8), max(8, num_classes * 0.6)))
    im = plt.imshow(cm_normalized, interpolation="nearest", cmap="Blues")
    plt.title(title, fontsize=14)
    plt.colorbar(im, label="Percentage")
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, class_names, rotation=45, ha="right", va="top", fontsize=tick_fontsize)
    plt.yticks(tick_marks, class_names, fontsize=tick_fontsize)

    # Switch the annotation colour on dark cells so the numbers stay readable.
    thresh = cm_normalized.max() / 2.0
    for i in range(num_classes):
        for j in range(num_classes):
            plt.text(
                j, i, f"{cm_normalized[i, j]:.1f}",
                ha="center", va="center",
                color="white" if cm_normalized[i, j] > thresh else "black",
                fontsize=tick_fontsize,
            )

    plt.ylabel("True Label", fontsize=12, labelpad=10)
    plt.xlabel("Predicted Label", fontsize=12, labelpad=10)
    plt.tight_layout()
    # Extra bottom margin for the rotated x tick labels.
    plt.subplots_adjust(bottom=0.2 + num_classes * 0.005)
    plt.savefig(output_path + "conf_matrix.png")
    plt.close()
    # Cleanup after plotting
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

### 3 Testing

In [None]:
# Evaluation Function
def evaluate_model(model, test_loader, label_encoder, device, output_path):
    """
    Evaluate the model on one test set and save its confusion matrix.

    Predictions are the argmax over dim 1 of the model output. The confusion
    matrix is built over every class known to ``label_encoder`` (so it always
    matches the tick labels, even if the test set lacks some classes), is
    row-normalised to percentages (rows without samples stay at 0), and is
    written to ``output_path + "test_conf_matrix.png"``.

    Parameters:
    - model: torch.nn.Module (moved to ``device`` and put in eval mode)
    - test_loader: iterable yielding (inputs, labels) batches; labels are the
      integer codes produced by ``label_encoder``
    - label_encoder: fitted encoder exposing ``classes_``
    - device: torch.device used for inference
    - output_path: str prefix the file name is appended to

    Returns:
    - accuracy in percent (float)

    Raises:
    - ValueError: if ``test_loader`` yields no samples
    """
    model.to(device)
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device).float()
            output_class = model(inputs)
            preds = torch.argmax(output_class, dim=1)
            y_true.extend(labels.tolist())
            y_pred.extend(preds.cpu().tolist())
            # No per-batch gc.collect()/empty_cache(): it only slows inference
            # down, the batch tensors are released on the next iteration.

    total = len(y_true)
    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")

    correct = sum(1 for truth, pred in zip(y_true, y_pred) if truth == pred)
    accuracy = 100 * correct / total
    print(f"Accuracy: {accuracy:.2f}%")

    class_names = label_encoder.classes_
    num_classes = len(class_names)
    # Pin the label set: without it the matrix only covers classes present in
    # this particular test set and can be smaller than display_labels.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm.astype(np.float64),
        row_totals,
        out=np.zeros(cm.shape, dtype=np.float64),
        where=row_totals != 0,
    ) * 100

    tick_fontsize = max(8, 12 - num_classes // 5)

    fig, ax = plt.subplots(figsize=(max(10, num_classes * 0.8), max(8, num_classes * 0.6)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm_normalized, display_labels=class_names)
    disp.plot(cmap="Blues", values_format=".1f", ax=ax)
    ax.set_xticklabels(class_names, rotation=45, ha="right", va="top", fontsize=tick_fontsize)
    ax.set_yticklabels(class_names, rotation=0, fontsize=tick_fontsize)
    ax.set_xlabel("Predicted Label", fontsize=12, labelpad=10)
    ax.set_ylabel("True Label", fontsize=12, labelpad=10)
    ax.set_title("Confusion Matrix (Percentage)", fontsize=14)
    plt.tight_layout()
    # Extra bottom margin for the rotated x tick labels.
    plt.subplots_adjust(bottom=0.2 + num_classes * 0.005)
    plt.savefig(output_path + "test_conf_matrix.png")
    plt.close()
    # Cleanup after plotting
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return accuracy

### 4 Main

In [None]:
# Setup and Execution
# Root of the preprocessed input images; one "snr_<value>" sub-folder per SNR
# is read from here by the data collection cell.
# NOTE(review): makedirs on the *input* path silently creates an empty folder
# when the dataset is missing; loading then only prints warnings.
data_path = "C:\\Apps\\Code\\aimc-spec-7\\preprocessed_images\\vit\\"
os.makedirs(data_path, exist_ok=True)

# Root folder for everything this run writes (model, plots, CSVs).
# Both paths end with a separator because later cells build file names by
# plain string concatenation.
output_path = "C:\\Apps\\Code\\ViT\\"
os.makedirs(output_path, exist_ok=True)

# SNR levels to load; each value is interpolated into the "snr_<value>" folder name.
snr_range = [10, 5, 0, -2, -4, -6, -8, -10, -12, -14, -16, -18, -20]
# snr_range = [10]

# Modulation families to include; each entry is a sub-folder of every snr_X folder.
mod_types = [
    "FM",
    # "PM",
    # "HYBRID",
]
# Hyper-parameters for the run. "epoch_count" and "learning_rate" also end up
# in the output folder name and in the results table.
# NOTE(review): "csv" is not read by any cell visible here — confirm it is still needed.
input_parameters = {
    "epoch_count": 200,
    "learning_rate": 1e-4,
    "csv": "vit"
}

In [None]:
# Collect training and testing data from all SNRs
# Each SNR level is loaded and split on its own (80 % train / 20 % test,
# stratified by modulation name, fixed seed). The training parts of all SNRs
# are pooled into one set; the test parts stay separate per SNR so accuracy
# can be reported per SNR later.
X_train_list = []
y_train_list = []
test_data_by_snr = {}  # snr -> (X_test, y_test) with string labels

for snr in snr_range:
    input_data_folder = data_path + f"snr_{snr}"
    print(f"Loading {input_data_folder}")
    X, y = load_algorithm_snr_h5s(input_data_folder, mod_types)
    # NOTE(review): if the folder is missing the loader returns empty arrays
    # and this split will raise — confirm the dataset is present for every SNR.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )
    X_train_list.append(X_train)
    y_train_list.append(y_train)
    test_data_by_snr[snr] = (X_test, y_test)
    # Cleanup after splitting
    del X, y
    gc.collect()

# Combine training data
print(f"Combining training data")    
# Concatenate along the sample axis; this requires every SNR to provide
# images of identical shape.
X_train_all = np.concatenate(X_train_list, axis=0)
y_train_all = np.concatenate(y_train_list, axis=0)
# Cleanup after concatenation
del X_train_list, y_train_list
gc.collect()

In [None]:
# Label encoding
# Map the modulation-name strings to integer class ids. The encoder is fitted
# on the pooled training labels only; test labels are transformed with the
# same encoder later on.
label_encoder = LabelEncoder()
y_train_all_encoded = label_encoder.fit_transform(y_train_all)

In [None]:
# Save label encoder
# Tag for the modulation families used in this run: "ALL" when all three are
# selected, otherwise the selected families joined with "_". (Joining instead
# of taking only the first entry keeps a two-family run from being labelled,
# and stored, as if it were a single-family run.)
mds = "ALL" if len(mod_types) == 3 else "_".join(mod_types)
# Per-run output folder; the trailing separator is required because file
# names are appended by string concatenation below and in later cells.
output_data_folder = output_path + f"snr_all_mds_{mds}_e{input_parameters['epoch_count']}_lr{input_parameters['learning_rate']}\\"
os.makedirs(output_data_folder, exist_ok=True)
# Persist the fitted encoder so class ids can be mapped back to modulation
# names when the saved model is used outside this notebook.
joblib.dump(label_encoder, output_data_folder + "label_encoder.pkl")

In [None]:
# Prepare training DataLoader
# Shuffled loader over the pooled training data from all SNRs.
train_loader = prepare_dataloader(
    X_train_all,
    y_train_all_encoded,
    batch_size=32,
    shuffle=True,
    num_workers=2
)

# Cleanup after creating DataLoader
# prepare_dataloader copies NumPy inputs into tensors, so the NumPy arrays
# are no longer needed once the loader exists.
del X_train_all, y_train_all, y_train_all_encoded
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Initialize model
# One output class per modulation name seen during label encoding.
# NOTE(review): 184 and 276 are both multiples of patch_size 23 (8 x 12),
# presumably so the image tiles into whole patches — confirm against
# ViT.MainModel, whose definition is not part of this notebook.
model = MainModel(
    num_classes=len(label_encoder.classes_),
    image_size=(184, 276),
    patch_size=23,
    dim=64,
    depth=8,
    heads=4,
    mlp_dims=(2048, 1024)
).to(device)

In [None]:
# Define loss function and optimizer
# Cross-entropy over the integer class ids produced by the label encoder.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=input_parameters["learning_rate"])
# Multiply the learning rate by 0.1 every 45 scheduler steps; train_model
# steps the scheduler once per epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=45, gamma=0.1)

In [None]:
# Train the model
# print(f"train")
# time_taken measures the training call only (not plotting or evaluation).
start_time = time.time()
# Early stopping uses the training loss with a patience of 50 epochs; the
# state_dict with the lowest training loss is written to best_model.pth.
# After the call, `model` itself holds the weights of the last epoch run.
loss_history = train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    epochs=input_parameters["epoch_count"],
    patience=50,
    model_save_path=output_data_folder + "best_model.pth"
)
time_taken = time.time() - start_time

# Save loss history and plot
# One average training loss per epoch actually run.
np.savetxt(output_data_folder + "loss_history.csv", loss_history, delimiter=",")
plot_loss_curve(loss_history, output_data_folder)
# Confusion matrix on the *training* data (written as conf_matrix.png).
display_confusion_matrix(model, train_loader, device, output_data_folder)

# Evaluate on each SNR's test set
# The in-memory model (weights of the last epoch run) is evaluated separately
# on the held-out 20 % of every SNR level; one result row is produced per SNR.
# Number of epochs actually run: early stopping may end training before the
# configured epoch_count, so report the real figure instead of a constant.
epochs_run = len(loss_history)
results = []
for snr in snr_range:
    X_test, y_test = test_data_by_snr[snr]
    # Test labels are still strings; map them with the encoder fitted on the
    # training labels.
    y_test_encoded = label_encoder.transform(y_test)
    test_loader = prepare_dataloader(
        X_test,
        y_test_encoded,
        batch_size=32,
        shuffle=False,
        num_workers=2,
    )
    # The "snr_<value>_" prefix gives every SNR its own confusion-matrix file.
    acc = evaluate_model(model, test_loader, label_encoder, device, output_data_folder + f"snr_{snr}_")
    results.append({
        "Algorithm": "ViT",
        "SNR": snr,
        "Modulations": mds,
        "Accuracy (%)": acc,
        "Time Taken (Minutes)": round(time_taken / 60, 2),
        "Learning Rate": input_parameters["learning_rate"],
        "Epoch Count": f"{epochs_run} / {input_parameters['epoch_count']}"
    })
    # Cleanup after evaluation
    del X_test, y_test, y_test_encoded, test_loader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Cleanup after loop
del test_data_by_snr
gc.collect()

# Save results to CSV
df = pd.DataFrame(results)
df.to_csv(output_data_folder + "test_results.csv", index=False)
# Final cleanup
# Drops the notebook-level references; re-run the earlier cells before
# evaluating again.
del model, criterion, optimizer, scheduler, label_encoder, results, df
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


In [None]:
# Final cleanup after execution
# Stand-alone cell that can be re-run at any time: collects unreachable Python
# objects and, when CUDA is available, releases PyTorch's cached GPU memory.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()