In [None]:
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import rc
from netCDF4 import Dataset
from IPython.display import HTML, display
rc('animation', html='jshtml')
import sys

In [None]:
def update_concentration_parallel(ocean_map, land_mask, S, Dx, Dy, u, v, dt, dx, dy, iterations):
    """Step a 2-D concentration field forward with an explicit, vectorised update.

    Each iteration every cell keeps a fraction of its own concentration
    (reduced by diffusion and by the local speed), optionally receives the
    source term, and then gains concentration from its four axis-aligned
    neighbours. Inside this function "x" always means array axis 0 (paired
    with ``u``/``Dx``/``dx``) and "y" means axis 1 (paired with ``v``/``Dy``/``dy``).

    Parameters
    ----------
    ocean_map : 2-D array
        Initial concentration. It is not modified; a copy is the first
        returned frame.
    land_mask : 2-D array, truthy on land
        Neighbours that are land contribute nothing to a cell, and after every
        iteration land cells are overwritten with the maximum of the freshly
        computed field.
    S : array broadcastable to ``ocean_map``
        Source rate. ``S * dt`` is added on every iteration whose index is
        below ``iterations / 2`` and not afterwards.
    Dx, Dy : scalar or array
        Diffusion coefficients along axis 0 and axis 1.
    u, v : 2-D arrays
        Velocity components along axis 0 and axis 1. Must not contain NaN.
    dt, dx, dy : scalar
        Time step and grid spacings.
    iterations : int
        Number of update steps.

    Returns
    -------
    list of 2-D arrays
        A copy of the initial map followed by a snapshot taken after every
        iteration whose zero-based index is a multiple of 10.

    Notes
    -----
    NOTE(review): the outflow mask only checks whether the ``+1`` neighbour is
    water, independent of the sign of ``u``/``v``, and ``np.roll`` wraps that
    check around the array edge. This is kept as-is to preserve results, but a
    direction-aware mask may have been intended - confirm.
    NOTE(review): nothing checks that the "kept" coefficient stays
    non-negative; for large ``dt`` or fast currents it can become negative.
    """
    rows, cols = ocean_map.shape
    maps_over_time = [np.copy(ocean_map)]

    # --- Loop-invariant pieces -------------------------------------------
    # None of the inputs change between iterations, so every mask and
    # coefficient is built once here instead of once per step.
    water = np.logical_not(land_mask)
    water_next_row = np.roll(water, -1, axis=0)  # is cell (i+1, j) water?
    water_prev_row = np.roll(water, 1, axis=0)   # is cell (i-1, j) water?
    water_next_col = np.roll(water, -1, axis=1)  # is cell (i, j+1) water?
    water_prev_col = np.roll(water, 1, axis=1)   # is cell (i, j-1) water?

    row_index = np.arange(rows)[:, np.newaxis]
    col_index = np.arange(cols)[np.newaxis, :]

    # Integer 0/1 masks switching the advective loss on or off.
    outflow_x = np.where(water_next_row, 1, 0)
    outflow_y = np.where(water_next_col, 1, 0)

    # Fraction of its own concentration that a cell keeps in one step.
    keep_fraction = (
        (1 - 2 * dt * Dx / dx**2 - 2 * dt * Dy / dy**2)
        - (np.abs(u) * outflow_x * dt / dx)
        - (np.abs(v) * outflow_y * dt / dy)
    )

    source_per_step = S * dt

    # Velocities of the four neighbours, aligned with the receiving cell.
    u_next = np.roll(u, -1, axis=0)
    u_prev = np.roll(u, 1, axis=0)
    v_next = np.roll(v, -1, axis=1)
    v_prev = np.roll(v, 1, axis=1)

    # Weight applied to each neighbour's concentration: diffusion always
    # contributes, advection only when the neighbour's flow points at this
    # cell. The index comparisons cancel the wrap-around that np.roll
    # introduces at the array edges.
    gain_from_next_row = (
        (dt / dx) * (Dx / dx - np.where(u_next < 0, u_next, 0))
        * (water_next_row & (row_index + 1 < rows))
    )
    gain_from_prev_row = (
        (dt / dx) * (Dx / dx + np.where(u_prev > 0, u_prev, 0))
        * (water_prev_row & (row_index - 1 >= 0))
    )
    gain_from_next_col = (
        (dt / dy) * (Dy / dy - np.where(v_next < 0, v_next, 0))
        * (water_next_col & (col_index + 1 < cols))
    )
    gain_from_prev_col = (
        (dt / dy) * (Dy / dy + np.where(v_prev > 0, v_prev, 0))
        * (water_prev_col & (col_index - 1 >= 0))
    )

    # --- Time stepping ----------------------------------------------------
    for iteration in range(iterations):
        C_new = keep_fraction * ocean_map

        # The source is only active during the first half of the run.
        if iteration < iterations / 2:
            C_new += source_per_step

        # Contributions from the four neighbours, all read from the previous
        # state so the update order does not matter.
        C_new = C_new + gain_from_next_row * np.roll(ocean_map, -1, axis=0)
        C_new = C_new + gain_from_prev_row * np.roll(ocean_map, 1, axis=0)
        C_new = C_new + gain_from_next_col * np.roll(ocean_map, -1, axis=1)
        C_new = C_new + gain_from_prev_col * np.roll(ocean_map, 1, axis=1)

        # Land cells are pinned to the current maximum of the field.
        # NOTE(review): presumably so land stands out when plotted - confirm
        # this is the intended boundary value.
        ocean_map = np.where(land_mask, np.max(C_new), C_new)

        if iteration % 10 == 0:
            maps_over_time.append(np.copy(ocean_map))
    return maps_over_time


def animate_concentration(maps_over_time):
    """Create an animation that flips through a sequence of 2-D maps.

    The first map initialises an image (viridis colormap, origin at the
    bottom, colour scale capped at 10) with a colorbar; each animation frame
    then swaps in the corresponding map and rewrites the title.

    Parameters
    ----------
    maps_over_time : sequence of 2-D arrays
        Frames to show, in order. Must contain at least one entry.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Animation with one frame per entry of ``maps_over_time``.
    """
    figure, axes = plt.subplots(figsize=(6, 6))
    image = axes.imshow(maps_over_time[0], cmap='viridis', origin='lower', vmax=10)
    figure.colorbar(image, ax=axes, label='Concentration')

    def draw_frame(frame):
        # Replace the pixel data and refresh the title for this frame.
        # NOTE(review): the title counts frames, not solver iterations;
        # the two differ if the caller stored only a subset of iterations.
        image.set_data(maps_over_time[frame])
        axes.set_title(f'Concentration Map - Iteration {frame + 1}')

    frame_count = len(maps_over_time)
    return animation.FuncAnimation(figure, draw_frame, frames=frame_count, interval=1)

In [None]:
%%time

# Load current velocity data
velocity_dataset = Dataset("./CMEMS_horizontal_current_velocity_data.nc")
uo_data = np.array(velocity_dataset.variables['uo'][:])  # Shape: (time, depth, latitude, longitude)
vo_data = np.array(velocity_dataset.variables['vo'][:])  # Shape: (time, depth, latitude, longitude)
velocity_dataset.close()

# Preprocess data
uo_data = np.squeeze(uo_data)
vo_data = np.squeeze(vo_data)
uo_data = uo_data[:, :240, :]  # Trim off one excess layer
vo_data = vo_data[:, :240, :]  # Trim off one excess layer

# Use data for u and v arrays
u, v = uo_data[0], vo_data[0]  # TODO: for each time step, use the next time slice of u and v

# Parameters
rows, cols = u.shape[0], u.shape[1]  # Map size (larger to better visualize)
dt = 3600  # Time step
dx, dy = 8000, 8000  # Spatial step sizes
iterations = 500  # Number of iterations

# Ignore diffusion terms for now
Dx = Dy = 1500 * np.ones((rows, cols))  # TODO: add diffusivity by incorporating temperature data and solving Stokes-Einstein

# Generate initial ocean map
ocean_map = np.zeros((rows, cols))

# Generate dirichlet boundary conditions using velocity data (NaNs in data represent land)
land_mask = np.zeros((rows, cols), dtype=bool)
missing_uo_indices = np.isnan(uo_data[0])
land_mask[missing_uo_indices] = True

u[np.isnan(u)] = 0
v[np.isnan(v)] = 0
# Create a source at the mouth of the Mississippi river: 29.1511 N, -89.2533 W. This is roughly index (206, 153)
S = np.zeros((rows, cols))
S[205, 152] = 100  # This value is on the water: (206, 153) is on land



In [None]:
# Run the solver; it returns the initial map plus periodic snapshots.
# TODO: fix update_concentration function
maps_over_time = update_concentration_parallel(
    ocean_map=ocean_map,
    land_mask=land_mask,
    S=S,
    Dx=Dx,
    Dy=Dy,
    u=u,
    v=v,
    dt=dt,
    dx=dx,
    dy=dy,
    iterations=iterations,
)

# Turn the snapshots into an animation and embed it as interactive HTML.
anim = animate_concentration(maps_over_time)
animation_markup = anim.to_jshtml()
display(HTML(animation_markup))

In [None]:
# Show the land/sea mask with the approximate river-mouth index marked in red.
land_fig, land_ax = plt.subplots(figsize=(10, 6))

mask_image = land_ax.imshow(land_mask, cmap="Greys", interpolation="nearest")
land_fig.colorbar(mask_image, ax=land_ax, label="Land/Sea")
land_ax.invert_yaxis()  # put row 0 at the bottom

# scatter takes (x, y) = (column, row), i.e. array index (206, 153)
land_ax.scatter(153, 206, color="red", s=10)

plt.show()