## Initial set-up ##

In [None]:
# Initial imports - some redundancies may remain
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import History
from tensorflow.keras.optimizers import Adam
import math
from tensorflow import py_function, double
import requests
import gc
import json
import random
from sklearn.metrics import mean_squared_error, mean_absolute_error
from tensorflow.keras.callbacks import History
import netCDF4 as nc


# ---- Run configuration ----
# Training / evaluation sizes
days = 7
numFeatures = 100 * days * 4    # flattened input length (2800 values per sample)
batchSize = 512
numEpochs = 30
lastLoss = 100000               # starting value for loss tracking
# (previously considered, currently unused: maxY = 40, minY = 0)

# ---- Dataset layout ----
numTrainingFiles = 137
numTestingFiles = 50
numRecordsPerFile = 100000

# ---- Image geometry ----
latDim = 101
longDim = 101
imageSize = latDim * longDim    # number of pixels in one image
minLat, maxLat = -12, -10
minLon, maxLon = 145, 147

# ---- Reproducibility ----
# Only NumPy's global generator is seeded here.
seed = 6
np.random.seed(seed)

# ---- Output naming ----
savePrefix = 'LSTMLite'

# Checkpoint
print("Checkpoint: initialization complete")

### Preprocessing and custom parameter set up

In [None]:
class PreprocessingLayer(Layer):
    """Order each sample's observations by a learned distance and keep the nearest.

    The input is reshaped to ``(batch, 700, 4)`` and each of the four columns
    is divided by ``[100, 100, 3, 40]``. Columns 0-2 are combined into a
    distance ``sqrt(c0**2 + c1**2 + temporalPenalty * c2**2)`` and column 3 is
    treated as the temperature.
    NOTE(review): columns 0/1 look like spatial offsets and column 2 like a
    time offset (hence ``temporalPenalty``) - confirm against the data files.

    The observations are sorted by ascending distance, interleaved as
    ``temperature, distance, temperature, distance, ...`` and truncated to the
    first 200 values (i.e. the 100 closest observations).

    Output shape: ``(batch, 200, 1)``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Single trainable scalar weighting column 2 against columns 0 and 1
        # in the distance calculation.
        self.temporalPenalty = self.add_weight(name='temporalPenalty',
                                    initializer=tf.keras.initializers.RandomUniform(minval=50, maxval=100),
                                    shape=(1,),
                                    trainable=True)

        super().build(input_shape)

    def call(self, x):
        """Return the sorted, interleaved (temperature, distance) features for ``x``."""
        # Reshape with a plain op rather than instantiating a new Keras layer
        # on every call, then normalise each column.
        x = tf.reshape(x, (-1, 700, 4))
        x = x / [100, 100, 3, 40]

        # Distance per observation, using the learned penalty on column 2
        distances = tf.math.sqrt(tf.math.add_n([
            tf.math.square(x[:, :, 0]),
            tf.math.square(x[:, :, 1]),
            tf.math.multiply(self.temporalPenalty, tf.math.square(x[:, :, 2])),
        ]))

        # Collect the temperatures
        temperatures = x[:, :, 3]

        # Sort every sample's observations by ascending distance.
        # argsort already returns shape (batch, 700), so no reshape is needed.
        sortOrder = tf.argsort(distances, direction='ASCENDING')
        sortedDistances = tf.gather(distances, sortOrder, batch_dims=1, axis=-1)
        sortedTemperatures = tf.gather(temperatures, sortOrder, batch_dims=1, axis=-1)

        # Interleave (temperature, distance) pairs into one sequence of length
        # 1400 and keep the first 200 entries = the 100 nearest observations.
        features = tf.reshape(tf.stack([sortedTemperatures, sortedDistances], axis=2), (-1, 1400, 1))
        return features[:, 0:200, :]

### Build the model and load in weights ###

In [None]:
def create_model():
    """Build and compile the "LSTMLite64" regression network.

    Pipeline: flat input of ``numFeatures`` values -> reshape to (700, 4) ->
    ``PreprocessingLayer`` -> bidirectional LSTM with 64 units per direction ->
    single linear output unit.

    Returns:
        A compiled Keras ``Model`` using Adam (learning rate 1e-4), MAE loss,
        and MAE / MSE metrics.
    """
    # Input and preprocessing (use numFeatures // 2 for temperatures only)
    flatInput = tf.keras.layers.Input(shape=(numFeatures, ))
    reshapedInput = keras.layers.Reshape((700, 4))(flatInput)
    sequence = PreprocessingLayer()(reshapedInput)

    # Recurrent encoder followed by the regression head
    encoded = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(sequence)
    prediction = tf.keras.layers.Dense(1, activation='linear')(encoded)

    network = tf.keras.models.Model(inputs=flatInput, outputs=prediction, name="LSTMLite64")

    # Optimiser, loss and reported metrics
    network.compile(
        loss='mae',
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
        metrics=['mae', 'mse'],
    )
    return network

# Instantiate the network and print its layer summary. Saved weights can be
# loaded afterwards when this is a transfer-learning run.
model_name = "LSTMLite64"
model = create_model()
model.summary()

In [None]:
## Inspect the temporal weight
# Restore the saved weights, then show the model's first trainable variable
# (expected to be the preprocessing layer's temporalPenalty).
weightsFile = "Models/LSTMLite64v2.h5"
model.load_weights(weightsFile)
firstTrainable = model.trainable_weights[0]
print(firstTrainable)

### Load in files ###

In [None]:
## Clean up
gc.collect()

# Load in the model weights.
# This must be an actual call: a bare `model.load_weights` only references the
# method and loads nothing.
# TODO: point weightsPath at the weights file that should be evaluated.
weightsPath = "Models/LSTMLite64v2.h5"
model.load_weights(weightsPath)

# Accumulators, filled file by file:
#   y_pred / y_true               - predictions and targets for the "with_nan" test sets
#   y_pred_no_nan / y_true_no_nan - predictions and targets for the plain test sets
#   pixel_locations               - one position row per entry of the "with_nan" set
y_true_no_nan = np.empty(0)
y_true = np.empty(0)
pixel_locations = np.empty(0)
y_pred_no_nan = np.empty(0)
y_pred = np.empty(0)

# Loop through all .npy files
for i in range(0, 9):
    print("Up to file " + str(i))

    # Load this file's testing tiles - note that paths may need to be changed!
    # mmap_mode='r' maps the files read-only instead of reading them in full.
    fileSuffix = str(i) + ".npy"
    X_test_no_nan = np.load("Data/X_test_AllN_" + fileSuffix, mmap_mode='r')
    y_test_no_nan = np.load("Data/Y_test_AllN_" + fileSuffix, mmap_mode='r')
    X_test = np.load("Data/X_test_with_nan_AllN_" + fileSuffix, mmap_mode='r')
    Y_test = np.load("Data/Y_test_with_nan_AllN_" + fileSuffix, mmap_mode='r')
    locations = np.load("Data/Position_test_AllN_" + fileSuffix, mmap_mode='r')
    print("I am loaded") # Checkpoint

    # Append the y values to the growing arrays above (np.append flattens).
    # Predictions are multiplied by 40 - presumably undoing the /40 scaling
    # the preprocessing layer applies to the temperature column; confirm.
    y_pred = np.append(y_pred, model.predict(X_test)*40)
    print("I have predicted") # Checkpoint
    y_pred_no_nan = np.append(y_pred_no_nan, model.predict(X_test_no_nan)*40)
    y_true = np.append(y_true, Y_test)
    y_true_no_nan = np.append(y_true_no_nan, y_test_no_nan)

    # Running mean absolute error over everything predicted so far
    print("Error so far: " + str(np.mean(np.absolute(y_pred_no_nan-y_true_no_nan))))

    # Append the positional values, keeping one row per sample
    if i == 0:
        pixel_locations = locations
    else:
        pixel_locations = np.append(pixel_locations, locations, axis=0)
        print(locations)
        print(pixel_locations)

    # Drop the per-file references and collect before loading the next file
    del X_test, Y_test, X_test_no_nan, y_test_no_nan, locations
    gc.collect()

print("Checkpoint: Data loaded")

### Fill the images ###

In [None]:
# TODO: replace these defaults with the real locations of the testing files.
testingInputNc = "TestingDataFrom2020.nc"
testingOutputNc = "TestingTargetsFrom2020.nc"
savePrefix = "LSTMLite"  # prefix (may include a directory) for the saved .txt images

# Read in the testing images
imageDataset = nc.Dataset(testingInputNc)
targetDataset = nc.Dataset(testingOutputNc)
sst = imageDataset['sst'][:]
sstTarget = targetDataset['sst'][:]

# Read in times: line 1 holds the comma-separated time of every test image,
# line 2 the matching cloud identifiers.
timeFilename = "TestingTimes2020.txt"  # TODO: set the real path
with open(timeFilename) as timeFile:
    timeFileLines = timeFile.readlines()
times = [int(s) for s in timeFileLines[0].split(',')]
clouds = [int(s) for s in timeFileLines[1].split(',')]

print("Checkpoint: about to make some images")

# No window into the predictions may extend past either array.
numPredictions = min(len(y_pred), len(pixel_locations))

for imageIndex in range(0, len(times)):

    # Info about the image
    testTime = times[imageIndex]
    testCloud = clouds[imageIndex]
    outputStem = savePrefix + 'Day' + str(testTime) + 'Cloud' + str(testCloud)

    # Find target index: position of the first image with this time, divided
    # by 5 (presumably five cloud images share one target - confirm).
    firstOccurance = times.index(testTime)
    targetIndex = firstOccurance // 5
    print(targetIndex)

    # Get the cloud image and save it in the same format
    testImage = sst[imageIndex, :, :]
    np.savetxt(outputStem + '_Cloud.txt', testImage)

    # Get the target image
    targetImage = sstTarget[targetIndex, :, :]

    # Image to be filled - a copy, so filling never writes back into `sst`
    testFilledImage = sst[imageIndex, :, :].copy()

    # Find image start: the first row whose (time, cloud) columns match this
    # image. Its rows are searched for in the imageSize rows from there on.
    startIndex = 0
    stopIndex = 0
    if numPredictions > 0:
        imageRows = np.flatnonzero(
            np.all(pixel_locations[:, 2:] == np.array([testTime, testCloud]), axis=1))
        if imageRows.size > 0:
            startIndex = int(imageRows[0])
            # range() excludes the stop, so clamp to the length itself (not
            # length - 1) or the final prediction could never be used.
            stopIndex = min(startIndex + imageSize, numPredictions)

    # Map (x, y) position -> prediction index for this image, built in one
    # pass over the window. Later rows overwrite earlier ones, so a repeated
    # position resolves to its last row.
    predictionIndexAt = {}
    windowRows = np.asarray(pixel_locations[startIndex:stopIndex]).tolist()
    for offset, row in enumerate(windowRows):
        if row[2] == testTime and row[3] == testCloud:
            predictionIndexAt[(row[0], row[1])] = startIndex + offset

    # Loop through every pixel in each image and fill the missing ones
    for x in range(0, longDim):
        for y in range(0, latDim):
            if math.isnan(testFilledImage[x, y]):
                # Stored positions are 1-based (Matlab), array indices 0-based
                predictionIndex = predictionIndexAt.get((x + 1, y + 1))
                if predictionIndex is not None:
                    testFilledImage[x, y] = y_pred[predictionIndex]

    np.savetxt(outputStem + '_Target.txt', targetImage)
    np.savetxt(outputStem + '_Filled.txt', testFilledImage)