In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AE(nn.Module):
    """Deterministic autoencoder computing ``decoder(encoder(x))``."""

    def __init__(self, encoder, decoder):
        """
        encoder, decoder : nn.Module
            ``encoder`` maps an input batch to a latent code; ``decoder`` maps
            that code back to the input space.
        """
        super(AE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Not read anywhere in this class; presumably consulted by external
        # training code — confirm before removing.
        self.own_optimizer = False

    def forward(self, x):
        """Return the reconstruction ``decoder(encode(x))``."""
        z = self.encode(x)
        recon = self.decoder(z)
        return recon

    def encode(self, x):
        """Return the latent code ``encoder(x)``."""
        z = self.encoder(x)
        return z

    def recon_error(self, x):
        """Per-sample mean squared reconstruction error.

        Returns a 1-D tensor of length ``len(x)``. Each entry is the squared
        difference between ``self(x)`` and ``x``, averaged over all non-batch
        dimensions.
        """
        recon = self(x)
        # reshape, not view: the elementwise result inherits x's memory layout
        # (e.g. channels_last), for which .view(len(x), -1) raises.
        recon_err = ((recon - x) ** 2).reshape(len(x), -1).mean(dim=1)
        return recon_err

    def reconstruct(self, x):
        """Alias for ``forward``: return the reconstruction of ``x``."""
        return self(x)


    
# ConvNet encoder. With no padding, a 28x28 input yields a 1x1 output map and a 32x32 input yields a 2x2 map.
class ConvNet2(nn.Module):
    """Five-layer convolutional encoder (unpadded convs, two 2x2 max-pools).

    Output spatial size: a 28x28 input gives 1x1, a 32x32 input gives 2x2.

    Parameters
    ----------
    in_chan : int
        Channels of the input image.
    out_chan : int
        Channels produced by the final convolution.
    nh : int
        Width multiplier; the hidden convs use 4*nh, 8*nh, 8*nh and 16*nh filters.
    out_activation : str or None
        'tanh', 'sigmoid', or 'softmax' (which applies log_softmax over the
        channel dimension). Any other value leaves the output linear.
    """

    def __init__(self, in_chan=1, out_chan=64, nh=8, out_activation=None):
        super(ConvNet2, self).__init__()
        c_small, c_mid, c_big = nh * 4, nh * 8, nh * 16
        # Submodules are registered in a fixed order so state_dict keys and
        # weight-init RNG consumption stay stable.
        self.conv1 = nn.Conv2d(in_chan, c_small, kernel_size=3, bias=True)
        self.conv2 = nn.Conv2d(c_small, c_mid, kernel_size=3, bias=True)
        self.max1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(c_mid, c_mid, kernel_size=3, bias=True)
        self.conv4 = nn.Conv2d(c_mid, c_big, kernel_size=3, bias=True)
        self.max2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.Conv2d(c_big, out_chan, kernel_size=4, bias=True)
        self.in_chan = in_chan
        self.out_chan = out_chan
        self.out_activation = out_activation

    def forward(self, x):
        """Encode ``x`` and apply the configured output activation."""
        h = F.relu(self.conv1(x))
        h = self.max1(F.relu(self.conv2(h)))
        h = F.relu(self.conv3(h))
        h = self.max2(F.relu(self.conv4(h)))
        h = self.conv5(h)

        act = self.out_activation
        if act == 'tanh':
            return torch.tanh(h)
        if act == 'sigmoid':
            return torch.sigmoid(h)
        if act == 'softmax':
            # Note: log-probabilities, not probabilities.
            return F.log_softmax(h, dim=1)
        return h



class DeConvNet2(nn.Module):
    """Transposed-convolution decoder whose channel widths mirror ConvNet2.

    Output spatial size: a 1x1 input gives 28x28, a 2x2 input gives 32x32
    (unpadded transposed convs plus two bilinear x2 upsamplings).

    out_activation : str or None
        'sigmoid' or 'tanh'. Any other value (e.g. None or 'linear') leaves
        the output linear.
    """

    def __init__(self, in_chan=1, out_chan=1, nh=8, out_activation=None):
        """nh: width multiplier; hidden layers use 16*nh, 8*nh, 8*nh and 4*nh filters."""
        super(DeConvNet2, self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_chan, nh * 16, kernel_size=4, bias=True)
        self.conv2 = nn.ConvTranspose2d(nh * 16, nh * 8, kernel_size=3, bias=True)
        self.conv3 = nn.ConvTranspose2d(nh * 8, nh * 8, kernel_size=3, bias=True)
        self.conv4 = nn.ConvTranspose2d(nh * 8, nh * 4, kernel_size=3, bias=True)
        self.conv5 = nn.ConvTranspose2d(nh * 4, out_chan, kernel_size=3, bias=True)
        self.in_chan, self.out_chan = in_chan, out_chan
        self.out_activation = out_activation

    def forward(self, x):
        """Decode ``x`` and apply the configured output activation."""
        x = self.conv1(x)
        x = F.relu(x)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.conv5(x)
        if self.out_activation == 'sigmoid':
            x = torch.sigmoid(x)
        elif self.out_activation == 'tanh':
            # Accept 'tanh' too, matching ConvNet2's output-activation options.
            x = torch.tanh(x)
        return x

In [29]:
# Smoke test: encode and decode a random batch of 32x32 images with D channels.
torch.manual_seed(0)  # reproducible random input and weight init
D = 64  # input/output channels
Z = 32  # latent channels
N = 100  # batch size
X = torch.rand(N, D, 32, 32)
# 'linear' is not a special case in ConvNet2; it falls through to no activation.
encoder = ConvNet2(in_chan=D, out_chan=Z, nh=4, out_activation='linear')
decoder = DeConvNet2(in_chan=Z, out_chan=D, nh=4, out_activation='sigmoid')
ae = AE(encoder, decoder)

recon = ae(X)
# Unpadded convs map 32x32 -> 2x2 latent map -> 32x32 reconstruction.
assert recon.shape == (N, D, 32, 32)



In [30]:
# Shape of the reconstruction (same as the input batch).
recon.size()

torch.Size([100, 64, 32, 32])

In [31]:
# One reconstruction error per sample in the batch.
ae.recon_error(X).size()

torch.Size([100])