In [17]:
from collections import Counter
from string import punctuation

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.datasets import fetch_20newsgroups
# The old `sklearn.feature_extraction.stop_words` module was made private in
# scikit-learn 0.22 and removed in 0.24, so importing it fails on current
# releases; the list is publicly exposed from `sklearn.feature_extraction.text`.
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS, TfidfVectorizer

# Stop words for the TF-IDF vectorizer: scikit-learn's English list plus each
# ASCII punctuation character.
# NOTE(review): TfidfVectorizer's default token pattern presumably never emits
# single-character tokens, so the punctuation entries are likely no-ops — confirm.
stop_words = list(ENGLISH_STOP_WORDS) + list(punctuation)

In [35]:
# Hyper-parameters for the TF-IDF features, the HLGC model and the training loop.
hparams = {
    'n_features' : 10000,          # TF-IDF vocabulary size (max_features)
    'hidden_dim' : 128,
    'latent_dims' : [10, 10, 10],  # one categorical layer per entry, sized by its value
    'batch_size' : 16,
    'lr' : 0.001,
    'dropout' : 0.5,
    'n_epochs' : 25,
    'clf_loss_weight' : 1.0,
    'latent_loss_weight' : 1.0,
    'recon_loss_weight' : 1.0,
    'display_interval' : 100,      # print training stats every N iterations
    'val_interval' : 1000,         # NOTE(review): not referenced by the training loop shown here
    'seed' : 0,                    # RNG seed: weight init, DataLoader shuffling, Gumbel noise
}

# Seed the RNGs up front so Restart Kernel -> Run All reproduces the same run.
torch.manual_seed(hparams['seed'])
np.random.seed(hparams['seed'])

In [3]:
# Load the 20 Newsgroups train and test subsets (the first run downloads the archive).
train, test = (fetch_20newsgroups(subset=split) for split in ('train', 'test'))

Downloading 20news dataset. This may take a few minutes.
Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)


In [4]:
tfidf = TfidfVectorizer(
    stop_words=stop_words,
    max_features=hparams['n_features'],
)
# Vocabulary and IDF weights are fitted on the training split only, then reused
# unchanged to transform the test split.
train_features, test_features = tfidf.fit_transform(train.data), tfidf.transform(test.data)

In [5]:
def straight_through_estimator(logits):
    """Hard one-hot in the forward pass, identity gradient in the backward pass.

    Args:
        logits: tensor of scores or probabilities; the last dim indexes categories.

    Returns:
        A tensor shaped like ``logits`` whose value is a one-hot vector marking
        the maximum along the last dim (the first one, on ties), while gradients
        flow back to ``logits`` unchanged.
    """
    # Scatter the argmax index instead of comparing against the max value:
    # `logits == max` marks *every* tied maximum and yields a multi-hot vector.
    index = logits.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
    # value == hard; d(out)/d(logits) == identity
    return (hard - logits).detach() + logits

def gumbel_softmax(logits, temperature=1.0, eps=1e-20):
    """Draw a relaxed (soft) one-hot sample with the Gumbel-softmax trick.

    Args:
        logits: unnormalised scores; the last dim indexes categories.
        temperature: softmax temperature; smaller values give sharper samples.
        eps: small constant added inside both logarithms so neither receives 0.

    Returns:
        ``softmax((logits + g) / temperature)`` over the last dim, where ``g`` is
        noise of the same shape, dtype and device as ``logits``.
    """
    noise_u = torch.rand(logits.size(), dtype=logits.dtype, device=logits.device)
    inner = -(noise_u + eps).log() + eps
    gumbel = -inner.log()
    perturbed = (logits + gumbel) / temperature
    return perturbed.softmax(dim=-1)

class CategoricalLayer(nn.Module):
    """One categorical step: score categories, then mix the choice back in.

    ``dense_in`` scores ``categorical_dim`` categories from the input. The
    resulting distribution (optionally Gumbel-sampled and/or hardened to a
    one-hot with a straight-through gradient) is concatenated with the input
    and projected by ``dense_out`` + tanh to form the new hidden state.
    """

    def __init__(self, input_dim, categorical_dim, output_dim=None):
        """
        Args:
            input_dim: size of the last dim of ``inputs``.
            categorical_dim: number of categories.
            output_dim: size of the returned hidden state; defaults to
                ``input_dim`` so layers of the same width can be chained.
        """
        super().__init__()

        # identity check against None, as PEP 8 recommends
        if output_dim is None:
            output_dim = input_dim

        self.dense_in = nn.Linear(input_dim, categorical_dim, bias=True)
        self.dense_out = nn.Linear(input_dim + categorical_dim, output_dim, bias=True)

    def forward(self, inputs, straight_through=True, sample=False, temperature=1.0,
                return_logits=False):
        """
        Args:
            inputs: tensor whose last dim is ``input_dim``.
            straight_through: harden the distribution to a one-hot in the
                forward pass while letting gradients pass through unchanged.
            sample: use a Gumbel-softmax sample at ``temperature`` instead of
                the plain softmax.
            temperature: only used when ``sample`` is true.
            return_logits: also return the raw ``dense_in`` scores.

        Returns:
            ``(h, dist)``, or ``(h, dist, logits)`` when ``return_logits``; ``h``
            has last dim ``output_dim``, ``dist`` and ``logits`` have last dim
            ``categorical_dim``.
        """
        logits = self.dense_in(inputs)

        if sample:
            dist = gumbel_softmax(logits, temperature=temperature)
        else:
            dist = F.softmax(logits, dim=-1)

        if straight_through:
            dist = straight_through_estimator(dist)

        h = torch.tanh(self.dense_out(torch.cat([inputs, dist], dim=-1)))

        if return_logits:
            return h, dist, logits
        return h, dist
    
class HLGC(nn.Module):
    """Classifier over a chain of categorical latent layers, plus a generator.

    Classifier: ``inputs -> tanh(input_dense) -> K x CategoricalLayer``; the K
    straight-through one-hot states are concatenated and mapped through
    ``global_dense`` + tanh and ``out_dense`` to ``n_classes`` logits.

    Generator: ``encode`` folds K categorical states into an ``n_classes``-dim
    code ``z``; ``generate`` decodes a sample of ``z`` back into K categorical
    states through its own chain of CategoricalLayers.

    Args:
        n_classes: number of target classes (also the size of ``z``).
        input_dim: size of the last dim of the input features.
        categorical_dims: one entry per latent layer, giving its number of categories.
        hidden_dim: width of the hidden state threaded through the layers.
        dropout_rate: dropout probability applied to the classifier's hidden states.
        batch_size, n_epochs: stored for ``fit``.
    """

    def __init__(self, n_classes, input_dim, categorical_dims, hidden_dim=128, dropout_rate=0.5, batch_size=16,
                 n_epochs=25):
        super().__init__()

        self.n_classes = n_classes
        self.input_dim = input_dim
        self.categorical_dims = categorical_dims
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate
        self.batch_size = batch_size
        self.n_epochs = n_epochs

        self._build_model()

    def _build_model(self):
        """Create every sub-module (called once from ``__init__``)."""
        # classifier
        self.input_dense = nn.Linear(self.input_dim, self.hidden_dim, bias=True)

        self.categorical_layers = nn.ModuleList([
            CategoricalLayer(self.hidden_dim, dim) for dim in self.categorical_dims
        ])

        self.global_dense = nn.Linear(sum(self.categorical_dims), self.hidden_dim, bias=True)
        self.out_dense = nn.Linear(self.hidden_dim, self.n_classes, bias=True)
        self.dropout = nn.Dropout(self.dropout_rate)

        # generator: encoder consumes [h, state_k] for each latent layer k
        self.encoder = nn.ModuleList([
            nn.Linear(self.hidden_dim + dim, self.hidden_dim, bias=True) for dim in self.categorical_dims
        ])
        self.encoder_out = nn.Linear(self.hidden_dim, self.n_classes, bias=True)

        self.decoder_in = nn.Linear(self.n_classes, self.hidden_dim, bias=True)
        self.decoder = nn.ModuleList([
            CategoricalLayer(self.hidden_dim, dim) for dim in self.categorical_dims
        ])

    def encode(self, dists):
        """Fold a list of categorical states into unnormalised class scores ``z``.

        Args:
            dists: list of K 2-D tensors, ``(batch, categorical_dims[k])``.

        Returns:
            Tensor of shape ``(batch, n_classes)``.
        """
        first = dists[0]
        # initial state matches the incoming states' dtype/device so torch.cat works
        h = torch.zeros(first.size(0), self.hidden_dim, dtype=first.dtype, device=first.device)

        for dist, layer in zip(dists, self.encoder):
            h = torch.tanh(layer(torch.cat([h, dist], dim=-1)))
        return self.encoder_out(h)

    def generate(self, z_sample, straight_through=True, temperature=1.0):
        """Decode a class-space sample into K sampled categorical states.

        Args:
            z_sample: tensor ``(batch, n_classes)``.
            straight_through: forwarded to every decoder CategoricalLayer.
            temperature: Gumbel-softmax temperature for every decoder layer.

        Returns:
            ``(gen_states, gen_logits)``: two lists with one tensor per latent layer.
        """
        h = torch.tanh(self.decoder_in(z_sample))

        gen_states, gen_logits = [], []
        for layer in self.decoder:
            h, dist, logits = layer(
                h, straight_through=straight_through, temperature=temperature, sample=True,
                return_logits=True
            )
            gen_states.append(dist)
            gen_logits.append(logits)
        return gen_states, gen_logits

    def classify(self, inputs, return_states=False):
        """Run the classifier path.

        Args:
            inputs: tensor whose last dim is ``input_dim``.
            return_states: also return the K one-hot latent states.

        Returns:
            ``logits`` ``(..., n_classes)``, or ``(logits, states)``.
        """
        # nn.Dropout returns a new tensor; its result must be assigned back,
        # otherwise the call has no effect.
        h = self.dropout(torch.tanh(self.input_dense(inputs)))

        states = []
        for layer in self.categorical_layers:
            h, dist = layer(h, straight_through=True, sample=False)
            h = self.dropout(h)
            states.append(dist)

        h = torch.tanh(self.global_dense(torch.cat(states, dim=-1)))
        logits = self.out_dense(h)

        if return_states:
            return logits, states
        return logits

    def forward(self, inputs, temperature=1.0):
        """Run classifier and generator.

        Returns:
            ``(clf_logits, clf_states, gen_logits, gen_states, z)``.
        """
        # classifier
        clf_logits, clf_states = self.classify(inputs, return_states=True)

        # generator: states are detached so generator losses do not update the classifier
        z = self.encode([x.detach() for x in clf_states])
        z_sample = straight_through_estimator(gumbel_softmax(z, temperature=temperature))
        gen_states, gen_logits = self.generate(z_sample, straight_through=True, temperature=temperature)
        return clf_logits, clf_states, gen_logits, gen_states, z

    def fit(self, train_features, train_targets, val_features=None, val_targets=None):
        """Not implemented yet.

        The model has no loss or optimizer wired in, so instead of iterating
        over the data for ``n_epochs`` without updating any weights, this
        raises. Train with an explicit loop over ``model(features)``.
        """
        raise NotImplementedError(
            'HLGC.fit is not implemented; train with an explicit loop over '
            'model(features) instead.'
        )

In [6]:
# The number of classes comes from the dataset instead of a hard-coded 20.
model = HLGC(
    n_classes=len(train.target_names),
    input_dim=hparams['n_features'],
    categorical_dims=hparams['latent_dims'],
    hidden_dim=hparams['hidden_dim'],
    dropout_rate=hparams['dropout'],
    batch_size=hparams['batch_size'],
    n_epochs=hparams['n_epochs'],
)

In [7]:
# Per-class weights for nn.CrossEntropyLoss(weight=...). That argument rescales
# each class's loss, so to counter imbalance the weights must be *inversely*
# proportional to class frequency (raw counts would up-weight the most frequent
# classes). Formula: n_samples / (n_classes * count), i.e. scikit-learn's
# "balanced" heuristic; every weight is 1.0 for a perfectly balanced dataset.
class_counts = Counter(train.target)
class_count_tensor = torch.tensor(
    [class_counts[i] for i in range(len(class_counts))], dtype=torch.float32
)
# clamp guards against a label id with no samples (would give an infinite weight)
class_weights = class_count_tensor.sum() / (
    len(class_count_tensor) * class_count_tensor.clamp(min=1)
)

In [8]:
# The DataLoader needs dense rows. Casting to float32 while the TF-IDF matrix is
# still sparse avoids materialising a temporary dense copy in the vectorizer's
# wider dtype before the cast (same values, roughly half the peak memory).
train_loader = DataLoader(
    list(zip(train_features.astype(np.float32).toarray(), train.target)),
    batch_size=hparams['batch_size'],
    shuffle=True,
)

In [19]:
# Adam over the trainable parameters only.
optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=hparams['lr'],
    betas=(0.9, 0.98),
    eps=1e-09,
)

In [None]:
# Loss modules only hold their fixed class weights; build them once, not per batch.
clf_criterion = nn.CrossEntropyLoss(weight=class_weights)
recon_criterion = nn.CrossEntropyLoss()

for epoch in range(hparams['n_epochs']):
    model.train()
    for itr, (features, targets) in enumerate(train_loader, start=1):
        clf_logits, clf_states, gen_logits, gen_states, z = model(features)

        # supervised loss against the true labels
        clf_loss = clf_criterion(clf_logits, targets)

        # each generated layer should reproduce the classifier's choice at the
        # same depth; CrossEntropyLoss already averages over the batch, so
        # only average over the number of latent layers here
        recon_loss = sum(
            recon_criterion(gen_logit, clf_state.argmax(-1))
            for clf_state, gen_logit in zip(clf_states, gen_logits)
        ) / len(clf_states)

        # the latent code z is trained towards the classifier's predicted class
        latent_loss = clf_criterion(z, clf_logits.argmax(-1))

        loss = (
            hparams['clf_loss_weight'] * clf_loss
            + hparams['recon_loss_weight'] * recon_loss
            + hparams['latent_loss_weight'] * latent_loss
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)  # max grad norm
        optimizer.step()

        if itr == 1 or itr % hparams['display_interval'] == 0:
            # metrics only; no autograd graph needed
            with torch.no_grad():
                clf_acc = (targets == clf_logits.argmax(-1)).float().mean()
                recon_acc = sum(
                    (clf_state.argmax(-1) == gen_state.argmax(-1)).float().mean()
                    for clf_state, gen_state in zip(clf_states, gen_states)
                ) / len(clf_states)

            # `{:.3f}` = 3 decimals (`{:3f}` would only set a minimum width of 3)
            print(
                '[{}, {:5d}] loss - (total : {:.3f}, clf : {:.3f}, latent : {:.3f}, recon : {:.3f}), '
                'acc - (clf : {:.3f}, recon : {:.3f})'.format(
                    epoch, itr, loss.item(), clf_loss.item(), latent_loss.item(),
                    recon_loss.item(), clf_acc.item(), recon_acc.item(),
                )
            )