# Adversarial autoencoders for text

### Code: https://github.com/shentianxiao/text-autoencoders

### Paper: https://arxiv.org/pdf/1905.12777.pdf



# Setup

In [1]:
import torch
from multiprocessing import cpu_count
print(cpu_count())
print(torch.cuda.is_available())

2
True


In [2]:
!git clone https://github.com/shentianxiao/text-autoencoders.git 

Cloning into 'text-autoencoders'...
remote: Enumerating objects: 114, done.
remote: Counting objects: 100% (31/31), done.
remote: Compressing objects: 100% (27/27), done.
remote: Total 114 (delta 11), reused 12 (delta 4), pack-reused 83
Receiving objects: 100% (114/114), 270.78 KiB | 7.52 MiB/s, done.
Resolving deltas: 100% (56/56), done.


In [3]:
%cd text-autoencoders

/content/text-autoencoders


In [4]:
!bash download_data.sh

--2021-07-22 11:02:07--  http://people.csail.mit.edu/tianxiao/data/yelp.zip
Resolving people.csail.mit.edu (people.csail.mit.edu)... 128.30.2.133
Connecting to people.csail.mit.edu (people.csail.mit.edu)|128.30.2.133|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3676642 (3.5M) [application/zip]
Saving to: ‘yelp.zip’


2021-07-22 11:02:08 (4.97 MB/s) - ‘yelp.zip’ saved [3676642/3676642]

Archive:  yelp.zip
   creating: yelp/
   creating: yelp/tense/
  inflating: yelp/tense/valid.past   
  inflating: yelp/tense/valid.present  
  inflating: yelp/tense/test.past    
  inflating: yelp/tense/test.present  
   creating: yelp/sentiment/
  inflating: yelp/sentiment/100.neg  
  inflating: yelp/sentiment/100.pos  
  inflating: yelp/sentiment/1000.neg  
  inflating: yelp/sentiment/1000.pos  
  inflating: yelp/test.txt           
  inflating: yelp/train.txt          
  inflating: yelp/valid.txt          
   creating: yelp/interpolate/
  inflating: yelp/interpolate/example

# Imports

In [5]:
import os
import random  # used below for seeding; imported explicitly instead of relying on a star import
import torch
import numpy as np
from sklearn.neighbors import NearestNeighbors

from vocab import Vocab
from model import *
from utils import *
from batchify import get_batches
from train import evaluate

# Getting the trained model checkpoint from the probml_data bucket

In [6]:
from google.colab import auth
auth.authenticate_user()

In [7]:
bucket_name = 'probml_data'

In [8]:
!mkdir -p /content/text-autoencoders/checkpoints

In [9]:
!gsutil cp -r gs://{bucket_name}/text-autoencoders/vocab.txt /content/text-autoencoders/checkpoints/

Copying gs://probml_data/text-autoencoders/vocab.txt...
/ [1 files][100.7 KiB/100.7 KiB]                                                
Operation completed over 1 objects/100.7 KiB.                                    


In [10]:
!gsutil cp -r gs://{bucket_name}/text-autoencoders/text_ae_yelp_30_epochs.pt /content/text-autoencoders/checkpoints/

Copying gs://probml_data/text-autoencoders/text_ae_yelp_30_epochs.pt...
| [1 files][133.3 MiB/133.3 MiB]                                                
Operation completed over 1 objects/133.3 MiB.                                    


# Creating vocab

In [11]:
# Stands in for os.path.join(args.checkpoint, 'vocab.txt')
vocab = Vocab('/content/text-autoencoders/checkpoints/vocab.txt')

In [12]:
seed = 1111
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
batch_size = 100
max_len = 35

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loading checkpoints


In [14]:
# map_location lets the checkpoint load on a CPU-only runtime as well
ckpt = torch.load('/content/text-autoencoders/checkpoints/text_ae_yelp_30_epochs.pt', map_location=device)

In [15]:
train_args = ckpt['args']

# Selecting AAE model

In [16]:
model = {'dae': DAE, 'vae': VAE, 'aae': AAE}['aae'](vocab, train_args).to(device) 

In [17]:
model.load_state_dict(ckpt['model'])
model.flatten()
model.eval()

AAE(
  (embed): Embedding(10005, 512)
  (proj): Linear(in_features=1024, out_features=10005, bias=True)
  (drop): Dropout(p=0.5, inplace=False)
  (E): LSTM(512, 1024, bidirectional=True)
  (G): LSTM(512, 1024)
  (h2mu): Linear(in_features=2048, out_features=128, bias=True)
  (h2logvar): Linear(in_features=2048, out_features=128, bias=True)
  (z2emb): Linear(in_features=128, out_features=512, bias=True)
  (D): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=1, bias=True)
    (3): Sigmoid()
  )
)

In [18]:
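def encode(sents, enc='mu'):
    """Encode sentences to latent codes, returned in the input order."""
    assert enc == 'mu' or enc == 'z'
    # get_batches reorders the sentences; `order` records their original indices
    batches, order = get_batches(sents, vocab, batch_size, device)
    z = []
    for inputs, _ in batches:
        mu, logvar = model.encode(inputs)
        if enc == 'mu':
            zi = mu  # deterministic: use the posterior mean
        else:
            zi = reparameterize(mu, logvar)  # stochastic: sample from the posterior
        z.append(zi.detach().cpu().numpy())
    z = np.concatenate(z, axis=0)
    # Undo the reordering so that row i corresponds to sents[i]
    z_ = np.zeros_like(z)
    z_[np.array(order)] = z
    return z_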
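
The last lines of `encode` scatter the batched codes back into the order of the input sentences. A minimal, self-contained sketch of that idiom follows. It assumes that `order[j]` holds the original index of the j-th sentence after the batching helper has sorted by length; the toy sentences and the length-based "encoder" are made up for illustration.

```python
import numpy as np

# Toy "sentences"; the batching helper is assumed to sort them by length.
sents = [["a", "b", "c"], ["d"], ["e", "f"]]

# order[j] is the original position of the j-th sentence after sorting.
order = sorted(range(len(sents)), key=lambda i: len(sents[i]))

# Stand-in for the encoder: one row per sentence, in sorted order.
z_sorted = np.array([[float(len(sents[i]))] for i in order])

# Scatter each row back to its original position.
z_restored = np.zeros_like(z_sorted)
z_restored[np.array(order)] = z_sorted

print(order)               # [1, 2, 0]
print(z_restored.ravel())  # row i now matches len(sents[i])
```

Without this step, `test.z` and `test.rec` would be written in length-sorted order rather than line-by-line alongside `test.txt`.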

In [19]:
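def decode(z, dec='greedy'):
    """Decode latent codes to token lists, batch_size rows at a time."""
    sents = []
    i = 0
    while i < len(z):
        zi = torch.tensor(z[i: i+batch_size], device=device)
        # Transpose the generated ids so that each row is one sentence
        outputs = model.generate(zi, max_len, dec).t()
        for s in outputs:
            # s[1:] skips the first token (presumably the start-of-sentence symbol)
            sents.append([vocab.idx2word[idx] for idx in s[1:]])
        i += batch_size
    return strip_eos(sents)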
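
The two mechanical parts of `decode` are the fixed-size slicing of `z` and the trimming of generated tokens. The sketch below shows both on toy data. `iter_chunks` and `strip_after_eos` are hypothetical helpers written for this illustration, and the `<eos>` and `<pad>` token strings are assumptions; the repository's own `strip_eos` may differ in detail.

```python
import numpy as np

def iter_chunks(z, batch_size):
    """Yield consecutive slices of z with at most batch_size rows each."""
    for start in range(0, len(z), batch_size):
        yield z[start:start + batch_size]

def strip_after_eos(sents, eos="<eos>"):
    """Cut each token list at the first end-of-sentence token, if present."""
    return [s[:s.index(eos)] if eos in s else s for s in sents]

# Seven latent codes split into batches of three: the last batch is smaller.
z = np.zeros((7, 4), dtype=np.float32)
chunk_sizes = [len(c) for c in iter_chunks(z, batch_size=3)]
print(chunk_sizes)  # [3, 3, 1]

decoded = [["good", "food", "<eos>", "<pad>"], ["great", "!"]]
stripped = strip_after_eos(decoded)
print(stripped)  # [['good', 'food'], ['great', '!']]
```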

# Reconstruction

In [20]:
sents = load_sent('/content/text-autoencoders/data/yelp/test.txt') 
z = encode(sents)
sents_rec = decode(z)
write_z(z, '/content/text-autoencoders/checkpoints/test.z') 
write_sent(sents_rec, '/content/text-autoencoders/checkpoints/test.rec') 

In [21]:
!head checkpoints/test.rec

my husband loves thin crust pizza .
sausage are great too
<unk> pizza is by far one of my favorite pizzas .
the little thin crust with bacon and <unk> it 's where its at .
i also like this location .
it has two levels which gives a little more warm heat .
and rather than the pizza was served missing a decent .
the reason ?
the `` oven '' pasta - '' - according to the waitress .
absolutely love their thin crust pizzas .


In [22]:
!head checkpoints/test.z

0.615650 0.843494 -0.963197 0.070233 0.056088 0.669214 0.485136 2.368523 0.772484 1.778468 1.293559 1.499551 0.275926 -1.152509 -1.007916 0.043573 -0.271024 0.050575 -1.482960 -3.467686 -0.270455 -1.402242 -1.689698 0.282012 0.010312 1.610025 0.975547 -0.883918 -1.522156 -1.117097 0.062196 0.786202 4.902931 0.293000 -0.692498 -1.617049 -2.322844 -1.172106 -1.931692 1.360384 0.921803 0.563292 0.891287 0.963349 -0.652590 -1.290544 0.416340 -2.521800 0.251845 1.603648 1.887917 -0.481518 -0.487605 -0.334191 -0.062563 2.082166 1.376685 1.007244 0.372941 -0.432422 -0.813552 3.452029 0.028994 0.205073 -0.597412 2.061502 0.077337 0.129536 1.283841 -0.504938 0.420815 -0.909272 0.402970 -0.220285 -0.952421 0.230303 0.129789 -0.360586 0.897388 0.122778 -0.466032 -1.223495 -0.565543 -0.287634 0.102804 2.894719 0.890665 -0.098187 -1.435500 -2.726444 -0.648847 -0.012419 0.455038 2.217391 -0.445422 0.073173 0.349082 0.153578 0.561829 -0.191955 0.505519 0.459280 0.497380 0.578776 1.259404 -0.433562 0.

# Sample

In [23]:
n = 10
dim_z = 128

In [24]:
z = np.random.normal(size=(n, dim_z)).astype('f')
sents = decode(z)
write_sent(sents, '/content/text-autoencoders/checkpoints/sample')

In [25]:
!head checkpoints/sample

special night with the food .
i ordered the very chorizo .
the pasta <unk> - three things : )
all are the best .
i had a reservation and received the service .
i <unk> will probably visit the city .
the <unk> is .
i think we were <unk> from <unk> with shipping .
we 've been to several people and to the service <unk> .
i wanted the eggs benedict and chili cheese strips .
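
The sampling cell draws latent codes from a standard normal distribution and casts them with `astype('f')`, where `'f'` is NumPy's shorthand for `float32`. As a hedged alternative, the sketch below uses a local seeded generator, so the draw is reproducible without touching NumPy's global random state. The seed value simply reuses the one set earlier in the notebook.

```python
import numpy as np

n, dim_z = 10, 128

# A local generator keeps this draw independent of np.random.seed(...)
rng = np.random.default_rng(1111)

# standard_normal can produce float32 directly, so no astype('f') is needed
z = rng.standard_normal(size=(n, dim_z), dtype=np.float32)

print(z.shape, z.dtype)
```

An array like this can be passed to `decode` in the same way as the one built in the sampling cell.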


# Tense

In [26]:
!head data/yelp/tense/valid.past

overpriced .
the pizza was pretty bland , despite a hefty helping of oregano .
perhaps because i was indoctrinated into loving this pizza as a small child .
we got the thin crust with steak and bacon and it was awesome !
i had to knock it down a star .
the food was n't good .
went here after a doctors appointment one day .
the pizza was good .
the small salad was n't their best .
the service sucked .


In [27]:
!head data/yelp/tense/valid.present

pizza is pretty good !
everything is usually pretty fresh and hot .
there is something about this pizza that is addictive .
there is very little sauce , fairly dry cheese , and a chewy crust .
i do n't know what to call it ... .
because every local i know loves monicals .
in fact , they even sell their french dressing online !
seriously , i 've looked . )
the pizza is pretty good , but not the best in champaign-urbana .
i prefer papa del 's .


In [28]:
!head data/yelp/tense/test.past

and more unbelievably the pizza served was missing a portion .
`` the oven ate it '' - according to the waitress .
reasonably priced .
could n't be more happy with their service either .
they could use a fork upgrade , but the rice was nice .
i got the garlic chicken with vegetables with lo mein .
paid over $ _ num _ and it was pretty greasy and just very underwhelming .
i walked in and briefly fingered coats , feeling lost and confused .
`` have you spent a winter here before ? ''
she asked , eyeing my confused expression .


In [29]:
k = 1

In [30]:
fa, fb, fc = '/content/text-autoencoders/data/yelp/tense/valid.past', '/content/text-autoencoders/data/yelp/tense/valid.present', '/content/text-autoencoders/data/yelp/tense/test.past'                     
sa, sb, sc = load_sent(fa), load_sent(fb), load_sent(fc)
za, zb, zc = encode(sa), encode(sb), encode(sc)
zd = zc + k * (zb.mean(axis=0) - za.mean(axis=0))
sd = decode(zd)
write_sent(sd, '/content/text-autoencoders/checkpoints/test.past2present') 

In [31]:
!head checkpoints/test.past2present

and rather than the pizza is served missing a decent .
`` the oven '' everything 's - '' according to the waitress .
reasonably priced .
you ca n't be more happy with their either .
they could use a text add , but the rice is nice .
i got the garlic chicken with vegetables with lo mein .
i paid over $ <unk> <unk> and it 's just pretty greasy and very bland .
i walked in and <unk> , book , and lost me look .
`` you have you a concert before here ? ''
she asked , <unk> my confused <unk> .
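
The tense transfer above is plain vector arithmetic in latent space: the difference between the mean "present" code and the mean "past" code is added to every test code. The sketch below reproduces that step on synthetic data, so the arithmetic can be checked without the model. The toy codes and `true_offset` are assumptions chosen so that the recovered direction is known in advance.

```python
import numpy as np

rng = np.random.default_rng(0)
dim_z = 4

# Toy latent codes: the two groups differ by a fixed, known offset.
true_offset = np.array([2.0, 0.0, -1.0, 0.0])
za = rng.normal(size=(50, dim_z))   # stand-in for "past" codes
zb = za + true_offset               # stand-in for "present" codes
zc = rng.normal(size=(3, dim_z))    # codes to be edited

k = 1
attr = zb.mean(axis=0) - za.mean(axis=0)  # attribute direction, shape (dim_z,)
zd = zc + k * attr                        # broadcasting shifts every row of zc

print(np.round(attr, 3))
print(zd.shape)
```

With real encodings the two groups are not exact translations of each other, so the mean difference is only an estimate of the attribute direction.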


# Sentiment

In [32]:
!head data/yelp/sentiment/100.neg

the $ _num_ minimum charge to use a credit card is also annoying .
sorry but i do n't get the rave reviews for this place .
the desserts were very bland .
the cake portion was extremely light and a bit dry .
it was super dry and had a weird taste to the entire slice .
once again , i have n't figured out why they change so much .
consistently slow .
even in the awkward freezer burn then microwaved scent .
so nasty .
i hate mayonnaise .


In [33]:
!head data/yelp/sentiment/1000.neg

the answer was none .
this place sucks .
we have been there _num_ times , with an attempted 3rd time .
1st time , burnt pizza , was horrible .
wo n't go back .
also like all spinatos they are closed on monday .
the food was bland ... if you like spinatos , just avoid this location .
it 's not worth it .
friendly staff , but not overly helpful .
we got the thicker crust and it tasted really dry and hard .


k = 2

In [34]:
k = 2

In [35]:
fa, fb, fc = '/content/text-autoencoders/data/yelp/sentiment/100.neg', '/content/text-autoencoders/data/yelp/sentiment/100.pos', '/content/text-autoencoders/data/yelp/sentiment/1000.neg'                     
sa, sb, sc = load_sent(fa), load_sent(fb), load_sent(fc)
za, zb, zc = encode(sa), encode(sb), encode(sc)
zd = zc + k * (zb.mean(axis=0) - za.mean(axis=0))
sd = decode(zd)
write_sent(sd, '/content/text-autoencoders/checkpoints/2_1000.neg2pos') 

In [36]:
!head checkpoints/2_1000.neg2pos

the staff is both .
this place sucks .
we have been here for times , with family and time .
1st time , pizza , and was great .
wo n't go back .
also love all the <unk> are always on .
the food is ok , especially , <unk> , and not great location .
it 's not worth it .
friendly staff , but very helpful .
we got the seafood crust and it was really very easy .


k = 1.5

In [37]:
k = 1.5

In [38]:
fa, fb, fc = '/content/text-autoencoders/data/yelp/sentiment/100.neg', '/content/text-autoencoders/data/yelp/sentiment/100.pos', '/content/text-autoencoders/data/yelp/sentiment/1000.neg'                     
sa, sb, sc = load_sent(fa), load_sent(fb), load_sent(fc)
za, zb, zc = encode(sa), encode(sb), encode(sc)
zd = zc + k * (zb.mean(axis=0) - za.mean(axis=0))
sd = decode(zd)
write_sent(sd, '/content/text-autoencoders/checkpoints/1_5_1000.neg2pos') 

In [39]:
!head checkpoints/1_5_1000.neg2pos

the answer was none .
this place sucks .
we have been here for times , with an family time .
1st time , pizza , was great .
wo n't go back .
also love all the <unk> are closed on .
the food is alright , especially like <unk> , and not this location .
it 's not worth it .
friendly staff , but very helpful .
we got the unique crust and it was really fresh and hard .


k = 1

In [40]:
k = 1

In [41]:
fa, fb, fc = '/content/text-autoencoders/data/yelp/sentiment/100.neg', '/content/text-autoencoders/data/yelp/sentiment/100.pos', '/content/text-autoencoders/data/yelp/sentiment/1000.neg'                     
sa, sb, sc = load_sent(fa), load_sent(fb), load_sent(fc)
za, zb, zc = encode(sa), encode(sb), encode(sc)
zd = zc + k * (zb.mean(axis=0) - za.mean(axis=0))
sd = decode(zd)
write_sent(sd, '/content/text-autoencoders/checkpoints/1_1000.neg2pos') 

In [42]:
!head checkpoints/1_1000.neg2pos

the answer was none .
this place sucks .
we have been here several times , with an husband each time .
1st time , bbq pizza , was horrible .
wo n't go back .
also like all <unk> are closed on monday .
the food was alright , if you like <unk> , just avoid this location .
it 's not worth it .
friendly staff , but not extremely helpful .
we got the thicker crust and it was really fresh and hard .
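
The three sentiment runs differ only in the scale factor `k`. A compact way to compare such settings is a loop over `k`, sketched here on synthetic codes. `shift` is a hypothetical helper that wraps the same expression used in the cells above; in the notebook, each shifted array would then be passed to `decode`.

```python
import numpy as np

def shift(zc, za, zb, k):
    """Move every row of zc by k times the difference of the group means."""
    return zc + k * (zb.mean(axis=0) - za.mean(axis=0))

rng = np.random.default_rng(0)
za = rng.normal(loc=-1.0, size=(100, 8))  # stand-in for negative codes
zb = rng.normal(loc=+1.0, size=(100, 8))  # stand-in for positive codes
zc = rng.normal(loc=-1.0, size=(5, 8))    # codes to be edited

base_step = np.linalg.norm(zb.mean(axis=0) - za.mean(axis=0))
for k in (1, 1.5, 2):
    zd = shift(zc, za, zb, k)
    moved = np.linalg.norm(zd - zc, axis=1)  # identical for every row
    print(k, moved[0] / base_step)           # ratio equals k up to rounding
```

The displacement grows linearly with `k`, which is consistent with the outputs above: at `k = 1` most sentences keep their wording, while at `k = 2` more of them change.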


# Interpolation

In [43]:
n = 5

In [44]:
f1, f2 = '/content/text-autoencoders/data/yelp/interpolate/example.long', '/content/text-autoencoders/data/yelp/interpolate/example.short'        
s1, s2 = load_sent(f1), load_sent(f2)
z1, z2 = encode(s1), encode(s2)
zi = [interpolate(z1_, z2_, n) for z1_, z2_ in zip(z1, z2)]
zi = np.concatenate(zi, axis=0)
si = decode(zi)
si = list(zip(*[iter(si)]*(n)))
write_doc(si, '/content/text-autoencoders/checkpoints/example.int')                          
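
The interpolation cell builds `n` intermediate codes per sentence pair, decodes them as one flat list, and then regroups the sentences `n` at a time. The sketch below shows both steps on toy values. `lerp_path` is a hypothetical linear interpolation written for illustration; whether the repository's `interpolate` matches it exactly is an assumption.

```python
import numpy as np

def lerp_path(z1, z2, n):
    """Return n points spaced evenly from z1 to z2, endpoints included."""
    t = np.linspace(0.0, 1.0, n)[:, None]  # column of mixing weights
    return (1.0 - t) * z1 + t * z2

n = 5
z1 = np.zeros(3)
z2 = np.ones(3)
path = lerp_path(z1, z2, n)
print(path[:, 0])  # weights 0, 0.25, 0.5, 0.75, 1 along the path

# Grouping idiom: the list holds n references to ONE iterator,
# so zip pulls n consecutive items into each tuple.
flat = list(range(10))
groups = list(zip(*[iter(flat)] * n))
print(groups)  # [(0, 1, 2, 3, 4), (5, 6, 7, 8, 9)]
```

Note that `zip` silently drops a trailing group shorter than `n`, which is harmless here because each pair contributes exactly `n` sentences.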

In [45]:
!head data/yelp/interpolate/example.long

i highly recommend it and i 'll definitely be back !
probably the worst chinese food i 've had in my life .
i say again , do n't ever stay at this hotel !
it 's so much better than the other chinese food places in this area .
the fried dumplings are a must if you ever visit this place .
definitely worth going for if you want good quick chinese food .
everyone who works there is very sweet and genuine too !
this was the must disgusting restaurant i have eaten at in years .
this is one of my favorite restaurants in town .
literally the best place for students and locals who want diversity in their food .


In [46]:
!head data/yelp/interpolate/example.short

i will be back !
worst chinese food .
do n't stay !
better than other places .
fried dumplings are a must .
definitely worth going !
everyone is sweet !
disgusting !
my favorite !
the best !


In [47]:
!head checkpoints/example.int

i highly recommend it and i 'll definitely be back !
i highly recommend it and i 'll definitely be back !
i will definitely be back !
i will be back !
i will be back !

probably the worst chinese food i 've had in my life .
probably the worst chinese food i 've had in my life .
worst chinese food .
worst chinese food .
