In [None]:
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
import xarray as xr
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, Dense
from sklearn.preprocessing import StandardScaler

# Load one 2-D field per nutrient from the WOA13 NetCDF files and stack them
# into a single (lat, lon, feature) array.

# Base directory and the file name for each nutrient
base_dir = '/content/drive/MyDrive/hakData/'
file_patterns = {
    'nitrate': 'woa13_all_n01_01.nc',
    'phosphate': 'woa13_all_p01_01.nc',
    'silicate': 'woa13_all_i01_01.nc'
}

# Name of the data variable to read from each nutrient's file
variable_names = {
    'nitrate': 'n_an',
    'phosphate': 'p_an',
    'silicate': 'i_an'
}

# Coordinates are read once; the per-nutrient 2-D fields go into the dict
lat = None
lon = None
nutrient_data = {}

# Read the data for each nutrient
for nutrient, pattern in file_patterns.items():
    # The context manager closes the NetCDF file handle as soon as the arrays
    # have been copied into memory, instead of leaving every file open.
    with xr.open_dataset(base_dir + pattern, engine='netcdf4', decode_times=False) as ds:
        # Extract latitude and longitude only once (assuming they are the same for all nutrients)
        if lat is None and lon is None:
            lat = ds['lat'].values
            lon = ds['lon'].values

        # Index before calling .values so only the selected 2-D slice is read.
        # Index 0 on the two leading axes is the surface layer
        # (presumably the axes are time and depth -- TODO confirm against the file).
        nutrient_data[nutrient] = ds[variable_names[nutrient]][0, 0, :, :].values

# Combine the nutrient fields into one (3, lat, lon) array; the order fixes the
# feature index used downstream: 0 = nitrate, 1 = phosphate, 2 = silicate
combined_data = np.stack(
    [nutrient_data[key] for key in ('nitrate', 'phosphate', 'silicate')],
    axis=0,
)

# Move the feature axis last to get shape (lat, lon, features)
data = combined_data.transpose((1, 2, 0))

In [None]:
# Build the model inputs/target from `data` (lat, lon, features), drop grid
# cells with missing values, split into train/test, and standardize the inputs.

# Feature 0 is the target (nitrate); the remaining features are the inputs
# (phosphate, silicate). Sizes are derived from `data` instead of hard-coding
# the 180 x 360 grid, so another grid resolution works unchanged.
n_features = data.shape[-1] - 1
X = data[:, :, 1:].reshape((-1, n_features))  # one row per grid cell
y = data[:, :, 0].reshape((-1, 1))

# Combine X and y so rows with a nan in any column can be dropped together
data_combined = np.hstack((X, y))

# Remove rows with nan values
data_combined = data_combined[~np.isnan(data_combined).any(axis=1)]

# Separate the data back into X and y
X = data_combined[:, :n_features]
y = data_combined[:, n_features:]

# Add the dummy time dimension to X: (samples, 1 time step, features)
X = X.reshape((X.shape[0], 1, n_features))

# Split BEFORE scaling so the scaler never sees test rows
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Normalize the input features using mean/std estimated on the training split
# only; fitting on the full data set would leak test-set statistics.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train.reshape(-1, n_features)).reshape(X_train.shape)
X_test = scaler.transform(X_test.reshape(-1, n_features)).reshape(X_test.shape)

# Keep `X` normalized as well (with the training statistics) so any later use
# of the full array sees the same scale as the model inputs.
X = scaler.transform(X.reshape(-1, n_features)).reshape(X.shape)

# Print the shapes to verify
print(f'X_train shape: {X_train.shape}, y_train shape: {y_train.shape}')
print(f'X_test shape: {X_test.shape}, y_test shape: {y_test.shape}')



X_train shape: (32870, 1, 2), y_train shape: (32870, 1)
X_test shape: (8218, 1, 2), y_test shape: (8218, 1)


In [None]:
# Build, train and evaluate a small SimpleRNN regressor that predicts nitrate
# from the standardized input features.

# Seed TensorFlow's global random generator so weight initialization (and other
# seeded TF random ops) are repeatable across kernel restarts.
tf.random.set_seed(42)

# Define the RNN model. The input shape (time steps, features) is taken from
# the training data rather than hard-coded.
# NOTE: with a single time step per sample the recurrent layer processes only
# one step, so no state is carried between steps.
model = Sequential()
model.add(SimpleRNN(50, activation='relu', input_shape=X_train.shape[1:]))
model.add(Dense(1))

# Compile with Adam at learning rate 0.001 (this is Adam's default value, not a
# reduced one) and mean-squared-error loss
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')

# Print the model summary
model.summary()

# Train the model; the held-out test split is used for per-epoch validation loss
history = model.fit(X_train, y_train, epochs=50, batch_size=16, validation_data=(X_test, y_test))

# Plot the training and validation loss per epoch
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train')
ax.plot(history.history['val_loss'], label='test')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and validation loss (MSE)')
ax.legend()
plt.show()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 simple_rnn_2 (SimpleRNN)    (None, 50)                2650      
                                                                 
 dense_2 (Dense)             (None, 1)                 51        
                                                                 
Total params: 2,701
Trainable params: 2,701
Non-trainable params: 0
_________________________________________________________________
Epoch 1/50
 392/2055 [====>.........................] - ETA: 5s - loss: 63.4082

KeyboardInterrupt: 