In [50]:
%matplotlib inline
import os, sys
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
import torch
import torchvision
import cv2
import flownet as fl
import goes16s3
import utils
from datetime import datetime
import xarray as xr
import scipy
import math
from skimage.measure import compare_ssim

font = {'family' : 'times',
        'size'   : 16}

matplotlib.rc('font', **font)



# Figures

##  1: Dataset Examples

##  2: Network Architecture

Network architecture as shown in slomo, including the multi-variate case. <br>
Flow and Intermediate Interpolation Networks. Can we use example input and flow images to represent it? I think so.

##  3: Optical flows of a test example
1. I0, I1, and It
2. Flows: F_01, F_10, F_01_delta, F_10_delta
3. Visible: V0 and V1
4. Difference between I0 and I1

## Figure 4: Time dependent errors
1. Linear interpolation
2. SloMo
3. MV-SloMo

# Tables
## 1: Overall and Per band errors
1, 3, 5, and 8 band experiments

In [55]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_models(n_channels, model_path, multivariate=False):
    
    if multivariate:
        model_filename = os.path.join(model_path, 'checkpoint.flownet.mv.pth.tar')
        flownet = fl.SloMoFlowNetMV(n_channels)#.cuda()
        interpnet = fl.SloMoInterpNetMV(n_channels)#.cuda()
    else:
        model_filename = os.path.join(model_path, 'checkpoint.flownet.pth.tar')
        flownet = fl.SloMoFlowNet(n_channels)#.cuda()
        interpnet = fl.SloMoInterpNet(n_channels)#.cuda()
        
    warper = fl.FlowWarper()
    
    flownet = flownet.to(device)
    interpnet = interpnet.to(device)
    warper = warper.to(device)

    def load_checkpoint(flownet, interpnet):
        epoch = 0
        if os.path.isfile(model_filename):
            print("loading checkpoint %s" % model_filename)
            checkpoint = torch.load(model_filename)
            flownet.load_state_dict(checkpoint['flownet_state_dict'])
            interpnet.load_state_dict(checkpoint['interpnet_state_dict'])
            epoch = checkpoint['epoch']
            print("=> loaded checkpoint '{}' (epoch {})"
                    .format(model_filename, epoch))
        else:
            print("=> no checkpoint found at '{}'".format(model_filename))
        return flownet, interpnet

    flownet.train()
    interpnet.train()
    flownet, interpnet = load_checkpoint(flownet, interpnet)
    return flownet, interpnet, warper

In [60]:
year = 2018
# Noreaster

days = [#datetime(year, 3, 3).timetuple().tm_yday, # Noreaster
        datetime(year, 7, 18).timetuple().tm_yday,]

channels = [1,3,5,8]



day = days[0]
n_channels = channels[0]
multivariate = True

if n_channels == 1:
    multivariate = False
    
model_path = './saved-models/5Min-%iChannels/' % n_channels

flownet, interpnet, warper = load_models(n_channels,model_path, multivariate)

loading checkpoint ./saved-models/5Min-1Channels/checkpoint.flownet.pth.tar
=> loaded checkpoint './saved-models/5Min-1Channels/checkpoint.flownet.pth.tar' (epoch 21)


In [65]:
def block_predictions_to_dataarray(predictions, block):
    block_predictions = np.concatenate(predictions, 0)
    block_predictions[block_predictions < 0] = 0
    block_predictions[block_predictions > 1] = 1

    N_pred = block_predictions.shape[0]
    da = xr.DataArray(block_predictions,#[:,:,shave:-shave,shave:-shave],
              coords=[block.t.values[:N_pred], block.band.values,
                      block.y.values,#[shave:-shave], 
                      block.x.values,],#][shave:-shave]],
              dims=['t', 'band', 'y', 'x'])

    return da

def inference(block, flownet, interpnet, warper, multivariate):
    block_vals = torch.from_numpy(block.values)
    N = block_vals.shape[0]
    idxs = np.arange(0, N+1, 5)
    preds = []
    for idx0, idx1 in zip(idxs[:-1], idxs[1:]):
        idx1 = min(N-1, idx1)
        I0 = torch.unsqueeze(block_vals[idx0], 0).to(device)
        I1 = torch.unsqueeze(block_vals[idx1], 0).to(device)

        f = flownet(I0, I1)
        n_channels = I0.shape[1]
        
        if multivariate:
            f_01 = f[:,:2*n_channels]
            f_10 = f[:,2*n_channels:]
        else:
            f_01 = f[:,:2]
            f_10 = f[:,2:]

        T = idx1 - idx0 - 1
        predicted_frames = []
        for j in range(1,T+1):
            t = 1. * j / (T+1)
            I_t, g0, g1, V_t0, V_t1, delta_f_t0, delta_f_t1 = interpnet(I0, I1, f_01, f_10, t)
            predicted_frames.append(I_t.cpu().detach().numpy())
        
        preds += [I0.cpu().numpy()] + predicted_frames
            
    return block_predictions_to_dataarray(preds, block)

def merge_and_average_dataarrays(dataarrays):
    ds = xr.merge([xr.Dataset({k: d}) for k, d in enumerate(dataarrays)])
    das = []
    for b in range(0,len(dataarrays)):
        das.append(ds[b])

    return xr.concat(das).mean('concat_dims', skipna=True)

def testset_inference(n_channels, year, day):
    model_path = './saved-models/5Min-%iChannels/' % n_channels
    flownet, interpnet, warper = load_models(n_channels, model_path, False)
    if n_channels > 1:
        flownetmv, interpnetmv, warpermv = load_models(n_channels, model_path, True)

        
    dataset = goes16s3.NOAAGOESS3(product='ABI-L1b-RadM', 
                                  channels=range(1,n_channels+1))

    saved_data_files = []
    for houri, hour_das in enumerate(dataset.read_day(year, day)): # N,12or13,512,512,3
        blocked_data = utils.blocks(hour_das, width=352)
        sv_interpolated_hour, mv_interpolated_hour = [], []
        for block_num, block in enumerate(blocked_data):            
            sv_da = inference(block, flownet, interpnet, warper, False)
            sv_interpolated_hour.append(sv_da)
            if n_channels > 1:
                mv_da = inference(block, flownetmv, interpnetmv, warpermv, True)
                mv_interpolated_hour.append(mv_da)

        sv_prediction_da = merge_and_average_dataarrays(sv_interpolated_hour)
        ds = xr.Dataset({'observed': hour_das, 'sv_predicted': sv_prediction_da})
        if n_channels > 1:
            ds['mv_predicted'] = merge_and_average_dataarrays(mv_interpolated_hour)
        ncpath = os.path.join(model_path, '%04i_%03i' % (year, day))
        if not os.path.exists(ncpath):
            os.mkdir(ncpath)
        ncfile = os.path.join(ncpath,  '%04i_%03i_%02i.nc' % (year, day, houri))

        print("Saved to file: {}".format(ncfile))
        ds.to_netcdf(ncfile)
        saved_data_files.append(ncfile)
        
#for c in channels:
for c in [1,3,5,8]:
    for day in days:
        print("Test inference for channel {}, year 2018, and day {}".format(c, day))
        testset_inference(c, 2018, day)

Test inference for channel 1, year 2018, and day 199
loading checkpoint ./saved-models/5Min-1Channels/checkpoint.flownet.pth.tar
=> loaded checkpoint './saved-models/5Min-1Channels/checkpoint.flownet.pth.tar' (epoch 21)
('Day', 199, 'Hour', 12)




Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_00.nc
('Day', 199, 'Hour', 13)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_01.nc
('Day', 199, 'Hour', 14)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_02.nc
('Day', 199, 'Hour', 15)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_03.nc
('Day', 199, 'Hour', 16)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_04.nc
('Day', 199, 'Hour', 17)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_05.nc
('Day', 199, 'Hour', 18)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_06.nc
('Day', 199, 'Hour', 19)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_07.nc
('Day', 199, 'Hour', 20)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_08.nc
('Day', 199, 'Hour', 21)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_199_09.nc
('Day', 199, 'Hour', 22)
Saved to file: ./saved-models/5Min-1Channels/2018_199/2018_1

  This is separate from the ipykernel package so we can avoid doing imports until
  after removing the cwd from sys.path.


Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_06.nc
('Day', 199, 'Hour', 19)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_07.nc
('Day', 199, 'Hour', 20)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_08.nc
('Day', 199, 'Hour', 21)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_09.nc
('Day', 199, 'Hour', 22)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_10.nc
('Day', 199, 'Hour', 23)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_11.nc
('Day', 199, 'Hour', 12)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_12.nc
('Day', 199, 'Hour', 13)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_13.nc
('Day', 199, 'Hour', 14)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_14.nc
('Day', 199, 'Hour', 15)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_199_15.nc
('Day', 199, 'Hour', 16)
Saved to file: ./saved-models/5Min-3Channels/2018_199/2018_1

In [66]:
def psnr(img1, img2, axis=None):
    img1[img1 == np.inf] = np.nan
    img2[img2 == np.inf] = np.nan
    
    mse = np.nanmean((img1 - img2) ** 2, axis=axis)
    
    #print('obs', np.histogram(img2.flatten(),
    #                range=[np.nanmin(img2), np.nanmax(img2)]))
    #if isinstance(mse, float) and mse == 0:
    #    return 100
    if np.any(mse[mse == np.inf]):
        return np.zeros_like(mse)
    
    #if isinstance(mse, float) and mse == np.inf:
    #    return None

    PIXEL_MAX = 1.0
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))


def ssim(img1, img2):
    img1[img1 == np.inf] = np.nan
    img2[img2 == np.inf] = np.nan
    
    img1[np.isnan(img1)] = 0.
    img2[np.isnan(img2)] = 0.
    r = []
    for b in range(img1.shape[1]):
        sms = [compare_ssim(img1[i,b], img2[i,b]) for i in range(img1.shape[0])]
        r.append(np.nanmean(sms))
    return np.array(r)

#def bandwise_psnr(img1, img2): # assuming (t,c,h,w)
    

In [67]:
results = []
for c in [1,3,5,8]:
    model_path = './saved-models/5Min-%iChannels/2018_199/' % c
    saved_data_files = sorted([os.path.join(model_path, f) 
                        for f in os.listdir(model_path) if f[-3:] == '.nc'])
    for sf in saved_data_files:
        print(sf)
        ds = xr.open_dataset(sf)        
        svpsnr = psnr(ds.sv_predicted.values, ds.observed.values, (3,2,0))
        svssim = ssim(ds.sv_predicted.values, ds.observed.values)
        if c > 1:
            mvpsnr = psnr(ds.mv_predicted.values, ds.observed.values, (3,2,0))
            mvssim = ssim(ds.mv_predicted.values, ds.observed.values)
        else:
            mvpsnr = [None]*len(svpsnr)
            mvssim = [None]*len(svpsnr)

        for b in range(len(svpsnr)):
            r = {'Channels': c, 'Band': b+1, 'Model': 'A-Single', 'PSNR': svpsnr[b], 'SSIM': svssim[b]}
            results.append(r)
            r = {'Channels': c, 'Band': b+1, 'Model': 'B-Multi', 'PSNR': mvpsnr[b], 'SSIM': mvssim[b]}
            results.append(r)

./saved-models/5Min-1Channels/2018_199/2018_199_00.nc
./saved-models/5Min-1Channels/2018_199/2018_199_01.nc
./saved-models/5Min-1Channels/2018_199/2018_199_02.nc
./saved-models/5Min-1Channels/2018_199/2018_199_03.nc
./saved-models/5Min-1Channels/2018_199/2018_199_04.nc
./saved-models/5Min-1Channels/2018_199/2018_199_05.nc
./saved-models/5Min-1Channels/2018_199/2018_199_06.nc
./saved-models/5Min-1Channels/2018_199/2018_199_07.nc
./saved-models/5Min-1Channels/2018_199/2018_199_08.nc
./saved-models/5Min-1Channels/2018_199/2018_199_09.nc
./saved-models/5Min-1Channels/2018_199/2018_199_10.nc
./saved-models/5Min-1Channels/2018_199/2018_199_11.nc
./saved-models/5Min-1Channels/2018_199/2018_199_12.nc
./saved-models/5Min-1Channels/2018_199/2018_199_13.nc
./saved-models/5Min-1Channels/2018_199/2018_199_14.nc
./saved-models/5Min-1Channels/2018_199/2018_199_15.nc
./saved-models/5Min-1Channels/2018_199/2018_199_16.nc
./saved-models/5Min-1Channels/2018_199/2018_199_17.nc
./saved-models/5Min-1Channel

In [68]:
results_df = pd.DataFrame(results)
results_df.set_index(["Channels", "Band", "Model"])
psnr_per_band_table = pd.pivot_table(results_df, values='PSNR', index=['Band'], 
                       columns=['Channels', 'Model'])

print(psnr_per_band_table)
reslatex = psnr_per_band_table.to_latex(float_format='{:,.2f}'.format)
print(reslatex)


results_df = pd.DataFrame(results)
results_df.set_index(["Channels", "Band", "Model"])
psnr_per_band_table = pd.pivot_table(results_df, values='PSNR', index=['Channels'], 
                       columns=['Model'])

print(psnr_per_band_table)
reslatex = psnr_per_band_table.to_latex(float_format='{:,.2f}'.format)
print(reslatex)


Channels          1          3                     5                     8  \
Model      A-Single   A-Single    B-Multi   A-Single    B-Multi   A-Single   
Band                                                                         
1         35.706924  35.088647  35.088176  34.888095  35.286151  34.395813   
2               NaN  31.464401  31.327234  31.398013  31.498074  31.075731   
3               NaN  34.257125  34.356584  34.119461  34.449316  33.674573   
4               NaN        NaN        NaN  40.777740  40.652763  40.606444   
5               NaN        NaN        NaN  31.716171  31.774259  31.323608   
6               NaN        NaN        NaN        NaN        NaN  33.676186   
7               NaN        NaN        NaN        NaN        NaN  35.723391   
8               NaN        NaN        NaN        NaN        NaN  42.826030   

Channels             
Model       B-Multi  
Band                 
1         34.703941  
2         31.047545  
3         33.927057  
4        

In [None]:
ssim_per_band_table = pd.pivot_table(results_df, values='SSIM', index=['Band'], 
                       columns=['Channels', 'Model'])

print(ssim_per_band_table)
reslatex = ssim_per_band_table.to_latex(float_format='{:,.3f}'.format)
print(reslatex)

#### Figure 1: Dataset Examples
###  1 Day - Full Disk Coverage, CONUS, and MESOSCALE


In [15]:

month = 3
day = 3
year = 2018
hour = 20
channels = [1,2,3]
#channels = [10]

day_of_year = datetime(year, month, day).timetuple().tm_yday

products = ['ABI-L1b-RadF', 'ABI-L1b-RadC', 'ABI-L1b-RadM']

arrs = []
for p in products:
    prod = goes16s3.NOAAGOESS3(product=p, 
                               channels=channels)
    prod_day_keys = prod.day_keys(year, day_of_year, hours=[hour])
    prod_key = prod_day_keys.keyname[0]
    minute = prod_day_keys.minute.min()
    prod_minute_keys = prod_day_keys[prod_day_keys.minute == minute]

    das = []
    for c in channels:
        k = prod_minute_keys[prod_minute_keys.channel == c].keyname.values[0]
        ds, _ = prod.read_nc_from_s3(k)
        da = ds.Rad
        if c == 2:
            da = utils.interp_da2d(da, 1./4, fillna=False)
        elif c in [1,3,5]:
            da = utils.interp_da2d(da, 1./2, fillna=False)
        das.append(da)

    arr = np.concatenate([d.values[np.newaxis] for d in das], axis=0)
    arrs.append(arr)

In [19]:
filenames = ['fulldisk.png', 'conus.png', 'mesoscale.png']
for i, arr in enumerate(arrs):
    plt.imsave('figures/' + filenames[i],
               arr.transpose(1,2,0)[:,:,[1,2,0]],
               dpi=50)
    

## Figure 2: Interpolation Network

In [None]:
prod = goes16s3.NOAAGOESS3(product='ABI-L1b-RadM', channels=channels)

ds = prod.read_day(year, day, hours=[hour]).next()

In [None]:
blocked_data = utils.blocks(hour_das, width=352)
B = blocked_data[0]

I0_np = B.isel(t=0).values
I0 = torch.from_numpy(I0_np[np.newaxis]).to(device)
I1_np = B.isel(t=15).values
I1 = torch.from_numpy(I1_np[np.newaxis]).to(device)


print(I0.shape, I1.shape)

f = flow_net(I0, I1)

# x, y optical flows
f_10 = f[:,:2]
f_01 = f[:,2:]

T = 4
i = 2
t = 1. * i / (T+1)
# Input Channels: predicted image and warped without derivatives
I_t, g0, g1 = interp_net(I0, I1, f_10, f_01, t)

In [None]:
obs = B.isel(t=i+1).values
pred = I_t.cpu().detach().numpy()

# Images to save
# I0
# I1 

def make_img(I, f, vmin=None, vmax=None):
    plt.imshow(I, vmin=vmin, vmax=vmax)
    plt.axis('off')
    plt.show()
    scipy.misc.imsave(f, I)
    
make_img(I0_np[[1,2,0]].transpose(1,2,0), 'figures/I0.png')
make_img(I1_np[[1,2,0]].transpose(1,2,0), 'figures/I1.png')
make_img((I1_np-I0_np)[[1,2,0]].transpose(1,2,0), 'figures/I1-minus-I0.png')

make_img(obs[[1,2,0]].transpose(1,2,0), 'figures/IT.png')


f_10_np = f_10.cpu().detach().numpy()
f_01_np = f_01.cpu().detach().numpy()

#mn = np.percentile(f_10_np[0,1], 10.)
#mx = np.percentile(f_10_np[0,1], 90.)

f_10_np = f_10_np.mean(axis=(0,1))
make_img(f_10_np[20:330,20:330], 'figures/f_10.png')#, vmin=mn, vmax=mx)

f_01_np = f_01_np.mean(axis=(0,1))
make_img(f_01_np[20:330,20:330], 'figures/f_01.png')#, vmin=mn, vmax=mx)

print(g0.shape)

## Flow Figures

In [79]:
dataset = goes16s3.NOAAGOESS3(product='ABI-L1b-RadM', 
                          channels=range(1,4))

print(datetime(2018,3,3))
hour_das = dataset.read_day(year, datetime(2018, 3, 3).timetuple().tm_yday, 
                            hours=[20]).next() # N,12or13,512,512,3
blocked_data = utils.blocks(hour_das, width=352)

2018-03-03 00:00:00
('Day', 62, 'Hour', 20)


In [80]:
def opticalflow(flow):
    hsv = np.ones((flow.shape[0], flow.shape[1], 3))*255.
    
    # Use Hue, Saturation, Value colour model 
    mag, ang = cv2.cartToPolar(flow[:,:, 0], flow[:,:, 1])
    
    hsv[:,:, 0] = ang * 180 / np.pi / 2
    hsv[:,:, 2] = cv2.normalize(mag, None, 0, 255., cv2.NORM_MINMAX)
    bgr = cv2.cvtColor(np.uint8(hsv), cv2.COLOR_HSV2BGR)
    return bgr

def flowfigures(block):
    n_channels = 3
    model_path = './saved-models/5Min-3Channels/'
    flownet, interpnet, warper = load_models(n_channels, model_path, False)
        
    block_vals = torch.from_numpy(block.values)
    N = block_vals.shape[0]
    idxs = np.arange(0, N+1, 5)
    preds = []
    
    idx0 = 0
    idx1 = 5

    j = 2
    t = 1. * j / 5
    
    I0 = torch.unsqueeze(block_vals[idx0], 0).to(device)
    I1 = torch.unsqueeze(block_vals[idx1], 0).to(device)
    IT_img = block_vals[idx0 + j].detach().cpu().numpy().transpose(1,2,0)
    
    f = flownet(I0, I1)

    f_01 = f[:,:2]
    f_10 = f[:,2:]

    I_t, g0, g1, V_t0, V_t1, delta_f_t0, delta_f_t1 = interpnet(I0, I1, f_01, f_10, t)
    F_t0_hat = -(1-t) * t * f_01 + t**2 * f_10
    F_t1_hat =  (1-t) ** 2 * f_01 - t * (1 - t) * f_10
    F_t0 = F_t0_hat + delta_f_t0
    F_t1 = F_t1_hat + delta_f_t1
    
    
    detach = lambda x: x.detach().cpu().numpy()[0].transpose(1,2,0)
    
    I0_img = detach(I0)
    I1_img = detach(I1)
    plt.imsave('figures/flowfigs/I0.png', I0_img)
    plt.imsave('figures/flowfigs/I1.png', I1_img)

    
    V_t0_img = detach(V_t0)
    plt.imsave('figures/flowfigs/V_t0.png', V_t0_img[:,:,0], cmap='gray')
    V_t1_img = detach(V_t1)
    plt.imsave('figures/flowfigs/V_t1.png', V_t1_img[:,:,0], cmap='gray')
    plt.imsave('figures/flowfigs/V_tdiff.png', V_t1_img[:,:,0] - V_t0_img[:,:,0], cmap='gray')

    f_01_img = opticalflow(detach(F_t0)[20:332,30:332])
    plt.imsave('figures/flowfigs/F_t0.png', f_01_img)
    f_10_img = opticalflow(detach(F_t1)[20:332,30:332])
    plt.imsave('figures/flowfigs/F_t1.png', f_10_img)
    
    I_t_img = detach(I_t)
    plt.imsave('figures/flowfigs/I_t.png', I_t_img)
    diff = IT_img - I_t_img
    plt.imsave('figures/flowfigs/residual.png', diff*5)
    
    
flowfigures(blocked_data[0])

loading checkpoint ./saved-models/5Min-3Channels/checkpoint.flownet.pth.tar
=> loaded checkpoint './saved-models/5Min-3Channels/checkpoint.flownet.pth.tar' (epoch 19)




In [35]:
## Average PSNR 