Skip to content

Commit

Permalink
NN refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
praksharma committed Mar 11, 2024
1 parent a829089 commit 15853e1
Show file tree
Hide file tree
Showing 6 changed files with 432 additions and 23 deletions.
4 changes: 4 additions & 0 deletions DeepINN/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ def __init__(self, float_type=torch.float32, random_seed=42, device = 'cuda'):
self.device = device
# Add more configuration parameters as needed

self.apply_seeds()
self.apply_float_type()
self.default_device()

def apply_seeds(self):
torch.manual_seed(self.random_seed)

Expand Down
8 changes: 4 additions & 4 deletions DeepINN/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def initialise_training(self, iterations : int = None):
self.training_history = [] # Initialize an empty list for storing loss values
self.iterations = iterations
# Load all the seeds, data types, devices etc.
self.config.apply_seeds()
self.config.apply_float_type()
self.config.default_device()
# self.config.apply_seeds()
# self.config.apply_float_type()
# self.config.default_device()

# In 1D problem we need to combine the BCs as there is only one point for each BC, which returns an undefined feature scaling because the ub and lb are same in the denominator, so we get infinity
# For problem with multiple points on each boundary, we don't need to combine them.
Expand All @@ -74,7 +74,7 @@ def train(self, iterations : int = None, display_every : int = 1):
"""
self.initialise_training(iterations)
self.trainer()

@timer
def trainer(self):
# implement training loop
Expand Down
6 changes: 3 additions & 3 deletions DeepINN/nn/FCNN.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from .base import BaseNetwork

class FullyConnected(BaseNetwork, torch.nn.Module):
class FullyConnected(BaseNetwork):
"""
Implementation of Fully Connected neural network
"""
Expand All @@ -28,10 +28,10 @@ def __init__(self,
self.linears = torch.nn.ModuleList([torch.nn.Linear(self.layer_size[i], self.layer_size[i+1]) for i in range(0,len(self.layer_size)-1)])

# initialise the weights
self.init()
self.weight_init()


def init(self):
def weight_init(self):
# weight initialisation
for i in range(len(self.layer_size)-1):

Expand Down
5 changes: 3 additions & 2 deletions DeepINN/nn/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from .utils import activation, initialiser
import torch

class BaseNetwork():
class BaseNetwork(torch.nn.Module):
"""
Base class for all neural networks
"""
def __init__(self) -> None:
super().__init__() # intialise all methods from nn.Module
super().__init__() # initialise all methods from nn.Module

self.activation_function = activation
self.initialiser_function = initialiser
Expand Down
Loading

0 comments on commit 15853e1

Please sign in to comment.