In [None]:
import os
os.environ["OMP_NUM_THREADS"] = "4"
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler, random_split
import numpy as np
import matplotlib.pyplot as plt
import time
import json
import random
from datetime import datetime
from tqdm.notebook import tqdm
import pennylane as qml
import seaborn as sns
from collections import Counter
import argparse
import warnings
# Silence library warnings so notebook output stays readable.
warnings.filterwarnings('ignore')

# Seed every RNG the experiment touches so runs are reproducible.
SEED = 1369
torch.manual_seed(SEED)
# manual_seed_all seeds every visible GPU (a no-op when CUDA is unavailable).
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
# Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    # The CUDA toolkit version is printed above; report the PyTorch release
    # itself here so the two lines carry different information.
    print(f"PyTorch version: {torch.__version__}")
    print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Current GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
# Informational only: list the devices PennyLane exposes. Any failure falls
# back to a coarser listing instead of aborting the notebook, but
# KeyboardInterrupt/SystemExit are no longer swallowed.
try:
    # NOTE(review): in current PennyLane releases `qml.device` is a factory
    # function and may not provide `all_devices()`; the fallback covers that.
    available_devices = qml.device.all_devices()
    print(f"Available quantum devices: {available_devices}")
except Exception:
    from pennylane import devices
    print(f"PennyLane devices module available: {dir(devices)[:10]}...")
print("--------------------------------------------------------------")
def _report_missing_module(error, filename):
    """Print the standard warning for an optional project module that failed to import."""
    print(f"Warning: {error}")
    print(f"You may need to ensure {filename} is in your working directory.")


# Each project module is optional at import time: a missing file prints a
# warning instead of stopping the notebook (later cells will fail if they use it).
try:
    from dataloader import Ali_DataLoader
    from dataset_analyzer import DatasetAnalyzer, analyze_cifar10_dataset
except ImportError as e:
    _report_missing_module(e, "dataloader.py")
else:
    print("Dataloader modules are available!")
print("--------------------------------------------------------------")

try:
    from target_resnet20 import ResNet20, train_target_model1_resnet20
except ImportError as e:
    _report_missing_module(e, "target_resnet20.py")
else:
    print("Target model 1:ResNet20 is available!")
print("--------------------------------------------------------------")

try:
    from target_efficientnet_b0 import EfficientNetB0, train_target_model2_efficientnet_b0
except ImportError as e:
    _report_missing_module(e, "target_efficientnet_b0.py")
else:
    print("Target model 2: EfficientNetB0 is available!")
print("--------------------------------------------------------------")

# The five quantum ensemble members.
try:
    from quantum_base_modifier import QuantumBaseModifier
except ImportError as e:
    _report_missing_module(e, "quantum_base_modifier.py")
else:
    print("1.Quantum ensemble member -quantum_base_modifier- specialized in low-frequency perturbations module is available!")

try:
    from quantum_texture_attacker import QuantumTextureAttacker
except ImportError as e:
    _report_missing_module(e, "quantum_texture_attacker.py")
else:
    print("2.Quantum ensemble member -quantum_texture_attacker- specialized in texture-based perturbations module is available!")

try:
    from quantum_edge_disruptor import QuantumEdgeDisruptor
except ImportError as e:
    _report_missing_module(e, "quantum_edge_disruptor.py")
else:
    print("3.Quantum ensemble member specialized in edge and detail disruption module is available!")

try:
    from quantum_color_distorter import QuantumColorDistorter
except ImportError as e:
    _report_missing_module(e, "quantum_color_distorter.py")
else:
    print("4.Quantum ensemble member specialized in color relationship disruption module is available!")

try:
    from quantum_focal_attacker import QuantumFocalAttacker
except ImportError as e:
    _report_missing_module(e, "quantum_focal_attacker.py")
else:
    print("5.Quantum ensemble member specialized in attacking sensitive regions module is available!")
print("--------------------------------------------------------------")

# Coordinator for the ensemble.
try:
    from quantum_ensemble_manager import QuantumEnsembleManager
except ImportError as e:
    _report_missing_module(e, "quantum_ensemble_manager.py")
else:
    print("Master class for managing the quantum ensemble of adversarial attackers is available!")

In [None]:
# ---- Paths and run identity -------------------------------------------------
data_dir = "data5"
results_dir = "quantum_ensemble_results"
# Output folders for each stage, created up front so later cells can save into them.
target_model_1_train_result = 'target model 1 training result'
os.makedirs(target_model_1_train_result, exist_ok=True)
target_model_2_train_result = 'target model 2 training result'
os.makedirs(target_model_2_train_result, exist_ok=True)
data_res = 'dataset analysis'
os.makedirs(data_res, exist_ok=True)
print(f"Loading CIFAR-10 dataset from {data_dir}")
target_model_1_path = os.path.join('target_model_1', 'resnet20_best.pth')
target_model_2_path = os.path.join('target_model_2', 'efficientnet_b0_best.pth')
run_name = f"experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(os.path.join(results_dir, run_name), exist_ok=True)

# ---- Data loading (passed to Ali_DataLoader / load_data) --------------------
data_mydir = data_dir
batch_size_d = 64
num_workers_d = 0
shuffle_d = True
pin_memory_d = True
augmentation_type_d = 'none'  # 'standard', 'advanced', 'randaugment', 'mixup', 'cutout', 'none'
train_percent_d = 0.80
val_percent_d = 0.10
test_percent_d = 0.10
subset_percent_d = 100

# ---- Target model 1 (ResNet20) training -------------------------------------
batch_size_t = 64
epochs_t = 300
# Plain string on purpose: a trailing comma here would make it a 1-tuple.
run_name_t = run_name
lr_t = 0.001
weight_decay_t = 5e-4
patience_t = 15
scheduler_type_t = 'plateau'  # Type of learning rate scheduler ('plateau', 'cosine', 'onecycle')
save_every_t = 35
mixup_alpha_t = 0.2

# ---- Target model 2 (EfficientNet-B0) training ------------------------------
batch_size_t2 = 64
epochs_t2 = 300
# Plain string on purpose: a trailing comma here would make it a 1-tuple.
run_name_t2 = run_name
lr_t2 = 0.008
weight_decay_t2 = 5e-4
patience_t2 = 15
scheduler_type_t2 = 'cosine'  # Type of learning rate scheduler ('plateau', 'cosine', 'onecycle')
save_every_t2 = 35
mixup_alpha_t2 = 0.2

# ---- Quantum ensemble attack ------------------------------------------------
batch_size_ens = 64
epsilon_ens = 0.01
epochs_ens = 50
save_every_ens = 10
vis_every = 25
eval_samples_ens = 5000
target_class_ens = None
print(f"Run name: {run_name}")
print(f"Results save path: {os.path.join(results_dir, run_name)}")

In [None]:
# Build the project's CIFAR-10 loader wrapper from the configuration cell.
print(f"Loading CIFAR-10 dataset from {data_dir}")
data_loader = Ali_DataLoader(
    data_dir=data_mydir,
    batch_size=batch_size_d,
    random_seed=42,  # NOTE(review): fixed at 42, independent of the global SEED — confirm intended
    shuffle=shuffle_d,
    num_workers=num_workers_d,
    pin_memory=pin_memory_d,
    classes=None,
    augmentation_type=augmentation_type_d,
)


# Split into train/validation/test loaders using the configured fractions.
print("Splitting dataset into train/validation/test sets...")
train_loader, val_loader, test_loader = data_loader.load_data(
    train_percent=train_percent_d,
    val_percent=val_percent_d,
    test_percent=test_percent_d,
    subset_percent= subset_percent_d
)
# Sanity-check pixel statistics on one training batch. This shows whether the
# loader normalizes inputs (later cells assume CIFAR-10 mean/std normalization
# when de-normalizing for display — confirm in dataloader.py).
data_iter = iter(train_loader)
images, labels = next(data_iter)
print(f"Min value: {images.min()}")
print(f"Max value: {images.max()}")
print(f"Mean value: {images.mean()}")
print(f"Std value: {images.std()}")
# Per-channel statistics; assumes channel-first (batch, 3, H, W) images.
for i in range(3):
    print(f"Channel {i} - Mean: {images[:, i].mean()}, Std: {images[:, i].std()}")
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")
print(f"Number of test batches: {len(test_loader)}")



# class_names is used by the plotting and evaluation cells below.
class_names = data_loader.get_class_names()
print(f"Class names: {class_names}")
# Fetch a fresh batch (a new iterator) to report shape/dtype/range.
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")
print(f"Data type: {images.dtype}")
print(f"Value range: [{images.min().item():.4f}, {images.max().item():.4f}]")
def visualize_batch(dataloader, num_samples=5, *, mean=None, std=None, class_names=None):
    """Show the first images of one batch from ``dataloader`` with their class names.

    Args:
        dataloader: iterable yielding ``(images, labels)``; images are presumably
            channel-first ``(batch, C, H, W)`` tensors (they are transposed to
            channel-last for ``imshow``).
        num_samples: maximum number of images to show; capped at the batch size.
        mean, std: optional per-channel statistics. When both are given, images
            are de-normalized as ``img * std + mean`` before display.
        class_names: label names to use for titles; defaults to
            ``data_loader.get_class_names()`` from the notebook's global loader.
    """
    images, labels = next(iter(dataloader))
    # Never ask for more panels than the batch actually holds.
    num_samples = min(num_samples, images.size(0))
    images = images[:num_samples].numpy().transpose((0, 2, 3, 1))
    labels = labels[:num_samples]
    if mean is not None and std is not None:
        images = images * np.asarray(std) + np.asarray(mean)
    if np.issubdtype(images.dtype, np.floating):
        # imshow expects float RGB in [0, 1]; clip explicitly rather than
        # relying on matplotlib's clip-with-warning.
        images = np.clip(images, 0, 1)
    if class_names is None:
        class_names = data_loader.get_class_names()

    # squeeze=False keeps `axs` 2-D, so indexing also works for a single sample.
    fig, axs = plt.subplots(1, num_samples, figsize=(15, 3), squeeze=False)
    for ax, image, label in zip(axs[0], images, labels):
        label_idx = label.item() if isinstance(label, torch.Tensor) else label
        ax.imshow(image)
        ax.set_title(class_names[label_idx])
        ax.axis('off')

    plt.show()
# Per-split class counts, printed for the training set and plotted for all splits.
print("\nCalculating class distribution...")
train_dist = data_loader.get_class_distribution(train_loader)
val_dist = data_loader.get_class_distribution(val_loader)
test_dist = data_loader.get_class_distribution(test_loader)
print("Training set class distribution:")
for class_name, count in train_dist.items():
    print(f"  {class_name}: {count}")

plt.figure(figsize=(15, 5))
for panel, (split_title, dist) in enumerate(
    [
        ("Training Set Distribution", train_dist),
        ("Validation Set Distribution", val_dist),
        ("Test Set Distribution", test_dist),
    ],
    start=1,
):
    plt.subplot(1, 3, panel)
    plt.bar(dist.keys(), dist.values())
    plt.title(split_title)
    plt.xticks(rotation=45)
# Lay out once, after every panel exists, so the saved figure is tidy.
plt.tight_layout()
plt.savefig(os.path.join(data_res, 'cifar10_class_distribution.png'))
plt.show()

# Report the augmentation actually configured instead of a fixed claim.
print(f"Displaying samples from training data (augmentation_type='{augmentation_type_d}'):")
visualize_batch(train_loader)
print("Displaying samples from validation data (without augmentation):")
visualize_batch(val_loader)
# Reuse the distribution computed above rather than making another full pass
# over train_loader (assumes get_class_distribution only counts labels).
train_distribution = train_dist
print("Class distribution in training data:")
for class_name, count in train_distribution.items():
    print(f"{class_name}: {count}")
print("Dataset preparation complete!")

In [None]:
# Run the project's dataset analysis over all three splits, passing `data_res`
# as the output directory (presumably where its reports are written — confirm
# in dataset_analyzer.analyze_cifar10_dataset).
analyzer_dataset = analyze_cifar10_dataset(data_loader, train_loader, val_loader, test_loader, data_res)

# Target Model 1 (ResNet20) Training & Evaluation 

In [None]:
# Decides below whether to load a saved ResNet20 checkpoint or train a new one.
target_model_1_exists = os.path.exists(target_model_1_path)
def evaluate_model_1(model, test_loader, device, *, num_classes=10, show_progress=True):
    """Compute overall and per-class top-1 accuracy of a classifier.

    Args:
        model: classifier whose outputs are reduced with ``torch.max`` over dim 1.
        test_loader: iterable of ``(images, labels)`` batches; labels must be
            integer class indices in ``[0, num_classes)``.
        device: device each batch is moved to before the forward pass.
        num_classes: number of classes (10 for CIFAR-10).
        show_progress: wrap the loader in a tqdm progress bar.

    Returns:
        ``(accuracy, class_accuracies)``: overall accuracy in percent, and a list of
        ``num_classes`` per-class accuracies in percent (0.0 for unseen classes).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # Per-class tallies live on the CPU and are updated with one bincount per
    # batch instead of a Python loop over every sample.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)
    batches = tqdm(test_loader, desc="Evaluating model") if show_progress else test_loader
    with torch.no_grad():
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()
            labels_cpu = labels.cpu()
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[hits.cpu()], minlength=num_classes)
    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    accuracy = 100 * correct / total
    class_accuracies = [
        100 * hit_count / max(1, seen)
        for hit_count, seen in zip(class_correct.tolist(), class_total.tolist())
    ]
    return accuracy, class_accuracies
print(f"Checking for existing model at: {target_model_1_path}")
if target_model_1_exists:
    print("Pre-trained model found! Loading model...")
    # weights_only=False matches the reload cell further down. The checkpoint
    # may hold more than tensors (e.g. a wrapper dict with training state),
    # which the weights_only=True default of newer PyTorch can reject.
    # Only load checkpoints you produced yourself: this unpickles arbitrary objects.
    checkpoint_1 = torch.load(target_model_1_path, map_location=device, weights_only=False)
    target_model_1 = ResNet20(num_classes=10).to(device)
    # Accept either a wrapper dict with 'model_state_dict' or a bare state_dict.
    if 'model_state_dict' in checkpoint_1:
        target_model_1.load_state_dict(checkpoint_1['model_state_dict'])
    else:
        target_model_1.load_state_dict(checkpoint_1)
    print("Model loaded successfully.")
    print("Evaluating model on test set...")
else:
    print("No pre-trained model found. Training a new model...")
    print("Starting the data augmentation process to train the target model in a robust manner...")
    print("Data augmentation techniques have been applied to the training dataset.")
    print("\nStarting model training...")
    print("This may take some time. Training progress will be displayed below:")
    target_model_1 = train_target_model1_resnet20(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        run_name=run_name_t,
        epochs=epochs_t,
        lr=lr_t,
        weight_decay=weight_decay_t,
        patience=patience_t,
        scheduler_type=scheduler_type_t,
        save_every=save_every_t,
        mixup_alpha=mixup_alpha_t
    )
    print("\nModel training completed!")
    print("\nEvaluating newly trained model...")

# Both branches finish with the same test-set evaluation and per-class plot.
accuracy, class_accuracies = evaluate_model_1(target_model_1, test_loader, device)
print(f"Test Accuracy: {accuracy:.2f}%")
print("\nPer-class accuracy:")
for i, class_name in enumerate(class_names):
    print(f"  {class_name}: {class_accuracies[i]:.2f}%")
plt.figure(figsize=(12, 5))
plt.bar(class_names, class_accuracies)
plt.title("Per-Class Accuracy on Test Set")
plt.ylabel("Accuracy (%)")
plt.xlabel("Class")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig(os.path.join(target_model_1_train_result, 'target_model_1_class_accuracy.png'))
plt.show()
print("\nModel architecture:")
print(target_model_1)
try:
    import contextlib
    import io

    from torchsummary import summary
    # torchsummary.summary() prints its table and returns None, so the text is
    # captured from stdout and then echoed to the notebook and saved.
    summary_buffer = io.StringIO()
    with contextlib.redirect_stdout(summary_buffer):
        summary(target_model_1, (3, 32, 32), device=str(device))
    summary_str = summary_buffer.getvalue()
    print(summary_str, end="")
    with open(os.path.join(target_model_1_train_result, 'model_summary.txt'), 'w') as f:
        f.write(summary_str)
    print(f"Detailed model summary saved to {os.path.join(target_model_1_train_result, 'model_summary.txt')}")
except Exception as e:
    # Best-effort: torchsummary is optional.
    print(f"Could not generate detailed model summary: {e}")
print("\nTesting model on a few sample images:")
test_batch = next(iter(test_loader))
test_images, test_labels = test_batch
# Show up to five samples (fewer if the batch is smaller).
num_show = min(5, test_images.size(0))
test_images = test_images[:num_show].to(device)
test_labels = test_labels[:num_show].to(device)
target_model_1.eval()
with torch.no_grad():
    outputs = target_model_1(test_images)
    _, predicted = torch.max(outputs, 1)

# Undo input normalization for display.
# NOTE(review): assumes the loader normalizes with these CIFAR-10 channel
# statistics — confirm in dataloader.py.
mean = np.array([0.4914, 0.4822, 0.4465])
std = np.array([0.2470, 0.2435, 0.2616])
plt.figure(figsize=(15, 3))
for i in range(num_show):
    img = test_images[i].cpu().numpy().transpose((1, 2, 0))
    img = np.clip(img * std + mean, 0, 1)
    plt.subplot(1, num_show, i + 1)
    plt.imshow(img)
    true_idx = test_labels[i].item()
    pred_idx = predicted[i].item()
    color = "green" if pred_idx == true_idx else "red"
    plt.title(f"True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}",
              color=color)
    plt.axis('off')
# Lay out before saving so the saved PNG matches what is displayed.
plt.tight_layout()
plt.savefig(os.path.join(target_model_1_train_result, 'model_predictions.png'))
plt.show()

# Target Model 2 (EfficientNet-B0) Training & Evaluation 

In [None]:
# Decides below whether to load a saved EfficientNet-B0 checkpoint or train a new one.
target_model_2_exists = os.path.exists(target_model_2_path)

def evaluate_model_2(model, test_loader, device, *, num_classes=10, show_progress=True):
    """Compute overall and per-class top-1 accuracy of a classifier.

    Args:
        model: classifier whose outputs are reduced with ``torch.max`` over dim 1.
        test_loader: iterable of ``(images, labels)`` batches; labels must be
            integer class indices in ``[0, num_classes)``.
        device: device each batch is moved to before the forward pass.
        num_classes: number of classes (10 for CIFAR-10).
        show_progress: wrap the loader in a tqdm progress bar.

    Returns:
        ``(accuracy, class_accuracies)``: overall accuracy in percent, and a list of
        ``num_classes`` per-class accuracies in percent (0.0 for unseen classes).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # Per-class tallies live on the CPU and are updated with one bincount per
    # batch instead of a Python loop over every sample.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)
    batches = tqdm(test_loader, desc="Evaluating model") if show_progress else test_loader

    with torch.no_grad():
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()
            labels_cpu = labels.cpu()
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[hits.cpu()], minlength=num_classes)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    accuracy = 100 * correct / total
    class_accuracies = [
        100 * hit_count / max(1, seen)
        for hit_count, seen in zip(class_correct.tolist(), class_total.tolist())
    ]

    return accuracy, class_accuracies

print(f"Checking for existing model at: {target_model_2_path}")

if target_model_2_exists:
    print("Pre-trained model found! Loading model...")
    # weights_only=False matches the reload cell further down. The checkpoint
    # may hold more than tensors (e.g. a wrapper dict with training state),
    # which the weights_only=True default of newer PyTorch can reject.
    # Only load checkpoints you produced yourself: this unpickles arbitrary objects.
    checkpoint_2 = torch.load(target_model_2_path, map_location=device, weights_only=False)

    target_model_2 = EfficientNetB0(num_classes=10).to(device)

    # Accept either a wrapper dict with 'model_state_dict' or a bare state_dict.
    if 'model_state_dict' in checkpoint_2:
        target_model_2.load_state_dict(checkpoint_2['model_state_dict'])
    else:
        target_model_2.load_state_dict(checkpoint_2)

    print("Model loaded successfully.")
    print("Evaluating model on test set...")
else:
    print("No pre-trained model found. Training a new model...")

    print("Data augmentation techniques have been applied to the training dataset.")
    print("\nStarting model training...")
    print("This may take some time. Training progress will be displayed below:")

    target_model_2 = train_target_model2_efficientnet_b0(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        run_name=run_name_t2,
        epochs=epochs_t2,
        lr=lr_t2,
        weight_decay=weight_decay_t2,
        patience=patience_t2,
        scheduler_type=scheduler_type_t2,
        save_every=save_every_t2,
        mixup_alpha=mixup_alpha_t2
    )

    print("\nModel training completed!")
    print("\nEvaluating newly trained model...")

# Both branches finish with the same test-set evaluation and per-class plot.
accuracy, class_accuracies = evaluate_model_2(target_model_2, test_loader, device)
print(f"Test Accuracy: {accuracy:.2f}%")

print("\nPer-class accuracy:")
for i, class_name in enumerate(class_names):
    print(f"  {class_name}: {class_accuracies[i]:.2f}%")

plt.figure(figsize=(12, 5))
plt.bar(class_names, class_accuracies)
plt.title("Per-Class Accuracy on Test Set")
plt.ylabel("Accuracy (%)")
plt.xlabel("Class")
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig(os.path.join(target_model_2_train_result, 'target_model_2_class_accuracy.png'))
plt.show()

print("\nModel architecture:")
print(target_model_2)

try:
    import contextlib
    import io

    from torchsummary import summary
    # torchsummary.summary() prints its table and returns None, so the text is
    # captured from stdout and then echoed to the notebook and saved.
    summary_buffer = io.StringIO()
    with contextlib.redirect_stdout(summary_buffer):
        summary(target_model_2, (3, 32, 32), device=str(device))
    summary_str = summary_buffer.getvalue()
    print(summary_str, end="")

    with open(os.path.join(target_model_2_train_result, 'model_summary.txt'), 'w') as f:
        f.write(summary_str)

    print(f"Detailed model summary saved to {os.path.join(target_model_2_train_result, 'model_summary.txt')}")
except Exception as e:
    # Best-effort: torchsummary is optional.
    print(f"Could not generate detailed model summary: {e}")

print("\nTesting model on a few sample images:")

test_batch = next(iter(test_loader))
test_images, test_labels = test_batch
# Show up to five samples (fewer if the batch is smaller).
num_show = min(5, test_images.size(0))
test_images = test_images[:num_show].to(device)
test_labels = test_labels[:num_show].to(device)

target_model_2.eval()
with torch.no_grad():
    outputs = target_model_2(test_images)
    _, predicted = torch.max(outputs, 1)

# Undo input normalization for display.
# NOTE(review): assumes the loader normalizes with these CIFAR-10 channel
# statistics — confirm in dataloader.py.
mean = np.array([0.4914, 0.4822, 0.4465])
std = np.array([0.2470, 0.2435, 0.2616])
plt.figure(figsize=(15, 3))
for i in range(num_show):
    img = test_images[i].cpu().numpy().transpose((1, 2, 0))
    img = np.clip(img * std + mean, 0, 1)

    plt.subplot(1, num_show, i + 1)
    plt.imshow(img)
    true_idx = test_labels[i].item()
    pred_idx = predicted[i].item()
    color = "green" if pred_idx == true_idx else "red"
    plt.title(f"True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}",
              color=color)
    plt.axis('off')

# Lay out before saving so the saved PNG matches what is displayed.
plt.tight_layout()
plt.savefig(os.path.join(target_model_2_train_result, 'model_predictions.png'))
plt.show()

In [None]:
# Reload target model 1 (ResNet20) from disk and freeze it as the attack target.
target_model_1 = ResNet20(num_classes=10).to(device)
# NOTE(review): weights_only=False unpickles arbitrary objects — only load
# checkpoints you produced yourself.
checkpoint_1 = torch.load(target_model_1_path, map_location=device, weights_only=False)
# Accept either a wrapper dict with 'model_state_dict' or a bare state_dict.
if 'model_state_dict' in checkpoint_1:
    target_model_1.load_state_dict(checkpoint_1['model_state_dict'])
else:
    target_model_1.load_state_dict(checkpoint_1)
target_model_1.eval()
# Freeze the weights so no parameter gradients are tracked while attacks run.
for param in target_model_1.parameters():
    param.requires_grad = False
print("Target model 1:ResNet20 loaded successfully.")
def evaluate_model_1(model, dataloader, num_samples=None, device=None):
    """Return the top-1 accuracy (in percent) of ``model`` on ``dataloader``.

    Note: this redefines the earlier ``evaluate_model_1`` with a different
    signature and return value (a single float).

    Args:
        model: classifier whose outputs are reduced with ``torch.max`` over dim 1.
        dataloader: iterable of ``(images, labels)`` batches.
        num_samples: if given, score exactly the first ``num_samples`` samples
            (the last batch is trimmed).
        device: device to run on; defaults to the device of the model's first
            parameter (CPU for parameter-free models).

    Raises:
        ValueError: if no samples were scored (empty loader or ``num_samples <= 0``).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            if num_samples is not None:
                remaining = num_samples - total
                if remaining <= 0:
                    break
                # Trim the final batch so exactly num_samples are scored.
                images, labels = images[:remaining], labels[:remaining]
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")
    return 100 * correct / total
# Confirm the reloaded model's clean test accuracy before attacking it.
accuracy = evaluate_model_1(target_model_1, test_loader)
print(f"Target model 1:ResNet20 accuracy on test data: {accuracy:.2f}%")
print("\nTarget model 1:ResNet20 ready for adversarial attacks!")

In [None]:
# Reload target model 2 (EfficientNet-B0) from disk and freeze it as an attack target.
target_model_2 = EfficientNetB0(num_classes=10).to(device)
# NOTE(review): weights_only=False unpickles arbitrary objects — only load
# checkpoints you produced yourself.
checkpoint_2 = torch.load(target_model_2_path, map_location=device, weights_only=False)
# Accept either a wrapper dict with 'model_state_dict' or a bare state_dict.
if 'model_state_dict' in checkpoint_2:
    target_model_2.load_state_dict(checkpoint_2['model_state_dict'])
else:
    target_model_2.load_state_dict(checkpoint_2)
target_model_2.eval()
# Freeze the weights so no parameter gradients are tracked while attacks run.
for param in target_model_2.parameters():
    param.requires_grad = False
print("Target model 2:EfficientNet_B0 loaded successfully.")
def evaluate_model_2(model, dataloader, num_samples=None, device=None):
    """Return the top-1 accuracy (in percent) of ``model`` on ``dataloader``.

    Note: this redefines the earlier ``evaluate_model_2`` with a different
    signature and return value (a single float).

    Args:
        model: classifier whose outputs are reduced with ``torch.max`` over dim 1.
        dataloader: iterable of ``(images, labels)`` batches.
        num_samples: if given, score exactly the first ``num_samples`` samples
            (the last batch is trimmed).
        device: device to run on; defaults to the device of the model's first
            parameter (CPU for parameter-free models).

    Raises:
        ValueError: if no samples were scored (empty loader or ``num_samples <= 0``).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            if num_samples is not None:
                remaining = num_samples - total
                if remaining <= 0:
                    break
                # Trim the final batch so exactly num_samples are scored.
                images, labels = images[:remaining], labels[:remaining]
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")
    return 100 * correct / total
# Confirm the reloaded model's clean test accuracy before attacking it.
accuracy = evaluate_model_2(target_model_2, test_loader)
print(f"Target model 2:EfficientNet_B0 accuracy on test data: {accuracy:.2f}%")
print("\nTarget model 2:EfficientNet_B0 ready for adversarial attacks!")

In [None]:
from quantum_ensemble_manager import QuantumEnsembleManager
# Smoke-test each quantum ensemble member on the first test image and save
# its diagnostic figures under results_dir/run_name.
test_images, test_labels = next(iter(test_loader))
test_image = test_images[0].to(device)
test_label = test_labels[0].item()
# Member 1: each member's forward takes a batch (hence unsqueeze(0)) and, as
# unpacked here, returns (perturbation, adversarial_image).
print("Testing QuantumBaseModifier...")
base_modifier = QuantumBaseModifier(n_qubits=4, n_layers=6, epsilon=epsilon_ens, device=device).to(device)
base_perturbation, base_adv_image = base_modifier(test_image.unsqueeze(0))
base_modifier.visualize_perturbation(
    test_image.cpu(),
    base_perturbation[0].cpu(),
    base_adv_image[0].cpu(),
    filename=os.path.join(results_dir, run_name, "base_modifier_test.png")
)
# Member 2: texture attacker, plus its Gabor-response diagnostic.
print("Testing QuantumTextureAttacker...")
texture_attacker = QuantumTextureAttacker(n_qubits=4, n_layers=5, epsilon=epsilon_ens, device=device).to(device)
texture_perturbation, texture_adv_image = texture_attacker(test_image.unsqueeze(0))
texture_attacker.visualize_perturbation(
    test_image.cpu(),
    texture_perturbation[0].cpu(),
    texture_adv_image[0].cpu(),
    filename=os.path.join(results_dir, run_name, "texture_attacker_test.png")
)
texture_attacker.visualize_gabor_responses(
    test_image.cpu(),
    filename=os.path.join(results_dir, run_name, "gabor_responses_test.png")
)
# Member 3: edge disruptor, plus its wavelet-decomposition diagnostic.
print("Testing QuantumEdgeDisruptor...")
edge_disruptor = QuantumEdgeDisruptor(n_qubits=3, n_layers=4, epsilon=epsilon_ens, device=device).to(device)
edge_perturbation, edge_adv_image = edge_disruptor(test_image.unsqueeze(0))
edge_disruptor.visualize_perturbation(
    test_image.cpu(),
    edge_perturbation[0].cpu(),
    edge_adv_image[0].cpu(),
    filename=os.path.join(results_dir, run_name, "edge_disruptor_test.png")
)
edge_disruptor.visualize_wavelet_decomposition(
    test_image.cpu(),
    filename=os.path.join(results_dir, run_name, "wavelet_decomposition_test.png")
)
# Member 4: color distorter, plus a clean-vs-adversarial colorspace comparison.
print("Testing QuantumColorDistorter...")
color_distorter = QuantumColorDistorter(n_qubits=3, n_layers=4, epsilon=epsilon_ens, device=device).to(device)
color_perturbation, color_adv_image = color_distorter(test_image.unsqueeze(0))
color_distorter.visualize_perturbation(
    test_image.cpu(),
    color_perturbation[0].cpu(),
    color_adv_image[0].cpu(),
    filename=os.path.join(results_dir, run_name, "color_distorter_test.png")
)
color_distorter.visualize_colorspaces(
    test_image.cpu(),
    color_adv_image[0].cpu(),
    filename=os.path.join(results_dir, run_name, "colorspaces_test.png")
)
# Member 5: focal attacker. It takes the target model and its forward returns a
# third value, gradcam_maps, which its visualizations consume. It is run
# outside torch.no_grad(), presumably because Grad-CAM needs gradients.
print("Testing QuantumFocalAttacker...")
focal_attacker = QuantumFocalAttacker(target_model=target_model_1, n_qubits=3, n_layers=6, n_focal_regions=6, epsilon=epsilon_ens, device=device).to(device)
focal_perturbation, focal_adv_image, gradcam_maps = focal_attacker(test_image.unsqueeze(0))
focal_attacker.visualize_perturbation(
    test_image.cpu(),
    focal_perturbation[0].cpu(),
    focal_adv_image[0].cpu(),
    gradcam_maps[0].cpu(),
    filename=os.path.join(results_dir, run_name, "focal_attacker_test.png")
)
focal_attacker.visualize_focal_clusters(
    test_image.cpu(),
    gradcam_maps[0].cpu(),
    filename=os.path.join(results_dir, run_name, "focal_clusters_test.png")
)
print("All ensemble members tested successfully.")

In [None]:
# Build the ensemble manager around target model 1 (ResNet20) with the shared epsilon.
ensemble_manager = QuantumEnsembleManager(
    target_model=target_model_1,
    device=device,
    epsilon=epsilon_ens,
    run_name=run_name
)
print("Quantum Ensemble Manager created successfully.")
# Baseline numbers from one forward pass on a test batch, before any training.
test_batch, test_labels = next(iter(test_loader))
test_batch = test_batch.to(device)
test_labels = test_labels.to(device)
# As unpacked here, the manager returns (adv_images, perturbation, member_outputs, weights).
adv_images, perturbation, member_outputs, weights = ensemble_manager(test_batch)
print(f"Initial ensemble weights: {weights.detach().cpu().numpy()}")
# Mean per-image L2 norm over the (C, H, W) dimensions.
print(f"Average perturbation norm: {torch.norm(perturbation, p=2, dim=(1,2,3)).mean().item():.4f}")
# "Success" here means the model's prediction changed between clean and
# adversarial input, regardless of whether the clean prediction was correct.
with torch.no_grad():
    original_outputs = target_model_1(test_batch)
    adversarial_outputs = target_model_1(adv_images)
    orig_preds = torch.argmax(original_outputs, dim=1)
    adv_preds = torch.argmax(adversarial_outputs, dim=1)
    success_rate = (orig_preds != adv_preds).float().mean().item() * 100
print(f"Initial attack success rate: {success_rate:.2f}%")
ensemble_manager.visualize_ensemble_attack(
    test_batch[0],
    test_labels[0].item(),
    os.path.join(results_dir, run_name, "ensemble_attack_test.png")
)


In [None]:
# Train the ensemble. On Ctrl-C, the current state is saved so the run is not lost.
try:
    train_metrics = ensemble_manager.train_model(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=epochs_ens,
        # Use the configured visualization interval (previously save_every_ens was passed here).
        vis_every=vis_every,
        save_every=save_every_ens,
        # NOTE(review): hard-coded class index 3, while the config cell sets
        # target_class_ens = None — confirm which target is intended.
        target_class=3
    )
    print("Quantum ensemble training completed successfully!")
except KeyboardInterrupt:
    print("Training interrupted by user.")
    # Save current model state
    ensemble_manager.save_model('interrupted_model.pth')
    print("Current model state saved.")

In [None]:
# Render the ensemble manager's Kalman-filter diagnostics, labelled "final"
# (the argument's exact use is defined in QuantumEnsembleManager — confirm there).
ensemble_manager.visualize_kalman_filter("final")

In [None]:
# Prefer the best checkpoint saved during training; fall back to the in-memory model.
try:
    ensemble_manager.load_model('best_model.pth')
    print("Best model loaded.")
except Exception as load_error:
    # Narrower than a bare `except:` (which would also swallow Ctrl-C), and
    # the reason is reported instead of being hidden.
    print("Best model not found, using current model.")
    print(f"  Reason: {type(load_error).__name__}: {load_error}")
eval_results = ensemble_manager.evaluate_on_dataset(
    test_loader,
    num_samples=eval_samples_ens
)
print("\nFinal Evaluation Results:")
print(f"Overall Attack Success Rate: {eval_results['success_rate']:.2f}%")
print(f"Perturbation Norm: {eval_results['perturbation_norm']:.4f}")
print(f"Prediction Changes: {eval_results['prediction_changes']}")
print(f"Confidence Decreases: {eval_results['confidence_decreases']}")
print("\nClass-wise Success Rates:")
class_names = data_loader.get_class_names()
# Only classes that actually appeared in the evaluated samples are listed.
for class_idx, stats in eval_results['class_success_rates'].items():
    if stats['count'] > 0:
        print(f"  {class_names[class_idx]}: {stats['success_rate']:.2f}% ({stats['success']}/{stats['count']})")
print("\nSuccess Rates of Ensemble Members:")
# Order assumed to match the manager's member order (base, texture, edge,
# color, focal) — confirm in QuantumEnsembleManager. Later plotting cells reuse it.
member_names = ["Base", "Texture", "Edge", "Color", "Focal"]
for i, rate in enumerate(eval_results['member_success_rates']):
    print(f"  {member_names[i]}: {rate:.2f}%")


In [None]:
# One figure per per-batch training metric recorded by the ensemble manager.
for _metric_key, _title, _ylabel, _filename in (
    ('attack_success_rate', 'Attack Success Rate During Training', 'Success Rate (%)', 'attack_success_rate.png'),
    ('perturbation_norm', 'L2 Perturbation Norm During Training', 'L2 Norm', 'perturbation_norm.png'),
):
    plt.figure(figsize=(10, 6))
    plt.plot(ensemble_manager.metrics[_metric_key])
    plt.title(_title)
    plt.xlabel('Batch')
    plt.ylabel(_ylabel)
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(results_dir, run_name, _filename))
    plt.show()
# Ensemble weight history; indexed as weights_array[:, member], so presumably
# one row per update step and one column per member.
if ensemble_manager.metrics['ensemble_weights']:
    weights_array = np.array(ensemble_manager.metrics['ensemble_weights'])
    plt.figure(figsize=(10, 6))
    for i, member_name in enumerate(member_names):
        plt.plot(weights_array[:, i], label=member_name)
    plt.title('Evolution of Ensemble Weights')
    plt.xlabel('Update Step')
    plt.ylabel('Weight')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(results_dir, run_name, 'ensemble_weights.png'))
    plt.show()

# Per-member performance histories (assumed: a sequence indexed by member).
member_perf = ensemble_manager.metrics['member_performance']
# Require one non-empty history per member. `all()` alone is True for an empty
# container, which would make the indexing below fail.
if len(member_perf) >= len(member_names) and all(member_perf):
    plt.figure(figsize=(10, 6))
    for i, member_name in enumerate(member_names):
        plt.plot(member_perf[i], label=member_name)
    plt.title('Performance of Ensemble Members')
    plt.xlabel('Update Step')
    plt.ylabel('Performance')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(results_dir, run_name, 'member_performance.png'))
    plt.show()
# Attack success rate per class, one bar per class name (not a hard-coded 10).
plt.figure(figsize=(12, 6))
class_success = [
    eval_results['class_success_rates'][i]['success_rate'] for i in range(len(class_names))
]
plt.bar(class_names, class_success)
plt.title('Attack Success Rate per Class')
plt.ylabel('Success Rate (%)')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig(os.path.join(results_dir, run_name, 'class_success_rates.png'))
plt.show()
# Show up to six saved evaluation sample images in a 2x3 grid.
samples_dir = os.path.join(results_dir, run_name, 'evaluation_samples')
if os.path.isdir(samples_dir):
    # os.listdir returns entries in arbitrary order; sort so the same six
    # samples are shown on every run.
    sample_images = sorted(f for f in os.listdir(samples_dir) if f.endswith('.png'))[:6]
    if sample_images:
        plt.figure(figsize=(15, 12))
        for i, img_file in enumerate(sample_images):
            img = plt.imread(os.path.join(samples_dir, img_file))
            plt.subplot(2, 3, i + 1)
            plt.imshow(img)
            # splitext strips only the extension, so dotted names stay intact.
            plt.title(os.path.splitext(img_file)[0])
            plt.axis('off')
        plt.tight_layout()
        plt.show()
