In [1]:
# Reload edited project modules automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [40]:
from fastai import *
from fastai.text import *
from fastai.imports import *


from utils import Logger

import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader


import pandas as pd
from pandas import Series, DataFrame

In [41]:
# Folder holding the lyrics dataset, relative to the working directory.
# NOTE(review): the CSV-loading cell reads from "./data/lyrics/..." rather than this
# "../data/lyrics" location, and `path` is not referenced by any visible cell — confirm which is right.
path = Path('../data/lyrics')

In [46]:
# Read the preprocessed lyrics CSV and keep only the column labelled "0" (a Series of text lines).
# NOTE(review): the location is hard-coded here instead of being derived from `path` — confirm.
# NOTE(review): the last displayed row ends with "EmbedShare URLCopyEmbedCopy", which looks like
# scraping residue — consider cleaning it in preprocessing.
data = pd.read_csv("./data/lyrics/lyrics_preprocessed.csv")["0"]
# Bare last expression: shows the (truncated) Series as the cell output.
data

0                                           Your man on the road, he doin promo
1                                  You said, "Keep our business on the low-low"
2                                     Im just tryna get you out the friend zone
3                                    Cause you look even better than the photos
4                                      I cant find your house, send me the info
                                          ...                                  
40819                                                 Scene goes frame by frame
40820                                                       Who swayed to suede
40821                                               Youll hear it all dont call
40822                                  Throw away everything Ive written you oh
40823    Anything just like every single thing todayEmbedShare URLCopyEmbedCopy
Name: 0, Length: 40824, dtype: object

In [53]:
# Set fastai's default device to the CPU.
defaults.device = torch.device('cpu') 
# Wrap the raw lyric Series in a PyTorch DataLoader yielding 64 items per batch.
# NOTE(review): `dataloader` is not referenced by any later visible cell.
dataloader = DataLoader(data, batch_size= 64)

In [54]:
class TextGANModule(nn.Module):
    """Wrapper around a `generator` and a `critic` to create a GAN.

    The module runs either the generator or the critic depending on `gen_mode`.
    Subclasses may call `__init__` without models and supply their own
    `generator` / `critic` attributes instead.
    """
    def __init__(self, generator:nn.Module=None, critic:nn.Module=None, gen_mode:bool=False):
        """Register `generator` and `critic` (when a generator is given) and set the initial mode."""
        super().__init__()
        self.gen_mode = gen_mode
        # Compare against None explicitly: container modules define `__len__`, so an
        # empty `nn.Sequential` is falsy and a truthiness test would skip registration.
        if generator is not None: self.generator,self.critic = generator,critic

    def forward(self, *args):
        """Run the active sub-model on `args`.

        In generator mode only element `[0]` of the generator's result is returned,
        so the generator must return something indexable (e.g. a tuple).
        In critic mode the critic's result is returned unchanged.
        """
        return self.generator(*args)[0] if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode:bool=None):
        "Put the model in generator mode if `gen_mode`, in critic mode otherwise; toggle when `gen_mode` is None."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

In [55]:
class TextGANTrainer(LearnerCallback):
    """Handles GAN Training.

    Keeps one optimizer per sub-model (`opt_gen`, `opt_critic`), tracks a smoothed
    loss for each of them, and flips the learner between generator mode and
    critic mode through `switch`.
    """
    # Callback ordering value; how it is interpreted is up to the training loop (not shown here).
    _order=-20
    def __init__(self, learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False):
        """Store the options and keep direct handles on the generator and the critic.

        switch_eval: when True, `_set_trainable` puts the trained model in train mode and the other in eval mode.
        clip: when not None, critic parameters are clamped to [-clip, clip] at the start of every batch.
        beta: constructor argument for the two `SmoothenValue` loss trackers.
        gen_first: mode selected in `on_train_begin` (True means generator mode).
        """
        super().__init__(learn)
        self.switch_eval,self.clip,self.beta,self.gen_first = switch_eval,clip,beta,gen_first
        # `self.model` is not assigned in this class; presumably exposed by `LearnerCallback` — TODO confirm.
        self.generator,self.critic = self.model.generator,self.model.critic

    def _set_trainable(self):
        "Enable gradients on the model being trained and disable them on the other one."
        # `train_model` is the model of the current mode, `loss_model` is the other one.
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        requires_grad(train_model, True)
        requires_grad(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def on_train_begin(self, **kwargs):
        "Create the optimizers for the generator and critic if necessary, initialize smootheners."
        # First run: derive a dedicated optimizer from `self.opt`; later runs only refresh lr/wd.
        if not getattr(self,'opt_gen',None):
            self.opt_gen = self.opt.new([nn.Sequential(*flatten_model(self.generator))])
        else: self.opt_gen.lr,self.opt_gen.wd = self.opt.lr,self.opt.wd
        if not getattr(self,'opt_critic',None):
            self.opt_critic = self.opt.new([nn.Sequential(*flatten_model(self.critic))])
        else: self.opt_critic.lr,self.opt_critic.wd = self.opt.lr,self.opt.wd
        # Start in the configured mode and push it explicitly to the model and the loss.
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        # Histories of the smoothed critic / generator losses, one entry per recorded batch.
        self.closses,self.glosses = [],[]
        self.smoothenerG,self.smoothenerC = SmoothenValue(self.beta),SmoothenValue(self.beta)
        self.recorder.no_val=False
        self.recorder.add_metric_names(['gen_loss', 'disc_loss'])

    def on_train_end(self, **kwargs):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Clamp the weights with `self.clip` if it's not None, return the correct input."
        if self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        # Generator mode keeps the batch as is; critic mode swaps input and target.
        return {'last_input':last_input,'last_target':last_target} if self.gen_mode else {'last_input':last_target,'last_target':last_input}

    def on_backward_begin(self, last_loss, last_output, **kwargs):
        "Record `last_loss` in the proper list."
        # Work on a detached CPU copy so the bookkeeping does not hold on to the graph.
        last_loss = last_loss.detach().cpu()
        if self.gen_mode:
            self.smoothenerG.add_value(last_loss)
            self.glosses.append(self.smoothenerG.smooth)
            # Keep the latest generator output around (detached, on CPU).
            self.last_gen = last_output.detach().cpu()
        else:
            self.smoothenerC.add_value(last_loss)
            self.closses.append(self.smoothenerC.smooth)

    def on_epoch_begin(self, epoch, **kwargs):
        "Put the critic or the generator back to eval if necessary."
        # Re-apply the current mode without changing it.
        self.switch(self.gen_mode)

    def on_epoch_end(self, pbar, epoch, last_metrics, **kwargs):
        "Put the various losses in the recorder"
        # `smooth` may not exist yet on a smoothener, hence the `None` fallback.
        return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])

    def switch(self, gen_mode:bool=None):
        "Switch the model, if `gen_mode` is provided, in the desired mode."
        # None toggles the current mode; an explicit bool selects it.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        # Point the shared optimizer wrapper at the optimizer of the active sub-model.
        self.opt.opt = self.opt_gen.opt if self.gen_mode else self.opt_critic.opt
        self._set_trainable()
        # NOTE(review): the raw `gen_mode` argument (possibly None) is forwarded, so the model
        # and the loss toggle their own flags instead of receiving `self.gen_mode`.
        self.model.switch(gen_mode)
        self.loss_func.switch(gen_mode)

NameError: name 'LearnerCallback' is not defined

In [56]:
class TextGANLoss(TextGANModule):
    """Wrapper around `loss_funcC` (for the critic) and `loss_funcG` (for the generator).

    Reuses `gen_mode` / `switch` from `TextGANModule`. Here `generator` and `critic`
    are methods that compute the loss of the corresponding mode.
    """
    def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:TextGANModule):
        """Keep the two loss callables and the GAN model whose sub-models they need."""
        super().__init__()
        self.loss_funcG,self.loss_funcC,self.gan_model = loss_funcG,loss_funcC,gan_model

    def forward(self, *args):
        """Return the generator loss in generator mode, the critic loss otherwise.

        Overrides `TextGANModule.forward`, which takes element `[0]` of the generator
        call. Applied to this wrapper, that would index into the loss value itself
        (an error for a 0-dim loss, a single element for a vector loss).
        """
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def generator(self, output, target):
        "Evaluate the `output` with the critic then uses `self.loss_funcG` to combine it with `target`."
        # The critic scores token ids sampled from the generator output.
        fake_pred = self.gan_model.critic(seq_gumbel_softmax(output))
        return self.loss_funcG(fake_pred, target, output)

    def critic(self, real_pred, input):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.loss_funcC`."
        # Only element [0] of the generator result is used; gradient flags are set in place.
        fake = self.gan_model.generator(input.requires_grad_(False))[0].requires_grad_(True)
        fake_pred = self.gan_model.critic(seq_gumbel_softmax(fake))
        return self.loss_funcC(real_pred, fake_pred)

NameError: name 'Callable' is not defined

In [27]:
class FixedTextGANSwitcher(LearnerCallback):
    "Callback alternating a fixed quota of critic batches with a fixed quota of generator batches."
    def __init__(self, learn:Learner, n_crit:Union[int,Callable]=1, n_gen:Union[int,Callable]=1):
        "`n_crit` / `n_gen` are either step counts or callables returning one."
        super().__init__(learn)
        self.n_crit = n_crit
        self.n_gen = n_gen

    def on_train_begin(self, **kwargs):
        "Start both per-mode batch counters at zero."
        self.n_c = 0
        self.n_g = 0

    def on_batch_end(self, iteration, **kwargs):
        "Count the finished batch and flip the trainer once the active mode reached its quota."
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            quota, other_count, own_count = self.n_gen, self.n_c, self.n_g
        else:
            self.n_c += 1
            quota, other_count, own_count = self.n_crit, self.n_g, self.n_c
        # A callable quota is evaluated on the other mode's counter.
        if not isinstance(quota, int):
            quota = quota(other_count)
        if quota != own_count:
            return
        self.learn.gan_trainer.switch()
        self.n_c = 0
        self.n_g = 0

In [28]:
class TextGANLearner(Learner):
    """A `Learner` suitable for GANs.

    Wraps `generator` and `critic` in a `TextGANModule`, wraps the two loss functions
    in a `TextGANLoss`, and attaches a `TextGANTrainer` plus a switcher callback.
    """
    def __init__(self, data:DataBunch, generator:nn.Module, critic:nn.Module, gen_loss_func:LossFunction,
                 crit_loss_func:LossFunction, switcher:Callback=None, gen_first:bool=False, switch_eval:bool=True,
                 clip:float=None, **learn_kwargs):
        """Build the GAN model, its loss and the training callbacks.

        switcher: callback factory deciding when to change mode; defaults to
            `FixedTextGANSwitcher` with 5 critic steps per generator step.
        gen_first, switch_eval, clip: forwarded to `TextGANTrainer`.
        learn_kwargs: forwarded to `Learner.__init__`.
        """
        gan = TextGANModule(generator, critic)
        loss_func = TextGANLoss(gen_loss_func, crit_loss_func, gan)
        switcher = ifnone(switcher, partial(FixedTextGANSwitcher, n_crit=5, n_gen=1))
        super().__init__(data, gan, loss_func=loss_func, callback_fns=[switcher], **learn_kwargs)
        # `gen_first` has to be forwarded explicitly; otherwise the trainer keeps its
        # own default (False) and always starts in critic mode.
        trainer = TextGANTrainer(self, clip=clip, switch_eval=switch_eval, gen_first=gen_first)
        self.gan_trainer = trainer
        self.callbacks.append(trainer)

    @classmethod
    def from_learners(cls, learn_gen:Learner, learn_crit:Learner, switcher:Callback=None,
                      weights_gen:Tuple[float,float]=None, **learn_kwargs):
        "Create a GAN from `learn_gen` and `learn_crit`."
        # NOTE(review): `gan_loss_from_func` is not defined in this notebook; confirm it is in scope.
        losses = gan_loss_from_func(learn_gen.loss_func, learn_crit.loss_func, weights_gen=weights_gen)
        return cls(learn_gen.data, learn_gen.model, learn_crit.model, *losses, switcher=switcher, **learn_kwargs)

    @classmethod
    def wgan(cls, data:DataBunch, generator:nn.Module, critic:nn.Module, switcher:Callback=None, clip:float=0.01, **learn_kwargs):
        "Create a WGAN from `data`, `generator` and `critic`."
        # Generator loss is `NoopLoss()`, critic loss is `WassersteinLoss()`; critic weights are clipped to `clip`.
        return cls(data, generator, critic, NoopLoss(), WassersteinLoss(), switcher=switcher, clip=clip, **learn_kwargs)

In [29]:
def lm_loss(input, target, kld_weight=0):
    """Cross-entropy between per-token class scores and target class ids.

    input: scores whose last dimension indexes the classes; every leading
        dimension is flattened into one batch of tokens.
    target: integer class ids with one entry per token of `input`.
    kld_weight: unused; kept so existing callers keep working.

    Returns the value of `F.cross_entropy` on the flattened tensors.
    """
    n_classes = input.size(-1)
    # `reshape` rather than `view`: `view` raises on non-contiguous tensors
    # (e.g. after a transpose), `reshape` copies only when it has to.
    return F.cross_entropy(input.reshape(-1, n_classes), target.reshape(-1))

In [30]:
def bn_drop_lin(n_in, n_out, bn=True, initrange=0.01,p=0, bias=True, actn=nn.LeakyReLU(inplace=True)):
    """Return the list of layers for one block: [BatchNorm1d] -> [Dropout] -> Linear -> [activation].

    n_in, n_out: input / output width of the linear layer.
    bn: prepend `nn.BatchNorm1d(n_in)` when True.
    initrange: when truthy, linear weights are re-drawn uniformly from [-initrange, initrange].
    p: dropout probability; no dropout layer is added when it equals 0.
    bias: whether the linear layer has a bias (zeroed when present).
    actn: activation module appended last, or None for no activation.
        The default instance is created once and shared by every call that relies on it.
    """
    stack = []
    if bn:
        stack.append(nn.BatchNorm1d(n_in))
    if p != 0:
        stack.append(nn.Dropout(p))
    dense = nn.Linear(n_in, n_out, bias=bias)
    if initrange:
        dense.weight.data.uniform_(-initrange, initrange)
    if bias:
        dense.bias.data.zero_()
    stack.append(dense)
    if actn is not None:
        stack.append(actn)
    return stack

In [14]:
# Display the currently configured default device as the cell output.
defaults.device

device(type='cpu')

In [55]:
# Make sure autograd gradient tracking is switched on globally.
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f9914feb8d0>

In [74]:
# Build an AWD_LSTM language-model learner, then load saved weights onto the CPU.
# NOTE(review): `data_lm` is not created in any visible cell — confirm where it comes from,
# otherwise this cell fails on a fresh kernel.
learn = language_model_learner(data_lm, arch=AWD_LSTM)
# Trailing `;` suppresses the cell output.
learn.load('models/poems_fine_tuned',device=torch.device('cpu'));

In [75]:
# Unfreeze the learner (presumably makes every layer group trainable — fastai behaviour, not shown here).
learn.unfreeze()

In [76]:
# Independent copy of the first sub-module of the language model; later handed to the discriminator.
encoder = deepcopy(learn.model[0])

In [77]:
# Independent copy of the whole language model, used as the GAN generator.
generator = deepcopy(learn.model) 

In [78]:
# Copy the learner's weights into the generator.
# NOTE(review): probably redundant — `generator` was just created with `deepcopy(learn.model)`.
generator.load_state_dict(learn.model.state_dict())

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [79]:
class TextDicriminator(nn.Module):
    """Sequence classifier: an encoder followed by a pooled, three-block linear head ending in a sigmoid.

    encoder: module called as `encoder(inp)` and returning a pair whose second item is a
        sequence of outputs; the last one is used and must be 3-D (batch, steps, features).
    nh: feature width of that output; the head consumes `nh*3` features
        (last step, max pool and average pool concatenated).
    """
    def __init__(self,encoder, nh):
        super().__init__()
        self.encoder = encoder
        # Classifier head: nh*3 -> nh -> nh -> 1, the final block ending in a sigmoid.
        head = bn_drop_lin(nh*3, nh, bias=False)
        head = head + bn_drop_lin(nh, nh, p=0.25)
        head = head + bn_drop_lin(nh, 1, p=0.15, actn=nn.Sigmoid())
        self.layers = nn.Sequential(*head)

    def pool(self, x, bs, is_max):
        "Pool `x` over its middle dimension (max when `is_max`, mean otherwise) and flatten to `bs` rows."
        channels_first = x.permute(0,2,1)
        if is_max:
            pooled = F.adaptive_max_pool1d(channels_first, (1,))
        else:
            pooled = F.adaptive_avg_pool1d(channels_first, (1,))
        return pooled.view(bs,-1)

    def forward(self, inp,y=None):
        "Encode `inp`, concatenate last-step, max-pooled and mean-pooled features, and classify them (`y` is unused)."
        _, outputs = self.encoder(inp)
        top = outputs[-1]
        bs, _, _ = top.size()
        mean_feats = self.pool(top, bs, False)
        max_feats = self.pool(top, bs, True)
        features = torch.cat([top[:,-1], max_feats, mean_feats], 1)
        return self.layers(features)

In [80]:
# Discriminator built on the copied encoder; 400 is passed as `nh`.
# TODO confirm that 400 matches the feature width of the encoder's last output.
disc = TextDicriminator(encoder,400)

In [81]:
def seq_gumbel_softmax(input):
    """Sample one class id per position of a 3-D score tensor.

    For every index along dimension 1, the scores `input[:, i, :]` go through
    `F.gumbel_softmax` and one id per row is drawn from the result with
    `torch.multinomial`. Returns an integer tensor of shape (batch, steps).
    """
    bs, sl, nc = input.size()
    # One (batch, 1) draw per step, in step order so the random stream is consumed identically.
    draws = [torch.multinomial(F.gumbel_softmax(input[:, step, :]), 1) for step in range(sl)]
    # (steps, batch, 1) -> (batch, steps, 1) -> (batch, steps)
    return torch.stack(draws).transpose(0, 1).squeeze(-1)

In [82]:
# Turn gradient tracking on for every parameter of both networks before building the GAN learner.
requires_grad(generator,True)
requires_grad(disc,True)

In [64]:
# GAN learner with `lm_loss` as generator loss and a Wasserstein critic loss.
# NOTE(review): `TextGANLoss.generator` calls the generator loss as (fake_pred, target, output),
# so `lm_loss` receives the critic prediction as `input` and `output` as `kld_weight` — confirm this is intended.
learn = TextGANLearner(data_lm,generator,disc,lm_loss,WassersteinLoss(),metrics=[accuracy])

In [83]:
# Run the learning-rate finder (the recorded run below was interrupted by hand).
learn.lr_find()

LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.


KeyboardInterrupt: 

In [66]:
# A dangling `learn.` attribute access used to sit here, left over from an interactive
# debugging session. It is a SyntaxError and stops "Restart & Run All" at this cell.

> /home/ubuntu/projects/fastai/lib/python3.6/site-packages/fastai/torch_core.py(133)requires_grad()
    131     if not ps: return None
    132     if b is None: return ps[0].requires_grad
--> 133     for p in ps: p.requires_grad=b
    134 
    135 def trainable_params(m:nn.Module)->ParamList:

ipdb> u
> <ipython-input-8-052cdeb210e0>(13)_set_tr

In [32]:
def reinforce_loss(input,sample,reward):
    """REINFORCE-style loss: negated score of each sampled token, weighted by `reward`, averaged over steps.

    input: scores indexed as input[batch, step, class] (presumably log-probabilities — TODO confirm).
    sample: integer class ids of shape (batch, steps).
    reward: value broadcast against a (batch,) tensor, e.g. a scalar or a (batch,) tensor.
        NOTE(review): a (batch, 1) reward would broadcast to (batch, batch); flatten it first.

    Returns the per-row loss, shape (batch,) for scalar or (batch,) rewards.
    """
    loss=0
    bs,sl = sample.size()
    # Pair each batch row with its own sampled token. Indexing with `input[:, i, sample[:, i]]`
    # would instead pick every row's token for every row and produce a (bs, bs) matrix.
    rows = torch.arange(bs, device=sample.device)
    for i in range(sl):
        loss += -input[rows,i,sample[:,i]] * reward
    return loss/sl

In [None]:
# Same GAN learner, with `reinforce_loss` passed as the critic loss.
# NOTE(review): `TextGANLoss.critic` calls the critic loss with two arguments (real_pred, fake_pred),
# while `reinforce_loss` takes three (input, sample, reward) — confirm the intended wiring.
learn = TextGANLearner(data_lm,generator,disc,lm_loss,reinforce_loss,metrics=[accuracy])