In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import gc
from sklearn.model_selection import train_test_split
import json

# Create directories for model checkpoints and plots
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("plots", exist_ok=True)

# ✅ Step 1: Memory-map the `.npy` files (np.load with mmap_mode='r') so samples
# are paged in from disk on demand instead of being read fully into RAM.
# The data location lives in one constant instead of four hard-coded absolute
# paths; change DATA_DIR to run the notebook on another machine.
DATA_DIR = "/home/amanbh/projects/tf217/DISS"


def _map_path(field):
    """Return the path of `Maps_<field>_IllustrisTNG_LH_z=0.00.npy` inside DATA_DIR."""
    return os.path.join(DATA_DIR, f"Maps_{field}_IllustrisTNG_LH_z=0.00.npy")


print("Mapping datasets to memory...")
X_mass = np.load(_map_path("Mcdm"), mmap_mode='r')  # input: dark matter mass
X_vel  = np.load(_map_path("Vcdm"), mmap_mode='r')  # input: dark matter velocity
Y_gas  = np.load(_map_path("Mgas"), mmap_mode='r')  # target: gas density
Y_temp = np.load(_map_path("T"), mmap_mode='r')     # target: temperature
print(f"Total samples: {len(X_mass)}")  # Check dataset size

# Every split index below is used to read all four arrays, but the split is
# built from len(X_mass) alone -- so the arrays must agree in length. Fail
# here rather than with an indexing error deep inside a training epoch.
_map_lengths = {
    "X_mass": len(X_mass),
    "X_vel": len(X_vel),
    "Y_gas": len(Y_gas),
    "Y_temp": len(Y_temp),
}
if len(set(_map_lengths.values())) != 1:
    raise ValueError(f"Map arrays have mismatched sample counts: {_map_lengths}")

# ✅ Step 2: Split data into train, validation, and test sets
# First, create indices for all samples
indices = np.arange(len(X_mass))

# Split into train (70%), validation (15%), and test (15%).
# Fixed random_state => the identical split is recomputed when resuming.
train_indices, temp_indices = train_test_split(indices, test_size=0.3, random_state=42)
val_indices, test_indices = train_test_split(temp_indices, test_size=0.5, random_state=42)

# Save indices for reproducibility and resuming training
indices_dict = {
    "train_indices": train_indices.tolist(),
    "val_indices": val_indices.tolist(),
    "test_indices": test_indices.tolist()
}

with open("data_split_indices.json", "w") as f:
    json.dump(indices_dict, f)

print(f"Train samples: {len(train_indices)}")
print(f"Validation samples: {len(val_indices)}")
print(f"Test samples: {len(test_indices)}")

# ✅ Step 3: Log-scale min-max normalization, applied to each sample map as it
# is read by the generator (despite its name, normalize_batch sees one map).
def normalize_batch(batch, epsilon=1e-6, min_val=None, max_val=None):
    """Log-compress and rescale an array to the range [~0, 1].

    Computes ``log1p(x - lo + epsilon) / log1p(hi - lo + epsilon)``. By default
    ``lo``/``hi`` are the array's own min/max, so its minimum maps to ~0 and its
    maximum to exactly 1.

    NOTE(review): with the per-array defaults every sample is rescaled
    independently, so absolute amplitude is discarded and model outputs cannot
    be mapped back to physical units. Pass dataset-wide ``min_val``/``max_val``
    to keep one shared scale across samples.

    Args:
        batch: array-like of values to normalize.
        epsilon: small offset that keeps the log1p argument positive at ``lo``.
        min_val: optional fixed lower bound (None -> ``np.min(batch)``).
        max_val: optional fixed upper bound (None -> ``np.max(batch)``).
            With fixed bounds, values outside ``[lo, hi]`` are clipped first,
            so the result stays within [~0, 1] and never hits log1p(< -1).

    Returns:
        Array of the same shape. float32 input stays float32 with the default
        bounds. A constant input (``hi == lo``) returns all zeros.
    """
    batch = np.asarray(batch)
    lo = np.min(batch) if min_val is None else min_val
    hi = np.max(batch) if max_val is None else max_val
    span = hi - lo
    if span <= 0:
        # Degenerate range: the formula below would give eps/eps == 1.0 for
        # every pixel, scoring a featureless map as "all maximum".
        return np.zeros_like(batch, dtype=np.result_type(batch, np.float32))
    # With default bounds batch - lo already lies in [0, span]; the clip only
    # matters for caller-supplied bounds.
    shifted = np.clip(batch - lo, 0, span)
    return np.log1p(shifted + epsilon) / np.log1p(span + epsilon)

# ✅ Step 4: Build a lazily-reading generator factory for one data split
def create_data_generator(indices_list):
    """Return a zero-argument generator function over the given sample indices.

    The returned function is what `tf.data.Dataset.from_generator` expects: it
    can be invoked again to start a new pass. For every index ``idx`` it yields
    ``(inputs, targets)``, each the two normalized maps for that sample stacked
    along a new trailing channel axis:

        inputs  = [dark matter mass, dark matter velocity]
        targets = [gas density, temperature]
    """
    def generator():
        for idx in indices_list:
            inputs = np.stack(
                [
                    normalize_batch(X_mass[idx]),  # dark matter mass
                    normalize_batch(X_vel[idx]),   # dark matter velocity
                ],
                axis=-1,
            )
            targets = np.stack(
                [
                    normalize_batch(Y_gas[idx]),   # gas density
                    normalize_batch(Y_temp[idx]),  # temperature
                ],
                axis=-1,
            )
            yield inputs, targets

    return generator

# ✅ Step 5: Convert Generators to TensorFlow Datasets
output_signature = (
    tf.TensorSpec(shape=(256, 256, 2), dtype=tf.float32),
    tf.TensorSpec(shape=(256, 256, 2), dtype=tf.float32)
)


def make_split_dataset(indices_list):
    """Wrap one split's generator in a tf.data.Dataset whose length is known.

    `from_generator` cannot know how many elements it will produce, so Keras
    shows the epoch length as "Unknown" and cannot report progress. Asserting
    the cardinality (exactly one element per index) restores that, and makes
    the pipeline raise if the generator ever yields a different count.
    """
    dataset = tf.data.Dataset.from_generator(
        create_data_generator(indices_list),
        output_signature=output_signature,
    )
    return dataset.apply(tf.data.experimental.assert_cardinality(len(indices_list)))


train_dataset = make_split_dataset(train_indices)
val_dataset = make_split_dataset(val_indices)
test_dataset = make_split_dataset(test_indices)

# ✅ Step 6: Preprocess the Data Pipelines
BATCH_SIZE = 8
# Only training data is shuffled; val/test keep their split order.
train_dataset = train_dataset.shuffle(1000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# ✅ Step 7: Define U-Net Model with Huber Loss
def unet_model(input_shape):
    """Build and compile a three-level U-Net.

    Architecture: three encoder levels (two 3x3 ReLU convs with 64/128/256
    filters, then 2x2 max-pooling), a 512-filter bottleneck, and three decoder
    levels (2x upsampling, a 3x3 ReLU conv with 256/128/64 filters, then
    concatenation with the matching encoder features). A final 1x1 conv
    produces 2 sigmoid channels.

    Because of the three pool/upsample pairs and skip concatenations, the
    spatial dimensions of `input_shape` must be divisible by 8.

    NOTE(review): Huber(delta=1.0) is quadratic whenever |error| <= 1. With
    targets in [0, 1] (as normalize_batch produces by default) and a sigmoid
    output, this loss is effectively 0.5 * MSE -- confirm that is intended.

    Args:
        input_shape: shape passed to `Input`, e.g. (256, 256, 2).

    Returns:
        A `tf.keras.Model` compiled with Adam, Huber loss and an MAE metric.
    """
    layers = tf.keras.layers

    def double_conv(tensor, filters):
        # Two same-padded 3x3 ReLU convolutions.
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)

    def up_and_merge(tensor, skip, filters):
        # Upsample, refine with one conv, then concatenate the skip features.
        tensor = layers.UpSampling2D(size=(2, 2))(tensor)
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return layers.Concatenate()([tensor, skip])

    inputs = layers.Input(input_shape)

    # Encoder: remember each level's pre-pooling features for the skips.
    skips = []
    x = inputs
    for filters in (64, 128, 256):
        x = double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = double_conv(x, 512)

    # Decoder: deepest skip first.
    for filters, skip in zip((256, 128, 64), reversed(skips)):
        x = up_and_merge(x, skip, filters)

    outputs = layers.Conv2D(2, (1, 1), activation='sigmoid')(x)

    model = tf.keras.models.Model(inputs, outputs)
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.Huber(delta=1.0),
        metrics=['mae'],
    )
    return model

# ✅ Step 8: Custom callbacks for model checkpointing and loss plotting
class LossHistory(tf.keras.callbacks.Callback):
    """Keras callback that records losses and re-plots the curves after each epoch.

    Attributes (one entry appended per epoch):
        epochs: 0-based epoch indices exactly as Keras passes them to
            `on_epoch_end`.
        train_losses: that epoch's `logs['loss']` (None if missing).
        val_losses: that epoch's `logs['val_loss']` (None if missing).

    Each epoch saves `plots/loss_plot_epoch_<N>.png`, where N is the 1-based
    epoch number. The plot's x-axis uses the same 1-based numbering, so the
    figure, its filename and the ModelCheckpoint `{epoch}` tag all agree.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.val_losses = []
        self.epochs = []

    def on_epoch_end(self, epoch, logs=None):
        # The callback signature allows logs=None; don't crash on .get().
        logs = logs or {}
        self.epochs.append(epoch)
        self.train_losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

        epoch_numbers = [e + 1 for e in self.epochs]
        # Explicit figure + try/finally so a failed save cannot leak one open
        # figure per epoch over a long run.
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(epoch_numbers, self.train_losses, 'b-', label='Training Loss')
            ax.plot(epoch_numbers, self.val_losses, 'r-', label='Validation Loss')
            ax.set_title('Training and Validation Loss vs Epochs')
            ax.set_xlabel('Epochs')
            ax.set_ylabel('Loss (Huber)')
            ax.legend()
            ax.grid(True)
            fig.savefig(f'plots/loss_plot_epoch_{epoch+1}.png')
        finally:
            plt.close(fig)

# Write a full-model checkpoint after every epoch (not only the best one) so
# training can resume from whichever epoch finished last. The filename
# pattern is parsed back by find_latest_checkpoint below.
CHECKPOINT_PATTERN = 'checkpoints/model_epoch_{epoch:02d}_val_loss_{val_loss:.4f}.h5'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=CHECKPOINT_PATTERN,
    monitor='val_loss',
    verbose=1,
    save_weights_only=False,  # the entire model, not just its weights
    save_best_only=False,     # one file per epoch
)

# Per-epoch loss recorder/plotter (defined in Step 8)
loss_history = LossHistory()

# ✅ Step 9: Load model from checkpoint if available (for resuming training)
import glob

def find_latest_checkpoint():
    """Locate the checkpoint with the highest epoch tag in `checkpoints/`.

    Filenames are expected to look like
    `model_epoch_<NN>_val_loss_<x>.h5` (the ModelCheckpoint pattern). Files
    that match the glob but do not carry a numeric epoch tag (for example
    `model_epoch_final.h5`) are ignored instead of crashing `int()`. If
    several files share the top epoch (e.g. from separate runs), the most
    recently modified one wins rather than an arbitrary glob order.

    Returns:
        None if no usable checkpoint exists, otherwise a tuple
        `(path, epoch)`, where `epoch` is the integer parsed from the filename.
    """
    prefix = 'model_epoch_'
    candidates = []
    for path in glob.glob('checkpoints/model_epoch_*.h5'):
        epoch_field = os.path.basename(path)[len(prefix):].split('_', 1)[0]
        if not epoch_field.isdecimal():
            continue  # not written by the ModelCheckpoint pattern
        candidates.append((int(epoch_field), os.path.getmtime(path), path))

    if not candidates:
        return None

    latest_epoch, _, latest_checkpoint = max(candidates)
    return latest_checkpoint, latest_epoch

# Check if there are previous checkpoints
latest_checkpoint_info = find_latest_checkpoint()
initial_epoch = 0

if latest_checkpoint_info:
    checkpoint_path, last_epoch = latest_checkpoint_info
    print(f"Resuming training from checkpoint: {checkpoint_path}")
    model = tf.keras.models.load_model(checkpoint_path)
    # The `{epoch:02d}` tag in checkpoint filenames is 1-based (ModelCheckpoint
    # formats it as epoch_index + 1), so a file tagged N means N epochs are
    # done. `fit(initial_epoch=...)` takes a 0-based index, so the next epoch
    # to run is N itself -- `last_epoch + 1` skipped one epoch per resume.
    initial_epoch = last_epoch
else:
    print("Starting training from scratch")
    model = unet_model((256, 256, 2))  # Change input shape as needed

model.summary()

# ✅ Step 10: Train the model with validation and callbacks
EPOCHS = 10
if initial_epoch >= EPOCHS:
    # fit() runs epochs initial_epoch .. EPOCHS-1, i.e. none in this case.
    print(f"Checkpoint already covers {initial_epoch} epochs; "
          f"raise EPOCHS above {EPOCHS} to train further.")
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    initial_epoch=initial_epoch,
    callbacks=[checkpoint_callback, loss_history]
)

# ✅ Step 11: Final evaluation on test set
# evaluate() returns [loss, mae] in the order given to compile().
test_loss = model.evaluate(test_dataset)
print(f"Test loss (Huber): {test_loss[0]}")
print(f"Test MAE: {test_loss[1]}")

# ✅ Step 12: Create and save final loss plot
# `history` only holds the epochs run by this fit() call, so offset the x-axis
# by initial_epoch to label a resumed run with its true 1-based epoch numbers
# (matching the checkpoint filenames). .get() keeps this working when no epoch
# ran and the history is empty.
train_curve = history.history.get('loss', [])
val_curve = history.history.get('val_loss', [])
epoch_axis = list(range(initial_epoch + 1, initial_epoch + 1 + len(train_curve)))

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(epoch_axis, train_curve, 'b-', label='Training Loss')
ax.plot(epoch_axis[:len(val_curve)], val_curve, 'r-', label='Validation Loss')
ax.set_title('Training and Validation Loss vs Epochs')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss (Huber)')
ax.legend()
ax.grid(True)
fig.savefig('plots/final_loss_plot.png')
plt.show()

# Save final model
model.save('final_model.h5')
print("Training complete. Final model saved.")

2025-03-11 20:14:21.791345: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-11 20:14:22.545896: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741724063.391373  375899 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741724063.628489  375899 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-11 20:14:25.631383: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

Mapping datasets to memory...
Total samples: 15000
Train samples: 10500
Validation samples: 2250
Test samples: 2250


I0000 00:00:1741724084.127035  375899 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9558 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4070, pci bus id: 0000:01:00.0, compute capability: 8.9


Starting training from scratch


Epoch 1/10


2025-03-11 20:15:15.843118: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:6: Filling up shuffle buffer (this may take a while): 130 of 1000
2025-03-11 20:15:35.842102: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:6: Filling up shuffle buffer (this may take a while): 400 of 1000
2025-03-11 20:15:45.877692: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:6: Filling up shuffle buffer (this may take a while): 560 of 1000
2025-03-11 20:16:05.918418: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:6: Filling up shuffle buffer (this may take a while): 723 of 1000
2025-03-11 20:16:26.223229: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:6: Filling up shuffle buffer (this may take a while): 866 of 1000
2025-03-11 20:16:36.317918: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:6: Filling up shuffle buffer (this may take a while)

    208/Unknown [1m204s[0m 174ms/step - loss: 0.0084 - mae: 0.0938

KeyboardInterrupt: 