In [1]:
% matplotlib inline

In [2]:
import os
import math

import numpy as np
from firelab import BaseTrainer
from firelab.utils import cudable

In [3]:
from torchtext import data
from torchtext.data import Field, Dataset, Example

# Mini-batch size, hard-coded for this notebook (cf. self.config.get('batch_size', 8)).
batch_size = 16
project_path = '/home/skorokhodov/neuro_dostoevsky'


def get_data_path(x):
    """Return the path of the tokenized Shakespeare split named `x` under `project_path`."""
    return os.path.join(project_path, 'data/shakespeare/%s.split.tok' % x)


modern_data_path = get_data_path('modern')
original_data_path = get_data_path('original')

# Each file is read as a list of lines; line i of one file is paired with
# line i of the other when the examples are built below.
with open(modern_data_path) as f:
    modern = f.read().splitlines()
with open(original_data_path) as f:
    original = f.read().splitlines()

# A single Field is shared by both columns, so both use the same vocabulary.
text = Field(init_token='<bos>', eos_token='<eos>', batch_first=True)
fields = [('modern', text), ('original', text)]
examples = [Example.fromlist(list(pair), fields) for pair in zip(modern, original)]

dataset = Dataset(examples, fields)
text.build_vocab(dataset)
data_iter = data.BucketIterator(dataset, batch_size=batch_size, repeat=False, shuffle=False)

In [23]:
import torch
import torch.nn as nn
from torch.optim import Adam

import torch.nn as nn


class RNNEncoder(nn.Module):
    """GRU sentence encoder whose final hidden state is split into a style code and a content code.

    Args:
        emb_size: dimensionality of the token embeddings.
        hid_size: hidden size of the GRU. Must be even, because the final state
            is cut into two halves of ``hid_size // 2`` units each.
        vocab_size: number of rows in the embedding table.

    Raises:
        ValueError: if ``hid_size`` is odd (the two halves would differ in width
            and the content head could not consume its half).
    """

    def __init__(self, emb_size, hid_size, vocab_size):
        super(RNNEncoder, self).__init__()

        # Fail fast: with an odd hid_size the second slice in forward() is one
        # unit wider than the Linear layer that receives it.
        if hid_size % 2 != 0:
            raise ValueError(
                'hid_size must be even to be split into style and content halves, got {}'.format(hid_size)
            )

        self.hid_size = hid_size
        self.embeddings = nn.Embedding(vocab_size, emb_size)
        self.gru = nn.GRU(emb_size, hid_size, batch_first=True)

        # Independent heads for the two halves of the hidden state.
        self.style_nn = nn.Sequential(
            nn.Linear(hid_size//2, hid_size//2),
            nn.SELU()
        )
        self.content_nn = nn.Sequential(
            nn.Linear(hid_size//2, hid_size//2),
            nn.SELU()
        )

    def forward(self, sentence):
        """Encode a batch of token-index sequences.

        Args:
            sentence: batch-first integer tensor of token indices.

        Returns:
            Tuple ``(style, content)``; each is ``(batch, hid_size // 2)``.
        """
        embeds = self.embeddings(sentence)
        # NOTE: the sequence is not packed, so for padded batches the final
        # state is the one reached after the padding positions.
        _, last_hidden_state = self.gru(embeds)
        # Drop the leading (num_layers * num_directions == 1) dimension.
        state = last_hidden_state.squeeze(0)

        half = self.hid_size // 2
        style = self.style_nn(state[:, :half])
        content = self.content_nn(state[:, half:])

        return style, content


class RNNDecoder(nn.Module):
    """GRU decoder that maps an initial hidden state plus input tokens to vocabulary logits.

    Args:
        emb_size: dimensionality of the token embeddings.
        hid_size: hidden size of the GRU (and width of the initial state `z`).
        vocab_size: size of the embedding table and of the output logits.
    """

    def __init__(self, emb_size, hid_size, vocab_size):
        super(RNNDecoder, self).__init__()

        self.hid_size = hid_size
        self.embeddings = nn.Embedding(vocab_size, emb_size)
        self.gru = nn.GRU(emb_size, hid_size, batch_first=True)
        # Output projection. Tying its weight to self.embeddings is left
        # disabled; the two matrices have the same shape only when
        # emb_size == hid_size.
        self.embs_to_logits = nn.Linear(hid_size, vocab_size)

    def forward(self, z, sentences):
        """Run the GRU over `sentences` starting from hidden state `z`.

        Args:
            z: initial hidden state, one row per batch element.
            sentences: batch-first integer tensor of input token indices.

        Returns:
            Logits over the vocabulary for every input position.
        """
        initial_state = torch.unsqueeze(z, 0)  # add the num_layers dimension
        token_vectors = self.embeddings(sentences)
        outputs, _ = self.gru(token_vectors, initial_state)
        return self.embs_to_logits(outputs)


class MLP(nn.Module):
    """Two-layer perceptron mapping a `size`-dimensional input to a single scalar.

    Args:
        size: width of the input and of the hidden layer.
    """

    def __init__(self, size):
        super(MLP, self).__init__()

        layers = [
            nn.Linear(size, size),
            nn.SELU(),
            nn.Linear(size, 1),
        ]
        self.nn = nn.Sequential(*layers)

    def forward(self, x):
        """Return one unnormalized score per input row."""
        scores = self.nn(x)
        return scores

#### Define models

In [24]:
# Model dimensions; the vocabulary size comes from the shared text field.
emb_size = hid_size = 512
voc_size = len(text.vocab)

# Autoencoder: the encoder's hidden state initialises the decoder.
encoder = cudable(RNNEncoder(emb_size, hid_size, voc_size))
decoder = cudable(RNNDecoder(emb_size, hid_size, voc_size))

# Both auxiliary MLPs are sized for one half of the encoder's hidden state.
code_size = hid_size // 2
critic = cudable(MLP(code_size))
motivator = cudable(MLP(code_size))

#### Define losses and optimizers

In [25]:
from itertools import chain
from torch.optim import Adam

# Reconstruction loss: every class has weight one except '<pad>', whose zero
# weight removes padded target positions from the averaged cross-entropy.
pad_idx = text.vocab.stoi['<pad>']
weights = cudable(torch.ones(voc_size))
weights[pad_idx] = 0
rec_criterion = nn.CrossEntropyLoss(weight=weights)

# Critic loss: WGAN-like difference of mean scores (no Lipschitz constraint is enforced).
class CriticLoss(nn.Module):
    """Difference between the mean of `real` and the mean of `fake`."""

    def forward(self, real, fake):
        """Return ``mean(real) - mean(fake)`` as a scalar tensor."""
        real_score = real.mean()
        fake_score = fake.mean()
        return real_score - fake_score

critic_criterion = CriticLoss()

# Motivator loss: binary classification on raw logits.
motivator_criterion = nn.BCEWithLogitsLoss()

# Optimizers: one per auxiliary network, and a joint one for encoder + decoder.
learning_rate = 1e-4
critic_optim = Adam(critic.parameters(), lr=learning_rate)
motivator_optim = Adam(motivator.parameters(), lr=learning_rate)
ae_params = chain(encoder.parameters(), decoder.parameters())
ae_optim = Adam(ae_params, lr=learning_rate)

In [31]:
# Single training step, run on the first batch only (see the `break` at the end).
for batch in data_iter:
    # Encode both sides of the parallel batch into (style, content) codes.
    style_modern, content_modern = encoder(batch.modern)
    style_original, content_original = encoder(batch.original)

    # The decoder starts from a full-width hidden state, so glue the halves back together.
    hid_modern = torch.cat([style_modern, content_modern], dim=1)
    hid_original = torch.cat([style_original, content_original], dim=1)

    # Reconstruction with teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
    recs_modern = decoder(hid_modern, batch.modern[:, :-1])
    recs_original = decoder(hid_original, batch.original[:, :-1])

    rec_loss_modern = rec_criterion(recs_modern.view(-1, voc_size), batch.modern[:, 1:].contiguous().view(-1))
    rec_loss_original = rec_criterion(recs_original.view(-1, voc_size), batch.original[:, 1:].contiguous().view(-1))
    rec_loss = rec_loss_modern + rec_loss_original

    # Critic loss: the critic network scores the content codes. The codes are
    # detached so this loss trains the critic only, and its graph is independent
    # of the encoder graph used by the other losses.
    critic_scores_modern = critic(content_modern.detach())
    critic_scores_original = critic(content_original.detach())
    critic_loss = critic_criterion(critic_scores_modern, critic_scores_original)

    # Motivator loss: classify the style code as modern (target 1) or original (target 0).
    motivator_logits_modern = motivator(style_modern)
    motivator_logits_original = motivator(style_original)
    motivator_loss_modern = motivator_criterion(motivator_logits_modern, torch.ones_like(motivator_logits_modern))
    motivator_loss_original = motivator_criterion(motivator_logits_original, torch.zeros_like(motivator_logits_original))
    motivator_loss = motivator_loss_modern + motivator_loss_original

    # Critic update (touches critic parameters only).
    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()

    # Autoencoder + motivator update. Both losses share the encoder graph, so
    # all gradients are computed in one backward pass BEFORE any parameter is
    # stepped; stepping in between would make the second backward run against
    # weights that no longer match the saved forward activations.
    ae_optim.zero_grad()
    motivator_optim.zero_grad()
    (rec_loss + motivator_loss).backward()
    motivator_optim.step()
    ae_optim.step()

    total_loss = rec_loss + critic_loss + motivator_loss
    print('Loss: {:.02f}'.format(total_loss.item()))

    # Smoke test: stop after the first batch.
    break

Loss: 11.47
