In [None]:
import sys 
import os 
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import yaml
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import logging
import time
from typing import Dict, Any
import matplotlib.pyplot as plt

from dataloader.dataloader import get_sar_dataloader, SARTransform
from model.model_utils import get_model_from_configs, create_model_with_pretrained
from training.training_loops import get_training_loop_by_model_name
from training.visualize import save_results_and_metrics, get_full_image_and_prediction
from sarpyx.utils.losses import get_loss_function
from training_script import setup_logging, load_config
from inference_script import create_test_dataloader
import matplotlib.pyplot as plt
import numpy as np

def display_inference_results(input_data, gt_data, pred_data, figsize=(20, 6), vminmax=(0, 1000),
                              titles=('Input (RCMC)', 'Ground Truth (AZ)', 'Prediction (AZ)')):
    """
    Display input, ground truth, and prediction side by side in a 1x3 grid.

    Torch tensors are detached and moved to the CPU before conversion, so raw
    model outputs that still require grad can be plotted directly. Complex
    arrays are shown as magnitude (``np.abs``); real arrays are shown as-is.
    The figure is displayed with ``plt.show()``; nothing is written to disk.

    Args:
        input_data: Input sample (torch tensor or array-like).
        gt_data: Ground-truth sample (torch tensor or array-like).
        pred_data: Model prediction (torch tensor or array-like).
        figsize: Matplotlib figure size.
        vminmax: Colour scaling for the panels:
            - ``(vmin, vmax)`` tuple or list: the same fixed range for every panel;
            - ``'auto'``: per-panel 2nd/98th percentile of the displayed values;
            - anything else (e.g. ``None``): per-panel min/max.
        titles: Three panel titles, in the order input, ground truth, prediction.

    Raises:
        ValueError: If ``titles`` does not contain exactly three entries.
    """
    if len(titles) != 3:
        raise ValueError(f"titles must have exactly 3 entries, got {len(titles)}")

    def to_numpy(data):
        # Tensor.numpy() raises on tensors that require grad, hence detach() first.
        if torch.is_tensor(data):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    def magnitude_and_range(data):
        magnitude = np.abs(data) if np.iscomplexobj(data) else data
        # The isinstance guard avoids an elementwise (ambiguous) comparison when
        # vminmax is an array rather than a string.
        if isinstance(vminmax, str) and vminmax == 'auto':
            vmin, vmax = np.percentile(magnitude, [2, 98])
        elif isinstance(vminmax, (tuple, list)):
            vmin, vmax = vminmax
        else:
            vmin, vmax = np.min(magnitude), np.max(magnitude)
        return magnitude, vmin, vmax

    # NOTE(review): imshow needs a 2-D array (or an RGB(A) stack). If samples carry
    # an extra channel axis (e.g. real/imag), it must be reduced first — confirm
    # against the dataset output.
    panels = [
        (title, *magnitude_and_range(to_numpy(data)))
        for title, data in zip(titles, (input_data, gt_data, pred_data))
    ]

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (title, img, vmin, vmax) in zip(axes, panels):
        im = ax.imshow(img, cmap='viridis', vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel('Range')
        ax.set_ylabel('Azimuth')

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=8)

        # Square pixels. imshow's own `aspect` argument is not passed because this
        # call determines the final aspect anyway.
        ax.set_aspect('equal', adjustable='box')

    plt.tight_layout()
    plt.show()

# Notebook stand-in for the command-line arguments the training/inference
# scripts would normally receive from argparse. The keyword order is kept as-is
# so vars(args) iterates in the same order.
_pretrained_checkpoint = os.path.join(
    os.getcwd(), os.pardir, 'results', 'rv_autoregressive', 'sar_transformer_best.pth'
)
_notebook_args = dict(
    config="rv_transformer_autoregressive.yaml",
    device="cuda",
    batch_size=64,
    save_dir="./visualizations",
    max_batches=64,
    max_samples_per_batch=4,
    mode="autoregressive",
    pretrained_path=_pretrained_checkpoint,
    learning_rate=1e-4,
    num_epochs=50,
)
args = argparse.Namespace(**_notebook_args)

# Set up logging, load the YAML config (merged with the notebook args), and echo
# the settings that matter for this visualization run.
logger = setup_logging()

config = load_config(Path(args.config), args)
dataloader_cfg = config['dataloader']
training_cfg = config.get('training', {})

# An explicit args.save_dir wins; the training config is only a fallback.
save_dir = args.save_dir or training_cfg.get('save_dir', './visualizations')

# (label, value) pairs, logged in this order as "  <label>: <value>".
_summary_items = (
    ("Data directory", dataloader_cfg.get('data_dir', 'Not specified')),
    ("Level from", dataloader_cfg.get('level_from', 'rcmc')),
    ("Level to", dataloader_cfg.get('level_to', 'az')),
    ("Patch size", dataloader_cfg.get('patch_size', [1000, 1])),
    ("Batch size", dataloader_cfg.get('test', {}).get('batch_size', 'Not specified')),
    ("Save directory", save_dir),
)
logger.info("Configuration Summary:")
for _label, _value in _summary_items:
    logger.info(f"  {_label}: {_value}")

# Create the test dataloader. Only its .dataset is used further down, but an
# empty loader over a non-empty dataset usually signals a configuration problem.
logger.info("Creating test dataloader...")
try:
    test_loader = create_test_dataloader(dataloader_cfg)
except Exception as e:
    # logger.exception records the traceback too; the error is still re-raised.
    logger.exception(f"Failed to create test dataloader: {e}")
    raise

_num_batches = len(test_loader)
_num_samples = len(test_loader.dataset)
logger.info(f"Created test dataloader with {_num_batches} batches")
logger.info(f"Dataset contains {_num_samples} samples")
if _num_batches == 0 and _num_samples > 0:
    logger.warning(
        f"Test dataloader reports 0 batches although its dataset has {_num_samples} "
        f"samples; check the test dataloader configuration"
    )

# Build the model and load its pretrained weights onto the requested device.
try:
    model = create_model_with_pretrained(config['model'], pretrained_path=args.pretrained_path, device=args.device)
except Exception as e:
    # logger.exception records the traceback too; the error is still re-raised.
    logger.exception(f"Failed to load model: {e}")
    raise

# Inference only: eval mode switches layers such as dropout (the model config may
# set dropout > 0) to evaluation behaviour. The isinstance guard is there because
# the factory's return type is not visible here.
if isinstance(model, nn.Module):
    model.eval()

# Run inference over one product of the test dataset and plot
# input / ground truth / prediction side by side.
logger.info("Starting sample visualization...")
try:
    trainer = get_training_loop_by_model_name(
        "rv_transformer_autoregressive",
        model=model,
        save_dir=save_dir,
        mode=args.mode,
        loss_fn_name="mse",
    )

    def inference_fn(x):
        """Adapt trainer.forward_pass(x, y, device) to a single-argument callable."""
        # forward_pass requires `y` and `device` besides the input, so passing the
        # bound method directly fails with a TypeError.
        # NOTE(review): y=None assumes forward_pass tolerates a missing target at
        # inference time (no teacher forcing / loss on y), and that its return value
        # is what get_full_image_and_prediction expects — confirm in TrainerBase.
        with torch.no_grad():  # no autograd graph needed for visualization
            return trainer.forward_pass(x, y=None, device=args.device)

    # `input_img` rather than `input`, to avoid shadowing the builtin.
    gt, pred, input_img = get_full_image_and_prediction(
        dataset=test_loader.dataset,
        zfile=0,
        inference_fn=inference_fn,
        max_samples_per_prod=args.max_samples_per_batch,
        return_input=True,
        device=args.device,  # same device as the model, instead of a hard-coded "cuda"
    )
    display_inference_results(
        input_data=input_img,
        gt_data=gt,
        pred_data=pred,
        figsize=(20, 6),
        vminmax=(0, 1000),  # Adjust this range based on your data
    )

    logger.info("Visualization completed successfully!")
    logger.info(f"Check the visualizations in: {save_dir}")

except Exception as e:
    # logger.exception records the traceback too; the error is still re-raised.
    logger.exception(f"Visualization failed with error: {e}")
    raise

  from .autonotebook import tqdm as notebook_tqdm
INFO:training_script:Configuration Summary:
INFO:training_script:  Data directory: /Data/sar_focusing
INFO:training_script:  Level from: rcmc
INFO:training_script:  Level to: az
INFO:training_script:  Patch size: [1000, 1]
INFO:training_script:  Batch size: 64
INFO:training_script:  Save directory: ./visualizations
INFO:training_script:Creating test dataloader...


{'data_dir': '/Data/sar_focusing', 'epochs': 200, 'lr': '1e-3', 'model': {'name': 'rv_transformer', 'seq_len': 1000, 'input_dim': 2000, 'model_dim': 1024, 'num_layers': 4, 'num_heads': 4, 'ff_dim': 256, 'dropout': 0.1, 'dim_head': 16, 'mode': 'autoregressive', 'dim': 1024, 'depth': 4, 'heads': 4, 'ff_mult': 1}, 'training': {'patience': 20, 'delta': 0.001, 'mode': 'autoregressive', 'device': 'cuda', 'batch_size': 64, 'learning_rate': 0.0001, 'num_epochs': 50, 'save_dir': './visualizations'}, 'dataloader': {'level_from': 'rcmc', 'level_to': 'az', 'num_workers': 0, 'patch_mode': 'rectangular', 'patch_size': [1000, 1], 'buffer': [1000, 1000], 'stride': [300, 1], 'shuffle_files': False, 'complex_valued': False, 'save_samples': False, 'backend': 'zarr', 'verbose': False, 'cache_size': 1000, 'online': True, 'concatenate_patches': True, 'concat_axis': 0, 'positional_encoding': True, 'train': {'batch_size': 64, 'samples_per_prod': 1, 'patch_order': 'row', 'max_products': 1, 'pattern': '*2023*.z

INFO:training_script:Created test dataloader with 0 batches
INFO:training_script:Dataset contains 5000 samples


Loading pretrained weights from: /Data/gdaga/sarpyx_new/sarpyx/training/../results/rv_autoregressive/sar_transformer_best.pth


INFO:training_script:Starting sample visualization...


Successfully loaded 125 parameters


ERROR:training_script:Visualization failed with error: TrainerBase.forward_pass() missing 2 required positional arguments: 'y' and 'device'


TypeError: TrainerBase.forward_pass() missing 2 required positional arguments: 'y' and 'device'