Skip to content

Commit

Permalink
FCNN example working
Browse files Browse the repository at this point in the history
  • Loading branch information
praksharma committed Nov 10, 2023
1 parent 4ed7f50 commit b779149
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 83 deletions.
5 changes: 5 additions & 0 deletions DeepINN/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def compile_network(self):
print("Network compiled", file=sys.stderr, flush=True)

def train(self, iterations : int = None, display_every : int = None):

if self.iter == 0: # We are running a fresh training
self.training_history = [] # Initialize an empty list for storing loss values
self.iterations = iterations
# Load all the seeds, data types, devices etc.
self.config.apply_seeds()
Expand Down Expand Up @@ -85,6 +87,9 @@ def train(self, iterations : int = None, display_every : int = None):
if self.iter % (self.iterations/10) == 0:
print(f"Iteration: {self.iter+1} \t BC Loss: {self.BC_loss:0.4f}\t PDE Loss: {self.PDE_loss:0.4f} \t Loss: {self.total_loss:0.4f}")

# Append the total loss value to the training history list
self.training_history.append(self.total_loss.item())

self.iter = self.iter + 1
else:
print('Training finished')
Expand Down
2 changes: 1 addition & 1 deletion DeepINN/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .FCNN import BaseNetwork, FullyConnected
from.utils import activation, initialiser
from .utils import activation, initialiser
2 changes: 1 addition & 1 deletion DeepINN/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .data import PointsDataset, PointsDataLoader, DeepONetDataLoader

from .user_fun import UserFunction
from .user_fun import UserFunction, tensor2numpy
from .plotting import plot, animate, scatter
from .evaluation import compute_min_and_max

Expand Down
8 changes: 5 additions & 3 deletions DeepINN/utils/plotting/scatter_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def scatter(subspace, *samplers, dpi=100, save=False):
Parameters
----------
subspace : torchphysics.problem.Space
subspace : dp.problem.Space
The (sub-)space of which the points should be plotted.
Only plotting for dimensions <= 3 is possible.
*samplers : torchphysics.problem.Samplers
*samplers : dp.problem.Samplers
The diffrent samplers for which the points should be plotted.
The plot for each sampler will be created in the order there were
passed in.
Expand Down Expand Up @@ -96,4 +96,6 @@ def _scatter_3D(points, labels, dpi, save):
ax.set_ylabel(labels[1])
ax.set_zlabel(labels[2])
if save:
plt.savefig('geom.jpg', dpi = dpi,bbox_inches='tight',transparent=True)
plt.savefig('geom.jpg', dpi = dpi,bbox_inches='tight',transparent=True)


8 changes: 7 additions & 1 deletion DeepINN/utils/user_fun.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,10 @@ def evaluate_function(self, device='cpu', **inp):
self.fun = self.fun.to(device)
return self.fun
else:
return torch.tensor(self.fun, device=device).float()
return torch.tensor(self.fun, device=device).float()

def tensor2numpy(tensor_list):
"""
Converts a list of torch.tensors to numpy arrays.
"""
return [tensor.detach().cpu().numpy() for tensor in tensor_list]
2 changes: 1 addition & 1 deletion Tutorials/3. Gradients/1. Gradients.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.8.10"
},
"orig_nbformat": 4
},
Expand Down
182 changes: 106 additions & 76 deletions Tutorials/5. FCNN/3. model.ipynb

Large diffs are not rendered by default.

0 comments on commit b779149

Please sign in to comment.