In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [31]:
# Reference on sharing weights between two layers:
# https://discuss.pytorch.org/t/how-to-share-weights-between-two-layers/55541/2


def init_weights():
    """Placeholder for weight initialisation; currently does nothing and returns None."""


class ImageAutoEncoder(nn.Module):
    """Convolutional autoencoder for 3-channel images with a 100-d bottleneck.

    Encoder: conv -> ReLU -> conv -> ReLU -> 2x2 max-pool -> linear.
    Decoder: linear -> 2x bilinear upsample -> conv -> conv.

    The linear layers are sized for a 64 x 20 x 20 pooled feature map
    (25600 values), i.e. the layer sizes assume 40x40 inputs.

    Args:
        tie_weights: If True, every decoder weight is made a transposed view of
            the matching encoder weight (fc4 <- fc3, conv5 <- conv2,
            conv6 <- conv1) so each pair shares one storage. Biases are not
            shared.
    """

    def __init__(self, tie_weights=True):
        super().__init__()
        # Remembered so the tying can be re-established after conversions
        # (see ``_apply``).
        self._weights_tied = bool(tie_weights)

        # encoder
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.fc3 = nn.Linear(25600, 100, bias=True)

        # decoder
        self.fc4 = nn.Linear(100, 25600, bias=True)
        self.conv5 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

        # utils
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.relu = nn.ReLU()

        # share encoder decoder weight matrices
        if self._weights_tied:
            self._tie_weights()

    def _tie_weights(self):
        """Point each decoder weight at a transposed view of its encoder weight.

        ``transpose(0, 1)`` returns a view, so after this call the paired
        parameters read and write the same underlying storage: an in-place
        update of one is immediately visible through the other.
        """
        self.fc4.weight.data = self.fc3.weight.data.transpose(0, 1)
        self.conv5.weight.data = self.conv2.weight.data.transpose(0, 1)
        self.conv6.weight.data = self.conv1.weight.data.transpose(0, 1)

    def _apply(self, fn, *args, **kwargs):
        """Re-tie the weights after any device/dtype conversion.

        ``.to()``, ``.cuda()``, ``.double()`` etc. go through ``_apply``, which
        gives every parameter its own freshly converted tensor. Without this
        hook the decoder weights would silently stop sharing storage with the
        encoder weights after such a call.
        """
        module = super()._apply(fn, *args, **kwargs)
        if getattr(self, "_weights_tied", False):
            self._tie_weights()
        return module

    def forward(self, x):
        """Encode ``x`` to a 100-d code and decode it back to image space.

        Args:
            x: Image batch of shape (N, 3, 40, 40). An unbatched (3, 40, 40)
                tensor is treated as a batch of one where the installed
                PyTorch accepts unbatched input to ``Conv2d``.

        Returns:
            Reconstruction of shape (N, 3, 40, 40).
        """
        # encoder
        h = self.relu(self.conv1(x))
        h = self.relu(self.conv2(h))
        h = self.pool(h)
        if h.dim() == 3:
            # Unbatched (C, H, W) feature map: add the batch axis explicitly.
            h = h.unsqueeze(0)
        # Flatten per sample. Unlike ``reshape(-1, 25600)`` this cannot merge
        # several wrong-sized samples into one row; a bad input size raises in
        # ``fc3`` instead.
        h = self.fc3(h.flatten(start_dim=1))
        print(h.shape)  # debug aid: shape of the bottleneck code

        # decoder
        # fc4 already yields (N, 25600); it must NOT be transposed before the
        # reshape, otherwise values from different samples get interleaved.
        h = self.fc4(h)
        h = h.reshape(-1, 64, 20, 20)
        h = self.upsample(h)
        h = self.conv5(h)
        x_hat = self.conv6(h)
        return x_hat


class TextAutoEncoder(nn.Module):
    """Fully connected autoencoder: 32 -> 400 -> 100 -> 400 -> 32.

    Args:
        tie_weights: If True, each decoder weight is made a transposed view of
            the matching encoder weight (fc3 <- fc2, fc4 <- fc1) so each pair
            shares one storage. Biases are not shared.
    """

    def __init__(self, tie_weights=True):
        super().__init__()
        # Remembered so the tying can be re-established after conversions
        # (see ``_apply``).
        self._weights_tied = bool(tie_weights)

        # encoder
        self.fc1 = nn.Linear(32, 400, bias=True)
        self.fc2 = nn.Linear(400, 100, bias=True)

        # decoder
        self.fc3 = nn.Linear(100, 400, bias=True)
        self.fc4 = nn.Linear(400, 32, bias=True)

        # utils
        self.relu = nn.ReLU()

        # share encoder decoder weight matrices
        if self._weights_tied:
            self._tie_weights()

    def _tie_weights(self):
        """Point each decoder weight at a transposed view of its encoder weight.

        ``transpose(0, 1)`` returns a view, so after this call the paired
        parameters read and write the same underlying storage.
        """
        self.fc3.weight.data = self.fc2.weight.data.transpose(0, 1)
        self.fc4.weight.data = self.fc1.weight.data.transpose(0, 1)

    def _apply(self, fn, *args, **kwargs):
        """Re-tie the weights after any device/dtype conversion.

        ``.to()``, ``.cuda()``, ``.double()`` etc. go through ``_apply``, which
        gives every parameter its own freshly converted tensor. Without this
        hook the decoder weights would silently stop sharing storage with the
        encoder weights after such a call.
        """
        module = super()._apply(fn, *args, **kwargs)
        if getattr(self, "_weights_tied", False):
            self._tie_weights()
        return module

    def forward(self, x):
        """Encode ``x`` to a 100-d code and decode it back.

        Args:
            x: Tensor whose last dimension is 32 (the ``fc1`` input size).

        Returns:
            Reconstruction with the same shape as ``x``.
        """
        # encoder
        h = self.relu(self.fc1(x))
        h = self.fc2(h)
        print(h.shape)  # debug aid: shape of the bottleneck code

        # decoder
        h = self.relu(self.fc3(h))
        x_hat = self.fc4(h)
        return x_hat

# Smoke test: build the text autoencoder, report its parameter count and
# layout, then push one random sample through it.
model = TextAutoEncoder()
print(sum(map(torch.numel, model.parameters())))
print(model)

x = torch.randn(1, 32)
print(x.size())
model(x).size()

106532
TextAutoEncoder(
  (fc1): Linear(in_features=32, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=32, bias=True)
  (relu): ReLU()
)
torch.Size([1, 32])
torch.Size([1, 100])


torch.Size([1, 32])

In [32]:
# Smoke test: build the image autoencoder, report its parameter count and
# layout, then push one random 40x40 RGB sample through it.
model = ImageAutoEncoder()
print(sum(map(torch.numel, model.parameters())))
print(model)

x = torch.randn(1, 3, 40, 40)
print(x.size())
model(x).size()

5223079
ImageAutoEncoder(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc3): Linear(in_features=25600, out_features=100, bias=True)
  (fc4): Linear(in_features=100, out_features=25600, bias=True)
  (conv5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (upsample): Upsample(scale_factor=2.0, mode='bilinear')
  (relu): ReLU()
)
torch.Size([1, 3, 40, 40])
torch.Size([1, 100])


torch.Size([1, 3, 40, 40])