In [None]:
import sys 
import os 
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import yaml
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import logging
import time
from typing import Dict, Any
import matplotlib.pyplot as plt

from dataloader.dataloader import get_sar_dataloader, SARTransform
from model.model_utils import get_model_from_configs, create_model_with_pretrained
from training.training_loops import TrainRVTransformer, TrainCVTransformer, TrainSSM
from training.visualize import save_results_and_metrics
from sarpyx.utils.losses import get_loss_function
from training_script import setup_logging, load_config
from inference_script import create_test_dataloader, visualize_batch_samples
# #parser = argparse.ArgumentParser(description='Visualize SAR data samples from test dataloader')
# config = "rv_transformer_autoregressive.yaml"
# #parser.add_argument('--config', type=str, default="rv_transformer_autoregressive.yaml", help='Path to configuration file')
# device = "cuda"
# #parser.add_argument('--device', type=str, default='cpu', help='Device override (cpu/cuda)')
# batch_size = 64
# #parser.add_argument('--batch_size', type=int, default=None, help='Batch size override')
# save_dir = "./visualizations"
# #parser.add_argument('--save_dir', type=str, default='./visualizations', help='Save directory for visualizations')
# max_batches = 64
# #parser.add_argument('--max_batches', type=int, default=5, help='Maximum number of batches to visualize')
# max_samples_per_batch = 4
# #parser.add_argument('--max_samples_per_batch', type=int, default=4, help='Maximum samples per batch')

# #args = parser.parse_args()
# pretrained_path = os.path.join(os.getcwd(), '..', 'results', 'rv_autoregressive','sar_transform_best.pth')

# Notebook replacement for the argparse CLI sketched in the comments above:
# a notebook kernel cannot take command-line flags, so the values are fixed
# here. load_config() below receives this namespace together with the YAML.
args = argparse.Namespace(
    config="rv_transformer_autoregressive.yaml",
    # Prefer the GPU, but fall back to CPU so the notebook still runs on
    # machines without CUDA instead of failing later during model loading.
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=64,
    save_dir="./visualizations",
    max_batches=64,
    max_samples_per_batch=4,
    mode="autoregressive",
    # abspath() collapses the '..' so log/error messages show a clean path
    # (.../results/...) rather than .../training/../results/...
    pretrained_path=os.path.abspath(
        os.path.join(
            os.getcwd(), "..", "results", "rv_autoregressive", "sar_transformer_best.pth"
        )
    ),
    learning_rate=1e-4,
    num_epochs=50,
)

# Set up logging, then load the YAML config named by args.config; the notebook
# `args` namespace is handed to load_config as well.
logger = setup_logging()
config = load_config(Path(args.config), args)

# Sections used below: `dataloader` is required (indexed directly),
# `training` is optional (defaults to an empty dict).
dataloader_cfg = config['dataloader']
training_cfg = config.get('training', {})

# An explicit args.save_dir takes precedence over training.save_dir.
save_dir = args.save_dir or training_cfg.get('save_dir', './visualizations')

# Echo the effective settings so each run's log records what was used.
logger.info("Configuration Summary:")
logger.info("  Data directory: %s", dataloader_cfg.get('data_dir', 'Not specified'))
logger.info("  Level from: %s", dataloader_cfg.get('level_from', 'rcmc'))
logger.info("  Level to: %s", dataloader_cfg.get('level_to', 'az'))
logger.info("  Patch size: %s", dataloader_cfg.get('patch_size', [1000, 1]))
logger.info("  Batch size: %s", dataloader_cfg.get('test', {}).get('batch_size', 'Not specified'))
logger.info("  Save directory: %s", save_dir)

# Build the test dataloader from the `dataloader` config section. Construction
# failures are logged and re-raised so the cell stops with the real traceback.
logger.info("Creating test dataloader...")
try:
    test_loader = create_test_dataloader(dataloader_cfg)
except Exception as e:
    logger.error(f"Failed to create test dataloader: {str(e)}")
    raise
else:
    # Stats are logged outside the try body so an error here is not
    # misreported as a dataloader-construction failure.
    num_batches = len(test_loader)
    num_samples = len(test_loader.dataset)
    logger.info(f"Created test dataloader with {num_batches} batches")
    logger.info(f"Dataset contains {num_samples} samples")
    if num_batches == 0:
        # A non-empty dataset has been seen to report zero batches here; the
        # cause is not visible from this notebook (batch size / sampler
        # settings are likely suspects — TODO investigate). Make it loud,
        # since the visualization step would otherwise silently do nothing.
        logger.warning(
            f"Test dataloader has 0 batches although the dataset reports "
            f"{num_samples} samples; visualization will likely produce no output. "
            f"Check the test batch size and batching settings."
        )

# Build the model from the `model` config section and load the checkpoint at
# args.pretrained_path onto args.device. Failures are logged and re-raised.
logger.info(f"Loading model from checkpoint: {args.pretrained_path}")
try:
    model = create_model_with_pretrained(config['model'], pretrained_path=args.pretrained_path, device=args.device)
except Exception as e:
    logger.error(f"Failed to load model: {str(e)}")
    raise

# This notebook only runs inference, so switch layers such as dropout to
# evaluation behaviour. eval() is idempotent, so this is harmless if the
# loader already did it; the isinstance guard keeps it safe for non-Module
# return values.
if isinstance(model, nn.Module):
    model.eval()

# Hand the model, the test loader and the batch/sample limits to
# visualize_batch_samples, which writes its output under save_dir.
# Any failure is logged and re-raised so the cell stops visibly.
logger.info("Starting sample visualization...")
viz_kwargs = dict(
    model=model,
    test_loader=test_loader,
    save_dir=save_dir,
    max_batches=args.max_batches,
    max_samples_per_batch=args.max_samples_per_batch,
)
try:
    visualize_batch_samples(**viz_kwargs)
except Exception as err:
    logger.error(f"Visualization failed with error: {str(err)}")
    raise
else:
    logger.info("Visualization completed successfully!")
    logger.info(f"Check the visualizations in: {save_dir}")

2025-08-18 15:07:08,938 - INFO - Configuration Summary:
2025-08-18 15:07:08,939 - INFO -   Data directory: /Data/sar_focusing
2025-08-18 15:07:08,939 - INFO -   Level from: rcmc
2025-08-18 15:07:08,940 - INFO -   Level to: az
2025-08-18 15:07:08,940 - INFO -   Patch size: [1000, 1]
2025-08-18 15:07:08,941 - INFO -   Batch size: 64
2025-08-18 15:07:08,941 - INFO -   Save directory: ./visualizations
2025-08-18 15:07:08,941 - INFO - Creating test dataloader...


{'data_dir': '/Data/sar_focusing', 'epochs': 200, 'lr': '1e-3', 'model': {'name': 'rv_transformer', 'seq_len': 1000, 'input_dim': 2000, 'model_dim': 1024, 'num_layers': 4, 'num_heads': 4, 'ff_dim': 256, 'dropout': 0.1, 'dim_head': 16, 'mode': 'autoregressive', 'dim': 1024, 'depth': 4, 'heads': 4, 'ff_mult': 1}, 'training': {'patience': 20, 'delta': 0.001, 'mode': 'autoregressive', 'device': 'cuda', 'batch_size': 64, 'learning_rate': 0.0001, 'num_epochs': 50, 'save_dir': './visualizations'}, 'dataloader': {'level_from': 'rcmc', 'level_to': 'az', 'num_workers': 0, 'patch_mode': 'rectangular', 'patch_size': [1000, 1], 'buffer': [1000, 1000], 'stride': [300, 1], 'shuffle_files': False, 'complex_valued': False, 'save_samples': False, 'backend': 'zarr', 'verbose': False, 'cache_size': 1000, 'online': True, 'concatenate_patches': True, 'concat_axis': 0, 'positional_encoding': True, 'train': {'batch_size': 64, 'samples_per_prod': 1, 'patch_order': 'row', 'max_products': 1, 'pattern': '*2023*.z

2025-08-18 15:07:11,132 - INFO - Created test dataloader with 0 batches
2025-08-18 15:07:11,133 - INFO - Dataset contains 1000 samples
2025-08-18 15:07:11,624 - ERROR - Failed to load model: Checkpoint not found: /Data/gdaga/sarpyx_new/sarpyx/training/../results/rv_autoregressive/sar_transform_best.pth


FileNotFoundError: Checkpoint not found: /Data/gdaga/sarpyx_new/sarpyx/training/../results/rv_autoregressive/sar_transform_best.pth