In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [24]:
class Disentangler(nn.Module):
    """Chains an encoder, a transformation network and a decoder.

    The three sub-modules are supplied by the caller and stored as-is.

    Args:
        encoder: module called as ``encoder(x)`` to obtain ``x0`` when the
            caller does not pass one.
        decoder: module called on the output of ``transnet``.
        transnet: module called as ``transnet(x, x0)``.
    """

    def __init__(self, encoder, decoder, transnet):
        super(Disentangler, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Original author's note: estimates trans parameters, contains
        # exponential weights, creates matrices.
        self.transnet = transnet

    def forward(self, x, x0=None):
        """Run ``decoder(transnet(x, x0))``.

        Args:
            x: input passed to the encoder (if needed) and to ``transnet``.
            x0: optional precomputed encoding. When omitted (``None``) it is
                computed as ``self.encoder(x)``.

        Returns:
            Whatever ``self.decoder`` returns for the transformed code.
        """
        # Identity test on purpose: `x0 == None` would dispatch to the
        # tensor's overloaded `__eq__` instead of checking for "not provided".
        if x0 is None:
            x0 = self.encoder(x)
        transformed = self.transnet(x, x0)
        return self.decoder(transformed)

class Encoder(nn.Module):
    """Two fully connected layers mapping ``og_dim`` features to ``latent_dim``.

    The hidden layer is ``max(latent_dim, og_dim // 16)`` units wide and a
    ReLU follows each linear layer, including the last one.

    Args:
        og_dim: number of input features.
        latent_dim: number of output features; must not exceed ``og_dim``.
    """

    def __init__(self, og_dim, latent_dim):
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        super().__init__()
        hidden_dim = max(latent_dim, og_dim // 16)
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(og_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Return ``relu(fc2(relu(fc1(x))))``."""
        hidden = F.relu(self.fc1(x))
        return F.relu(self.fc2(hidden))

class Decoder(nn.Module):
    """Two fully connected layers mapping ``latent_dim`` features to ``og_dim``.

    Uses a hidden width of ``max(latent_dim, og_dim // 16)``; every linear
    layer, including the output layer, is followed by a ReLU.

    Args:
        og_dim: number of output features.
        latent_dim: number of input features; must not exceed ``og_dim``.
    """

    def __init__(self, og_dim, latent_dim):
        assert latent_dim <= og_dim, 'latent space must have lower dimension'
        super().__init__()
        self.og_dim, self.latent_dim = og_dim, latent_dim
        width = max(latent_dim, og_dim // 16)
        self.fc1 = nn.Linear(latent_dim, width)
        self.fc2 = nn.Linear(width, og_dim)

    def forward(self, x):
        """Apply ``fc1`` then ``fc2``, each followed by a ReLU."""
        for layer in (self.fc1, self.fc2):
            x = F.relu(layer(x))
        return x
    
class Transnet(nn.Module):
    """Predicts a k-sparse log-scale per latent unit and applies it to ``x0``.

    The input ``x`` and the latent code ``x0`` are concatenated along dim 1,
    passed through three ReLU-activated linear layers, reduced to the
    ``k_sparse`` largest activations per row (all others set to zero),
    exponentiated, and multiplied element-wise with ``x0``. Entries that were
    masked to zero therefore scale ``x0`` by ``exp(0) == 1``.

    Args:
        og_dim: feature size of ``x``.
        latent_dim: feature size of ``x0`` (and of the output).
        k_sparse: number of entries per row kept by ``k_mask``.
    """

    def __init__(self, og_dim, latent_dim, k_sparse):
        super(Transnet, self).__init__()
        self.og_dim = og_dim
        self.latent_dim = latent_dim
        self.ttl_dim = og_dim + latent_dim
        self.k_sparse = k_sparse
        # The original referenced a bare `ttl_dim` here, which is undefined
        # in this scope (NameError); the value lives on `self`.
        hidden1 = max(latent_dim, self.ttl_dim // 8)
        hidden2 = max(latent_dim, self.ttl_dim // 32)
        self.fc1 = nn.Linear(self.ttl_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, latent_dim)

    def k_mask(self, x):
        """Keep the ``k_sparse`` largest values in each row of ``x``; zero the rest.

        Args:
            x: 2-D tensor; the top-k selection and the scatter both run
                along dim 1.

        Returns:
            A new tensor shaped like ``x``.
        """
        # `dim=1` is explicit so the selection axis matches the scatter axis.
        topk, ix = torch.topk(x, k=self.k_sparse, dim=1)
        return torch.zeros_like(x).scatter_(1, ix, topk)

    def forward(self, x, x0):
        """Return ``exp(k_mask(net(cat(x, x0)))) * x0``.

        Args:
            x: tensor with ``og_dim`` features along dim 1.
            x0: tensor with ``latent_dim`` features along dim 1.
        """
        x1 = torch.cat((x, x0), dim=1)  # create (B, N+M) tensor
        x1 = F.relu(self.fc1(x1))
        x1 = F.relu(self.fc2(x1))
        x1 = F.relu(self.fc3(x1))
        return torch.exp(self.k_mask(x1)) * x0
        
    

In [25]:
def make_model(og_dim, latent_dim, k_sparse):
    """Build a ``Disentangler`` from a fresh Encoder, Decoder and Transnet.

    Args:
        og_dim: passed to all three sub-networks as their ``og_dim``.
        latent_dim: passed to all three sub-networks as their ``latent_dim``.
        k_sparse: passed to ``Transnet`` only.

    Returns:
        The assembled ``Disentangler`` instance.
    """
    # Arguments are evaluated left to right, so the sub-networks are still
    # constructed in encoder, decoder, transnet order.
    return Disentangler(
        Encoder(og_dim, latent_dim),
        Decoder(og_dim, latent_dim),
        Transnet(og_dim, latent_dim, k_sparse),
    )

In [None]:
def train()