In [17]:
from functions import *
import pickle
import time
from contextlib import redirect_stdout
import gc

# Load and augment prepared data

In [49]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import json

def augment_data(idx, inputs_label, reference, outlier_mask, W, additional_inputs=False, visualize=False, augment_iterations=1): 
    """
    Augments the data by applying transformations to the input images. Allows for multiple augmentation iterations.

    Args:
        idx (int): Index of the data point.
        inputs_label (tuple): Tuple containing the image and label (and additional inputs if applicable).
        reference (tuple): Reference data containing months and coordinates.
        outlier_mask (dict): Dictionary mapping (month, coords, label) to outlier status.
        W (int): Image width/height.
        additional_inputs (bool): Whether additional inputs are included.
        visualize (bool): Whether to return augmented outputs for visualization.
        augment_iterations (int): Number of augmentations to apply per image.

    Returns:
        tf.data.Dataset: Dataset containing augmented images and labels.
    """
    if additional_inputs:
        image = tf.cast(inputs_label[0][0], dtype=tf.float32)
        additional_inp = tf.cast(inputs_label[0][1], dtype=tf.float32)
        label = tf.cast(inputs_label[1], dtype=tf.float32)
    else:
        image = tf.cast(inputs_label[0], dtype=tf.float32)
        label = tf.cast(inputs_label[1], dtype=tf.float32)
    
    reference = (tf.cast(reference[0], dtype=tf.float32), tf.cast(reference[1], dtype=tf.float32))
    reference_month = tf.gather(reference[0], idx)
    reference_coords = tf.gather(reference[1], idx)

    reference_month = tf.expand_dims(reference_month, axis=0)  # Convert to [1]
    label = tf.expand_dims(label, axis=0) 

    def get_outlier_status(month, coords, lbl):
        python_key = (float(month), tuple(map(float, coords)), float(lbl))
        return outlier_mask.get(python_key, False)
    
    is_outlier = tf.py_function(
        func=lambda month, coords, lbl: get_outlier_status(month, coords, lbl),
        inp=[reference_month, reference_coords, label],
        Tout=tf.bool
    )

    # Generate base rotations
    base_rotations = [
        image,  # 0°
        tf.image.rot90(image, k=1),  # 90°
        tf.image.rot90(image, k=2),  # 180°
        tf.image.rot90(image, k=3),  # 270°
    ]

    base_mirrors = [
        tf.image.flip_left_right(image),  # Mirror left-to-right
        tf.image.flip_up_down(image),    # Mirror top-to-bottom
    ]

    all_base_transformations = base_rotations + base_mirrors
    
    def augment(transformations, iterations):
        """Generates multiple augmentations by repeating the transformations."""
        return transformations * iterations

    # Apply augmentations based on outlier status
    
    transformations = tf.cond(
        is_outlier,
        lambda: augment(all_base_transformations, max(2, augment_iterations * 2)),
        lambda: augment(all_base_transformations, augment_iterations) +  [tf.zeros_like(all_base_transformations[0]) for _ in range(len(all_base_transformations)*augment_iterations)]#[tf.zeros_like(all_base_transformations[0]) for _ in range(len(all_base_transformations*augment_iterations*2)-len(all_base_transformations))]#augment(all_base_transformations, 0) + [tf.zeros_like(all_base_transformations[0]) for _ in all_base_transformations]
    )


    valid_mask = tf.logical_not(tf.reduce_all(tf.equal(transformations, 0), axis=[1, 2, 3]))
    transformations = tf.boolean_mask(transformations, valid_mask)
    num_channels = tf.get_static_value(tf.shape(image)[-1])
    transformations = tf.map_fn(
                    lambda trans: tf.ensure_shape(trans, [W, W, num_channels]),
                    transformations
                )

    def pack_output(trans):
        if additional_inputs:
            return idx, (trans, tf.convert_to_tensor(additional_inp)), label
        else:
            return idx, trans, label

    if visualize:
        visualize_augmented_outputs = [
            {"image": trans, "label": label, "is_outlier": is_outlier} for trans in transformations
        ]
        return visualize_augmented_outputs
        
    augmented_dataset = tf.data.Dataset.from_tensor_slices(transformations).map(
        lambda trans: pack_output(trans),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return augmented_dataset

def prepare_dataset(ds, target, conditioned, W=128, reference=None, augment=False, augment_iterations=1):
    """
    Prepares the dataset with optional augmentation.

    Args:
        ds (tuple): Tuple of input data.
        target (array): Target data.
        conditioned (bool): Whether additional inputs are included.
        W (int): Image width/height.
        reference (tuple): Reference data containing months and coordinates.
        augment (bool): Whether to apply augmentation.
        augment_iterations (int): Number of augmentations to apply per image.

    Returns:
        tf.data.Dataset: Prepared dataset.
    """
    inputs = (ds[0], ds[1]) if conditioned else ds
    dataset = tf.data.Dataset.from_tensor_slices((inputs, target))
    
    dataset_with_indices = dataset.enumerate() 
    
    if augment:
        with open('../data/external/outlier_dict.json', 'r') as json_file:
            loaded_dict = json.load(json_file)
        outlier_mask = {eval(key): value for key, value in loaded_dict.items()}

        def augment_with_index(idx, inputs_target):
            augmented = augment_data(
                idx, inputs_target, reference, outlier_mask, W,
                additional_inputs=conditioned, visualize=False, augment_iterations=augment_iterations
            )
            return augmented

        dataset_with_indices = dataset_with_indices.flat_map(
            augment_with_index
        )
    
    return dataset_with_indices


In [55]:
var_channels = {'lst': 0, 'ndvi': 1, 'slope': 2, 'altitude': 3, 'direction': 4}
var_position = {'month': [0, 1], 'coords': [2, 3], 'discharge': 4}
with open('../data/external/cos_to_month.pkl', 'rb') as file:
    cos_to_month = pickle.load(file)

inputs = ['lst', 'ndvi','slope','altitude','direction','month','coords','discharge']
augment = True
W = 128

train_dss = {}
for split in range(1,6):
    data_folder = f'../data/processed_data/{W}x{W}/{split}'
        
    train_model_input, train_additional_inputs, train_target = load_set(data_folder, inputs, 'train', var_channels, var_position)
    
    months = [cos_to_month[val] for val in train_additional_inputs[..., var_position['month'][0]]]
    reference = (months, train_additional_inputs[..., var_position['coords']]) 
    conditioned = len(train_model_input) == 2
    
    train_dataset = prepare_dataset(train_model_input, train_target, conditioned, W, reference, augment=True, augment_iterations=5)
    train_dss[split] = train_dataset
    count = train_dataset.reduce(0, lambda x, _: x + 1)
    print("Número total de elementos en el dataset:", count.numpy())


train model input shapes: [(784, 128, 128, 5), (784, 5)]
Número total de elementos en el dataset: 24330
train model input shapes: [(784, 128, 128, 5), (784, 5)]
Número total de elementos en el dataset: 24270
train model input shapes: [(784, 128, 128, 5), (784, 5)]
Número total de elementos en el dataset: 24330
train model input shapes: [(784, 128, 128, 5), (784, 5)]
Número total de elementos en el dataset: 24180
train model input shapes: [(784, 128, 128, 5), (784, 5)]
Número total de elementos en el dataset: 24060


In [5]:
for element in train_dataset.as_numpy_iterator():
    inputs = element[1:-1]  
    target = element[-1] 
    idx = element[0]
    
    if isinstance(inputs, tuple):  # Si inputs es una tupla
        print("Índice:", idx)
        print("Forma de inputs:", [(x[0].shape, x[1].shape) for x in inputs])  # Forma de cada input en la tupla
    else:
        print("Índice:", idx)
        print("Forma de inputs:", inputs.shape)  # Forma del input
    print("Forma de target:", target.shape)  # Forma del target
    break  # Solo inspecciona el primer elemento


Índice: 0
Forma de inputs: [((128, 128, 5), (5,))]
Forma de target: (1,)


In [74]:
count = train_dataset.reduce(0, lambda x, _: x + 1)
print("Número total de elementos en el dataset:", count.numpy())

Número total de elementos en el dataset: 982


## Analyze augmented data

In [42]:
def plot_augmented_outliers(outliers, idx):
    """
    Plots augmented outlier images in a grid.

    Args:
        outliers (list): List of augmented outlier dictionaries containing "image" keys.
        idx (int): Index of the data point being visualized.
    """
    num_samples = len(outliers)
    rows = 3  # Number of rows
    cols = 4  # Number of columns (adjust to match total samples)

    plt.figure(figsize=(15, rows * 5))  # Adjust height proportionally to rows
    plt.suptitle(f"Augmented Outlier Images - Index {idx}")

    for i, outlier in enumerate(outliers):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(outlier["image"][:, :, 0].numpy(), cmap="gray")  # Convert to NumPy and show
        plt.title(f"Rotation {i}")
        plt.axis("off")

    plt.tight_layout()
    #plt.savefig('../plots/augmented_outliers.png')
    plt.show()


In [50]:
import matplotlib.pyplot as plt
import pandas as pd

total_original = []
outs = []
no_outs = []
outs_augmented = []
no_outs_augmented = []
total_augmented = []
conditioned = False
with open('../data/external/cos_to_month.pkl', 'rb') as file:
    cos_to_month = pickle.load(file) 

for split in range(1,6):
    print('Split',split)
    W=64
    inputs = ['lst']
    data_folder = f'../data/processed_data/{W}x{W}/{split}'
    var_channels = {'lst': 0, 'ndvi': 1, 'slope': 2, 'altitude': 3, 'direction': 4}
    var_position = {'month': [0, 1], 'coords': [2, 3], 'discharge': 4}
    train_model_input, train_additional_inputs, train_target = load_set(data_folder, inputs, 'train', var_channels, var_position)
    val_model_input, validation_additional_inputs, validation_target = load_set(data_folder, inputs, 'validation', var_channels, var_position)
    
    months = [cos_to_month[val] for val in train_additional_inputs[..., var_position['month'][0]]]
    reference = (months, train_additional_inputs[..., var_position['coords']])
    
    outlier_mask = np.load('../data/external/outliers_mask.npy')
    with open('../data/external/outlier_dict.json', 'r') as json_file:
        loaded_dict = json.load(json_file)
    
    outlier_mask = {eval(key): value for key, value in loaded_dict.items()}
    
    ds = train_model_input
    target = train_target
    inputs = ds
    dataset = tf.data.Dataset.from_tensor_slices((inputs, target))
    dataset_with_index = dataset.enumerate()
    show=0
    show_no_outliers=0
    total_outliers = 0
    total_non_outliers = 0
    out=0
    no_out=0
    total=0
    
    for idx, inputs in dataset.enumerate():
        augmented_data = augment_data(idx, inputs, reference, outlier_mask, W,additional_inputs=conditioned, visualize=True,augment_iterations=1)
    
        # Dividir en outliers y normales
        outliers = [x for x in augmented_data if x["is_outlier"]]
        non_outliers = [x for x in augmented_data if not x["is_outlier"]]
    
        total_outliers += len(outliers)
        total_non_outliers += len(non_outliers)
    
        # Mostrar una muestra de outliers
        if outliers:
            out +=1
            if show<0:
                print(f"Index {idx}: Total augmentations for outlier: {len(outliers)}")
            
                # Mostrar todas las augmentaciones para este outlier
                plot_augmented_outliers(outliers, idx)
                '''
                plt.figure(figsize=(15, len(outliers)))
                plt.suptitle(f"Augmented Outlier Images - Index {idx}")
                
                for i, outlier in enumerate(outliers):
                    plt.subplot(2, 4, i + 1)
                    plt.imshow(outlier["image"][:,:,0].numpy(), cmap="gray")  # Convertir a NumPy y mostrar
                    plt.title(f"Rotation {i}")
                    plt.axis("off")
                plt.savefig('../plots/augmented_outliers.png')
                plt.show()'''
                show+=1
    
        # Mostrar una muestra de no outliers
        if non_outliers:
            no_out+=1
            if show_no_outliers<0:
                print(f"Non-Outlier Sample - Index {idx}")
                plt.figure(figsize=(12, 6))
                plt.suptitle("Augmented Non-Outlier Images")
                for i, non_outlier in enumerate(non_outliers[:4]):  # Máximo 4 augmentaciones
                    plt.subplot(1, 4, i + 1)
                    plt.imshow(non_outlier["image"][:,:,0].numpy(), cmap="gray")
                    plt.title(f"Non-Outlier Rotation {i}")
                    plt.axis("off")
                
                #plt.savefig('../plots/augmented_non_outliers.png')
                plt.show()
                show_no_outliers+=1
        total+=1

    
        # Rompe el bucle después de visualizar un ejemplo de cada tipo
        if outliers and non_outliers:
            break
            
    total_original.append(total)
    outs.append(out)
    no_outs.append(no_out)
    outs_augmented.append(total_outliers)
    no_outs_augmented.append(total_non_outliers)
    total_augmented.append(total_outliers + total_non_outliers)
    



Split 1
train model input shape: (784, 64, 64, 1)
validation model input shape: (261, 64, 64, 1)


2024-12-19 13:35:40.905063: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Split 2
train model input shape: (784, 64, 64, 1)
validation model input shape: (261, 64, 64, 1)
Split 3
train model input shape: (784, 64, 64, 1)
validation model input shape: (261, 64, 64, 1)
Split 4
train model input shape: (784, 64, 64, 1)
validation model input shape: (261, 64, 64, 1)
Split 5
train model input shape: (784, 64, 64, 1)
validation model input shape: (261, 64, 64, 1)


In [51]:
data = {'Original Total Count':total_original, 'Outliers Count':outs, 'Non Outliers Count':no_outs, 'Augmented Outliers Count':outs_augmented, 'Augmented Non-Outliers Count':no_outs_augmented, 'Augmented Total Count':total_augmented}
df=pd.DataFrame(data)
df

Unnamed: 0,Original Total Count,Outliers Count,Non Outliers Count,Augmented Outliers Count,Augmented Non-Outliers Count,Augmented Total Count
0,784,27,757,324,4542,4866
1,784,25,759,300,4554,4854
2,784,27,757,324,4542,4866
3,784,22,762,264,4572,4836
4,784,18,766,216,4596,4812


In [37]:
import pandas as pd

data = {
    'Original Total Count': total_original,
    'Outliers Count': outs,
    'Non Outliers Count': no_outs,
    'Augmented Outliers Count': outs_augmented,
    'Augmented Non-Outliers Count': no_outs_augmented,
    'Augmented Total Count': total_augmented
}

# Crear DataFrame
df = pd.DataFrame(data)
df.index.name = "Split"  # Añadir nombre al índice

# Dividir el DataFrame en dos partes
df_part1 = df[['Original Total Count', 'Outliers Count', 'Non Outliers Count']]
df_part2 = df[['Augmented Outliers Count', 'Augmented Non-Outliers Count', 'Augmented Total Count']]

# Generar el código LaTeX para ambas tablas
latex_part1 = df_part1.to_latex(index=True)
latex_part2 = df_part2.to_latex(index=True)

# Mostrar el código LaTeX
print("Tabla 1 (Parte 1):")
print(latex_part1)
print("\nTabla 2 (Parte 2):")
print(latex_part2)


Tabla 1 (Parte 1):
\begin{tabular}{lrrr}
\toprule
 & Original Total Count & Outliers Count & Non Outliers Count \\
Split &  &  &  \\
\midrule
0 & 784 & 27 & 757 \\
1 & 784 & 25 & 759 \\
2 & 784 & 27 & 757 \\
3 & 784 & 22 & 762 \\
4 & 784 & 18 & 766 \\
\bottomrule
\end{tabular}


Tabla 2 (Parte 2):
\begin{tabular}{lrrr}
\toprule
 & Augmented Outliers Count & Augmented Non-Outliers Count & Augmented Total Count \\
Split &  &  &  \\
\midrule
0 & 648 & 9084 & 9732 \\
1 & 600 & 9108 & 9708 \\
2 & 648 & 9084 & 9732 \\
3 & 528 & 9144 & 9672 \\
4 & 432 & 9192 & 9624 \\
\bottomrule
\end{tabular}



In [58]:
# 27*8*4
data = {'Original Count':total_original, 'Outliers Count':outs, 'Non Outliers Count':no_outs, 'Augmented Outliers Count':outs_augmented, 'Augmented Non-Outliers Count':no_outs_augmented, 'Augmented Total Count':total_augmented}
df=pd.DataFrame(data)
df

Unnamed: 0,Original Count,Outliers Count,Non Outliers Count,Augmented Outliers Count,Augmented Non-Outliers Count,Augmented Total Count
0,784,27,757,648,9084,9732
1,784,25,759,600,9108,9708
2,784,27,757,648,9084,9732
3,784,22,762,528,9144,9672
4,784,18,766,432,9192,9624


In [31]:
# 27*8*10
data = {'Original Count':total_original, 'Outliers Count':outs, 'Non Outliers Count':no_outs, 'Augmented Outliers Count':outs_augmented, 'Augmented Non-Outliers Count':no_outs_augmented, 'Augmented Total Count':total_augmented}
df=pd.DataFrame(data)
df

Unnamed: 0,Original Count,Outliers Count,Non Outliers Count,Augmented Outliers Count,Augmented Non-Outliers Count,Augmented Total Count
0,784,27,757,1080,3028,4108
1,784,25,759,1000,3036,4036
2,784,27,757,1080,3028,4108
3,784,22,762,880,3048,3928
4,784,18,766,720,3064,3784


In [34]:
data = {'Original Count':total_original, 'Outliers Count':outs, 'Non Outliers Count':no_outs, 'Augmented Outliers Count':outs_augmented, 'Augmented Non-Outliers Count':no_outs_augmented, 'Augmented Total Count':total_augmented}
df=pd.DataFrame(data)
df

Unnamed: 0,Original Count,Outliers Count,Non Outliers Count,Augmented Outliers Count,Augmented Non-Outliers Count,Augmented Total Count
0,784,27,757,1512,3028,4540
1,784,25,759,1400,3036,4436
2,784,27,757,1512,3028,4540
3,784,22,762,1232,3048,4280
4,784,18,766,1008,3064,4072


# Save prepared augmented datasets

In [15]:
import numpy as np

def save_as_numpy(dataset, file_path, conditioned=False):
    indices, inputs_list, additional_inputs_list, targets = [], [], [], []

    for element in dataset.as_numpy_iterator():
        if conditioned:
            inputs_data = element[1]  
            inputs, additional_inputs = inputs_data
            target = element[2] 
            index = element[0]
        else:
            # Correct unpacking for non-augmented datasets
            index, (inputs_data, target) = element
            inputs = inputs_data[0]
            additional_inputs = None  # No additional inputs in non-augmented datasets

        # Collect data
        indices.append(index)
        inputs_list.append(inputs)
        if additional_inputs is not None:
            additional_inputs_list.append(additional_inputs)
        targets.append(target)

    # Convert to NumPy arrays
    indices = np.array(indices)
    inputs = np.array(inputs_list)
    additional_inputs = (
        np.array(additional_inputs_list) if additional_inputs_list else None
    )
    targets = np.array(targets)
    if conditioned:
        print(len(indices), inputs[0].shape,len(additional_inputs), len(additional_inputs[0]),additional_inputs[0].shape ,targets[0].shape)
    else:
        print(len(indices), inputs[0].shape,targets[0].shape)

    # Save to .npz
    np.savez(
        file_path,
        indices=indices,
        inputs=inputs,
        additional_inputs=additional_inputs,
        targets=targets,
    )
    print(f"Dataset saved to {file_path}.npz")


In [56]:
# Save training, validation, and test datasets
dest_dir = "../data/processed_data/augmented_all_5x"
os.makedirs(dest_dir, exist_ok = True)
for split, train_ds in train_dss.items():
    split_folder = f"{dest_dir}/{split}"
    os.makedirs(split_folder, exist_ok = True)
    save_as_numpy(train_ds, f"{split_folder}/train_dataset", conditioned=True)

24330 (128, 128, 5) 24330 5 (5,) (1,)
Dataset saved to ../data/processed_data/augmented_all_5x/1/train_dataset.npz
24270 (128, 128, 5) 24270 5 (5,) (1,)
Dataset saved to ../data/processed_data/augmented_all_5x/2/train_dataset.npz
24330 (128, 128, 5) 24330 5 (5,) (1,)
Dataset saved to ../data/processed_data/augmented_all_5x/3/train_dataset.npz
24180 (128, 128, 5) 24180 5 (5,) (1,)
Dataset saved to ../data/processed_data/augmented_all_5x/4/train_dataset.npz
24060 (128, 128, 5) 24060 5 (5,) (1,)
Dataset saved to ../data/processed_data/augmented_all_5x/5/train_dataset.npz


## Load for checks

In [3]:
def load_numpy_dataset(file_path):
    data = np.load(file_path + ".npz")
    indices = data["indices"]
    inputs = data["inputs"]
    additional_inputs = data['additional_inputs']
    print(inputs.shape,additional_inputs.shape)
    targets = data["targets"]
    return tf.data.Dataset.from_tensor_slices((indices, ((inputs,additional_inputs), targets)))


In [4]:
# Load datasets
dest_dir = "../data/processed_data/augmented"

train_dataset = load_numpy_dataset(f"{dest_dir}/train_dataset")

(3244, 128, 128, 5) (3244, 5)


2024-12-18 10:21:53.076350: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10530 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:02:00.0, compute capability: 6.1
2024-12-18 10:21:53.077019: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 7000 MB memory:  -> device: 1, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:81:00.0, compute capability: 6.1


In [11]:
def select_channels(dataset, channels_to_keep):
    """
    Preprocess the dataset to keep only specified channels from the inputs.

    Args:
    - dataset: tf.data.Dataset with structure (index, ((image_inputs, additional_inputs), target)).
    - channels_to_keep: List of channel indices to keep (e.g., [0, 2]).

    Returns:
    - A preprocessed tf.data.Dataset.
    """
    def preprocess(index, inputs_targets):
        inputs, target = inputs_targets
        image_inputs, additional_inputs = inputs

        # Select specific channels
        selected_image_inputs = tf.gather(image_inputs, channels_to_keep, axis=-1)

        # Return the new structure
        return index, ((selected_image_inputs, additional_inputs), target)

    return dataset.map(preprocess)


In [12]:
channels_to_keep = [0, 2]  # Example: Keep channels 'lst' and 'slope'
#train_dataset = select_channels(train_dataset, channels_to_keep)


In [13]:
for idx, ((image_inputs, additional_inputs), target) in train_dataset.take(1):
    print("Index:", idx.numpy())
    print("Image input shape after selecting channels:", image_inputs.shape)  # Should be (128, 128, len(channels_to_keep))
    print("Additional inputs shape:", additional_inputs.shape)  # Should remain (5,)
    print("Target shape:", target.shape)

Index: 0
Image input shape after selecting channels: (128, 128, 5)
Additional inputs shape: (5,)
Target shape: (1,)


2024-12-17 16:35:48.052977: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [5]:
loss_type = 'RMSE'

In [8]:
batch_size = 64
val_model_input, validation_additional_inputs, validation_target = load_set(data_folder, inputs, 'validation', var_channels, var_position)
val_dataset = prepare_dataset(val_model_input, validation_target, conditioned, W, augment=False)
train_dataset = train_dataset.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)
model_name = "simple_CNN"
epochs = 300
augment = True
conditioned = True

validation model input shapes: [(261, 128, 128, 5), (261, 5)]


In [None]:
# Function to inspect dataset contents
def inspect_dataset(dataset, num_samples=5):
    print("Inspecting dataset...")
    count = 0
    for element in dataset.take(num_samples):  # Limit to `num_samples` for inspection
        print(len(element))
        idx = element[0]
        inputs = element[1]
        inp = inputs[0][0]
        inp2 = inputs[0][1]
        target = inputs[1]

        # Print shapes and types
        print(f"Sample {count + 1}:")
        print(f"  Index shape: {idx.shape}, Value: {idx.numpy()}")
        print(f"  Input 1 shape: {inp.shape}, Type: {inp.dtype}")
        print(f"  Input 2 shape: {inp2.shape}, Type: {inp2.dtype}")
        print(f"  Target shape: {target.shape}, Value: {target.numpy()}")
        print("-" * 40)

        count += 1

    print(f"Total samples inspected: {count}")

# Inspect train and validation datasets
inspect_dataset(train_dataset, num_samples=1)


In [10]:
for b in train_dataset.take(1):
    print(len(b), len(b[1]), len(b[1][0]), len(b[1][0][0]),b[1][0][0].shape, b[1][0][1].shape)

3 2 64 128 (128, 128, 5) (128, 128, 5)


2024-12-18 10:22:24.205111: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-12-18 10:22:24.205340: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [11]:
for b in val_dataset.take(1):
    print(len(b), len(b[1]), len(b[1][0]), len(b[1][0][0]),b[1][0][0].shape, b[1][0][1].shape)

2 2 2 64 (64, 128, 128, 5) (64, 5)


2024-12-18 10:22:24.955198: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-12-18 10:22:24.958813: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [9]:
# Start model
conditioned = False
for idx, ((image_inputs, additional_inputs), target) in train_dataset.take(1):
    image_inp_shape = tuple(image_inputs.shape.as_list()[1:])
    additional_inp_shape = tuple(additional_inputs.shape.as_list()[1:])
    print("Image input shape after selecting channels:", image_inp_shape)  # Should be (128, 128, len(channels_to_keep))
    print("Additional inputs shape:", additional_inp_shape)  # Should remain (5,)
    print("Target shape:", target.shape)
    
if conditioned:
    input_args = (image_inp_shape, additional_inp_shape)
else:
    input_args = image_inp_shape
    
model = build_model_map(model_name, input_args, W)
start_time = time.time()
    
summary_file = f"../models/{model_name}_summary.txt"
with open(summary_file, "w") as f:
    with redirect_stdout(f):
        model.summary()

# Set hyperparmeters variables
#initial_lr = 0.01
#lr_schedule = ExponentialDecay(initial_lr, decay_steps=50, decay_rate=0.96, staircase=True)
#optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=True)
#optimizer = tf.keras.optimizers.SGD()
optimizer = tf.keras.optimizers.Adam()
errors_log = {"epoch": [], "month": [], "error": []}
loss_per_epoch = []
val_loss_per_epoch = []

# Early Stopping parameters
patience = 30  # Number of epochs with no improvement before stopping
min_delta = 1e-4  # Minimum improvement required to consider progress
best_val_loss = float('inf')  # Best observed validation loss
wait = 0  # Counter for epochs without improvement

# Train model
for epoch in range(epochs):
    epoch_loss = 0  
    num_batches = 0
    train_dataset = train_dataset.shuffle(buffer_size=len(train_model_input)).prefetch(tf.data.AUTOTUNE)
    
    for batch in train_dataset:
        if augment:
            if conditioned:
                model_input_batch = batch[1][0]  
                target_batch = batch[1][1]
                idx = batch[0]
            else:
                model_input_batch = batch[1][0][0] 
                target_batch = batch[1][1]
                idx = batch[0]
                
        else:
            if conditioned:
                model_input_batch = batch[1][0]
                target_batch = batch[1][1]  
                idx = batch[0]
            else:
                idx, batch_data = batch
                model_input_batch = batch_data[:-1]
                target_batch = batch_data[-1]
        
        with tf.GradientTape() as tape:
            # Forward pass
            #y_pred = model([*model_input_batch], training=True) if conditioned else model(model_input_batch, training=True)
            if conditioned:
                
                print("Shape of model_input_batch[0]:", model_input_batch[0].shape)  # Images
                print("Shape of model_input_batch[1]:", model_input_batch[1].shape)  # Additional scalar inputs
                y_pred = model([model_input_batch[0],model_input_batch[1]], training = True)

            else:
                y_pred = model(model_input_batch, training=True)
            
            # Compute loss based on the selected method
            if loss_type == 'Physics_guided':
                lst_batch = model_input_batch[0][:, :, :, :0] if conditioned else model_input_batch[:, :, :, 0]
                loss = conservation_energy_loss(target_batch, y_pred, lst_batch, alpha=0.5, beta=0.5)
            elif loss_type == 'RMSE_sensitive':
                loss = rmse_extreme_sensitive(target_batch, y_pred, k1=0.01, k2=1.0, alpha=1.0)
            elif loss_type == 'RMSE_focal':
                loss = rmse_focal(target_batch, y_pred, gamma=1.0)
            else:
                loss = root_mean_squared_error(target_batch, y_pred) 
        # Calculate gradients and apply optimization
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        epoch_loss += loss.numpy()
        num_batches += 1
        
        # Log variables values and error
        y_true = tf.cast(target_batch, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        
        current_batch_size = y_true.shape[0]
        batch_cosine_values = train_additional_inputs[:, 0][idx]
    
        # Log RMSE values for each prediction
        for cos, pred, true in zip(batch_cosine_values, y_pred, y_true):
            squared_error = tf.square(pred - true) 
            rmse_sample = tf.sqrt(squared_error)  # RMSE 
            rmse_value = rmse_sample.numpy()
            errors_log["epoch"].append(epoch + 1)
            errors_log["month"].append(cos_to_month[cos])
            errors_log["error"].append(rmse_value)
                
    avg_epoch_loss = epoch_loss / num_batches
    loss_per_epoch.append(avg_epoch_loss)
    
    # Validation loss
    val_loss = 0
    val_batches = 0
    
    for val_batch in val_dataset:
        if conditioned:
            val_input_batch = val_batch[1][0]
            val_target_batch = val_batch[1][1]  
            idx = val_batch[0]
        else:
            val_input_batch = val_batch[1][0][0] 
            val_target_batch = val_batch[1][1]  
            idx = val_batch[0]
            
        if conditioned:
            print("Shape of model_input_batch[0]:", val_input_batch[0].shape)  # Images
            print("Shape of model_input_batch[1]:", val_input_batch[1].shape)  # Additional scalar inputs
            val_pred = model([val_input_batch[0],val_input_batch[1]], training = False)

        else:
            val_pred = model(val_input_batch, training=False)
            
        #val_pred = model([*val_input_batch], training=False) if conditioned else model(val_input_batch, training=False)
        val_loss += root_mean_squared_error(val_target_batch, val_pred).numpy()
        val_batches += 1
        
    avg_val_loss = val_loss / val_batches
    val_loss_per_epoch.append(avg_val_loss)

    print(f"Epoch {epoch + 1}/{epochs} - Loss: {avg_epoch_loss:.4f} - Val Loss: {avg_val_loss:.4f}")

    # Early Stopping Logic
    if avg_val_loss < best_val_loss - min_delta:
        # Update the best validation loss and reset patience counter
        best_val_loss = avg_val_loss
        wait = 0  
        print(f"Validation loss improved to {best_val_loss:.4f}.")
    else:
        # Increment patience counter
        wait += 1
        print(f"No improvement in validation loss for {wait} epochs.")
        if wait >= patience:
            # Stop training if patience threshold is exceeded
            print(f"Stopping early at epoch {epoch + 1}.")
            break
    
    gc.collect()  # Free up memory after each epoch



2024-12-18 10:22:13.916711: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


ValueError: too many values to unpack (expected 2)