Commit a14672f

update

pedrodiamel committed Jul 21, 2019
1 parent ac68081

Showing 1 changed file with 15 additions and 2 deletions.

17 changes: 15 additions & 2 deletions pytvision/netmodels/ferattentionnet.py
@@ -18,6 +18,15 @@
 __all__ = ['FERAttentionNet', 'ferattention', 'FERAttentionSTNNet', 'ferattentionstn' ]
 
 
+def weights_init(model):
+    for m in model.modules():
+        if isinstance(m, torch.nn.Conv2d) or isinstance(m, nn.Linear):
+            torch.nn.init.xavier_uniform(m.weight.data)
+        elif isinstance(m, nn.BatchNorm2d):
+            m.weight.data.fill_(1)
+            m.bias.data.zero_()
+
+
 def ferattention(pretrained=False, **kwargs):
     """FERAttention model architecture
     """
@@ -29,7 +38,8 @@ def ferattention(pretrained=False, **kwargs):
         state = torch.load('../../ferattention/out/attnet/att_attgmmnet_ferattentiongmm_attloss_adam_affectnetdark_dim64_resnet18x32_fold0_mixup_retrain_000/models/chk000390.pth.tar')
         #model.load_state_dict( state['state_dict'] )
         utl.load_state_dict(model.state_dict(), state['state_dict'] )
-        model.netclass.weights_init()
+        #model.netclass.weights_init()
+        weights_init( model.netclass )
         #pass
     return model

@@ -45,7 +55,8 @@ def ferattentionstn(pretrained=False, **kwargs):
         state = torch.load('../../ferattention/out/attnet/att_attstnnet_ferattentionstn_attloss_adam_affectnetdark_dim64_resnet18x32_fold0_mixup_retrain_000/models/chk000025.pth.tar')
         #model.load_state_dict( state['state_dict'] )
         utl.load_state_dict(model.state_dict(), state['state_dict'] )
-        model.netclass.weights_init()
+        #model.netclass.weights_init()
+        weights_init( model.netclass )
         #pass
     return model
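
Both factory functions follow the same pretrained-loading pattern: construct the network, load a checkpoint from a hard-coded path, merge it via utl.load_state_dict (presumably copying only the entries of state['state_dict'] that match the model's own state dict), and then re-initialize the classification head netclass with the new weights_init. A hedged usage sketch, with constructor kwargs inferred from the FERAttentionNet.__init__ signature shown further down in this diff:

    from pytvision.netmodels.ferattentionnet import ferattention

    # pretrained=False avoids the hard-coded checkpoint paths above;
    # the kwargs here are assumptions, not confirmed by this commit.
    model = ferattention(pretrained=False, num_classes=8, num_channels=3, num_filters=32)
    model.eval()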

@@ -264,6 +275,7 @@ def __init__(self, num_classes=8, num_channels=3, num_filters=32 ):
         super().__init__()
         self.num_classes = num_classes
         self.num_filters = num_filters
+        self.size_input=64
 
         #////////
         #attention module
@@ -356,6 +368,7 @@ def __init__(self, num_classes=8, num_channels=3, num_filters=32 ):
         super().__init__()
         self.num_classes = num_classes
         self.num_filters = num_filters
+        self.size_input=64
 
 
         #////////
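
Both __init__ hunks add a size_input attribute. Judging by the dim64 tag in the checkpoint names above, this presumably records the expected 64×64 spatial input size, e.g. for building a dummy batch:

    import torch

    # Assumption: NCHW input with num_channels=3 at the recorded resolution.
    x = torch.randn(1, 3, model.size_input, model.size_input)  # 1 x 3 x 64 x 64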
