Skip to content

Commit

Permalink
Merge pull request #31 from tfjgeorge/bn1d_test
Browse files Browse the repository at this point in the history
adds a linear batchnorm layer in tests
  • Loading branch information
tfjgeorge committed Sep 24, 2021
2 parents 522a9a4 + 715ce59 commit 8f96563
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,31 @@ def __init__(self, normalization='none'):
if self.normalization == 'batch_norm':
self.bn1 = nn.BatchNorm2d(6)
elif self.normalization == 'group_norm':
self.gn = nn.GroupNorm(2, 6)
self.gn1 = nn.GroupNorm(2, 6)
self.conv2 = nn.Conv2d(6, 5, 4, 1)
self.conv3 = nn.Conv2d(5, 7, 3, 1, 1)
self.fc1 = nn.Linear(1*1*7, 3)
self.fc1 = nn.Linear(1*1*7, 4)
if self.normalization == 'batch_norm':
self.bn2 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 3)

def forward(self, x):
if self.normalization == 'batch_norm':
x = tF.relu(self.bn1(self.conv1(x)))
elif self.normalization == 'group_norm':
x = tF.relu(self.gn(self.conv1(x)))
x = tF.relu(self.gn1(self.conv1(x)))
else:
x = tF.relu(self.conv1(x))
x = tF.max_pool2d(x, 2, 2)
x = tF.relu(self.conv2(x))
x = tF.max_pool2d(x, 2, 2)
x = tF.relu(self.conv3(x))
x = x.view(-1, 1*1*7)
return self.fc1(x)
x = self.fc1(x.view(-1, 1*1*7))
if self.normalization == 'batch_norm':
x = self.fc2(self.bn2(x))
else:
x = self.fc2(x)
return x


class LinearFCNet(nn.Module):
Expand Down

0 comments on commit 8f96563

Please sign in to comment.