In [1]:
%matplotlib inline

from multiprocessing import cpu_count
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import _LRScheduler
import netCDF4
import xarray as xr

import datetime

seed = 1  # single RNG seed for this notebook; drives the per-class pixel sampling
np.random.seed(seed)  # seeds NumPy's legacy global generator used by np.random.*

In [2]:
# One training NetCDF per region, all under the same dataset directory.
# 'central3' is not included (it was commented out of the list).
data = ['../lstm_dataset/20210730_dataset/{}_rmmeh/trainSet.nc'.format(region)
        for region in ('kpp1', 'kpp2',
                       'central1', 'central2',
                       'north_eastern1', 'north_eastern2', 'north_eastern3',
                       'eastern1', 'eastern2')]

In [3]:
def sampling_ds(ds, idx, n_time=70):
    """Extract per-pixel band time series and metadata for selected locations.

    Bands read from ``ds``:
    'coastal', 'blue', 'green', 'red', 'veg5', 'veg6', 'veg7', 'nir',
    'narrow_nir', 'water_vapour', 'swir1', 'swir2', 'SCL', 'WVP', 'AOT'.
    Two indices are derived and appended after them:
    NDVI = (nir - red) / (nir + red + 1e-5) and
    NDWI = (green - nir) / (green + nir + 1e-5).

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain the bands above with ``location`` and ``time``
        dimensions, plus ``y``, ``row``, ``col``, ``latitude`` and
        ``longitude`` selectable by ``location``.
    idx : array-like of int
        Values passed to ``.sel(location=idx)``.
        NOTE(review): the caller builds these with ``np.where`` over
        ``ds.y.values`` (i.e. positions); that only matches ``.sel`` if the
        ``location`` labels equal their positions -- confirm.
    n_time : int, optional
        Number of leading time steps to keep (default 70).

    Returns
    -------
    x : numpy.ndarray, shape (len(idx), 17, n_time)
        Features ordered as the 15 bands above, then NDVI, then NDWI.
    y : numpy.ndarray
        ``ds.y`` at the selected locations.
    position : numpy.ndarray, shape (len(idx), 2)
        ``(row, col)`` per selected location.
    location : numpy.ndarray, shape (len(idx), 2)
        ``(latitude, longitude)`` per selected location.
    timestamps : numpy.ndarray
        The first ``n_time`` values of ``ds.time``.
    """
    # (variable name, divisor) in stacking order; None keeps the stored value.
    # NOTE(review): 10000 / 1000 look like Sentinel-2 L2A quantification values
    # (reflectance bands vs. WVP/AOT); SCL (presumably the scene classification
    # layer) is left unscaled -- confirm against the dataset metadata.
    band_divisors = (
        ('coastal', 10000), ('blue', 10000), ('green', 10000), ('red', 10000),
        ('veg5', 10000), ('veg6', 10000), ('veg7', 10000), ('nir', 10000),
        ('narrow_nir', 10000), ('water_vapour', 10000),
        ('swir1', 10000), ('swir2', 10000),
        ('SCL', None), ('WVP', 1000), ('AOT', 1000),
    )
    # The same time selector applies to every band; build it once.
    times = ds.time[:n_time]

    series = {}
    for name, divisor in band_divisors:
        # Pin (location, time) order so the final transpose always yields
        # (sample, feature, time), whatever the variable's on-disk layout.
        da = ds[name].sel(location=idx, time=times).transpose('location', 'time')
        series[name] = da if divisor is None else da / divisor

    nir, red, green = series['nir'], series['red'], series['green']
    # The small epsilon avoids division by zero when the denominator bands sum to 0.
    series['ndvi'] = (nir - red) / (nir + red + .00001)
    series['ndwi'] = (green - nir) / (green + nir + .00001)

    feature_order = [name for name, _ in band_divisors] + ['ndvi', 'ndwi']
    features = np.array([series[name] for name in feature_order])  # (17, N, T)

    y = ds.y.sel(location=idx)
    position = np.array([ds.row.sel(location=idx), ds.col.sel(location=idx)])
    location = np.array([ds.latitude.sel(location=idx),
                         ds.longitude.sel(location=idx)])

    return (features.transpose((1, 0, 2)), y.values, position.transpose(),
            location.transpose(), ds.time.values[:n_time])

In [4]:
# Pixels drawn per class label for EACH split, from every file
# (0 = other, 1 = sugar, 2 = rice, matching the original variable names).
samples_per_class = ((0, 5000), (1, 10000), (2, 5000))


def _draw_disjoint_indices(labels, per_class):
    """Return ``(train_idx, test_idx)`` position arrays that never share a pixel.

    Each class's pixel positions are shuffled and split into two disjoint
    halves; train indices come only from the first half and test indices only
    from the second. Drawing both sets independently from the full pool would
    put identical pixels in train and test (evaluation leakage).

    If a half holds fewer pixels than requested, it is sampled with
    replacement so the per-class counts (and hence array shapes) stay fixed.

    Raises
    ------
    ValueError
        If a class has fewer than 2 pixels and cannot be split.
    """
    trn_parts, test_parts = [], []
    for label, n in per_class:
        pool = np.random.permutation(np.flatnonzero(labels == label))
        if len(pool) < 2:
            raise ValueError('class {} has {} pixel(s); need at least 2 to split '
                             'into train and test'.format(label, len(pool)))
        half = len(pool) // 2
        for parts, sub_pool in ((trn_parts, pool[:half]), (test_parts, pool[half:])):
            if len(sub_pool) >= n:
                # Already shuffled, so a prefix is a random subset without repeats.
                parts.append(sub_pool[:n])
            else:
                # Too few pixels in this half: allow repeats to keep the count.
                parts.append(np.random.choice(sub_pool, n))
    return np.concatenate(trn_parts), np.concatenate(test_parts)


trn_chunks, test_chunks = [], []

for file in data:
    print('read file:', file)
    # The context manager closes the netCDF file; sampling_ds returns plain
    # NumPy arrays, so nothing lazily read from `ds` is needed afterwards.
    with xr.open_dataset(file) as ds:
        # NOTE(review): assumes ds.keys() lists exactly the bands sampling_ds
        # stacks, in the same order; otherwise the saved `bands` metadata will
        # not describe axis 1 of x -- confirm.
        band = list(ds.keys()) + ['ndvi', 'ndwi']
        trn_idx, test_idx = _draw_disjoint_indices(ds.y.values, samples_per_class)
        trn_chunks.append(sampling_ds(ds, trn_idx))
        test_chunks.append(sampling_ds(ds, test_idx))

# Concatenate once: calling np.append per file re-copied the growing arrays each time.
trn_x, trn_y, trn_pos, trn_location = (
    np.concatenate([chunk[k] for chunk in trn_chunks], axis=0) for k in range(4))
test_x, test_y, test_pos, test_location = (
    np.concatenate([chunk[k] for chunk in test_chunks], axis=0) for k in range(4))

# Timestamps are kept from the first file only.
# NOTE(review): this assumes all files share the same leading time steps -- confirm.
trn_time = trn_chunks[0][4]
test_time = test_chunks[0][4]

# Release the per-file copies; the concatenated arrays are large.
del trn_chunks, test_chunks

read file: ../lstm_dataset/20210730_dataset/kpp1_rmmeh/trainSet.nc


module 'contextlib' has no attribute 'nullcontext'


read file: ../lstm_dataset/20210730_dataset/kpp2_rmmeh/trainSet.nc
read file: ../lstm_dataset/20210730_dataset/central1_rmmeh/trainSet.nc
read file: ../lstm_dataset/20210730_dataset/central2_rmmeh/trainSet.nc
read file: ../lstm_dataset/20210730_dataset/north_eastern1_rmmeh/trainSet.nc
read file: ../lstm_dataset/20210730_dataset/north_eastern2_rmmeh/trainSet.nc
read file: ../lstm_dataset/20210730_dataset/north_eastern3_rmmeh/trainSet.nc
read file: ../lstm_dataset/20210730_dataset/eastern1_rmmeh/trainSet.nc
read file: ../lstm_dataset/20210730_dataset/eastern2_rmmeh/trainSet.nc


In [5]:
print("Shape of x:", trn_x.shape)
print("Shape of y:", trn_y.shape)
print("Shape of position:", trn_pos.shape)
print("Shape of location:", trn_location.shape)
print("Shape of timestamp:", trn_time.shape)

# Guard the invariants the saved .npz files rely on: every per-sample array
# describes the same samples, and x's last (time) axis matches the timestamps.
for split_name, (split_x, split_y, split_pos, split_loc, split_time) in (
        ("train", (trn_x, trn_y, trn_pos, trn_location, trn_time)),
        ("test", (test_x, test_y, test_pos, test_location, test_time))):
    assert len(split_y) == len(split_pos) == len(split_loc) == len(split_x), \
        split_name + ": per-sample array lengths differ"
    assert split_x.shape[-1] == len(split_time), \
        split_name + ": time axis of x does not match timestamps"

Shape of x: (180000, 17, 70)
Shape of y: (180000,)
Shape of position: (180000, 2)
Shape of location: (180000, 2)
Shape of timstamp: (70,)


In [6]:
# Collapse to a binary sugar-vs-rest task: label 2 (rice) is merged into
# label 0 (other); label 1 (sugar) is untouched. Arrays are modified in place.
for labels in (trn_y, test_y):
    labels[labels == 2] = 0

In [7]:
#Save to NPY format
# Save each split as an uncompressed NumPy .npz archive; the band-name list from
# the loading cell is stored alongside the arrays under the key 'bands'.
for out_file, arrays in (
        ("trainSetSugarRMMEH.npz",
         dict(x=trn_x, y=trn_y, location=trn_location, position=trn_pos,
              timestamp=trn_time, bands=np.array(band))),
        ("testSetSugarRMMEH.npz",
         dict(x=test_x, y=test_y, location=test_location, position=test_pos,
              timestamp=test_time, bands=np.array(band)))):
    np.savez(out_file, **arrays)