In [None]:
# Load the trained model
model = load_model('trained_model.h5')

# Load the NetCDF file
file_path = 'satellite_and_WOA13_1_degree_Jan_v2.nc'
data = xr.open_dataset(file_path)

# Extract the features and target variables
features = ['CHL', 'APH', 'FLU', 'PIC', 'POC', 'PAR', 'SST']
targets = ['nitrate', 'phosphate', 'silicate']

# Build one table with a row per (lat, lon) and one column per feature by
# flattening each variable and inner-joining the pieces on the coordinates.
dfs = []
for var in features:
    df = data[var].to_dataframe().reset_index()
    df = df[['lat', 'lon', var]]
    dfs.append(df)
feature_data = dfs[0]
for df in dfs[1:]:
    feature_data = pd.merge(feature_data, df, on=['lat', 'lon'])
feature_data.columns = ['lat', 'lon'] + features

# Handle missing values by filling them with the mean value of each feature,
# but only in cells that have at least one observed feature. Cells with no
# observations at all (presumably land / no-data -- confirm against the
# dataset) are left as NaN so the NaN mask applied before prediction skips
# them; filling them would make every such cell an identical all-mean input
# and produce a meaningless prediction there.
has_any_feature = feature_data[features].notna().any(axis=1)
feature_data.loc[has_any_feature, features] = (
    feature_data.loc[has_any_feature, features]
    .fillna(feature_data[features].mean())
)

# Create a coarser grid of latitude and longitude values. Taking every 5th
# value of the dataset's own coordinate axes means each grid point is an
# exact (lat, lon) pair that can appear in feature_data.
lats = data['lat'].values[::5]  # Downsample by a factor of 5
lons = data['lon'].values[::5]  # Downsample by a factor of 5
lon_grid, lat_grid = np.meshgrid(lons, lats)

# Prepare the input data for prediction, shape (n_lat, n_lon, n_features).
# All grid points are looked up at once through a (lat, lon) index instead
# of scanning the whole table once per point. As before, matching is exact,
# the first row wins if a (lat, lon) pair is duplicated, and points absent
# from feature_data come back as NaN.
feature_lookup = (
    feature_data.drop_duplicates(subset=['lat', 'lon'], keep='first')
    .set_index(['lat', 'lon'])[features]
)
grid_index = pd.MultiIndex.from_arrays(
    [lat_grid.ravel(), lon_grid.ravel()], names=['lat', 'lon']
)
input_data = (
    feature_lookup.reindex(grid_index)
    .to_numpy(dtype=float)
    .reshape(lat_grid.shape + (len(features),))
)

# Filter out grid points with any NaN feature
mask = ~np.isnan(input_data).any(axis=2)
valid_input_data = input_data[mask]

# Make predictions for valid data. Skip the model call when no grid point is
# valid: predicting on an empty batch may raise depending on the framework.
if mask.any():
    predictions = model.predict(valid_input_data)
else:
    predictions = np.empty((0, len(targets)))

# Initialize arrays to hold predictions with NaNs for invalid points
nitrate_pred = np.full(lon_grid.shape, np.nan)
phosphate_pred = np.full(lon_grid.shape, np.nan)
silicate_pred = np.full(lon_grid.shape, np.nan)

# Assign predictions to the respective grid points; boolean-mask assignment
# fills the True cells of `mask` in row-major order, the same order in which
# valid_input_data was extracted.
nitrate_pred[mask] = predictions[:, 0]
phosphate_pred[mask] = predictions[:, 1]
silicate_pred[mask] = predictions[:, 2]

# Create a colormap with white for NaN values. Work on a copy: calling
# set_bad on plt.cm.viridis directly would mutate the globally shared
# colormap object used by every other plot in the session.
import copy

cmap = copy.copy(plt.cm.viridis)
cmap.set_bad(color='white')

# Function to plot a world map of predictions
def plot_world_map(data, title, cmap, lons=None, lats=None):
    """Draw a gridded field on a global PlateCarree map and show it.

    Parameters
    ----------
    data : 2-D array
        Values on the latitude/longitude grid. NaN/inf entries are masked
        (``np.ma.masked_invalid``) before plotting.
    title : str
        Figure title.
    cmap : matplotlib colormap
        Colormap used to render ``data``.
    lons, lats : 2-D arrays, optional
        Longitude and latitude grids matching ``data``'s shape. When omitted
        they default to the module-level ``lon_grid`` / ``lat_grid``, which
        is what the function always used before these parameters existed.
    """
    if lons is None:
        lons = lon_grid
    if lats is None:
        lats = lat_grid

    fig = plt.figure(figsize=(12, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.set_global()
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='gray')
    ax.add_feature(cfeature.OCEAN)
    ax.add_feature(cfeature.LAKES, edgecolor='black')
    ax.add_feature(cfeature.RIVERS)

    # Plot the data with masked NaNs. X/Y have the same shape as the data
    # (cell centres), so request 'auto' shading explicitly; with 'flat'
    # shading older matplotlib versions drop the last row and column.
    mesh = ax.pcolormesh(
        lons,
        lats,
        np.ma.masked_invalid(data),
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        shading='auto',
    )
    fig.colorbar(mesh, ax=ax, label='Concentration')
    ax.set_title(title)
    plt.show()

# Render one world map per predicted nutrient, in nitrate, phosphate,
# silicate order, all sharing the NaN-aware colormap defined above.
nutrient_maps = (
    (nitrate_pred, 'Predicted Nitrate Concentration'),
    (phosphate_pred, 'Predicted Phosphate Concentration'),
    (silicate_pred, 'Predicted Silicate Concentration'),
)
for grid_values, map_title in nutrient_maps:
    plot_world_map(grid_values, map_title, cmap)