In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import gc
from scipy import fftpack
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard

# Memory management
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Step 1: Load data with memory mapping
print("Mapping datasets to memory...")
X_mass = np.load("/home/amanbh/projects/tf217/DISS/Maps_Mcdm_IllustrisTNG_LH_z=0.00.npy", mmap_mode='r')
X_vel  = np.load("/home/amanbh/projects/tf217/DISS/Maps_Vcdm_IllustrisTNG_LH_z=0.00.npy", mmap_mode='r')
Y_gas  = np.load("/home/amanbh/projects/tf217/DISS/Maps_Mgas_IllustrisTNG_LH_z=0.00.npy", mmap_mode='r')
Y_temp = np.load("/home/amanbh/projects/tf217/DISS/Maps_T_IllustrisTNG_LH_z=0.00.npy", mmap_mode='r')

print(f"Total samples: {len(X_mass)}")

# Create train, validation, and test indices
indices = np.arange(len(X_mass))
train_indices, temp_indices = train_test_split(indices, test_size=0.2, random_state=42)
val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

# Step 2: Optimized Normalization (using simple but effective approach)
def normalize_batch(batch, epsilon=1e-8):
    """Simple log normalization that worked well in Model 2"""
    min_val = np.min(batch)
    max_val = np.max(batch)
    return np.log1p(batch - min_val + epsilon) / np.log1p(max_val - min_val + epsilon)

# Step 3: Create Data Generators with Selective Augmentation
def train_data_generator():
    while True:  # Infinite generator for multiple epochs
        indices_copy = train_indices.copy()
        np.random.shuffle(indices_copy)
        
        for i in indices_copy:
            X1 = normalize_batch(X_mass[i])
            X2 = normalize_batch(X_vel[i])
            Y1 = normalize_batch(Y_gas[i])
            Y2 = normalize_batch(Y_temp[i])
            
            X = np.stack([X1, X2], axis=-1)
            Y = np.stack([Y1, Y2], axis=-1)
            
            # Original sample
            yield (X, Y)
            
            # Selective data augmentation (only horizontal/vertical flips - no rotations)
            # This adds diversity without excessive transformations
            if np.random.rand() > 0.7:  # 30% chance of augmentation
                if np.random.rand() > 0.5:
                    X_aug = np.flip(X, axis=0)
                    Y_aug = np.flip(Y, axis=0)
                else:
                    X_aug = np.flip(X, axis=1)
                    Y_aug = np.flip(Y, axis=1)
                    
                yield (X_aug, Y_aug)

def val_data_generator():
    for i in val_indices:
        X1 = normalize_batch(X_mass[i])
        X2 = normalize_batch(X_vel[i])
        Y1 = normalize_batch(Y_gas[i])
        Y2 = normalize_batch(Y_temp[i])
        
        yield (np.stack([X1, X2], axis=-1), np.stack([Y1, Y2], axis=-1))

def test_data_generator():
    for i in test_indices:
        X1 = normalize_batch(X_mass[i])
        X2 = normalize_batch(X_vel[i])
        Y1 = normalize_batch(Y_gas[i])
        Y2 = normalize_batch(Y_temp[i])
        
        yield (np.stack([X1, X2], axis=-1), np.stack([Y1, Y2], axis=-1))

# Step 4: Create TensorFlow Datasets
train_dataset = tf.data.Dataset.from_generator(
    train_data_generator, 
    output_signature=(
        tf.TensorSpec(shape=(256, 256, 2), dtype=tf.float32),
        tf.TensorSpec(shape=(256, 256, 2), dtype=tf.float32)
    )
)

val_dataset = tf.data.Dataset.from_generator(
    val_data_generator, 
    output_signature=(
        tf.TensorSpec(shape=(256, 256, 2), dtype=tf.float32),
        tf.TensorSpec(shape=(256, 256, 2), dtype=tf.float32)
    )
)

test_dataset = tf.data.Dataset.from_generator(
    test_data_generator, 
    output_signature=(
        tf.TensorSpec(shape=(256, 256, 2), dtype=tf.float32),
        tf.TensorSpec(shape=(256, 256, 2), dtype=tf.float32)
    )
)

# Step 5: Optimize Dataset Pipeline
train_dataset = train_dataset.batch(8).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(8).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(8).prefetch(tf.data.AUTOTUNE)

# Step 6: Modified Loss Function (lighter on structure preservation)
def balanced_loss(y_true, y_pred):
    """Combined loss with a lighter weight on structure preservation"""
    # MSE for pixel-wise accuracy (primary objective based on Model 2's success)
    mse_loss = tf.reduce_mean(tf.square(y_true - y_pred))
    
    # Light gradient penalty (to preserve some structure)
    dy_true, dx_true = tf.image.image_gradients(y_true)
    dy_pred, dx_pred = tf.image.image_gradients(y_pred)
    grad_loss = tf.reduce_mean(tf.abs(dy_true - dy_pred)) + tf.reduce_mean(tf.abs(dx_true - dx_pred))
    
    # Combined loss with MSE as the primary focus
    return mse_loss + 0.5 * grad_loss

# Step 7: Simplified U-Net with selective advanced features
def improved_unet_model(input_shape=(256, 256, 2)):
    inputs = Input(input_shape)
    
    # Encoder path - keeping it simple like Model 2
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    conv2 = Conv2D(128, 3, activation='relu', padding='same')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same')(conv2)
    # Add lightweight residual connection (from Model 1)
    res2 = Conv2D(128, 1, padding='same')(pool1)
    conv2 = Add()([conv2, res2])
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    conv3 = Conv2D(256, 3, activation='relu', padding='same')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same')(conv3)
    # Add lightweight residual connection
    res3 = Conv2D(256, 1, padding='same')(pool2)
    conv3 = Add()([conv3, res3])
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
    # Bridge
    conv4 = Conv2D(512, 3, activation='relu', padding='same')(pool3)
    conv4 = Conv2D(512, 3, activation='relu', padding='same')(conv4)
    # Add dropout for regularization
    conv4 = Dropout(0.2)(conv4)
    
    # Decoder path
    up3 = UpSampling2D(size=(2, 2))(conv4)
    up3 = Conv2D(256, 2, activation='relu', padding='same')(up3)
    # Simplified attention - focus only on the important regions
    gate3 = Conv2D(1, 1, activation='sigmoid')(conv3)
    att3 = Multiply()([conv3, gate3])
    merge3 = Concatenate()([up3, att3])
    
    conv5 = Conv2D(256, 3, activation='relu', padding='same')(merge3)
    conv5 = Conv2D(256, 3, activation='relu', padding='same')(conv5)
    
    up2 = UpSampling2D(size=(2, 2))(conv5)
    up2 = Conv2D(128, 2, activation='relu', padding='same')(up2)
    merge2 = Concatenate()([up2, conv2])
    
    conv6 = Conv2D(128, 3, activation='relu', padding='same')(merge2)
    conv6 = Conv2D(128, 3, activation='relu', padding='same')(conv6)
    
    up1 = UpSampling2D(size=(2, 2))(conv6)
    up1 = Conv2D(64, 2, activation='relu', padding='same')(up1)
    merge1 = Concatenate()([up1, conv1])
    
    conv7 = Conv2D(64, 3, activation='relu', padding='same')(merge1)
    conv7 = Conv2D(64, 3, activation='relu', padding='same')(conv7)
    
    # Output layer - linear activation for physical values
    outputs = Conv2D(2, 1, activation='linear')(conv7)
    
    model = Model(inputs, outputs)
    
    # Compile with appropriate optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)  # Higher initial learning rate
    model.compile(
        optimizer=optimizer,
        loss='mse',  # Using MSE as primary loss based on Model 2's success
        metrics=['mae', 'mse']  # Track both metrics
    )
    
    return model

# Step 8: Initialize and Train Model with Callbacks
model = improved_unet_model()
model.summary()

# Callbacks
callbacks = [
    # Early stopping
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    ),
    # Learning rate scheduler
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-6
    ),
    # Model checkpoint
    ModelCheckpoint(
        filepath='best_improved_model.keras',
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
    # Tensorboard logging
    TensorBoard(log_dir='./logs/improved_model')
]

# Train with validation
history = model.fit(
    train_dataset,
    epochs=50,  # More epochs with early stopping
    validation_data=val_dataset,
    callbacks=callbacks
)

# Save the final model
model.save('final_improved_model.keras')

# Step 9: Evaluation Functions
def evaluate_model(model, dataset, num_samples=10):
    """Evaluate model performance on dataset"""
    # Collect predictions and ground truth
    all_mae = []
    all_mse = []
    all_rmse = []
    all_psnr = []
    all_ssim = []
    
    for X_batch, Y_batch in dataset.take(num_samples):
        Y_pred = model.predict(X_batch)
        
        # Calculate metrics batch-wise
        mae = tf.reduce_mean(tf.abs(Y_batch - Y_pred))
        mse = tf.reduce_mean(tf.square(Y_batch - Y_pred))
        rmse = tf.sqrt(mse)
        
        # Convert to numpy for further calculations
        Y_batch_np = Y_batch.numpy()
        Y_pred_np = Y_pred.numpy()
        
        # Calculate PSNR (Peak Signal-to-Noise Ratio)
        max_val = 1.0  # Normalized data
        psnr = 20 * np.log10(max_val) - 10 * np.log10(np.mean(np.square(Y_batch_np - Y_pred_np)))
        
        # For SSIM, process each channel separately
        ssim_values = []
        for b in range(Y_batch_np.shape[0]):
            for c in range(Y_batch_np.shape[-1]):
                # Using TensorFlow's SSIM implementation
                ssim = tf.image.ssim(
                    tf.expand_dims(Y_batch_np[b, :, :, c], -1),
                    tf.expand_dims(Y_pred_np[b, :, :, c], -1),
                    max_val=1.0
                )
                ssim_values.append(ssim.numpy())
        
        all_mae.append(mae.numpy())
        all_mse.append(mse.numpy())
        all_rmse.append(rmse.numpy())
        all_psnr.append(psnr)
        all_ssim.append(np.mean(ssim_values))
    
    # Calculate averages
    avg_mae = np.mean(all_mae)
    avg_mse = np.mean(all_mse)
    avg_rmse = np.mean(all_rmse)
    avg_psnr = np.mean(all_psnr)
    avg_ssim = np.mean(all_ssim)
    
    print(f"Model Evaluation Metrics (Average):")
    print(f"MAE: {avg_mae:.6f}")
    print(f"MSE: {avg_mse:.6f}")
    print(f"RMSE: {avg_rmse:.6f}")
    print(f"PSNR: {avg_psnr:.6f}")
    print(f"SSIM: {avg_ssim:.6f}")
    
    return {
        'mae': avg_mae,
        'mse': avg_mse,
        'rmse': avg_rmse,
        'psnr': avg_psnr,
        'ssim': avg_ssim
    }

def visualize_predictions(model, dataset, num_samples=3):
    """Visualize model predictions"""
    for X_batch, Y_batch in dataset.take(1):  # Get one batch
        # Select a few samples from the batch
        for i in range(min(num_samples, X_batch.shape[0])):
            X = X_batch[i:i+1]
            Y_true = Y_batch[i]
            Y_pred = model.predict(X)[0]
            
            # Visualization
            fig, axes = plt.subplots(2, 3, figsize=(15, 10))
            
            # Input data
            axes[0,0].imshow(X[0,:,:,0], cmap='viridis')
            axes[0,0].set_title("Input: Dark Matter Density")
            axes[0,1].imshow(X[0,:,:,1], cmap='viridis')
            axes[0,1].set_title("Input: Dark Matter Velocity")
            
            # Ground truth vs predictions - Gas density
            im1 = axes[0,2].imshow(Y_true[:,:,0], cmap='viridis')
            axes[0,2].set_title("Ground Truth: Gas Density")
            im2 = axes[1,0].imshow(Y_pred[:,:,0], cmap='viridis')
            axes[1,0].set_title("Predicted: Gas Density")
            
            # Add colorbar for comparison
            plt.colorbar(im1, ax=axes[0,2])
            plt.colorbar(im2, ax=axes[1,0])
            
            # Ground truth vs predictions - Temperature
            im3 = axes[1,1].imshow(Y_true[:,:,1], cmap='viridis')
            axes[1,1].set_title("Ground Truth: Temperature")
            im4 = axes[1,2].imshow(Y_pred[:,:,1], cmap='viridis')
            axes[1,2].set_title("Predicted: Temperature")
            
            # Add colorbar for comparison
            plt.colorbar(im3, ax=axes[1,1])
            plt.colorbar(im4, ax=axes[1,2])
            
            plt.tight_layout()
            plt.savefig(f"prediction_sample_{i}.png")
            plt.show()

def analyze_power_spectrum(model, dataset):
    """Analyze power spectrum of predictions vs ground truth"""
    for X_batch, Y_batch in dataset.take(1):
        # Get predictions for the batch
        Y_pred = model.predict(X_batch)
        
        # Select first sample from batch
        Y_true = Y_batch[0].numpy()
        Y_pred = Y_pred[0]
        
        # Calculate power spectra for gas density field
        true_field = Y_true[:,:,0]
        pred_field = Y_pred[:,:,0]
        
        # Calculate 2D power spectra
        true_fft = fftpack.fftshift(fftpack.fft2(true_field))
        pred_fft = fftpack.fftshift(fftpack.fft2(pred_field))
        
        true_power = np.abs(true_fft)**2
        pred_power = np.abs(pred_fft)**2
        
        # Radial average for 1D power spectrum
        rows, cols = true_field.shape
        center_row, center_col = rows // 2, cols // 2
        
        y, x = np.indices((rows, cols))
        r = np.sqrt((x - center_col)**2 + (y - center_row)**2)
        r = r.astype(int)
        
        # Calculate the radial averages
        true_radial_prof = np.bincount(r.ravel(), true_power.ravel()) / np.bincount(r.ravel())
        pred_radial_prof = np.bincount(r.ravel(), pred_power.ravel()) / np.bincount(r.ravel())
        
        # Plot
        plt.figure(figsize=(12, 10))
        
        # 1D power spectra
        plt.loglog(true_radial_prof, label='Ground Truth')
        plt.loglog(pred_radial_prof, label='Predicted', linestyle='--')
        plt.xlabel("Wavenumber k")
        plt.ylabel("Power P(k)")
        plt.legend()
        plt.title("Radially Averaged Power Spectra")
        plt.grid(True, which="both", ls="-", alpha=0.2)
        
        plt.tight_layout()
        plt.savefig("power_spectrum_comparison.png")
        plt.show()

# Evaluate model on test set
print("Evaluating model on test set...")
test_metrics = evaluate_model(model, test_dataset)

# Visualize some predictions
print("Visualizing predictions...")
visualize_predictions(model, test_dataset)

# Analyze power spectrum
print("Analyzing power spectrum...")
analyze_power_spectrum(model, test_dataset)

2025-03-11 00:51:44.415402: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-11 00:51:44.531376: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741654304.637992   65443 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741654304.668109   65443 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-11 00:51:44.866430: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

Mapping datasets to memory...
Total samples: 15000


I0000 00:00:1741654315.088176   65443 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4269 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4070, pci bus id: 0000:01:00.0, compute capability: 8.9


Epoch 1/50


I0000 00:00:1741654320.973260   65613 service.cc:148] XLA service 0x7f442c020f70 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741654320.973866   65613 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 4070, Compute Capability 8.9
2025-03-11 00:52:01.455822: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1741654322.616358   65613 cuda_dnn.cc:529] Loaded cuDNN version 90300
2025-03-11 00:52:08.098018: E external/local_xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng14{} for conv (f32[8,64,256,256]{3,2,1,0}, u8[0]{0}) custom-call(f32[8,64,256,256]{3,2,1,0}, f32[64,64,3,3]{3,2,1,0}, f32[64]{0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_r

   9118/Unknown [1m1210s[0m 129ms/step - loss: 0.1446 - mae: 0.0765 - mse: 0.1446

KeyboardInterrupt: 

In [2]:
tf.keras.backend.clear_session()