In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

import mantaray

In [None]:
# Gravitational acceleration in m/s^2
g = 9.8


def period2wavenumber(T):
    """Return the deep-water wavenumber magnitude for a wave period.

    Inverts the deep-water dispersion relation ``omega**2 = g*k`` with
    ``omega = 2*pi/T``, giving ``k = (2*pi)**2 / g / T**2``.

    Parameters
    ----------
    T : float or numpy.ndarray
        Wave period in seconds.

    Returns
    -------
    float or numpy.ndarray
        Wavenumber magnitude (rad/m), same shape as ``T``.
    """
    two_pi_squared = (2*np.pi)**2
    return two_pi_squared/g/T**2


def group_velocity(k):
    """Return the deep-water group velocity for a wavenumber magnitude.

    In deep water the group velocity is half the phase speed
    ``sqrt(g/k)``.

    Parameters
    ----------
    k : float or numpy.ndarray
        Wavenumber magnitude (rad/m).

    Returns
    -------
    float or numpy.ndarray
        Group velocity (m/s), same shape as ``k``.
    """
    phase_speed = (g/k)**.5
    return 0.5*phase_speed

In [None]:
# Period of incident waves in seconds
T0 = 10
# Direction of incident waves in radians (trig convention, going to)
theta0 = 0
# Convert period to wavenumber magnitude (deep-water dispersion relation)
k0 = period2wavenumber(T0)
# Calculate wavenumber components
kx0 = k0*np.cos(theta0)
ky0 = k0*np.sin(theta0)

# Number of rays
n_rays = 100
# Every ray starts with the same incident wavenumber vector
Kx0 = np.full(n_rays, kx0)
Ky0 = np.full(n_rays, ky0)

# Current and bathymetry file paths
current = 'data/currents/zonal_jet.nc'
bathymetry = 'data/bathymetry/bathy_zonal_jet.nc'

# Read x and y from the current file to get the domain extent.
# `ds` is intentionally left open: the plotting cell reads ds.u / ds.v.
ds = xr.open_dataset(current)
x = ds.x.values
y = ds.y.values

# Initial ray positions: a line of rays 10 units (units of x, presumably
# metres) inside the x.min() edge, spread evenly over the full y extent.
# Offsetting from x.min()/y.min() keeps rays inside grids not starting at 0.
x0 = np.full(n_rays, x.min() + 10)
y0 = np.linspace(y.min(), y.max(), n_rays)

# Time-step estimate from a CFL condition with Courant number 1:
# the time a packet moving at the (deep-water) group velocity needs to
# cross the smallest grid cell. The true minimum spacing is used so
# non-uniform grids do not yield a step that is too large.
dd = min(np.abs(np.diff(x)).min(), np.abs(np.diff(y)).min())
# Group velocity of the incident waves
cg = group_velocity(k0)
# NOTE: despite the name, `cfl` is a time step (s if x, y are in metres)
cfl = dd/cg

# Total integration time: long enough to cross the domain width in x
duration = round((x.max() - x.min())/cg)
step_size = cfl

In [None]:
# Trace all rays through the current and bathymetry fields.
# The plotting cell indexes the result along a `ray` dimension and reads
# per-ray `x` / `y` trajectories (presumably an xarray Dataset).
bando = mantaray.ray_tracing(
    x0, y0, Kx0, Ky0,
    duration, step_size,
    bathymetry, current,
)

In [None]:
# Show the ray-tracing output as this cell's result
bando

### Plot rays over the current speed

In [None]:
# Background field: current speed sqrt(u^2 + v^2) on the model grid.
# NOTE(review): pcolormesh with 1-D X, Y expects U dimensioned (y, x) —
# confirm against the dimension order of ds.u / ds.v.
X = ds.x
Y = ds.y
U = (ds.u**2 + ds.v**2)**.5

fig, ax = plt.subplots(figsize=(12, 6))
# Explicit shading='auto' centres each cell on its (x, y) coordinate when
# X, Y and U have matching lengths, independent of the matplotlib default.
cs = ax.pcolormesh(X, Y, U, shading='auto')
# Overlay every second ray trajectory to keep the figure readable
for i in range(0, bando.ray.size, 2):
    ray = bando.isel(ray=i)
    ax.plot(ray.x, ray.y, 'k', lw=.78)
fig.colorbar(cs, ax=ax, label='Current speed')
ax.set(xlabel='x', ylabel='y', title='Ray paths over current speed');