In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import gc

# ✅ Step 1: Memory-map the .npy files so maps are paged in from disk on access
# instead of being loaded into RAM up front (np.load with mmap_mode='r'
# returns a read-only np.memmap).
DATA_DIR = "/home/amanbh/projects/tf217/DISS"
# All four files share this naming pattern; only the field tag differs.
MAP_FILE_TEMPLATE = "Maps_{field}_IllustrisTNG_LH_z=0.00.npy"


def _load_map_file(field):
    """Return a read-only memory map of the map file for ``field`` (e.g. ``"Mcdm"``)."""
    path = os.path.join(DATA_DIR, MAP_FILE_TEMPLATE.format(field=field))
    return np.load(path, mmap_mode='r')


print("Mapping datasets to memory...")
X_mass = _load_map_file("Mcdm")  # dark matter mass (model input)
X_vel = _load_map_file("Vcdm")   # dark matter velocity (model input)
Y_gas = _load_map_file("Mgas")   # gas density (model target)
Y_temp = _load_map_file("T")     # temperature (model target)

# data_generator reads the same sample index from all four arrays, so they
# must contain the same number of maps; fail here rather than mid-training.
_sample_counts = {
    "X_mass": len(X_mass),
    "X_vel": len(X_vel),
    "Y_gas": len(Y_gas),
    "Y_temp": len(Y_temp),
}
if len(set(_sample_counts.values())) != 1:
    raise ValueError(f"Map arrays have mismatched sample counts: {_sample_counts}")
total_samples = len(X_mass)
print(f"Total samples: {total_samples}")  # Check dataset size


def get_dataset_stats(data, n_samples=100):
    """Return the global ``(min, max)`` over the first ``n_samples`` maps of ``data``.

    Only a leading prefix of the dataset is inspected, to keep this cheap on a
    memory-mapped array. Later samples may fall outside the returned range,
    so treat it as an estimate.

    Args:
        data: Array indexed by sample along axis 0 (e.g. an ``np.memmap``).
        n_samples: Maximum number of leading samples to inspect.

    Returns:
        Tuple ``(min_value, max_value)`` as NumPy scalars.

    Raises:
        ValueError: If ``data`` is empty or ``n_samples`` is less than 1.
    """
    n = min(n_samples, len(data))
    if n < 1:
        raise ValueError(
            f"get_dataset_stats needs at least one sample (len(data)={len(data)}, n_samples={n_samples})"
        )
    # One reduction over the prefix equals the min/max of the per-sample min/max.
    subset = data[:n]
    return np.min(subset), np.max(subset)


# Dataset-wide bounds for the dark-matter mass maps, estimated from a prefix.
# NOTE(review): these bounds are not yet passed to normalize_batch anywhere;
# data_generator still normalizes each map by its own min/max.
X_mass_min, X_mass_max = get_dataset_stats(X_mass)

# ✅ Step 2: Log-scaled min-max normalization, applied to each map as it is read
def normalize_batch(batch, epsilon=1e-6, *, min_val=None, max_val=None):
    """Log-compress ``batch`` into roughly (0, 1].

    Computes ``log1p(x - lo + epsilon) / log1p(hi - lo + epsilon)``. By default
    ``lo`` and ``hi`` are the min and max of ``batch`` itself (per-map
    normalization).

    Passing explicit ``min_val``/``max_val`` (e.g. dataset-wide statistics)
    gives a scale that is consistent across samples. In that case ``batch`` is
    first clipped to ``[lo, hi]``: bounds not computed from this batch may not
    enclose it, and a value below ``lo`` could drive the log argument to <= 0
    (NaN / -inf).

    Args:
        batch: Array of values to normalize.
        epsilon: Small offset keeping the log argument and denominator positive.
        min_val: Optional lower bound; defaults to ``np.min(batch)``.
        max_val: Optional upper bound; defaults to ``np.max(batch)``.

    Returns:
        Array with the same shape as ``batch``. The upper bound maps to 1.0 and
        the lower bound to a value near 0 when ``hi - lo`` is much larger than
        ``epsilon``. A constant input (``hi == lo``) maps to all ones.

    Raises:
        ValueError: If the lower bound exceeds the upper bound.
    """
    lo = np.min(batch) if min_val is None else min_val
    hi = np.max(batch) if max_val is None else max_val
    if lo > hi:
        raise ValueError(f"min_val ({lo}) must not exceed max_val ({hi})")
    if min_val is not None or max_val is not None:
        batch = np.clip(batch, lo, hi)
    return np.log1p(batch - lo + epsilon) / np.log1p(hi - lo + epsilon)

# ✅ Step 3: Lazily yield normalized samples for a contiguous index range
def data_generator(start, end):
    """Yield one normalized ``(inputs, targets)`` pair per index in ``range(start, end)``.

    ``inputs`` stacks the normalized dark-matter mass and velocity maps along a
    new last axis. ``targets`` stacks the normalized gas-density and
    temperature maps the same way. Samples are produced lazily, one index at a
    time.
    """
    for idx in range(start, end):
        inputs = np.stack(
            [normalize_batch(X_mass[idx]), normalize_batch(X_vel[idx])],
            axis=-1,
        )
        targets = np.stack(
            [normalize_batch(Y_gas[idx]), normalize_batch(Y_temp[idx])],
            axis=-1,
        )
        yield inputs, targets

# ✅ Step 4: Contiguous index split -- the first 80% of samples train; the rest
# is halved into validation and test (test takes the extra sample if odd).
train_fraction = 0.8
train_samples = int(total_samples * train_fraction)
remaining_samples = total_samples - train_samples
val_samples = remaining_samples // 2
test_samples = remaining_samples - val_samples

print(
    f"Training samples: {train_samples}, "
    f"Validation samples: {val_samples}, "
    f"Test samples: {test_samples}"
)

# ✅ Step 5: Create TensorFlow Datasets for Training, Validation, and Test
BATCH_SIZE = 8
MAP_SHAPE = (256, 256, 2)  # shape of each input and target sample; adjust if needed


def _make_dataset(start, end, shuffle=False):
    """Build a batched, prefetched dataset over sample indices ``[start, end)``.

    With ``shuffle=True`` the sample *indices* are permuted afresh each time
    the dataset is iterated (``from_generator`` calls ``generate`` again for
    every new iterator, e.g. once per Keras epoch). This gives a full shuffle
    of the split without keeping decoded samples in a ``Dataset.shuffle``
    buffer. Each sample is about 1 MiB (input + target), so a 1000-element
    buffer would hold about 1 GiB, and it would only mix samples within a
    sliding window of the sequential read order.
    """
    def generate():
        if not shuffle:
            yield from data_generator(start, end)
            return
        for i in np.random.permutation(np.arange(start, end)).tolist():
            yield from data_generator(i, i + 1)

    spec = tf.TensorSpec(shape=MAP_SHAPE, dtype=tf.float32)
    dataset = tf.data.Dataset.from_generator(generate, output_signature=(spec, spec))
    return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_dataset = _make_dataset(0, train_samples, shuffle=True)
val_dataset = _make_dataset(train_samples, train_samples + val_samples)
test_dataset = _make_dataset(train_samples + val_samples, total_samples)

# ✅ Step 6: Define U-Net Model (or any model you want)
def unet_model(input_shape, *, base_filters=64, depth=3):
    """Build and compile a U-Net producing 2 sigmoid output channels.

    Encoder: ``depth`` stages of two 3x3 ReLU convs followed by 2x2 max-pooling,
    with ``base_filters`` filters in the first stage, doubling per stage.
    Bottleneck: two 3x3 ReLU convs with ``base_filters * 2**depth`` filters.
    Decoder: per stage, 2x upsampling, one 3x3 ReLU conv, then concatenation
    with the matching encoder output (skip connection). A final 1x1 conv with
    sigmoid activation produces 2 channels. With the defaults this is
    64/128/256 filters with a 512-filter bottleneck.

    Args:
        input_shape: Shape passed to ``tf.keras.layers.Input``. The first two
            (spatial) dims, if known, must be divisible by ``2**depth`` so the
            skip connections line up.
        base_filters: Filters in the first encoder stage.
        depth: Number of pooling stages.

    Returns:
        A ``tf.keras.Model`` compiled with Adam, MSE loss and MAE / Huber metrics.

    Raises:
        ValueError: If a known spatial dim is not divisible by ``2**depth``.
    """
    factor = 2 ** depth
    for dim in tuple(input_shape)[:2]:
        if dim is not None and dim % factor:
            raise ValueError(
                f"Spatial dims of input_shape {tuple(input_shape)} must be divisible by 2**depth={factor}"
            )

    def conv3x3(filters, x):
        return tf.keras.layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)

    inputs = tf.keras.layers.Input(input_shape)
    x = inputs

    # Encoder: remember each stage's output for the matching decoder stage.
    skips = []
    for stage in range(depth):
        filters = base_filters * 2 ** stage
        x = conv3x3(filters, x)
        x = conv3x3(filters, x)
        skips.append((filters, x))
        x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    bottleneck_filters = base_filters * 2 ** depth
    x = conv3x3(bottleneck_filters, x)
    x = conv3x3(bottleneck_filters, x)

    # Decoder: walk the encoder stages deepest-first.
    for filters, skip in reversed(skips):
        x = tf.keras.layers.UpSampling2D(size=(2, 2))(x)
        x = conv3x3(filters, x)
        x = tf.keras.layers.Concatenate()([x, skip])

    outputs = tf.keras.layers.Conv2D(2, (1, 1), activation='sigmoid')(x)

    # Compile with MSE loss; also track MAE and Huber loss metrics.
    model = tf.keras.models.Model(inputs, outputs)
    model.compile(
        optimizer='adam',
        loss='mse',
        metrics=[
            'mae',
            # Wrap the per-sample ``huber`` function rather than a ``Huber()``
            # loss object. The object already reduces each batch to a single
            # scalar, so the running mean would weight every batch equally
            # and over-weight a short final batch.
            tf.keras.metrics.MeanMetricWrapper(tf.keras.losses.huber, name='huber_loss'),
        ],
    )
    return model

# Build the network for (256, 256, 2) inputs and print a per-layer summary.
model = unet_model(input_shape=(256, 256, 2))
model.summary()

# ✅ Step 7: Setup Checkpointing and Resume Training Capability
checkpoint_dir = './checkpoints'
# Full-model checkpoints use the native ``.keras`` format. Keras 3 (bundled
# with TF >= 2.16) rejects other extensions when save_weights_only=False.
checkpoint_path = os.path.join(checkpoint_dir, 'cp-{epoch:04d}.keras')
os.makedirs(checkpoint_dir, exist_ok=True)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=False,  # Save full model
    save_freq='epoch'
)

csv_logger = tf.keras.callbacks.CSVLogger('training_log.csv', append=True)

# EarlyStopping callback: Stop training if validation loss doesn't improve for 10 epochs.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
    verbose=1
)


def _find_latest_checkpoint(directory):
    """Return ``(path, epoch)`` for the highest-numbered ``cp-NNNN.keras`` in ``directory``.

    ``tf.train.latest_checkpoint`` only reads the ``checkpoint`` state file
    written by TF-format checkpoint/weight saves. Whole-model saves from
    ModelCheckpoint do not write it, so it cannot be used to resume here.
    Returns ``(None, 0)`` when no matching file exists.
    """
    prefix, suffix = 'cp-', '.keras'
    best_path, best_epoch = None, None
    for name in os.listdir(directory):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        stem = name[len(prefix):-len(suffix)]
        if not stem.isdigit():
            continue
        epoch = int(stem)
        if best_epoch is None or epoch > best_epoch:
            best_path, best_epoch = os.path.join(directory, name), epoch
    return best_path, (best_epoch or 0)


# Check if there is an existing checkpoint to resume from. The file number is
# the count of completed epochs, which is exactly fit()'s initial_epoch.
latest, initial_epoch = _find_latest_checkpoint(checkpoint_dir)
if latest:
    print("Found checkpoint at", latest)
    model = tf.keras.models.load_model(latest)
    print(f"Resuming training from epoch {initial_epoch}")

# ✅ Step 8: Create a Callback for Live Plotting of Epoch vs Loss, MAE, and Huber Loss
class LivePlotCallback(tf.keras.callbacks.Callback):
    """Redraw per-epoch loss, MAE and Huber-loss curves in a 1x3 figure.

    Training values are drawn as solid lines. The matching ``val_*`` values
    are drawn as dashed lines whenever they appear in the epoch logs. A value
    missing from ``logs`` is recorded as NaN, which leaves a gap in its line
    instead of crashing the axis-limit update.
    """

    # (log key, axis label, line colour) for each subplot, left to right.
    _PANELS = (
        ('loss', 'Loss', 'r'),
        ('mae', 'MAE', 'b'),
        ('huber_loss', 'Huber Loss', 'g'),
    )

    def __init__(self):
        super().__init__()
        self.epochs = []
        self.losses = []
        self.mae = []
        self.huber = []
        # Validation history per training key: self.val_history['loss'] <- logs['val_loss'].
        self.val_history = {key: [] for key, _, _ in self._PANELS}
        plt.ion()  # Enable interactive mode
        self.figure, (self.ax1, self.ax2, self.ax3) = plt.subplots(1, 3, figsize=(18, 5))
        train_lines = []
        self.val_lines = []
        for ax, (_, label, colour) in zip((self.ax1, self.ax2, self.ax3), self._PANELS):
            train_line, = ax.plot([], [], f'{colour}-', label=label)
            val_line, = ax.plot([], [], f'{colour}--', label=f'Val {label}')
            train_lines.append(train_line)
            self.val_lines.append(val_line)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(label)
            ax.legend()
        self.line1, self.line2, self.line3 = train_lines
        plt.show()

    @staticmethod
    def _as_float(value):
        """Convert a logged metric to float, mapping a missing (None) value to NaN."""
        return float('nan') if value is None else float(value)

    @staticmethod
    def _upper_ylim(values):
        """Return 110% of the largest finite value, or 1.0 if none is positive."""
        arr = np.asarray(values, dtype=float)
        finite = arr[np.isfinite(arr)]
        top = finite.max() if finite.size else 0.0
        # A zero/empty upper bound would give identical ylim endpoints.
        return top * 1.1 if top > 0 else 1.0

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs.append(epoch + 1)
        panels = zip(
            (self.ax1, self.ax2, self.ax3),
            (self.line1, self.line2, self.line3),
            self.val_lines,
            (self.losses, self.mae, self.huber),
            self._PANELS,
        )
        for ax, train_line, val_line, history, (key, _, _) in panels:
            history.append(self._as_float(logs.get(key)))
            val_values = self.val_history[key]
            val_values.append(self._as_float(logs.get(f'val_{key}')))
            train_line.set_data(self.epochs, history)
            val_line.set_data(self.epochs, val_values)
            ax.set_xlim(0, max(self.epochs) + 1)
            ax.set_ylim(0, self._upper_ylim(history + val_values))
        self.figure.canvas.draw()
        self.figure.canvas.flush_events()

live_plot_callback = LivePlotCallback()

# ✅ Step 9: Fit on the training split, validating after every epoch.
# 100 is an upper bound; EarlyStopping may end the run sooner.
model.fit(
    train_dataset,
    validation_data=val_dataset,
    initial_epoch=initial_epoch,
    epochs=100,
    callbacks=[
        cp_callback,
        csv_logger,
        live_plot_callback,
        early_stop,
    ],
)

# ✅ Step 10: Score the held-out test split; evaluate() yields the loss
# followed by the compiled metrics (MAE, then Huber loss).
test_loss, test_mae, test_huber = model.evaluate(test_dataset)
print(
    f"Test Loss: {test_loss:.4f}, "
    f"Test MAE: {test_mae:.4f}, "
    f"Test Huber Loss: {test_huber:.4f}"
)
