In [None]:
%matplotlib inline

# system modules
import os
import time
from pathlib import Path
from IPython.display import display, Math

# scientific computing
import numpy as np
from numpy import linalg as LA
import pandas as pd
np.random.seed(42)
from fipy import CellVariable,PeriodicGrid2D, Grid2D
from fipy import DiffusionTerm, ExponentialConvectionTerm, DefaultAsymmetricSolver, ImplicitSourceTerm
from fipy import MatplotlibStreamViewer
from fipy.tools.numerix import array, reshape

# plotting
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from plotly.offline import iplot, init_notebook_mode
import plotly.graph_objs as go
import plotly.io as pio
from plotly import subplots
init_notebook_mode(connected=True)

# pytorch importing
import torch
import torch.nn as nn
from torchvision import datasets
from torch.optim import lr_scheduler, Adam
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
torch.manual_seed(42)

### Taylor Green Vortex

In [None]:
def vortex():
    """Sample the Taylor-Green velocity field at the cell centres of the FiPy mesh.

    Reads the module-level globals ``Nx, Ny, dx, dy, Lx, Ly`` (grid) and ``V_0``
    (velocity amplitude):

        u =  V_0 * sin(2*pi*x/Lx) * cos(2*pi*y/Ly)
        v = -V_0 * cos(2*pi*x/Lx) * sin(2*pi*y/Ly)

    with x = (j + 0.5)*dx and y = (i + 0.5)*dy.

    Returns
    -------
    vortex_u, vortex_v : 1-D ndarrays of length Ny*Nx
        Row-major flattening of (Ny, Nx) grids whose row index i increases with y.
        This matches the order of FiPy cell values (x fastest, y ascending; view_cell
        flips rows with [::-1] before imshow for the same reason), so sample_vel()
        can assign the arrays to cell values directly.  Sampling rows top-down
        (y = (Ny - 0.5 - i)*dy) would mirror the field in y on the mesh: u is
        unchanged but v flips sign, giving a compressible field instead of the vortex.
    """
    theta_x = (2*np.pi/Lx) * (np.arange(Nx) + 0.5) * dx   # shape (Nx,)
    theta_y = (2*np.pi/Ly) * (np.arange(Ny) + 0.5) * dy   # shape (Ny,), ascending y

    # outer products: rows follow y, columns follow x -> shape (Ny, Nx)
    vortex_u = V_0 * np.cos(theta_y)[:, None] * np.sin(theta_x)[None, :]
    vortex_v = -V_0 * np.sin(theta_y)[:, None] * np.cos(theta_x)[None, :]

    return vortex_u.flatten(), vortex_v.flatten()

### Velocity Function

In [None]:
def sample_vel(mesh):
    """Build the 'Velocity' variable on `mesh` from the vortex() field.

    A copy of ``mesh.cellCenters`` serves as a ready-made variable with one row
    per coordinate; row 0 is overwritten with u and row 1 with v.
    """
    u_cells, v_cells = vortex()
    velocity = mesh.cellCenters.copy()
    velocity.name = 'Velocity'
    for component, values in enumerate((u_cells, v_cells)):
        velocity.value[component] = values
    return velocity

### Strain Field

In [None]:
def strain_rate_mag():
    """Strain-rate magnitude field of the Taylor-Green vortex, scaled by 1/(4*pi).

    Reads the globals ``Nx, Ny, dx, dy, Lx, Ly, V_0`` and evaluates, at every cell
    centre x = (j + 0.5)*dx, y = (i + 0.5)*dy,

        |4*pi*V_0/Lx * cos(2*pi*x/Lx) * cos(2*pi*y/Ly)| / (4*pi)

    Rows ascend in y to match FiPy's cell ordering.  The magnitude is symmetric
    under y -> Ly - y, so a top-down row ordering gives the same values.

    Returns
    -------
    ndarray, shape (Ny, Nx)

    NOTE(review): rotation_mag() does not apply the 1/(4*pi) factor; confirm the
    asymmetry is intended.
    """
    theta_x = (2*np.pi/Lx) * (np.arange(Nx) + 0.5) * dx   # shape (Nx,)
    theta_y = (2*np.pi/Ly) * (np.arange(Ny) + 0.5) * dy   # shape (Ny,)

    strain = np.abs(4*np.pi*V_0/Lx * np.cos(theta_y)[:, None] * np.cos(theta_x)[None, :])

    return strain/(4*np.pi)

### Spin Field

In [None]:
def rotation_mag():
    """Rotation (spin) magnitude field of the Taylor-Green vortex.

    Reads the globals ``Nx, Ny, dx, dy, Lx, Ly, V_0`` and evaluates, at every cell
    centre x = (j + 0.5)*dx, y = (i + 0.5)*dy,

        |4*pi*V_0/Lx * sin(2*pi*x/Lx) * sin(2*pi*y/Ly)|

    Rows ascend in y to match FiPy's cell ordering.  The magnitude is symmetric
    under y -> Ly - y, so a top-down row ordering gives the same values.

    Returns
    -------
    ndarray, shape (Ny, Nx)
    """
    theta_x = (2*np.pi/Lx) * (np.arange(Nx) + 0.5) * dx   # shape (Nx,)
    theta_y = (2*np.pi/Ly) * (np.arange(Ny) + 0.5) * dy   # shape (Ny,)

    omega = np.abs(4*np.pi*V_0/Lx * np.sin(theta_y)[:, None] * np.sin(theta_x)[None, :])

    return omega

### Strain Function

In [None]:
def sample_stra(mesh):
    """CellVariable named 'Strain' on `mesh`, filled with strain_rate_mag() flattened row-major."""
    strain_field = CellVariable(mesh=mesh, name='Strain')
    strain_field.value = strain_rate_mag().flatten()
    return strain_field

### Rotation Function

In [None]:
def sample_rota(mesh):
    """CellVariable named 'Rotation' on `mesh`, filled with rotation_mag() flattened row-major."""
    rotation_field = CellVariable(mesh=mesh, name='Rotation')
    rotation_field.value = rotation_mag().flatten()
    return rotation_field

### Source Function

In [None]:
def compute_src(mesh):
    """Production (source) term as a CellVariable named 'Production' on `mesh`.

    With S = strain_rate_mag() and R = rotation_mag(), both of shape (Ny, Nx),

        P = (4*sin(2*pi*S) + 6*S**2 + 5*exp(S)) / (1 + R**2)

    The result is flattened row-major into the cell values.
    """
    strain = strain_rate_mag()
    rota = rotation_mag()
    src = CellVariable(name='Production', mesh=mesh)
    src_matrix = (1/(1+rota**2))*(4*np.sin(2*np.pi*strain)+6*strain**2+5*np.exp(strain)) # production function
    src.value = src_matrix.flatten()
    return src

### Concentration Data Generation Function

In [None]:
def solve_cdr_pde(mesh, diff_coeff, diss_coeff):
    """Solve the steady convection-diffusion-reaction problem on `mesh`.

    The FiPy equation assembled below is

        -ExponentialConvectionTerm(vel) + DiffusionTerm(diff_coeff)
            - ImplicitSourceTerm(diss_coeff) + src == 0

    where the velocity comes from sample_vel() and the production term from
    compute_src().  There is no transient term and no coefficient depends on the
    unknown, so a single linear solve gives the steady field.

    Parameters
    ----------
    mesh : FiPy mesh the fields live on.
    diff_coeff : diffusion coefficient.
    diss_coeff : coefficient of the implicit (linear) sink term.

    Returns
    -------
    dict
        'var' (solved CellVariable), 'src', 'vel', 'stra', 'rota' (input fields),
        and 'diff', 'diss' (the coefficients passed in).
    """
    var = CellVariable(name='Variable',mesh=mesh)
    vel = sample_vel(mesh)
    stra = sample_stra(mesh)
    rota = sample_rota(mesh)
    src = compute_src(mesh)
    eq = - ExponentialConvectionTerm(coeff=vel) + DiffusionTerm(coeff=diff_coeff) - ImplicitSourceTerm(diss_coeff) + src
    # the convection term makes the system non-symmetric, hence the asymmetric solver
    eq.solve(var=var, solver=DefaultAsymmetricSolver(tolerance=1.e-12, iterations=10000))
    data = {'var': var, 'src': src, 'vel': vel, 'stra': stra, 'rota': rota, 'diff': diff_coeff, 'diss': diss_coeff}
    return data

### Periodic Tiling of Fields

In [None]:
def periodic(array):  # dimension of array => [Ny,Nx]
    """Tile a 2-D field periodically so stencils near its edges can wrap around.

    The number of copies per axis is ``2*(int(N*h/L) + 1) + 1``, with L = Ly for
    rows and Lx for columns, using the module-level globals N, h, Lx and Ly.  The
    central copy therefore starts at ``(int(N*h/L) + 1) * size`` along each axis.

    Parameters
    ----------
    array : 2-D array-like, normally of shape (Ny, Nx)

    Returns
    -------
    ndarray of float64, shape (copies_y * rows, copies_x * cols)
    """
    peri_num_x = 2*(int(N*h/Lx)+1)+1
    peri_num_y = 2*(int(N*h/Ly)+1)+1
    # Tile using the array's own shape rather than the global Nx/Ny; the float64
    # cast matches the float buffer the result has always been returned in.
    return np.tile(np.asarray(array, dtype=np.float64), (peri_num_y, peri_num_x))

### Neural Network Model

In [None]:
class CNN_Network(nn.Module):
    """Two-branch network mapping 4-channel patches to one scalar per sample.

    ``forward`` expects X of shape (batch, 4, H, W).  In this notebook the
    channels are the u, v, strain and rotation patches assembled in the data
    generation cells.

    * Channels 0-1 pass through six 3x3 convolutions (padding 1, so H and W are
      preserved), giving G of shape (batch, 1, H, W).
    * Channels 2-3 are processed pixel by pixel: each (ch2, ch3) pair goes through
      a 2-32-16-8-4-1 MLP, giving P of shape (batch, 1, H, W).
    * The prediction is c = sum over H and W of G * P, shape (batch, 1).
    """

    def __init__(self):
        super(CNN_Network, self).__init__()

        self.relu = nn.ReLU()

        # convolutional branch: channels 0-1 -> G
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=3, stride=1, padding=1)

        # point-wise MLP branch: channels 2-3 -> P
        self.fc1 = nn.Linear(2,32)
        self.fc2 = nn.Linear(32,16)
        self.fc3 = nn.Linear(16,8)
        self.fc4 = nn.Linear(8,4)
        self.fc5 = nn.Linear(4,1)

    def forward(self, X):
        """Return (G, P, c) for a batch X of shape (batch, 4, H, W)."""
        batch, _, height, width = X.shape

        # --- convolutional branch ---
        out = X[:, 0:2, :, :]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            out = self.relu(conv(out))
        G = self.conv6(out)

        # --- point-wise branch: one row per pixel, columns = (channel 2, channel 3) ---
        X2 = X[:, 2:3, :, :].reshape(-1, 1)
        X3 = X[:, 3:4, :, :].reshape(-1, 1)
        X4 = torch.cat((X2, X3), 1)

        X4_out = self.relu(self.fc1(X4))
        X4_out = self.relu(self.fc2(X4_out))
        X4_out = self.relu(self.fc3(X4_out))
        X4_out = self.relu(self.fc4(X4_out))
        X4_out = self.fc5(X4_out)
        # height and width are kept separate so non-square patches work
        P = X4_out.reshape(batch, -1, height, width)

        # c = sum_{h,w} G*P.  Computed directly instead of through the deprecated
        # torch.addcmul(input, value, t1, t2) overload, and without allocating on a
        # global device, so the module runs wherever its inputs live.
        GP = G * P
        c = torch.sum(torch.sum(GP, dim=3), dim=2).reshape(batch, 1)

        return G, P, c

### Training Function

In [None]:
def _run_epoch(loader, training):
    """One pass of the global `model` over `loader`.

    When `training` is True every batch takes an optimizer step with the global
    `optimizer`; otherwise the caller wraps the call in ``torch.no_grad()``.

    Returns
    -------
    mean_loss : float
        Mean of the per-batch ``loss_fn`` values.
    err_rate : torch.Tensor of shape (k,)
        100 * sqrt(sum of squared per-batch error norms) / sqrt(sum of squared
        per-batch target norms), per target column, using report_err_rate().

    Raises
    ------
    ValueError
        If `loader` yields no batches.
    """
    batch_losses = []
    err_num_sq = 0
    err_den_sq = 0
    for features, target in loader:
        if training:
            optimizer.zero_grad()
        _, _, forward = model(features)
        loss = loss_fn(forward, target)
        if training:
            loss.backward()
            optimizer.step()
        batch_losses.append(loss.item())
        # detach so the error bookkeeping does not keep autograd graphs alive
        err_num, err_den = report_err_rate(target, forward.detach())
        err_num_sq = err_num_sq + err_num.view(-1) ** 2
        err_den_sq = err_den_sq + err_den.view(-1) ** 2
    if not batch_losses:
        raise ValueError('loader yielded no batches')
    mean_loss = sum(batch_losses) / len(batch_losses)
    err_rate = 100 * err_num_sq ** 0.5 / err_den_sq ** 0.5
    return mean_loss, err_rate


def train(train_loader, valid_loader, num_epoch):
    """Train the global `model` for epochs 0..num_epoch and record its history.

    Uses the globals model, optimizer, loss_fn, exp_lr_scheduler, device and
    report_err_rate.  Every epoch runs one optimisation pass over `train_loader`,
    steps the LR scheduler, then evaluates on `valid_loader` without gradients.

    Histories are recorded at epochs 50, 60, 70, ... and a progress line is
    printed every 50 epochs.

    Returns
    -------
    train_loss_hist, train_err_hist, valid_loss_hist, valid_err_hist : tensors on `device`
        Each has shape (1 + number_of_recorded_epochs, 1).  Row 0 is a zero
        placeholder that error_plot() skips with ``[1:, 0]``.
    """
    train_loss_hist = [0.0]
    valid_loss_hist = [0.0]
    train_err_hist = [torch.zeros(1, device=device)]
    valid_err_hist = [torch.zeros(1, device=device)]

    for epoch in range(num_epoch + 1):
        model.train()
        train_loss, train_err_rate = _run_epoch(train_loader, training=True)

        exp_lr_scheduler.step()

        model.eval()
        with torch.no_grad():
            valid_loss, valid_err_rate = _run_epoch(valid_loader, training=False)

        if epoch >= 50 and epoch % 10 == 0:
            train_loss_hist.append(train_loss)
            train_err_hist.append(train_err_rate.view(-1).to(device))
            valid_loss_hist.append(valid_loss)
            valid_err_hist.append(valid_err_rate.view(-1).to(device))
        if epoch % 50 == 0:
            # get_last_lr() reports the LR the scheduler set; get_lr() is deprecated for this use
            print('{:4}   lr: {:.2e}   train_loss: {:.2e}   valid_loss: {:.2e}   train_error:{:7.2f}%   valid_error:{:7.2f}%'
                  .format(epoch, exp_lr_scheduler.get_last_lr()[0], train_loss, valid_loss,
                          train_err_rate[0].item(), valid_err_rate[0].item()))

    print('Finished Training')
    return (torch.tensor(train_loss_hist, device=device).view(-1, 1),
            torch.stack(train_err_hist),
            torch.tensor(valid_loss_hist, device=device).view(-1, 1),
            torch.stack(valid_err_hist))

In [None]:
def report_err_rate(target, forward):
    """Column-wise L2 norms (over dim 0) of the prediction error and of the target.

    Returns
    -------
    (numerator, denominator) : ||forward - target|| and ||target||, one entry per column.
    """
    residual = forward - target
    numerator = residual.norm(dim=0)
    denominator = target.norm(dim=0)
    return numerator, denominator

### Plotting Functions

In [None]:
def error_plot(training_loss_history, training_error_history, valid_loss_history, valid_error_history):
    """Plot loss (left) and error-% (right) histories side by side with plotly.

    Each history is a 2-D array whose row 0 is a placeholder (skipped with
    ``[1:, 0]``) and whose remaining rows were recorded at epochs 50, 60, 70, ...
    The epoch axis is derived from the history length rather than the global
    ``num_epoch``, so x and y always have the same length.
    """
    n_points = len(training_loss_history) - 1
    epochs = np.arange(50, 50 + 10 * n_points, 10)

    series = [
        (training_loss_history, 'Training Loss', 1),
        (valid_loss_history, 'Validation Loss', 1),
        (training_error_history, 'Training Error', 2),
        (valid_error_history, 'Validation Error', 2),
    ]

    fig = subplots.make_subplots(rows=1, cols=2)
    for history, label, col in series:
        trace = go.Scatter(x=epochs, y=history[1:, 0], line=dict(width=1.7), name=label, mode='lines')
        fig.append_trace(trace, 1, col)

    axis_style = dict(showgrid=True, gridwidth=0.5, gridcolor='lightgrey', showline=True, linecolor='black')
    for axis, axis_title in (('xaxis1', 'Epochs'), ('yaxis1', 'Loss'), ('xaxis2', 'Epochs'), ('yaxis2', 'Error %')):
        fig['layout'][axis].update(title=axis_title, **axis_style)
    fig['layout'].update(height=450, width=1000, plot_bgcolor = 'rgba(0,0,0,0)', title='Loss and Error Percentage History')
    iplot(fig)

In [None]:
def plot_valid(nn_target,nn_output, title):
    """Parity plot (prediction vs truth) of the first column of two CPU tensors.

    A black y = x reference line spans the common min/max of both series, and
    both axes use that same range.
    """
    truth = nn_target[:, 0].numpy()
    pred = nn_output[:, 0].numpy()
    lo = np.array([truth.min(), pred.min()]).min()
    hi = np.array([truth.max(), pred.max()]).max()

    scatter = go.Scatter(x=truth, y=pred, mode='markers',
                         marker=dict(color='rgb(158, 22, 25)', size=7, opacity=0.3, line=dict(width=1)),
                         showlegend=True, name='NN')
    diagonal = go.Scatter(x=np.linspace(lo, hi, 10), y=np.linspace(lo, hi, 10), mode='lines',
                          line=dict(width=1.7, color='black'), showlegend=False)

    axis_common = dict(showgrid=False, showline=True, linecolor='black', zeroline=False, mirror='ticks')
    layout = go.Layout(title=title,
                       xaxis=dict(title='Truth', range=[lo, hi], **axis_common),
                       yaxis=dict(title='Prediction', range=[lo, hi], **axis_common),
                       width=600, height=570, plot_bgcolor='rgba(0,0,0,0)')

    iplot(go.Figure(data=[scatter, diagonal], layout=layout))

In [None]:
def view_cell(var):
    """Show a FiPy CellVariable as a 4x4-inch jet-coloured image with a colour bar.

    Cell values are reshaped to (ny, nx) and their rows reversed so that the
    first mesh row is drawn at the bottom, spanning the mesh extents.
    """
    (xmin, ymin), (xmax, ymax) = var.mesh.extents['min'], var.mesh.extents['max']
    grid = reshape(array(var), var.mesh.shape[::-1])

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    ax.set_title('{}'.format(var.name))
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    image = ax.imshow(grid[::-1], extent=(xmin, xmax, ymin, ymax), cmap=matplotlib.cm.jet)
    plt.colorbar(image)

### Domain Specification

In [None]:
# Use cuda:1 when at least two GPUs are visible, cuda:0 when only one is,
# and the CPU otherwise (a hard-coded cuda:1 fails on single-GPU machines).
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device('cuda:{}'.format(gpu_index))
else:
    device = torch.device('cpu')

device_cpu = torch.device('cpu')

In [None]:
# source field = calculation domain: an Nx-by-Ny periodic grid of dx-by-dy cells
Nx, Ny = 100, 100
dx, dy = 0.01, 0.01
Lx, Ly = Nx*dx, Ny*dy   # domain lengths
mesh = PeriodicGrid2D(dx=dx, dy=dy, nx=Nx, ny=Ny)

In [None]:
epsilon = 2e-1
n = 3      # stride, in cells, between neighbouring stencil points
N = 1.5    # N = max(D/2h)
# 11x11 stencil of integer offsets in [-5, 5]: cells_x varies along columns,
# cells_y along rows; cells keeps its (1, 11) shape for cells.shape[1]
cells = np.arange(-5, 6)[np.newaxis, :]
cells_y, cells_x = np.meshgrid(cells[0], cells[0], indexing='ij')

## Case 3a: Training and Validation data from different flow cases

### Data Generation - Validation

In [None]:
num_samples_valid=5

# Per-sample tensors are collected in lists and concatenated once after the
# loop; concatenating inside the loop re-copies everything gathered so far.
dataX_valid_parts = []
dataY_valid_parts = []

for sample in range(num_samples_valid):
    # PDE coefficients (fixed here; the commented lines draw them at random instead)
    # diff_coeff = np.random.uniform(low=0.01, high=0.1)     # Diffusion Coefficient
    # diss_coeff = np.random.uniform(low=10, high=15)        # Dissipation Coefficient
    diff_coeff = 0.02
    diss_coeff = 20

    C_max = 1.0  # the maximum value of the velocity
    # V_0 is a global read by vortex(), strain_rate_mag() and rotation_mag()
    V_0 = np.random.uniform(low=C_max*0.05, high=C_max)

    data_3_valid = solve_cdr_pde(mesh, diff_coeff, diss_coeff)

    # lambda1 < 0 and h solves exp(lambda1*h) == epsilon; periodic() reads h as a global
    lambda1=(C_max - (((C_max**2)+(4*diff_coeff*diss_coeff))**0.5)) / (2*diff_coeff)
    h=np.abs(np.log(epsilon)/lambda1)   # h => maximum h

    print('diff     diss     h      V_0')
    print('{:.3f}   {:.3f}   {:.3f}   {:.3f}'.format(diff_coeff,diss_coeff,h,V_0))
    view_cell(data_3_valid['stra'])
    view_cell(data_3_valid['rota'])
    view_cell(data_3_valid['src'])
    view_cell(data_3_valid['var'])

    # (Ny, Nx) grids in FiPy cell order
    c_data_valid_mesh_3 = reshape(array(data_3_valid['var']), data_3_valid['var'].mesh.shape[::-1])
    s_data_valid_mesh_3 = reshape(array(data_3_valid['stra']), data_3_valid['stra'].mesh.shape[::-1])
    r_data_valid_mesh_3 = reshape(array(data_3_valid['rota']), data_3_valid['rota'].mesh.shape[::-1])
    u_data_valid_mesh_3 = reshape(array(data_3_valid['vel'])[0,:], data_3_valid['vel'].mesh.shape[::-1])
    v_data_valid_mesh_3 = reshape(array(data_3_valid['vel'])[1,:], data_3_valid['vel'].mesh.shape[::-1])

    periodic_c_valid = periodic(c_data_valid_mesh_3)
    periodic_s_valid = periodic(s_data_valid_mesh_3)
    periodic_r_valid = periodic(r_data_valid_mesh_3)
    periodic_u_valid = periodic(u_data_valid_mesh_3)
    periodic_v_valid = periodic(v_data_valid_mesh_3)

    # Start of the central copy in the tiled arrays: periodic() lays out
    # 2*(k+1)+1 copies per axis with k = int(N*h/L), so it begins after k+1 copies.
    row0 = Ny*(int(N*h/Ly)+1)
    col0 = Nx*(int(N*h/Lx)+1)
    patch_size = cells.shape[1]
    # Broadcast index grids; together they address (Ny, Nx, patch, patch) points
    # [i, j, a, b] -> (row0 + i + n*cells_y[a, b], col0 + j + n*cells_x[a, b])
    patch_rows = (row0 + np.arange(Ny))[:, None, None, None] + cells_y*n
    patch_cols = (col0 + np.arange(Nx))[None, :, None, None] + cells_x*n
    fields_valid = (periodic_u_valid, periodic_v_valid, periodic_s_valid, periodic_r_valid)
    # (Ny, Nx, 4, patch, patch) -> (Ny*Nx, 4, patch, patch): one row per cell, i-major / j-minor
    dataX_valid_current = np.stack([field[patch_rows, patch_cols] for field in fields_valid], axis=2)
    dataX_valid_current = dataX_valid_current.reshape(Ny*Nx, 4, patch_size, patch_size)
    dataY_valid_current = periodic_c_valid[row0:row0+Ny, col0:col0+Nx].reshape(Ny*Nx, 1)

    dataX_valid_parts.append(torch.tensor(dataX_valid_current).to(dtype=torch.float))
    dataY_valid_parts.append(torch.tensor(dataY_valid_current).to(dtype=torch.float))

dataX_valid = torch.cat(dataX_valid_parts)
dataY_valid = torch.cat(dataY_valid_parts)

### Data Generation - Training

In [None]:
num_samples=100

# Per-sample tensors are collected in lists and concatenated once after the
# loop; concatenating inside the loop re-copies everything gathered so far,
# which is quadratic in num_samples.
dataX_train_parts = []
dataY_train_parts = []

for sample in range(num_samples):
    # PDE coefficients (fixed here; the commented lines draw them at random instead)
    # diff_coeff = np.random.uniform(low=0.01, high=0.1)     # Diffusion Coefficient
    # diss_coeff = np.random.uniform(low=10, high=15)        # Dissipation Coefficient
    diff_coeff = 0.02
    diss_coeff = 20

    C_max = 1.0  # the maximum value of the velocity
    # V_0 is a global read by vortex(), strain_rate_mag() and rotation_mag()
    V_0 = np.random.uniform(low=C_max*0.05, high=C_max)   # velocity amplitude in [0.05*C_max, C_max)

    data_3 = solve_cdr_pde(mesh, diff_coeff, diss_coeff)

    # lambda1 < 0 and h solves exp(lambda1*h) == epsilon; periodic() reads h as a global
    lambda1=(C_max - (((C_max**2)+(4*diff_coeff*diss_coeff))**0.5)) / (2*diff_coeff)
    h=np.abs(np.log(epsilon)/lambda1)   # h => maximum h

    # (Ny, Nx) grids in FiPy cell order
    c_data_mesh_3 = reshape(array(data_3['var']), data_3['var'].mesh.shape[::-1])
    s_data_mesh_3 = reshape(array(data_3['stra']), data_3['stra'].mesh.shape[::-1])
    r_data_mesh_3 = reshape(array(data_3['rota']), data_3['rota'].mesh.shape[::-1])
    u_data_mesh_3 = reshape(array(data_3['vel'])[0,:], data_3['vel'].mesh.shape[::-1])
    v_data_mesh_3 = reshape(array(data_3['vel'])[1,:], data_3['vel'].mesh.shape[::-1])

    periodic_c = periodic(c_data_mesh_3)
    periodic_s = periodic(s_data_mesh_3)
    periodic_r = periodic(r_data_mesh_3)
    periodic_u = periodic(u_data_mesh_3)
    periodic_v = periodic(v_data_mesh_3)

    # Start of the central copy in the tiled arrays: periodic() lays out
    # 2*(k+1)+1 copies per axis with k = int(N*h/L), so it begins after k+1 copies.
    row0 = Ny*(int(N*h/Ly)+1)
    col0 = Nx*(int(N*h/Lx)+1)
    patch_size = cells.shape[1]
    # Broadcast index grids; together they address (Ny, Nx, patch, patch) points
    # [i, j, a, b] -> (row0 + i + n*cells_y[a, b], col0 + j + n*cells_x[a, b])
    patch_rows = (row0 + np.arange(Ny))[:, None, None, None] + cells_y*n
    patch_cols = (col0 + np.arange(Nx))[None, :, None, None] + cells_x*n
    fields_train = (periodic_u, periodic_v, periodic_s, periodic_r)
    # (Ny, Nx, 4, patch, patch) -> (Ny*Nx, 4, patch, patch): one row per cell, i-major / j-minor
    dataX_train_current = np.stack([field[patch_rows, patch_cols] for field in fields_train], axis=2)
    dataX_train_current = dataX_train_current.reshape(Ny*Nx, 4, patch_size, patch_size)
    dataY_train_current = periodic_c[row0:row0+Ny, col0:col0+Nx].reshape(Ny*Nx, 1)

    dataX_train_parts.append(torch.tensor(dataX_train_current).to(dtype=torch.float))
    dataY_train_parts.append(torch.tensor(dataY_train_current).to(dtype=torch.float))

    print(sample)

dataX_train = torch.cat(dataX_train_parts)
dataY_train = torch.cat(dataY_train_parts)

In [None]:
# Keep a random subset of the training rows: as many rows as 50 samples would
# produce, drawn from all num_samples generated samples.
n_total = dataX_train.shape[0]
reduc_data_size = int(n_total * (50/num_samples))
ind = list(range(n_total))
np.random.shuffle(ind)
train_ind = ind[:reduc_data_size]
dataX_train, dataY_train = dataX_train[train_ind], dataY_train[train_ind]
print('Reduced Training Data Size: {}   {}'.format(dataX_train.shape, dataY_train.shape))
print('Validation Data Size:       {}   {}'.format(dataX_valid.shape, dataY_valid.shape))

In [None]:
# move every train/validation tensor onto the compute device
dataX_train, dataY_train, dataX_valid, dataY_valid = (
    tensor.to(device) for tensor in (dataX_train, dataY_train, dataX_valid, dataY_valid)
)

In [None]:
# wrap (features, target) pairs as datasets
dataset_train = TensorDataset(dataX_train, dataY_train)
dataset_valid = TensorDataset(dataX_valid, dataY_valid)

# shuffled mini-batches for training; the whole validation set is one batch
batch_size_train = 1024          #int(dataX_train.shape[0]/20) + 1
batch_size_valid = len(dataset_valid)

train_loader = DataLoader(dataset_train, batch_size=batch_size_train, shuffle=True)
valid_loader = DataLoader(dataset_valid, batch_size=batch_size_valid, shuffle=False)

In [None]:
# Weight initialisation (and DataLoader shuffling) draw from torch's RNG, so
# seeding NumPy alone does not make the model reproducible; seed both.
np.random.seed(7)
torch.manual_seed(7)
model = CNN_Network()
model.to(device)
loss_fn = nn.MSELoss(reduction='sum')

In [None]:
def count_parameters(model):
    """Total number of scalar parameters in `model` that have requires_grad set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
para_count = count_parameters(model)  # trainable parameters of the model built above
print('Total Learnable Parameters: {}'.format(para_count))

In [None]:
# training hyper-parameters
num_epoch = 500
learning_rate = 1e-3
optimizer = Adam(model.parameters(), lr=learning_rate)
# NOTE(review): step_size (600) exceeds num_epoch (500), so StepLR never decays the LR in this run
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, gamma=0.7, step_size=600)

In [None]:
# perf_counter is a monotonic clock meant for measuring elapsed time
start_time = time.perf_counter()
training_loss_history, training_error_history, valid_loss_history, valid_error_history = train(train_loader, valid_loader, num_epoch)
elapsed = time.perf_counter() - start_time
print('Training time: %.1f s' % (elapsed))

In [None]:
# Pickle the whole module twice: first on `device` as trained, then after moving
# it to the CPU.  `model` stays on the CPU after this cell.
gpu_path, cpu_path = 'TG_learning-P_gpu.pt', 'TG_learning-P_cpu.pt'  # 1-1 --> num of case_fixed/unfixed
torch.save(model, gpu_path)
model = model.to(device_cpu)
torch.save(model, cpu_path)

In [None]:
# rebind the four histories as CPU tensors, then plot them as NumPy arrays
(training_loss_history, training_error_history,
 valid_loss_history, valid_error_history) = (
    hist.to(device_cpu)
    for hist in (training_loss_history, training_error_history, valid_loss_history, valid_error_history)
)
error_plot(*(hist.detach().numpy()
             for hist in (training_loss_history, training_error_history, valid_loss_history, valid_error_history)))

In [None]:
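# NOTE: the model returns (G, P, c), and it was moved to the CPU when it was
# saved, so the inputs are moved to the CPU as well.
# with torch.no_grad():
#     for i, nn_data in enumerate(valid_loader):
#         nn_features, nn_target = nn_data
#         G_output, P_output, nn_output = model(nn_features.to(device_cpu))

# nn_output = nn_output.to(device_cpu)
# nn_target = nn_target.to(device_cpu)
# title = 'Validation'
# plot_valid(nn_target,nn_output,title)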