In [None]:
import math
import time
import decimal
from decimal import *
import numpy as np
from matplotlib import pyplot as plt
from math import exp, sqrt,pi

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset,RandomSampler

from mpl_toolkits import mplot3d


In [None]:
# --- Training configuration ---------------------------------------------------
epochs = 10000
device = torch.device("cpu")

# Coefficients of the BVP  -eps1*u'' + eps2*u' + u = cos(pi*x),  u(0) = u(1) = 0.
eps1 = 1e-4
eps2 = 1e-2

learning_rate = 1e-3
batchflag = True
batchsize = 10

# Seed numpy and torch (weight init and DataLoader shuffling use torch's RNG)
# so that Restart & Run All reproduces the same run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Uniform grid of N_POINTS points on [start, end], stored as a column vector (N_POINTS, 1).
# NOTE: the closed-form solution below is derived for the domain [0, 1].
start = 0
end = 1
N_POINTS = 100
x = np.linspace(start, end, N_POINTS).reshape(-1, 1)

# Closed-form exact solution on [0, 1]:
#   u(x) = A1*cos(pi x) + B1*sin(pi x) + A2*exp(-muL*x) + B2*exp(-muR*(1 - x))
# (A1, B1) give the particular solution; -muL and muR are the roots of
# eps1*r^2 - eps2*r - 1 = 0; (A2, B2) enforce u(0) = u(1) = 0.
_denom = eps2**2 * np.pi**2 + (eps1 * np.pi**2 + 1)**2
A1 = (eps1 * np.pi**2 + 1) / _denom
B1 = (eps2 * np.pi) / _denom
_disc = np.sqrt(eps2**2 + 4 * eps1)
muR = (eps2 + _disc) / (2 * eps1)
# Algebraically equal to (-eps2 + _disc) / (2*eps1), but free of the catastrophic
# cancellation that form suffers when eps1 << eps2**2.
muL = 2 / (eps2 + _disc)
# 1 - exp(-(muL + muR)), computed accurately even when muL + muR is small.
_one_minus_decay = -np.expm1(-(muL + muR))
A2 = -A1 * (1 + np.exp(-muR)) / _one_minus_decay
B2 = A1 * (1 + np.exp(-muL)) / _one_minus_decay


def actual_soln(x_eval=None):
    """Evaluate the exact solution u(x) of the BVP defined above.

    Args:
        x_eval: Points at which to evaluate (any array shape). Defaults to the
            global grid ``x``, which keeps the original no-argument call working.

    Returns:
        np.ndarray of u-values with the same shape as the evaluation points.
    """
    pts = x if x_eval is None else np.asarray(x_eval, dtype=float)
    return (A1 * np.cos(np.pi * pts) + B1 * np.sin(np.pi * pts)
            + A2 * np.exp(-muL * pts) + B2 * np.exp(-muR * (1 - pts)))

In [None]:
def plot_graph(soln, soln_name):
    """Plot one solution curve on a uniform grid over [start, end].

    Args:
        soln: Solution values, one per grid point (e.g. shape (N,) or (N, 1)).
        soln_name: Figure title.
    """
    # Size the grid from the data instead of hard-coding 100, so the plot keeps
    # working if the sampling grid in the config cell changes.
    grid = np.linspace(start, end, len(soln))
    fig, ax = plt.subplots()
    ax.plot(grid, soln)
    ax.set(title=soln_name, xlabel="x", ylabel="u(x)")
    plt.show()
    plt.close(fig)

def plot_graphs(actual_soln, pred_soln):
    """Overlay the exact and predicted solutions on a uniform grid over [start, end].

    Args:
        actual_soln: Exact-solution values, one per grid point. (Inside this
            function the name shadows the global function of the same name.)
        pred_soln: Predicted values on the same grid.

    Raises:
        ValueError: If the two inputs contain different numbers of points.
    """
    n_points = len(actual_soln)
    if len(pred_soln) != n_points:
        raise ValueError(
            "actual_soln has {} points but pred_soln has {}".format(n_points, len(pred_soln)))
    # Grid size follows the data rather than a hard-coded 100.
    grid = np.linspace(start, end, n_points)
    fig, ax = plt.subplots()
    ax.plot(grid, actual_soln, label='Actual solution')
    ax.plot(grid, pred_soln, label='Predicted solution')
    ax.set(xlabel="x", ylabel="u(x)")
    ax.legend()
    plt.show()
    # Called repeatedly during training; close to avoid accumulating open figures.
    plt.close(fig)

	
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    Args:
        inplace: If True (the default), the input tensor is overwritten with the
            result and that same tensor object is returned; otherwise a new
            tensor is returned and the input is left untouched.
    """

    def __init__(self, inplace=True):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        gate = torch.sigmoid(x)
        if not self.inplace:
            return x * gate
        # In-place path: reuse the input's storage for the output.
        x.mul_(gate)
        return x
	

class FBPINN(nn.Module):
    """Fully connected network whose output is forced to zero at both domain ends.

    The raw network output is multiplied by ``tanh(x - left) * tanh(right - x)``.
    Both factors vanish at the ends, so u(left) = u(right) = 0 holds exactly by
    construction.

    Architecture: Linear(input_dim, hid_dim) -> Swish, then the SAME
    Linear(hid_dim, hid_dim) applied three times (each followed by Swish),
    then Linear(hid_dim, 1).

    Args:
        hid_dim: Hidden width; defaults to the class attribute (512).
        input_dim: Number of input features; defaults to the class attribute (1).
            Only the first input column enters the boundary factor.
        left, right: Domain end points where the output is forced to zero. The
            defaults (0, 1) reproduce the original hard-coded domain.
    """
    hid_dim = 512
    input_dim = 1

    def __init__(self, hid_dim=None, input_dim=None, left=0.0, right=1.0):
        super(FBPINN, self).__init__()
        if hid_dim is not None:
            self.hid_dim = hid_dim
        if input_dim is not None:
            self.input_dim = input_dim
        self.left = left
        self.right = right
        self.tanh = nn.Tanh()
        # Layers are created in the original order, so seeded initialisation is unchanged.
        self.lin0 = nn.Linear(self.input_dim, self.hid_dim)
        self.lin = nn.Linear(self.hid_dim, self.hid_dim)
        self.lin1 = nn.Linear(self.hid_dim, 1)
        self.swish = Swish()

    def forward(self, x):
        """Evaluate the network at x of shape (batch, input_dim); returns (batch, 1)."""
        # Boundary factors come from the raw input (first column), before any layer.
        left_factor = self.tanh(x - self.left)[:, 0].unsqueeze(1)
        right_factor = self.tanh(self.right - x)[:, 0].unsqueeze(1)

        h = self.swish(self.lin0(x))
        # NOTE(review): the same `self.lin` is reused for all three hidden layers,
        # so they share weights. Preserved from the original; confirm it is intended.
        for _ in range(3):
            h = self.swish(self.lin(h))
        return self.lin1(h) * left_factor * right_factor


In [None]:
# Plot the closed-form reference solution on the evaluation grid.
exact_curve = actual_soln()
plot_graph(exact_curve, "Actual solution")

In [None]:
def train(device, x, eps1, eps2, learning_rate, epochs, batch_flag, batch_size,
          decay_epoch=50, decayed_lr=1e-5):
    """Train an FBPINN by minimising the residual of -eps1*u'' + eps2*u' + u = cos(pi*x).

    The boundary values are imposed by the FBPINN architecture itself (its
    output is multiplied by tanh factors that vanish at the domain ends). The
    loss is therefore only the mean squared PDE residual at the collocation points.

    Args:
        device: torch device on which the network and data live.
        x: Collocation points, array-like of shape (N, 1).
        eps1: Coefficient of -u''.
        eps2: Coefficient of u'.
        learning_rate: Initial Adam learning rate.
        epochs: Number of passes over the collocation points.
        batch_flag: If True, train on shuffled mini-batches of size batch_size,
            dropping an incomplete final batch. Otherwise train full-batch.
        batch_size: Mini-batch size when batch_flag is True.
        decay_epoch: Epoch at which Adam is rebuilt with decayed_lr
            (None disables the decay).
        decayed_lr: Learning rate used from decay_epoch onward.

    Returns:
        (output, losses): the network evaluated at x (still attached to the
        autograd graph) and the mean training loss of every epoch.

    Raises:
        ValueError: If batch_flag is True and batch_size exceeds the number of
            points (with drop_last=True no batch would ever be produced).
    """
    xnet = torch.Tensor(x).to(device)

    if batch_flag:
        dataloader = DataLoader(TensorDataset(xnet), batch_size=batch_size,
                                shuffle=True, num_workers=0, drop_last=True)
        if len(dataloader) == 0:
            raise ValueError(
                "batch_size={} exceeds the number of collocation points ({}); "
                "no training batches would be produced".format(batch_size, len(xnet)))
        print('Batches per epoch:', len(dataloader))

        def iter_batches():
            # TensorDataset yields 1-tuples; unpack to the bare tensor.
            for (batch,) in dataloader:
                yield batch
    else:
        # Full-batch mode. (Previously batch_flag=False skipped training entirely.)
        def iter_batches():
            yield xnet

    net = FBPINN()

    def init_normal(m):
        # Kaiming-normal weights for every Linear layer; biases keep PyTorch's default init.
        if type(m) == nn.Linear:
            nn.init.kaiming_normal_(m.weight)

    net.apply(init_normal)
    net = net.to(device)

    def make_adam(lr):
        return optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), eps=10**-15)

    optimizer = make_adam(learning_rate)

    def pde_loss(batch):
        """Mean squared residual of -eps1*u'' + eps2*u' + u - cos(pi*x) on a batch."""
        # Fresh leaf requiring grad, so U can be differentiated w.r.t. the inputs
        # without flagging the shared xnet tensor.
        points = batch.detach().requires_grad_(True)
        U = net(points)
        U = U.view(len(U), -1)
        f = torch.cos(torch.pi * points)
        U_x = torch.autograd.grad(U, points, grad_outputs=torch.ones_like(U),
                                  create_graph=True)[0]
        U_xx = torch.autograd.grad(U_x, points, grad_outputs=torch.ones_like(U_x),
                                   create_graph=True)[0]
        residual = -eps1 * U_xx + eps2 * U_x + U - f
        return nn.MSELoss()(residual, torch.zeros_like(residual))

    # actual_soln() evaluates on the notebook's global grid, so this assumes the
    # `x` passed in is that same grid. Loop-invariant: compute it once.
    reference = actual_soln()
    losses = []
    tic = time.perf_counter()

    for epoch in range(epochs):
        if decay_epoch is not None and epoch == decay_epoch:
            # A fresh Adam (as in the original) also resets its moment estimates.
            optimizer = make_adam(decayed_lr)

        epoch_losses = []
        for batch_idx, batch in enumerate(iter_batches()):
            optimizer.zero_grad()
            loss = pde_loss(batch)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
            if batch_idx % 20 == 0:
                print('\nTrain Epoch: {} \tLoss: {:.20f}'.format(epoch, epoch_losses[-1]))
        losses.append(sum(epoch_losses) / len(epoch_losses))

        # The evaluation pass is only needed on logging epochs; no autograd graph required.
        if epoch % 10 == 0:
            with torch.no_grad():
                z = net(xnet).cpu().numpy()
            actual_loss = np.square(reference - z).mean()
            print('\nAfter Epoch {}, \t Actual solution loss: {:.10f}\n'.format(epoch, actual_loss))
            if epoch % 50 == 0:
                plot_graphs(reference, z)

    elapseTime = time.perf_counter() - tic
    print("Time elapsed = ", elapseTime)

    output = net(xnet)
    return output, losses

In [None]:
# Train the network: `output` is the final prediction on `x`, `losses` the list returned by train().
output, losses = train(
    device,
    x,
    eps1,
    eps2,
    learning_rate=learning_rate,
    epochs=epochs,
    batch_flag=batchflag,
    batch_size=batchsize,
)