# Open-Set Recognition with RF-Diffusion: Main Experiment Notebook

This notebook is the main entry point for the complete open-set recognition (OSR) pipeline. It imports modules from the other project files to run four phases:

1.  **Phase 1:** Train the Disentangled Feature Extractor.
2.  **Phase 2:** Train the Conditional Diffusion Model.
3.  **Phase 3:** Calculate the optimal rejection threshold.
4.  **Phase 4:** Run a final, detailed evaluation on the test set.

In [None]:
import config
import torch

# --- Control Panel ---
# Set these flags to True/False to control which parts of the pipeline are executed.
# After running a phase, you can set its flag to False to avoid re-running it.
DO_PHASE_1_TRAINING = False
DO_PHASE_2_TRAINING = False
DO_PHASE_3_THRESHOLD_CALC = True
DO_PHASE_4_EVALUATION = True

print("--- EXPERIMENT CONFIGURATION ---")
print(f"KNOWN CLASSES: {config.KNOWN_CLASSES_LIST}")
print(f"KNOWN UNKNOWN (for Threshold): {config.KNOWN_UNKNOWN_CLASS}")
print(f"TEST UNKNOWN (for Evaluation): {config.TEST_UNKNOWN_CLASS}")
print(f"PHASE 1 EPOCHS: {config.TRAINING_PARAMS['phase1_epochs']}")
print(f"PHASE 2 EPOCHS: {config.TRAINING_PARAMS['phase2_epochs']}")
print("---------------------------------")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

This cell prepares the three DataLoaders (`train_loader`, `threshold_loader`, and `test_loader`) from your source `.mat` file.

In [None]:
from data_loader import prepare_dataloaders

# This function (defined in data_loader.py) handles all the data loading and splitting.
train_loader, threshold_loader, test_loader = prepare_dataloaders(config.TRAINING_PARAMS['batch_size'])
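
To make the role of the three loaders concrete, here is a minimal sketch of one way such a split could be organised. It is not the code in `data_loader.py`: the function name `split_open_set`, the `holdout` fraction, and the class names other than 'DSSS' are assumptions made for illustration. The idea is that training sees known classes only, the threshold set adds the "known unknown" class, and the test set adds the class that was never seen before.

```python
import random
from collections import defaultdict


def split_open_set(samples, known, known_unknown, test_unknown, holdout=0.2, seed=0):
    """Hypothetical three-way split for open-set recognition.

    samples: list of (signal, label) pairs.
    Returns (train, threshold, test) lists of the same pairs.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for item in samples:
        by_label[item[1]].append(item)

    train, threshold, test = [], [], []
    for label in known:
        items = by_label[label][:]
        rng.shuffle(items)
        n_hold = max(1, int(len(items) * holdout))
        # Two held-out slices of each known class: one to calibrate the
        # threshold, one for the final test. The rest is used for training.
        threshold += items[:n_hold]
        test += items[n_hold:2 * n_hold]
        train += items[2 * n_hold:]

    # Unknown classes never enter the training set.
    threshold += by_label[known_unknown]
    test += by_label[test_unknown]
    return train, threshold, test


# Toy data: 10 samples per class; 'A', 'B', 'U1' are made-up class names.
samples = [(i, label) for label in ("A", "B", "U1", "DSSS") for i in range(10)]
train, threshold, test = split_open_set(samples, ["A", "B"], "U1", "DSSS")
print(len(train), len(threshold), len(test))
```

Keeping the test-time unknown class out of the threshold set is what makes the final evaluation a genuine open-set test.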

This cell calls the Phase 1 training function from `train.py`. It trains the feature extractor and saves it as `disentangled_feature_extractor.pt`.

In [None]:
from train import run_phase1_training

if DO_PHASE_1_TRAINING:
    run_phase1_training()
else:
    print("Skipping Phase 1 Training.")

This cell calls the training function from `train.py` for Phase 2. It loads the feature extractor from Phase 1 and uses it to train the diffusion model.

In [None]:
from train import run_phase2_training

if DO_PHASE_2_TRAINING:
    run_phase2_training()
else:
    print("Skipping Phase 2 Training.")

This phase loads both trained models and uses the `threshold_loader` to find the best reconstruction-error threshold with Youden's index, that is, the threshold that maximizes `TPR - FPR`.

In [None]:
from evaluate import calculate_optimal_threshold
from models.feature_extractor import DisentangledFeatureExtractor
from models.diffusion_model import tfdiff_WiFi
from utils.diffusion_helper import SignalDiffusion

optimal_threshold = None

if DO_PHASE_3_THRESHOLD_CALC:
    print("Loading models for threshold calculation...")
    
    # Load the trained models
    feature_extractor = DisentangledFeatureExtractor(
        num_classes=config.FEATURE_EXTRACTOR_PARAMS['num_classes'],
        feature_dim=config.FEATURE_EXTRACTOR_PARAMS['feature_dim']
    ).to(device)
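    feature_extractor.load_state_dict(torch.load(config.PATHS['feature_extractor'], map_location=device))
    feature_extractor.eval()  # Inference only; harmless if the evaluation code also does this

    diffusion_model = tfdiff_WiFi(config.DIFFUSION_PARAMS).to(device)
    diffusion_model.load_state_dict(torch.load(config.PATHS['diffusion_model'], map_location=device))
    diffusion_model.eval()  # Inference only; harmless if the evaluation code also does this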

    diffusion_helper = SignalDiffusion(config.DIFFUSION_PARAMS)
    
    # Calculate the threshold
    optimal_threshold = calculate_optimal_threshold(
        threshold_loader,
        feature_extractor,
        diffusion_model,
        diffusion_helper
    )
else:
    print("Skipping Phase 3 Threshold Calculation.")
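
For readers unfamiliar with Youden's index, the following is a minimal pure-Python sketch of the criterion. It is not the implementation in `evaluate.py`: the function name `youden_threshold` and the convention that a sample is flagged as unknown when its reconstruction error is at or above the threshold are assumptions. Each observed error is tried as a candidate threshold, and the one with the largest `J = TPR - FPR` is kept.

```python
def youden_threshold(errors, is_unknown):
    """Return (threshold, J) maximising Youden's J = TPR - FPR.

    errors: reconstruction error per sample.
    is_unknown: 1/True for unknown-class samples, 0/False for known ones.
    Assumed convention: a sample is flagged unknown when error >= threshold.
    """
    n_pos = sum(1 for u in is_unknown if u)
    n_neg = len(is_unknown) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need both known and unknown samples")

    best_j, best_t = -1.0, None
    for t in sorted(set(errors)):
        # True positives: unknown samples correctly flagged at threshold t.
        tp = sum(1 for e, u in zip(errors, is_unknown) if u and e >= t)
        # False positives: known samples wrongly flagged at threshold t.
        fp = sum(1 for e, u in zip(errors, is_unknown) if not u and e >= t)
        j = tp / n_pos - fp / n_neg
        if j > best_j:
            best_j, best_t = j, t
    return best_t, best_j


# Toy errors: known samples reconstruct well, unknown ones do not.
errors = [0.05, 0.10, 0.12, 0.30, 0.35, 0.40]
is_unknown = [0, 0, 0, 1, 1, 1]
best_t, best_j = youden_threshold(errors, is_unknown)
print(best_t, best_j)
```

The quadratic scan is fine for a small calibration set; a sorted single pass or a library ROC routine would be the usual choice for larger ones.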

This final phase uses the trained models and the optimal threshold to perform a detailed evaluation on the test set, which includes the completely unseen 'DSSS' class.

In [None]:
import numpy as np

from evaluate import run_final_evaluation

if DO_PHASE_4_EVALUATION:
    if optimal_threshold is None:
        # Phase 3 was skipped: try to load a saved threshold, otherwise use a default.
        try:
            optimal_threshold = float(np.loadtxt("optimal_threshold.txt"))
            print(f"Loaded optimal threshold from file: {optimal_threshold}")
        except OSError:
            optimal_threshold = 0.2  # Fallback to a default value
            print(f"WARNING: optimal_threshold.txt not found. Using default value: {optimal_threshold}")

    # Reload both models from disk so the evaluation does not depend on
    # whether the Phase 3 block above loaded them.
    print("\nReloading models for final evaluation to ensure they are fresh...")
    feature_extractor = DisentangledFeatureExtractor(
        num_classes=config.FEATURE_EXTRACTOR_PARAMS['num_classes'],
        feature_dim=config.FEATURE_EXTRACTOR_PARAMS['feature_dim']
    ).to(device)
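    feature_extractor.load_state_dict(torch.load(config.PATHS['feature_extractor'], map_location=device))
    feature_extractor.eval()  # Inference only; harmless if the evaluation code also does this

    diffusion_model = tfdiff_WiFi(config.DIFFUSION_PARAMS).to(device)
    diffusion_model.load_state_dict(torch.load(config.PATHS['diffusion_model'], map_location=device))
    diffusion_model.eval()  # Inference only; harmless if the evaluation code also does this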

    diffusion_helper = SignalDiffusion(config.DIFFUSION_PARAMS)

    # Run the final evaluation
    run_final_evaluation(
        test_loader,
        optimal_threshold,
        feature_extractor,
        diffusion_model,
        diffusion_helper
    )
else:
    print("Skipping Phase 4 Final Evaluation.")
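
As a rough illustration of what a threshold-based open-set evaluation computes, the sketch below applies a rejection rule and reports two summary numbers. It is not the code in `evaluate.py`: the `UNKNOWN` label value, the function names, and the assumption that a reconstruction error at or above the threshold means "reject as unknown" are all illustrative.

```python
UNKNOWN = -1  # hypothetical label used for rejected / unknown samples


def open_set_predict(class_scores, recon_error, threshold):
    """Reject as UNKNOWN when the reconstruction error reaches the threshold,
    otherwise return the index of the highest class score."""
    if recon_error >= threshold:
        return UNKNOWN
    return max(range(len(class_scores)), key=class_scores.__getitem__)


def summarize(preds, labels):
    """Accuracy on known-class samples and detection rate on unknown ones."""
    known = [(p, y) for p, y in zip(preds, labels) if y != UNKNOWN]
    unknown = [(p, y) for p, y in zip(preds, labels) if y == UNKNOWN]
    return {
        "known_accuracy": (
            sum(p == y for p, y in known) / len(known) if known else float("nan")
        ),
        "unknown_detection_rate": (
            sum(p == UNKNOWN for p, _ in unknown) / len(unknown)
            if unknown else float("nan")
        ),
    }


# Toy batch of (class scores, reconstruction error) pairs and true labels.
batch = [([0.9, 0.1], 0.05), ([0.2, 0.8], 0.10), ([0.6, 0.4], 0.50), ([0.7, 0.3], 0.25)]
labels = [0, 1, UNKNOWN, 0]
preds = [open_set_predict(scores, err, threshold=0.2) for scores, err in batch]
report = summarize(preds, labels)
print(preds, report)
```

In the toy batch the last sample is a known class rejected by mistake, so it counts against known-class accuracy. This is the trade-off that the threshold chosen in Phase 3 controls.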