In [1]:
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
torch.cuda.empty_cache()
import time
import matplotlib.pyplot as plt
import numpy as np
import os

# Seed every random number generator the notebook relies on so that a
# Restart & Run All reproduces the same shuffling and weight initialisation.
seed = 1234
torch.manual_seed(seed)
np.random.seed(seed)

# # ------- UNCOMMENT IF USING COLAB ----
# # Mount google drive
# from google.colab import drive
# drive.mount('/content/drive')

# # Update path to import from Drive
# import sys
# dir_path = '/content/drive/MyDrive/HybridFilters/'
# sys.path.append(dir_path) # path to folder in drive

# device = 'cuda:0'
# # ------------------------------------

# Root folder for the training data and for everything the notebook writes.
dir_path = os.getcwd()

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
## --------------- LOAD DATA ----------------
# NOTE(review): N is not referenced by any other visible cell; kept in case
# code outside this view relies on it.
N=128

project_label = 'topHat8_valLaxSod50' # update to match the training data used

# All three splits currently live in the same .npz archive; point the
# validation/test names at other files if the splits are stored separately.
filename_training = dir_path + '/training_data/disc_data_norm_siac0_syntheticTrainTophat8_testLaxSod484_valLaxSod50_LaxSodVal.npz'
filename_validation = filename_training
filename_test = filename_training


def _load_split(npz_path, split):
    """Read one (input, target) array pair from an .npz archive.

    Parameters
    ----------
    npz_path : str
        Path of the .npz archive to open.
    split : str
        Split suffix; the arrays stored under 'disc_<split>' and
        'exact_<split>' are returned.

    Returns
    -------
    tuple of numpy.ndarray
        (disc, exact) arrays, fully read into memory.

    Raises
    ------
    KeyError
        If the archive does not contain the requested keys.
    """
    # NpzFile keeps the underlying zip file open until it is closed, so use it
    # as a context manager; indexing inside the block materialises the arrays.
    with np.load(npz_path) as archive:
        return archive['disc_' + split], archive['exact_' + split]


# Training / validation / test datasets
train_siac, train_siac_exact = _load_split(filename_training, 'train')
val_siac, val_siac_exact = _load_split(filename_validation, 'val')
test_siac, test_siac_exact = _load_split(filename_test, 'test')

# Selecting datasets (the rest of the notebook only uses these generic names)
train = train_siac
train_exact = train_siac_exact
val = val_siac
val_exact = val_siac_exact
test = test_siac
test_exact = test_siac_exact
plt_label = 'SIAC'  # legend label for the unfiltered input curves

# Convert to Torch tensor datasets
batch = 200
x_train = torch.tensor(train)
y_train = torch.tensor(train_exact)
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)

# Validation and test arrays are converted to tensors in the training cell.
# x_val = torch.tensor(val)
# y_val = torch.tensor(val_exact)

# x_test = torch.tensor(test)
# y_test = torch.tensor(test_exact)


In [None]:
################ FILTER #########################
## ---------------DEFINE & TRAIN FILTER ANN ----------------

# Architecture hyper-parameters (read by Net at construction / call time).
kernel = 7             # Conv1d kernel size used by every layer
hidden_channels = 128  # channel width of the hidden convolutions
num_layers = 7         # conv_in + 5 hidden + conv_out; only used in the checkpoint filename
grid_length = 5000     # length of the constant signal used for the consistency normalisation


class Net(torch.nn.Module):
    """Convolutional filter with a consistency normalisation.

    The network adds a learned correction to its input and then rescales the
    result so that a constant signal of ones is mapped to a signal whose mean
    is one.

    Inputs are expected as (batch, 1, length) tensors, as fixed by the single
    input channel of ``conv_in``; 'same' padding keeps the length unchanged.
    """

    def __init__(self):
        super().__init__()

        self.leaky_relu_1 = torch.nn.LeakyReLU(negative_slope=0.1)

        # Attribute names are part of the saved state_dict keys - do not rename.
        self.conv_in = torch.nn.Conv1d(1, hidden_channels, kernel_size=kernel, padding='same', padding_mode='replicate')

        self.conv2 = torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size=kernel, padding='same', padding_mode='replicate')
        self.conv3 = torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size=kernel, padding='same', padding_mode='replicate')
        self.conv4 = torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size=kernel, padding='same', padding_mode='replicate')
        self.conv5 = torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size=kernel, padding='same', padding_mode='replicate')
        self.conv6 = torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size=kernel, padding='same', padding_mode='replicate')

        self.conv_out = torch.nn.Conv1d(hidden_channels, 1, kernel_size=kernel, padding='same', padding_mode='replicate')

    def resnet(self, x):
        """Return the learned correction for ``x`` (same shape as ``x``).

        Every convolution, including the output one, is followed by the
        leaky ReLU.
        """
        x = self.leaky_relu_1(self.conv_in(x))

        x = self.leaky_relu_1(self.conv2(x))
        x = self.leaky_relu_1(self.conv3(x))
        x = self.leaky_relu_1(self.conv4(x))
        x = self.leaky_relu_1(self.conv5(x))
        x = self.leaky_relu_1(self.conv6(x))

        x = self.leaky_relu_1(self.conv_out(x))

        return x

    def forward(self, x):
        """Filter ``x`` and normalise by the response to a constant signal.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, 1, length).

        Returns
        -------
        torch.Tensor
            ``(x + resnet(x)) / c`` where ``c`` is the mean of
            ``1 + resnet(1)`` over a ones-signal of length ``grid_length``.
        """
        # Apply data-driven, nonlinear NN filter
        x_nn = x + self.resnet(x)

        # Build the constant signal on the same device/dtype as the input so
        # the model does not depend on a notebook-level `device` variable.
        ones_vec = torch.ones(1, 1, grid_length, device=x.device, dtype=x.dtype)
        consistency = ones_vec + self.resnet(ones_vec)

        # `consistency` has exactly grid_length elements, so its mean equals
        # sum / grid_length. Tensor.mean() is a single differentiable op; the
        # Python builtin sum() would add the elements one at a time and create
        # thousands of autograd nodes on every forward pass.
        constant = consistency.mean()

        return x_nn / constant


# Build the model and place its parameters on the selected device
# (Module.to moves the parameters in place and returns the module itself).
net = Net().to(device)

# Count only the parameters the optimiser will actually update.
pytorch_total_params = sum(
    weight.numel()
    for weight in filter(lambda weight: weight.requires_grad, net.parameters())
)
print(f'Number of trainable parameters: {pytorch_total_params}')

In [None]:
# Constant input used to monitor the consistency condition during training.
ones_vec = torch.ones(1,1,grid_length, device=device)

# x-axis used by the sample figures: (2*half_window+1)*pts_per_cell points
# spanning [-half_window, half_window+1].
# NOTE(review): linspace includes the right endpoint, so the spacing is
# 9/35 rather than 1/pts_per_cell - confirm this matches how the data was sampled.
half_window = 4
pts_per_cell = 4
mesh = np.linspace(-half_window, half_window+1, (2*half_window+1)*pts_per_cell)


# Define Loss Function and Optimization method
criterion_l2 = torch.nn.MSELoss()
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Saving Model
# Use a name that does not shadow the builtin `dir`, and let makedirs handle
# the "already exists" case itself instead of a separate exists() check.
save_dir = os.path.join(dir_path, project_label)
os.makedirs(save_dir, exist_ok=True)

# Common prefix for the checkpoint, the figures and the loss history; it
# encodes the hyper-parameters of this run.
filename = os.path.join(
    save_dir,
    f"B{batch}_L{num_layers}_Ch{hidden_channels}_K{kernel}_LR{learning_rate:5.1e}",
)


## ---------------TRAIN MODEL ----------------
# Number of passes over the training set, and the x-axis scale of the loss plot.
epochs, plot_loss_freq = 500, 1

# One zero-initialised slot per epoch for every quantity that is tracked.
training_loss_data, validation_loss, test_loss, consistency = (
    torch.zeros(epochs) for _ in range(4)
)

# Epoch indices at which the validation loss improved.
best_epochs = list()

# *************UPDATE to previous best Validation Loss if training from Previous Model
best_loss = float('inf')

# torch.cuda.empty_cache()

def _save_sample_figure(raw, filtered, exact, title, path):
    """Plot one raw / filtered / exact triple over `mesh` and save it to `path`.

    Parameters
    ----------
    raw, filtered, exact : 1-D array-like
        Unfiltered input, network output and reference solution; each must
        have as many entries as `mesh`.
    title : str
        Figure title.
    path : str
        Destination image file (overwritten on every call).
    """
    plt.figure()
    plt.plot(mesh, raw, '-s', label=plt_label)
    plt.plot(mesh, filtered, '-o', label='NN Filter * '+plt_label)
    plt.plot(mesh, exact, '-k', label='Exact')
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('Approx.')
    plt.legend()
    plt.savefig(path)
    plt.close()


# Validation / test tensors do not change between epochs: build them and move
# them to the device once instead of on every epoch.
val_input = torch.FloatTensor(val).to(device)
val_output = torch.FloatTensor(val_exact).to(device)
test_input = torch.FloatTensor(test).to(device)
test_output = torch.FloatTensor(test_exact).to(device)

# Indices of the samples shown in the diagnostic figures.
sample_id = 0
val_id = 1
test_id = 1

for epoch in range(epochs):
    running_loss_data = 0.0

    # The loss of the unfiltered input is accumulated during the first epoch
    # only and reused as a constant baseline afterwards.
    if epoch == 0:
        unfiltered_loss = 0.0

    start_time = time.time()

    net.train()
    for x_batch, y_batch in train_dl:

        x_batch = x_batch.to(device).float()
        y_batch = y_batch.to(device).float()

        # Zero out the gradients
        optimizer.zero_grad()

        # Run forward step and compute loss
        pred = net(x_batch)
        loss = criterion_l2(pred, y_batch)

        # Run backpropagation and update the weights
        loss.backward()
        optimizer.step()

        running_loss_data += loss.item()

        if epoch == 0:
            unfiltered_loss += criterion_l2(x_batch, y_batch).item()

    epoch_time = time.time() - start_time

    training_loss_data[epoch] = running_loss_data/len(train_dl)
    unfiltered_plot = unfiltered_loss/len(train_dl)

    net.eval()
    with torch.no_grad():
        # Mean response to a constant-ones signal. net(ones_vec) has exactly
        # grid_length elements, so mean() equals sum / grid_length. Doing this
        # under no_grad avoids building an autograd graph for a metric.
        consistency[epoch] = net(ones_vec).mean().item()

        filtered_val = net(val_input)
        validation_loss[epoch] = criterion_l2(filtered_val, val_output).item()

        filtered_test = net(test_input)
        test_loss[epoch] = criterion_l2(filtered_test, test_output).item()

        # Model selection is driven by the validation loss only.
        improved = validation_loss[epoch].item() < best_loss
        if improved:
            best_loss = validation_loss[epoch].item()
            best_epochs.append(epoch)
            # Checkpoint is written to `filename` (no extension).
            torch.save(net.state_dict(), filename)

            # Training Sample Figure
            sample_input = train[sample_id, 0, :]
            sample_exact = train_exact[sample_id, 0, :]
            sample_tensor = torch.FloatTensor(sample_input).view(1,1,-1).to(device)
            filtered_input = net(sample_tensor).cpu()
            _save_sample_figure(
                sample_input, filtered_input[0, 0, :], sample_exact,
                f'Training Data (current best val. loss model) at epoch {epoch}',
                filename + '_trainapprox.png',
            )

            # Validation Data Figure (reuses the forward pass from above)
            _save_sample_figure(
                val[val_id, 0, :], filtered_val[val_id, 0, :].cpu(), val_exact[val_id, 0, :],
                f'Validation Data (current best val. loss model) at epoch {epoch}',
                filename + '_valapprox.png',
            )

            # Test Data Figure (reuses the forward pass from above)
            _save_sample_figure(
                test[test_id, 0, :], filtered_test[test_id, 0, :].cpu(), test_exact[test_id, 0, :],
                f'Test Data (current best val. loss model) at epoch {epoch}',
                filename + '_testapprox.png',
            )

    # Progress line; epochs that improved the validation loss end with ' ***'.
    log_fields = [
        'epoch: ', epoch,
        '| unfiltered: ', "{:5.2e}".format(unfiltered_plot),
        '| train loss: ', "{:5.2e}".format(training_loss_data[epoch].item()),
        '| consistency: ', "{:5.6e}".format(consistency[epoch].item()),
        '| epoch_time = ', epoch_time,
        '| test loss: ', "{:5.2e}".format(test_loss[epoch].item()),
        '| validation loss: ', "{:5.2e}".format(validation_loss[epoch].item()),
    ]
    if improved:
        log_fields.append(' ***')
    print(*log_fields)

    # Loss Figure (rewritten after every epoch so progress can be inspected
    # while training is still running)
    epoch_axis = plot_loss_freq*np.arange(epoch + 1)
    plt.figure()
    plt.semilogy(epoch_axis, unfiltered_plot*np.ones(epoch + 1), '-', label=plt_label)
    plt.semilogy(epoch_axis, training_loss_data[:epoch+1], '-v', label='NN*'+plt_label+'- Training')
    plt.semilogy(epoch_axis, validation_loss[:epoch+1], '-o', label='NN*'+plt_label+'- Validation')
    plt.semilogy(epoch_axis, test_loss[:epoch+1], '-*', label='NN*'+plt_label+'- Test')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(filename + '_loss.png')
    plt.close()

    # Loss history; np.savez appends '.npz', so this does not overwrite the
    # checkpoint saved under the bare `filename`.
    np.savez(filename, training=training_loss_data.numpy(), validation=validation_loss.numpy(), test=test_loss.numpy(), epochs=best_epochs, consistency=consistency.numpy())

