In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import csv
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [0]:
class GatedShortcutConnection(nn.Module):
  """Element-wise gate over a feature map with ``2 * n_ch`` channels.

  Two independent 1x1 convolutions are applied to the same input; the output
  is the first projection multiplied by the sigmoid of the second.

  Args:
    n_ch: base channel count; both convolutions map ``2 * n_ch`` channels to
      ``2 * n_ch`` channels.
  """

  def __init__(self, n_ch=64):
    super().__init__()
    width = 2 * n_ch  # shared in/out channel count of both projections
    self.convA = nn.Conv2d(width, width, 1)
    self.convB = nn.Conv2d(width, width, 1)
    self.sigm = nn.Sigmoid()

  def forward(self, x):
    """Return ``convA(x) * sigmoid(convB(x))`` (same shape as ``x``)."""
    gate = self.sigm(self.convB(x))
    return self.convA(x) * gate

class ResidualStack(nn.Module):
  """Stack of five conv/gate units mapping ``4 * n_ch`` to ``2 * n_ch`` channels.

  Each unit is ReLU -> 3x3 conv -> ReLU -> 3x3 conv -> ReLU -> gate -> ReLU.
  The 3x3 convolutions use stride 1 and no padding, so every unit trims
  4 pixels from each spatial dimension (20 across the whole stack).

  NOTE(review): despite the name, no skip connection is added around the
  units; the only "shortcut" is the gating inside GatedShortcutConnection.

  Args:
    n_ch: base channel count. The first unit takes ``4 * n_ch`` channels;
      later units take the ``2 * n_ch`` channels produced by the previous one.
  """

  def __init__(self, n_ch=64):
    super().__init__()
    self.layers = nn.ModuleList()
    n_in_ch = n_ch * 4  # only the first unit sees the wide 4*n_ch input
    for _ in range(5):
      self.layers.append(nn.ReLU())
      self.layers.append(nn.Conv2d(n_in_ch, n_ch, (3,3), (1,1)))
      n_in_ch = n_ch * 2  # later units consume the previous unit's output
      self.layers.append(nn.ReLU())
      self.layers.append(nn.Conv2d(n_ch, n_ch * 2, (3,3), (1,1)))
      self.layers.append(nn.ReLU())
      # Forward n_ch so the gate width follows the stack instead of the
      # gate's own default of 64.
      self.layers.append(GatedShortcutConnection(n_ch))
      # This trailing ReLU sits directly before the next unit's leading ReLU;
      # the duplicate is a no-op but is kept so layer indices (and therefore
      # state_dict keys) stay where they were.
      self.layers.append(nn.ReLU())
    self.relu = nn.ReLU()

  def forward(self, x):
    """Apply every layer in order, then a final ReLU.

    Args:
      x: feature map with ``4 * n_ch`` channels; each spatial dimension must
        be larger than 20 for the unpadded convolutions to fit.

    Returns:
      Non-negative feature map with ``2 * n_ch`` channels.
    """
    z = x
    # nn.ModuleList is a plain container and cannot be called like a module,
    # so the layers are applied one by one.
    for layer in self.layers:
      z = layer(z)
    return self.relu(z)

In [17]:
# Build a stack with the default n_ch so the notebook displays its layer layout.
ResidualStack()

ResidualStack(
  (layers): ModuleList(
    (0): ReLU()
    (1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1))
    (2): ReLU()
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): GatedShortcutConnection(
      (convA): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      (convB): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      (sigm): Sigmoid()
    )
    (6): ReLU()
    (7): ReLU()
    (8): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
    (9): ReLU()
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (11): ReLU()
    (12): GatedShortcutConnection(
      (convA): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      (convB): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      (sigm): Sigmoid()
    )
    (13): ReLU()
    (14): ReLU()
    (15): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
    (16): ReLU()
    (17): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (18): ReLU()
    (19): GatedShort

In [0]:
class VAEEncoder(nn.Module):
    """Convolutional encoder trunk followed by a ResidualStack.

    Layers, in order: 4x4 stride-2 conv (3 -> 2*n_ch), ReLU, 4x4 stride-2 conv
    (2*n_ch -> 4*n_ch), ReLU, 3x3 conv (4*n_ch -> 4*n_ch), ResidualStack.
    None of the convolutions is padded.

    Args:
        n_ch: base channel count, also forwarded to the residual stack.
    """

    def __init__(self, n_ch):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Conv2d(3, n_ch * 2, (4,4), (2,2)))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.Conv2d(n_ch * 2, n_ch * 4, (4,4), (2,2)))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.Conv2d(n_ch * 4, n_ch * 4, (3,3), (1,1)))
        # Pass n_ch through so the stack's width matches the 4*n_ch channels
        # produced above instead of silently using the stack's default.
        self.layers.append(ResidualStack(n_ch))

    def encode(self, x, muf, sigmaf):
        """Run the trunk and apply the two heads to its output.

        Args:
            x: input batch with 3 channels.
            muf: callable applied to the features to produce the mean.
            sigmaf: callable applied to the features; its output is
                exponentiated before being returned.

        Returns:
            Tuple ``(muf(z), exp(sigmaf(z)))`` where ``z = forward(x)``.
            NOTE(review): VAE.sample takes the square root of the second
            element, so it is presumably a variance — confirm.
        """
        z = self(x)
        mu = muf(z)
        sigma = sigmaf(z)
        return mu, torch.exp(sigma)

    def forward(self, x):
        """Apply every layer in order and return the resulting feature map."""
        z = x
        # nn.ModuleList is only a container and cannot be called directly.
        for layer in self.layers:
            z = layer(z)
        return z

class VAEDecoder(nn.Module):
    """Convolutional decoder: 3x3 conv, ResidualStack, two transposed convs.

    Layers, in order: 3x3 conv (2*n_ch -> 4*n_ch), ResidualStack,
    4x4 stride-2 transposed conv (2*n_ch -> 2*n_ch), ReLU,
    4x4 stride-2 transposed conv (2*n_ch -> 6).

    Args:
        n_ch: base channel count, also forwarded to the residual stack.
    """

    def __init__(self, n_ch):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Conv2d(n_ch * 2, n_ch * 4, (3,3), (1,1)))
        # Pass n_ch through so the stack's width matches the surrounding
        # layers instead of silently using the stack's default.
        self.layers.append(ResidualStack(n_ch))
        self.layers.append(nn.ConvTranspose2d(n_ch * 2, n_ch * 2, (4,4), (2,2)))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.ConvTranspose2d(n_ch * 2, 6, (4,4), (2,2)))

    def decode(self, x, muf, sigmaf):
        """Run the decoder and apply the two heads to its output.

        Args:
            x: latent feature map with ``2 * n_ch`` channels.
            muf: callable applied to the decoder output to produce the mean.
            sigmaf: callable applied to the decoder output; its result is
                exponentiated before being returned.

        Returns:
            Tuple ``(muf(z), exp(sigmaf(z)))`` where ``z = forward(x)``.
        """
        z = self(x)
        mu = muf(z)
        sigma = sigmaf(z)
        return mu, torch.exp(sigma)

    def encode(self, x, muf, sigmaf):
        """Backward-compatible alias of :meth:`decode` (the original name)."""
        return self.decode(x, muf, sigmaf)

    def forward(self, x):
        """Apply every layer in order and return the 6-channel output."""
        z = x
        # nn.ModuleList is only a container and cannot be called directly.
        for layer in self.layers:
            z = layer(z)
        return z
        
class VAE(nn.Module):
    """Variational autoencoder wiring a VAEEncoder and VAEDecoder to linear heads.

    Args:
        latd: output width of the decoder-side heads (``mu_decode`` /
            ``sigma_decode``).
        ind: output width of the encoder-side heads (``mu_encode`` /
            ``sigma_encode``).
        lin_dim: input width of all four heads. nn.Linear acts on the last
            axis of its input, so this must equal the size of the last axis
            of the tensors the heads are applied to.

    NOTE(review): if ``latd`` is meant to be the latent width and ``ind`` the
    data width, the head sizes look swapped (encoder heads emit ``ind``,
    decoder heads emit ``latd``). Both default to 2, so the defaults are
    unaffected — confirm before using ``latd != ind``.
    """

    def __init__(self, latd=2, ind=2, lin_dim=64):
        super().__init__()
        self.latd = latd
        self.ind = ind
        self.encoder = VAEEncoder(64)
        self.decoder = VAEDecoder(64)
        self.mu_encode = nn.Linear(lin_dim, ind)
        self.mu_decode = nn.Linear(lin_dim, latd)
        self.sigma_encode = nn.Linear(lin_dim, ind)
        self.sigma_decode = nn.Linear(lin_dim, latd)
        # Kept for backward compatibility only: sample() no longer uses it
        # because its parameters are plain CPU tensors that do not follow
        # the module across devices.
        self.normd = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0]))

    def sample(self, mu, sigma):
        """Draw ``mu + eps * sqrt(sigma)`` with ``eps ~ N(0, 1)``.

        Args:
            mu: mean tensor.
            sigma: tensor broadcastable against ``mu``; its square root
                scales the noise, i.e. it is used as a variance.

        Returns:
            ``mu`` itself in eval mode; otherwise a reparameterized sample
            with the same shape, dtype and device as ``mu``.
        """
        if not self.training:
            return mu
        # randn_like creates the noise directly on mu's device/dtype, so no
        # module-level `device` variable is required.
        standard = torch.randn_like(mu)
        return mu + standard * torch.sqrt(sigma)

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            Tuple ``(mu_z, sigma_z, mu_x, sigma_x)``: the encoder outputs
            followed by the decoder outputs for the sampled latent. The
            decoder must expose ``decode(z, muf, sigmaf)``.
        """
        mu_z, sigma_z = self.encoder.encode(x, self.mu_encode, self.sigma_encode)
        z = self.sample(mu_z, sigma_z)
        mu_x, sigma_x = self.decoder.decode(z, self.mu_decode, self.sigma_decode)
        return mu_z, sigma_z, mu_x, sigma_x
        

In [22]:
# Build the model with default arguments so the notebook displays its module tree.
VAE()

VAE(
  (encoder): VAEEncoder(
    (layers): ModuleList(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2))
      (1): ReLU()
      (2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      (5): ResidualStack(
        (layers): ModuleList(
          (0): ReLU()
          (1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1))
          (2): ReLU()
          (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
          (4): ReLU()
          (5): GatedShortcutConnection(
            (convA): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
            (convB): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
            (sigm): Sigmoid()
          )
          (6): ReLU()
          (7): ReLU()
          (8): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
          (9): ReLU()
          (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
          (11): ReLU()
          (12): G