Skip to content

Commit

Permalink
Printing number of model parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed Mar 12, 2019
1 parent 4347033 commit bb2eace
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
7 changes: 7 additions & 0 deletions digideep/policy/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import abstractmethod
import torch
import torch.nn as nn
import numpy as np

class PolicyBase(object):
"""The base class for all policy classes. Policy is a model inside the agent which generates
Expand Down Expand Up @@ -39,6 +40,12 @@ def model_to_gpu(self):
gpu_count = 0
self.model.to(self.device) # dtype=model_type

def count_parameters(self):
"""
Counts the number of parameters in a PyTorch model.
"""
return np.sum(p.numel() for p in list(self.model.parameters()) if p.requires_grad)

@abstractmethod
def state_dict(self):
"""Returns state dict of the policy. It is ``model.state_dict`` by default, but child classes
Expand Down
1 change: 1 addition & 0 deletions digideep/policy/deterministic/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, device, **params):
self.averager["critic"] = Averager(self.model["critic"], self.model["critic_target"], **self.params["average_args"])

self.model_to_gpu()
logger("Number of parameters:\n>>>>>>", self.count_parameters())

def generate_actions(self, inputs, deterministic=False):
"""
Expand Down
2 changes: 2 additions & 0 deletions digideep/policy/stochastic/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .blocks import MLPBlock, RNNBlock, CNNBlock

from digideep.utility.toolbox import get_class #, get_module
from digideep.utility.logging import logger

from digideep.policy.base import PolicyBase

Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(self, device, obs_space, act_space, modelname, modelargs):
raise NotImplementedError("The action_space of the environment is not supported!")

self.model_to_gpu()
logger("Number of parameters:\n>>>>>>", self.count_parameters())


def generate_actions(self, inputs, hidden, masks, deterministic=False):
Expand Down

0 comments on commit bb2eace

Please sign in to comment.