Skip to content

Commit

Permalink
Added channels argument to specify number of input channels.
Browse files Browse the repository at this point in the history
  • Loading branch information
Podgorskiy committed Aug 30, 2018
1 parent 0ede3ee commit 98dff89
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions net.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def weight_init(self, mean, std):

class Generator(nn.Module):
# initializers
def __init__(self, z_size, d=128):
def __init__(self, z_size, d=128, channels=1):
super(Generator, self).__init__()
self.deconv1_1 = nn.ConvTranspose2d(z_size, d*2, 4, 1, 0)
self.deconv1_1_bn = nn.BatchNorm2d(d*2)
Expand All @@ -73,7 +73,7 @@ def __init__(self, z_size, d=128):
self.deconv2_bn = nn.BatchNorm2d(d*2)
self.deconv3 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
self.deconv3_bn = nn.BatchNorm2d(d)
self.deconv4 = nn.ConvTranspose2d(d, 1, 4, 2, 1)
self.deconv4 = nn.ConvTranspose2d(d, channels, 4, 2, 1)

# weight_init
def weight_init(self, mean, std):
Expand All @@ -86,15 +86,14 @@ def forward(self, input):#, label):
x = F.relu(self.deconv2_bn(self.deconv2(x)))
x = F.relu(self.deconv3_bn(self.deconv3(x)))
x = F.tanh(self.deconv4(x)) * 0.5 + 0.5

return x


class Discriminator(nn.Module):
# initializers
def __init__(self, d=128):
def __init__(self, d=128, channels=1):
super(Discriminator, self).__init__()
self.conv1_1 = nn.Conv2d(1, d//2, 4, 2, 1)
self.conv1_1 = nn.Conv2d(channels, d//2, 4, 2, 1)
self.conv2 = nn.Conv2d(d // 2, d*2, 4, 2, 1)
self.conv2_bn = nn.BatchNorm2d(d*2)
self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
Expand All @@ -117,9 +116,9 @@ def forward(self, input):

class Encoder(nn.Module):
# initializers
def __init__(self, z_size, d=128):
def __init__(self, z_size, d=128, channels=1):
super(Encoder, self).__init__()
self.conv1_1 = nn.Conv2d(1, d//2, 4, 2, 1)
self.conv1_1 = nn.Conv2d(channels, d//2, 4, 2, 1)
self.conv2 = nn.Conv2d(d // 2, d*2, 4, 2, 1)
self.conv2_bn = nn.BatchNorm2d(d*2)
self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
Expand Down

0 comments on commit 98dff89

Please sign in to comment.