In [7]:
%matplotlib inline
import numpy as np
import os, sys
import cv2
import matplotlib.pyplot as plt
import xarray as xr
import pandas as pd
from torch.utils import data

sys.path.append(os.path.dirname(os.path.abspath(os.getcwd())))
from data import goes16s3
from tools import inference_tools

In [18]:
# Number of ABI channels to load; `channels` holds the 1-based channel numbers 1..n_channels.
n_channels = 3
channels = range(1, n_channels + 1)
# Window length in minutes. Below it sizes the frame queue (n_minutes + 1) and
# the number of interpolated frames (n_minutes - 1).
n_minutes = 15

# Alternative save_directory used on another machine: '/mnt/nexai-goes/GOES/S3/'
noaadata = goes16s3.NOAAGOESS3(product='ABI-L1b-RadM', channels=channels,
                               save_directory='/raid/tj/GOES/TEST/')

# Output directory for the per-example NetCDF files. An earlier run used the
# '...-Inference-Test' suffix instead of '...-Inference-Hurricane'.
inference_dir = '/raid/tj/GOES/SloMo/%iChannel-%iminute-Inference-Hurricane' % (len(channels), n_minutes)

# exist_ok=True makes this idempotent and avoids the race between an
# os.path.exists check and a separate os.makedirs call.
os.makedirs(inference_dir, exist_ok=True)

In [19]:
# Table of GOES files found on local disk (the printed output lists paths under
# the save_directory configured above).
# NOTE(review): local_files presumably accepts a filter such as year=2019 — confirm in goes16s3.
localfiles = noaadata.local_files()
print(localfiles)

# Distinct day-of-year values present in the local file index.
days = localfiles.index.unique(level='dayofyear')


                                                                                        file  \
channel                                                                                    1   
year dayofyear hour minute second spatial                                                      
2018 63        12   0      272    RadM1    /raid/tj/GOES/TEST/ABI-L1b-RadM/2018/063/12/00...   
                           572    RadM2    /raid/tj/GOES/TEST/ABI-L1b-RadM/2018/063/12/00...   
                    1      272    RadM1    /raid/tj/GOES/TEST/ABI-L1b-RadM/2018/063/12/01...   
                           572    RadM2    /raid/tj/GOES/TEST/ABI-L1b-RadM/2018/063/12/01...   
                    2      272    RadM1    /raid/tj/GOES/TEST/ABI-L1b-RadM/2018/063/12/02...   
                           572    RadM2    /raid/tj/GOES/TEST/ABI-L1b-RadM/2018/063/12/02...   
                    3      272    RadM1    /raid/tj/GOES/TEST/ABI-L1b-RadM/2018/063/12/03...   
                           572    RadM2 

In [20]:
# Checkpoint directories for the single-variate (sv) and multivariate (mv) models.
# Both are parameterised by the channel count.
checkpoint_sv = '../saved-models/9Min-{:d}Channels-LambdaW_0.10-LambdaS_0.10-Batch20'.format(len(channels))
checkpoint_mv = '../saved-models/9Min-{:d}Channels-LambdaW_0.10-LambdaS_0.10-Batch20_MV2/'.format(len(channels))

# Each load_models call yields three objects, unpacked here as flow net,
# interpolation net and warper. They are passed together to inference_tools.inference later.
flownetsv, interpnetsv, warpersv = inference_tools.load_models(
    n_channels, checkpoint_sv, multivariate=False)
flownetmv, interpnetmv, warpermv = inference_tools.load_models(
    n_channels, checkpoint_mv, multivariate=True)


loading checkpoint ../saved-models/9Min-3Channels-LambdaW_0.10-LambdaS_0.10-Batch20/checkpoint.flownet.pth.tar
=> loaded checkpoint '../saved-models/9Min-3Channels-LambdaW_0.10-LambdaS_0.10-Batch20/checkpoint.flownet.pth.tar' (epoch 50)
loading checkpoint ../saved-models/9Min-3Channels-LambdaW_0.10-LambdaS_0.10-Batch20_MV2/checkpoint.flownet.mv.pth.tar
=> loaded checkpoint '../saved-models/9Min-3Channels-LambdaW_0.10-LambdaS_0.10-Batch20_MV2/checkpoint.flownet.mv.pth.tar' (epoch 50)


In [21]:
def _linear_interpolation(X0, X1, t):
    diff = X1 - X0
    return X0 + t * diff

def linear_interpolation(X0, X1, ts):
    """Build a linear-interpolation baseline sequence between two frames.

    Parameters
    ----------
    X0, X1
        Start and end frames. Each must expose ``.values``, as an
        xarray.DataArray does.
    ts : iterable of float
        Fractional positions between ``X0`` (0) and ``X1`` (1) at which to
        create intermediate frames.

    Returns
    -------
    The result of ``inference_tools.block_predictions_to_dataarray`` applied to
    ``[X0, X(t_1), ..., X(t_k), X1]``. Each frame's array has a leading axis of
    length 1, and ``X0`` serves as the template. Presumably this is a DataArray
    stacked along time — TODO confirm against inference_tools.
    """
    frames = [X0]
    # One evaluation per requested fraction.
    frames.extend(_linear_interpolation(X0, X1, t) for t in ts)
    frames.append(X1)
    # Give every raw array a leading length-1 axis before stacking.
    arrays = [frame.values[np.newaxis] for frame in frames]
    return inference_tools.block_predictions_to_dataarray(arrays, X0)

def time_rmse(x1, x2):
    """Root-mean-square difference between ``x1`` and ``x2`` over dims 'x' and 'y'.

    Only the 'x' and 'y' dimensions are reduced. Any other dimension of the
    difference is kept, giving one RMSE per remaining coordinate.
    """
    squared_error = np.square(x1 - x2)
    mean_squared_error = squared_error.mean(['x', 'y'])
    return mean_squared_error ** 0.5



In [22]:
year = 2018

# Fractional times of the n_minutes - 1 intermediate frames between the two
# endpoints: 1/n_minutes, 2/n_minutes, ..., (n_minutes - 1)/n_minutes.
# These do not depend on the day, so they are computed once outside the loop.
ts = np.linspace(1. / n_minutes, 1 - 1. / n_minutes, n_minutes - 1)

for day in [281]:  # replace [281] with `days` to process every locally available day
    iterator = noaadata.iterate_day(year, day, max_queue_size=n_minutes + 1, min_queue_size=1)

    for i, example in enumerate(iterator):
        fpath = os.path.join(inference_dir, '%4i_%03i_Example-%03i.nc' % (year, day, i))
        # Skip examples already written by an earlier run, so the cell is resumable.
        if os.path.exists(fpath):
            continue

        # Endpoints: first and last frame along the 't' dimension.
        # NOTE(review): min_queue_size=1 suggests shorter windows might be yielded.
        # If an example can have fewer than n_minutes + 1 frames, `ts` and
        # T=n_minutes - 1 no longer match the X0..X1 gap. Confirm iterate_day's behaviour.
        X0 = example.isel(t=0)
        X1 = example.isel(t=-1)

        # SloMo predictions from the single-variate and multivariate models.
        ressv = inference_tools.inference(X0, X1, flownetsv, interpnetsv, warpersv,
                                          multivariate=False, T=n_minutes - 1)
        resmv = inference_tools.inference(X0, X1, flownetmv, interpnetmv, warpermv,
                                          multivariate=True, T=n_minutes - 1)
        # Baseline: straight-line blend of the two endpoints at the fractions in `ts`.
        linearres = linear_interpolation(X0, X1, ts)

        ds = xr.Dataset({'slomo_sv': ressv, 'slomo_mv': resmv,
                         'linear': linearres, 'observed': example})
        # Error of each method against the observed sequence, reduced over x and y.
        ds['sv_rmse'] = time_rmse(ressv, example)
        ds['mv_rmse'] = time_rmse(resmv, example)
        ds['linear_rmse'] = time_rmse(linearres, example)

        ds.to_netcdf(fpath)
        print(fpath)


/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-000.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-001.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-002.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-003.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-004.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-005.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-006.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-007.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-008.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-009.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-010.nc
/raid/tj/GOES/SloMo/3Channel-15minute-Inference-Hurricane/2018_281_Example-011.nc
/raid/tj/GOES/Sl

In [None]:
# Show the files currently in the inference output directory, in name order.
output_files = sorted(os.listdir(inference_dir))
print(output_files)