# Calculate the frequency-domain FFT of atmosphere temperature spectra

### Compute over three regions:

- full atmosphere domain
- portion of atmosphere that lies over the ocean
- portion of atmosphere that lies over land

In [1]:
%matplotlib inline

import os
from glob import glob
import numpy as np
import matplotlib.pyplot as plt

from detrend_func import detrend_func
from window_func import main as window_func

import dask
from distributed import Client, LocalCluster
import dask.array as da

import xarray as xr

In [2]:
# Create a Dask Distributed cluster with one processor and as many
# threads as cores
cluster = LocalCluster(n_workers=1)
client = Client(cluster)

### Define constants to be used

In [3]:
dx = 80000. # meters
dy = 80000. # meters
dt = 1 # in days
H = [2000.0, 3000.0, 4000.0]  # meters
Htot = H[0] + H[1] + H[2]
Hm = 1000. # meters
f0 = 9.37456*(10**(-5)) #1/s (Coriolis parameter)
g = [1.2, .4]
rho = 1.0 # kg/(m^3)
Cp = 1000. # specific heat of atm in J/(kg*K)
T1 = 287
ocnorm = 1.0/(960*960) #from QGCM write-up
atnorm = 1/(384*96)
dt = 1 #86400.0 # seconds in a day (time difference between each point in time)
K2 = 2.5e4 # m^2/s (called st2d in input.params)
K4 = 2e14 # m^4/s (called st4d in input.params)
tmbaro = 3.04870e2 # mean ocean mixed layer absolute temperature (K)
tmbara = 3.05522e2 # mean atmosphere mixed layer absolute temperature (K)
toc1 = -1.78699e1 # Relative temperature for ocean layer 1

In [4]:
tile_size = 100

## Temperature Spectrum (frequency domain)

$$
\left| \widehat{ T(x,y,\omega)}  \right| ^2
$$

# Calculate spectra over all windows, full domain

In [8]:
def calc_SD(dsx, tile_index):
    
    # Select variable
    T = dsx.ast

    # Select specified tile, with 2-cell padding for taking of derivatives
    ix, iy = tile_index
    print('ix,iy',ix,iy)
    T = T.isel(yt=slice(max(iy*tile_size,0), (iy+1)*tile_size),
               xt=slice(max(ix*tile_size,0), (ix+1)*tile_size))
    xp = T.xt.values # grab values of dask arrays
    yp = T.yt.values
    time = T.time.values
    
    T = T.data
    print('T shape with buffer of 2 on either side = ',T.shape)

    # Update size variables
    ny, nx, nt = T.shape
    print('after adding two cells for boundaries ',T.shape)
    T = T.rechunk(chunks={0: ny, 1: nx})

    # Use smaller spatial chunks for fft
    T = T.rechunk(chunks={0: tile_size/10, 1: tile_size/10, 2: nt})
    
    print('T shape after chunking/10 = ',T.shape)
        
    # Function that detrends, windows, and takes fft
    def fft_block(var):
        var = detrend_func(var,'time')
        var = window_func(var,'time')
        print('fft function')
        varhat = (1./var.shape[2])*np.fft.rfft(var, axis=2)
        return varhat
    
    # Execute fft function above
    That = T.map_blocks(fft_block)
    print('T shape after fft = ',That.shape)
    
    # Resize back to tile_size chunks
    That = That.rechunk(chunks={0: tile_size, 1:tile_size, 2: 365})

    # Multiply together
    #SD = 0.5 * (1./(f0**2)) * (That.conj()*That) #.real) <-- needed?
    SD = (That.conj()*That).real
    
    print('final shape = ',SD.shape)

    # Sum over x- and y-axes
    SD = da.sum(SD, axis=(0,1))
    
    SD = SD.compute()
    
    # Wrap as DataArray
    n = len(time)
    d = time[1] - time[0]
    freq = np.fft.rfftfreq(n, d)
    
    SD = xr.DataArray(SD, 
                      dims=['freq'], 
                      coords={'freq': freq,
                              'xp': xp.mean(),
                              'yp': yp.mean()})
    
    return SD

In [9]:
%%time
from itertools import product

datadir = '/g/data/hh5/tmp/pm2987/Spunup_ycexp_atast/'
save_datapath = '/g/data/v45/pm2987/Power_spectra/Tvar/'
#datadir = '/g/data/v45/pm2987/Spunup_ocsst/'

# ceil(x/y) = (x+y+1)//y
yi = range((384+tile_size+1)//(tile_size)) 
xi = range((96+tile_size+1)//(tile_size))

tile_indexes = list(product(*[yi, xi]))

KE1_sum = 0
KE2_sum = 0
KE3_sum = 0
buoyancy_sum = 0
PE1_sum = 0
PE2_sum = 0
windstress_sum = 0
bottomDrag_sum = 0

# All files
ncfiles = sorted(glob(os.path.join(datadir, 'output*/atast.nc'))) # for all files

for i in np.arange(3,7):
    print('i = ',i)
    # Select desired files
    ncfiles_loop = ncfiles[(i*50):(i*50+100)] # 0 starts with year 233
    print(ncfiles_loop[0],ncfiles_loop[-1])
    
    chunks = {'xt': tile_size, 'yt': tile_size,'time':365}
    datasets = [xr.open_dataset(fn,chunks=chunks) for fn in ncfiles_loop]
    dsx = xr.concat(datasets, dim='time', coords='all')
    
    spec3_w_sum = 0
    spec2_w_sum = 0
    spec1_w_sum = 0
    
    for tile_index in tile_indexes[:]:
        print(tile_index)
        spec1_w = calc_SD(dsx, tile_index)
        spec1_w_sum += spec1_w
    
    np.save(save_datapath+'T_atm_ycexp_spectrum_'+str(i*50+233)+'_'+str(i*50+332),spec1_w_sum)
    
    print('Done with round ',i)

i =  3
/g/data/hh5/tmp/pm2987/Spunup_ycexp_atast/output383/atast.nc /g/data/hh5/tmp/pm2987/Spunup_ycexp_atast/output482/atast.nc
(0, 0)
ix,iy 0 0
T shape with buffer of 2 on either side =  (96, 100, 36500)
after adding two cells for boundaries  (96, 100, 36500)
T shape after chunking/10 =  (96, 100, 36500)
fft function
T shape after fft =  (96, 100, 36500)
final shape =  (96, 100, 36500)
(1, 0)
ix,iy 1 0
T shape with buffer of 2 on either side =  (96, 100, 36500)
after adding two cells for boundaries  (96, 100, 36500)
T shape after chunking/10 =  (96, 100, 36500)
fft function
T shape after fft =  (96, 100, 36500)
final shape =  (96, 100, 36500)
(2, 0)
ix,iy 2 0
T shape with buffer of 2 on either side =  (96, 100, 36500)
after adding two cells for boundaries  (96, 100, 36500)
T shape after chunking/10 =  (96, 100, 36500)
fft function
T shape after fft =  (96, 100, 36500)
final shape =  (96, 100, 36500)
(3, 0)
ix,iy 3 0
T shape with buffer of 2 on either side =  (96, 84, 36500)
after add

# Calculate over ocean spectrum

In [6]:
def calc_SD_overOcean(dsx, tile_index):
    
    # Select variable
    T = dsx.ast

    # Select specified tile, with 2-cell padding for taking of derivatives
    ix, iy = tile_index
    print('ix,iy',ix,iy)
    T = T.isel(yt=slice(max(iy*tile_size,0), (iy+1)*tile_size),
               xt=slice(max(ix*tile_size,0), (ix+1)*tile_size))
    
    if ix == 1:
        print('ix==1')
        T = T.isel(yt=slice(18,-18),xt=slice(62,None))
        print('T.shape = ',T.shape)
  
    elif ix == 2:
        T = T.isel(yt=slice(18,-18),xt=slice(None,22))
        print('T.shape = ',T.shape)  
    
    xp = T.xt.values # grab values of dask arrays
    yp = T.yt.values
    time = T.time.values
    
    T = T.data
    print('T shape with buffer of 2 on either side = ',T.shape)

    # Update size variables
    ny, nx, nt = T.shape
    print('after adding two cells for boundaries ',T.shape)
    T = T.rechunk(chunks={0: ny, 1: nx})

    # Use smaller spatial chunks for fft
    T = T.rechunk(chunks={0: tile_size/10, 1: tile_size/10, 2: nt})
    
    print('T shape after chunking/10 = ',T.shape)
        
    # Function that detrends, windows, and takes fft
    def fft_block(var):
        var = detrend_func(var,'time')
        var = window_func(var,'time')
        print('fft function')
        varhat = (1./var.shape[2])*np.fft.rfft(var, axis=2)
        return varhat
    
    # Execute fft function above
    That = T.map_blocks(fft_block)
    print('T shape after fft = ',That.shape)
    
    # Resize back to tile_size chunks
    That = That.rechunk(chunks={0: tile_size, 1:tile_size, 2: 365})

    # Multiply together
    #SD = 0.5 * (1./(f0**2)) * (That.conj()*That) #.real) <-- needed?
    SD = (That.conj()*That).real
    
    print('final shape = ',SD.shape)

    # Sum over x- and y-axes
    SD = da.sum(SD, axis=(0,1))
    
    SD = SD.compute()
    
    # Wrap as DataArray
    n = len(time)
    d = time[1] - time[0]
    freq = np.fft.rfftfreq(n, d)
    
    SD = xr.DataArray(SD, 
                      dims=['freq'], 
                      coords={'freq': freq,
                              'xp': xp.mean(),
                              'yp': yp.mean()})
    
    return SD

In [7]:
%%time
from itertools import product

datadir = '/g/data/hh5/tmp/pm2987/Spunup_ycexp_atast/'
save_datapath = '/g/data/v45/pm2987/Power_spectra/Tvar/'

# ceil(x/y) = (x+y+1)//y
yi = range(1,3) 
xi = range((96+tile_size+1)//(tile_size))

tile_indexes = list(product(*[yi, xi]))

KE1_sum = 0
KE2_sum = 0
KE3_sum = 0
buoyancy_sum = 0
PE1_sum = 0
PE2_sum = 0
windstress_sum = 0
bottomDrag_sum = 0

# All files
ncfiles = sorted(glob(os.path.join(datadir, 'output*/atast.nc'))) # for all files

for i in np.arange(3,7):
    print('i = ',i)
    # Select desired files
    ncfiles_loop = ncfiles[(i*50):(i*50+100)] # 0 starts with year 233
    print(ncfiles_loop[0],ncfiles_loop[-1])
    
    chunks = {'xt': tile_size, 'yt': tile_size,'time':365}
    datasets = [xr.open_dataset(fn,chunks=chunks) for fn in ncfiles_loop]
    dsx = xr.concat(datasets, dim='time', coords='all')
    
    spec3_w_sum = 0
    spec2_w_sum = 0
    spec1_w_sum = 0
    
    for tile_index in tile_indexes[:]:
        print(tile_index)
        spec1_w = calc_SD_overOcean(dsx, tile_index)
        spec1_w_sum += spec1_w
    
    np.save(save_datapath+'T_atm_ycexp_overOcean_spectrum_'+str(i*50+233)+'_'+str(i*50+332),spec1_w_sum)
    
    print('Done with round ',i)

i =  3
/g/data/hh5/tmp/pm2987/Spunup_ycexp_atast/output383/atast.nc /g/data/hh5/tmp/pm2987/Spunup_ycexp_atast/output482/atast.nc
(1, 0)
ix,iy 1 0
ix==1
T.shape =  (60, 38, 36500)
T shape with buffer of 2 on either side =  (60, 38, 36500)
after adding two cells for boundaries  (60, 38, 36500)
T shape after chunking/10 =  (60, 38, 36500)
fft function
T shape after fft =  (60, 38, 36500)
final shape =  (60, 38, 36500)
(2, 0)
ix,iy 2 0
T.shape =  (60, 22, 36500)
T shape with buffer of 2 on either side =  (60, 22, 36500)
after adding two cells for boundaries  (60, 22, 36500)
T shape after chunking/10 =  (60, 22, 36500)
fft function
T shape after fft =  (60, 22, 36500)
final shape =  (60, 22, 36500)
Done with round  3
i =  4
/g/data/hh5/tmp/pm2987/Spunup_ycexp_atast/output433/atast.nc /g/data/hh5/tmp/pm2987/Spunup_ycexp_atast/output532/atast.nc
(1, 0)
ix,iy 1 0
ix==1
T.shape =  (60, 38, 36500)
T shape with buffer of 2 on either side =  (60, 38, 36500)
after adding two cells for boundaries  (

save_datapath = '/g/data/v45/pm2987/Power_spectra/Tvar/'
np.save(save_datapath+'T_atm_overOcean_spectrum_34_133.npy',spec1_w_sum)

# Calculate spectrum over land

In [9]:
def calc_SD_overland(dsx, tile_index):
    
    # Select variable
    T = dsx.ast

    # Select specified tile, with 2-cell padding for taking of derivatives
    ix, iy = tile_index
    print('ix,iy',ix,iy)
    T = T.isel(yt=slice(max(iy*tile_size,0), (iy+1)*tile_size),
               xt=slice(max(ix*tile_size,0), (ix+1)*tile_size))
    
    # Set ocean points to zero
    #  - there are 38 atm points over ocean in ix=1
    #  - here I slice the land portions into rectangles
    if ix == 1:
        print('ix==1')
        
        T_left = T.isel(xt=slice(None,62))
        T_top = T.isel(yt=slice(-18,None),xt=slice(62,None))
        T_bot = T.isel(yt=slice(None,18),xt=slice(62,None))
        
        T_zeros = T.isel(yt=slice(18,-18),xt=slice(62,None)) # grab only the ocean points
        T_zeros = T_zeros.where(T_zeros>10000,0) # set all ocean points to zero: if T_zeros>10000 is false, fill with zeros
        print(T_zeros.isel(xt=slice(4,5),yt=slice(3,4),time=slice(3,4)))
        print('T_left,T_top,T_bot,T_zeros shapes = ',T_left.shape,T_top.shape,T_bot.shape,T_zeros.shape)
        
        # Concatenate right side of chunk
        T_right = xr.concat([T_top,T_zeros,T_bot],'yt')
        print(T_right.shape)
        
        # Concatenate entire chunk together
        T = xr.concat([T_left,T_right],'xt')
        print('T.shape = ',T.shape)
        
    # Set ocean points to zero
    #  - there are 22 atm points over ocean in ix=2
    elif ix == 2:
        
        T_right = T.isel(xt=slice(22,None))
        T_top = T.isel(yt=slice(-18,None),xt=slice(None,22))
        T_bot = T.isel(yt=slice(None,18),xt=slice(None,22))
        
        T_zeros = T.isel(yt=slice(18,-18),xt=slice(None,22))
        T_zeros = T_zeros.where(T_zeros>10000,0)
        print(T_zeros.isel(xt=slice(None,5),yt=slice(3,4),time=slice(3,4)))
        print('T_right,T_top,T_bot,T_zeros shapes = ',T_right.shape,T_top.shape,T_bot.shape,T_zeros.shape)
        
        T_left = xr.concat([T_top,T_zeros,T_bot],'yt')
        print('T_left.shape = ',T_left.shape)
        
        T = xr.concat([T_left,T_right],'xt')
        print('T.shape = ',T.shape)
    
    
    xp = T.xt.values # grab values of dask arrays
    yp = T.yt.values
    time = T.time.values
    
    T = T.data
    print('T shape with buffer of 2 on either side = ',T.shape)

    # Update size variables
    ny, nx, nt = T.shape
    print('after adding two cells for boundaries ',T.shape)
    T = T.rechunk(chunks={0: ny, 1: nx})

    # Use smaller spatial chunks for fft
    T = T.rechunk(chunks={0: tile_size/10, 1: tile_size/10, 2: nt})
    
    print('T shape after chunking/10 = ',T.shape)
        
    # Function that detrends, windows, and takes fft
    def fft_block(var):
        var = detrend_func(var,'time')
        var = window_func(var,'time')
        print('fft function')
        varhat = (1./var.shape[2])*np.fft.rfft(var, axis=2)
        return varhat
    
    # Execute fft function above
    That = T.map_blocks(fft_block)
    print('T shape after fft = ',That.shape)
    
    # Resize back to tile_size chunks
    That = That.rechunk(chunks={0: tile_size, 1:tile_size, 2: 365})

    # Multiply together
    #SD = 0.5 * (1./(f0**2)) * (That.conj()*That) #.real) <-- needed?
    SD = (That.conj()*That).real
    
    print('final shape = ',SD.shape)

    # Sum over x- and y-axes
    SD = da.sum(SD, axis=(0,1))
    
    SD = SD.compute()
    
    # Wrap as DataArray
    n = len(time)
    d = time[1] - time[0]
    freq = np.fft.rfftfreq(n, d)
    
    SD = xr.DataArray(SD, 
                      dims=['freq'], 
                      coords={'freq': freq,
                              'xp': xp.mean(),
                              'yp': yp.mean()})
    
    return SD

In [11]:
%%time
from itertools import product

datadir = '/g/data/v45/pm2987/Spunup_atast/'
save_datapath = '/g/data/v45/pm2987/Power_spectra/Tvar/'
#datadir = '/g/data/v45/pm2987/Spunup_ocsst/'

# ceil(x/y) = (x+y+1)//y
yi = range((384+tile_size+1)//(tile_size)) 
xi = range((96+tile_size+1)//(tile_size))

tile_indexes = list(product(*[yi, xi]))

KE1_sum = 0
KE2_sum = 0
KE3_sum = 0
buoyancy_sum = 0
PE1_sum = 0
PE2_sum = 0
windstress_sum = 0
bottomDrag_sum = 0

# All files
ncfiles = sorted(glob(os.path.join(datadir, 'output*/atast.nc'))) # for all files

for i in np.arange(3,7):
    print('i = ',i)
    # Select desired files
    ncfiles_loop = ncfiles[(i*50):(i*50+100)] # 0 starts with year 233
    print(ncfiles_loop[0],ncfiles_loop[-1])
    
    chunks = {'xt': tile_size, 'yt': tile_size,'time':365}
    datasets = [xr.open_dataset(fn,chunks=chunks) for fn in ncfiles_loop]
    dsx = xr.concat(datasets, dim='time', coords='all')
    
    spec3_w_sum = 0
    spec2_w_sum = 0
    spec1_w_sum = 0
    
    for tile_index in tile_indexes[:]:
        print(tile_index)
        spec1_w = calc_SD_overland(dsx, tile_index)
        spec1_w_sum += spec1_w
    
    np.save(save_datapath+'T_atm_overLand_spectrum_'+str(i*50+34)+'_'+str(i*50+133),spec1_w_sum)
    
    print('Done with round ',i)

i =  3
/g/data/v45/pm2987/Spunup_atast/output184/atast.nc /g/data/v45/pm2987/Spunup_atast/output283/atast.nc
(0, 0)
ix,iy 0 0
T shape with buffer of 2 on either side =  (96, 100, 36500)
after adding two cells for boundaries  (96, 100, 36500)
T shape after chunking/10 =  (96, 100, 36500)
fft function
T shape after fft =  (96, 100, 36500)
final shape =  (96, 100, 36500)
(1, 0)
ix,iy 1 0
ix==1
<xarray.DataArray 'ast' (yt: 1, xt: 1, time: 1)>
dask.array<getitem, shape=(1, 1, 1), dtype=float32, chunksize=(1, 1, 1), chunktype=numpy.ndarray>
Coordinates:
  * xt       (xt) float32 13320.0
  * yt       (yt) float32 1720.0
  * time     (time) float64 223.0
Attributes:
    units:      K
    long_name:  Atmosphere surface temperature
T_left,T_top,T_bot,T_zeros shapes =  (96, 62, 36500) (18, 38, 36500) (18, 38, 36500) (60, 38, 36500)
(96, 38, 36500)
T.shape =  (96, 100, 36500)
T shape with buffer of 2 on either side =  (96, 100, 36500)
after adding two cells for boundaries  (96, 100, 36500)
T shape