# Adversarial autoencoders for text

### Code: https://github.com/shentianxiao/text-autoencoders

### Paper: https://arxiv.org/pdf/1905.12777.pdf

### GCP account creation: https://cloud.google.com/apigee/docs/hybrid/v1.1/precog-gcpaccount



# Set up

In [1]:
import torch
from multiprocessing import cpu_count

print(cpu_count())
print(torch.cuda.is_available())

2
True


In [2]:
!git clone https://github.com/shentianxiao/text-autoencoders.git

Cloning into 'text-autoencoders'...
remote: Enumerating objects: 114, done.
remote: Counting objects: 100% (31/31), done.
remote: Compressing objects: 100% (27/27), done.
remote: Total 114 (delta 11), reused 12 (delta 4), pack-reused 83
Receiving objects: 100% (114/114), 270.78 KiB | 22.56 MiB/s, done.
Resolving deltas: 100% (56/56), done.


In [3]:
%cd text-autoencoders

/content/text-autoencoders


In [4]:
!bash download_data.sh

--2021-07-25 06:04:40--  http://people.csail.mit.edu/tianxiao/data/yelp.zip
Resolving people.csail.mit.edu (people.csail.mit.edu)... 128.30.2.133
Connecting to people.csail.mit.edu (people.csail.mit.edu)|128.30.2.133|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3676642 (3.5M) [application/zip]
Saving to: ‘yelp.zip’


2021-07-25 06:04:40 (7.49 MB/s) - ‘yelp.zip’ saved [3676642/3676642]

Archive:  yelp.zip
   creating: yelp/
   creating: yelp/tense/
  inflating: yelp/tense/valid.past   
  inflating: yelp/tense/valid.present  
  inflating: yelp/tense/test.past    
  inflating: yelp/tense/test.present  
   creating: yelp/sentiment/
  inflating: yelp/sentiment/100.neg  
  inflating: yelp/sentiment/100.pos  
  inflating: yelp/sentiment/1000.neg  
  inflating: yelp/sentiment/1000.pos  
  inflating: yelp/test.txt           
  inflating: yelp/train.txt          
  inflating: yelp/valid.txt          
   creating: yelp/interpolate/
  inflating: yelp/interpolate/example

# Imports

In [5]:
import os
import random
import torch
import numpy as np
from sklearn.neighbors import NearestNeighbors

from vocab import Vocab
from model import *
from utils import *
from batchify import get_batches
from train import evaluate

# Getting the trained model checkpoint from the probml_data bucket

In [6]:
# Authentication is required to access a protected bucket. This is not required for a public one.
from google.colab import auth

auth.authenticate_user()

In [7]:
bucket_name = "probml_data"

In [8]:
!mkdir /content/text-autoencoders/checkpoints

See the [gsutil documentation](https://cloud.google.com/storage/docs/gsutil/commands/help) for how to use `gsutil`.

In [9]:
!gsutil cp -r gs://{bucket_name}/text-autoencoders/vocab.txt /content/text-autoencoders/checkpoints/

Copying gs://probml_data/text-autoencoders/vocab.txt...
/ [1 files][100.7 KiB/100.7 KiB]
Operation completed over 1 objects/100.7 KiB.


In [10]:
!gsutil cp -r gs://{bucket_name}/text-autoencoders/text_ae_yelp_30_epochs.pt /content/text-autoencoders/checkpoints/

Copying gs://probml_data/text-autoencoders/text_ae_yelp_30_epochs.pt...
\ [1 files][133.3 MiB/133.3 MiB]                                                
Operation completed over 1 objects/133.3 MiB.                                    


# Loading the vocabulary

In [11]:
vocab = Vocab("/content/text-autoencoders/checkpoints/vocab.txt")  # os.path.join(args.checkpoint, 'vocab.txt')

In [12]:
seed = 1111
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
batch_size = 100
max_len = 35

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading the checkpoint


In [14]:
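# map_location lets the GPU-trained checkpoint load on a CPU-only runtime as well.
# On newer PyTorch releases that default to weights_only=True, weights_only=False may
# also be needed, because the checkpoint stores the training arguments object.
ckpt = torch.load("/content/text-autoencoders/checkpoints/text_ae_yelp_30_epochs.pt", map_location=device)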

In [15]:
train_args = ckpt["args"]

# Selecting the AAE model

In [16]:
model = {"dae": DAE, "vae": VAE, "aae": AAE}["aae"](vocab, train_args).to(device)

In [17]:
model.load_state_dict(ckpt["model"])
model.flatten()
model.eval()

AAE(
  (embed): Embedding(10005, 512)
  (proj): Linear(in_features=1024, out_features=10005, bias=True)
  (drop): Dropout(p=0.5, inplace=False)
  (E): LSTM(512, 1024, bidirectional=True)
  (G): LSTM(512, 1024)
  (h2mu): Linear(in_features=2048, out_features=128, bias=True)
  (h2logvar): Linear(in_features=2048, out_features=128, bias=True)
  (z2emb): Linear(in_features=128, out_features=512, bias=True)
  (D): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=1, bias=True)
    (3): Sigmoid()
  )
)
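The printout shows two heads on top of the bidirectional encoder, `h2mu` and `h2logvar` (each 2048 to 128). When `encode` below is called with `enc="z"`, it uses the repo's `reparameterize` to turn these into a sampled code instead of returning the mean. The NumPy stand-in below is a minimal sketch of the standard reparameterization trick, assuming the repo's helper follows the usual formula z = mu + exp(0.5 * logvar) * eps; the function and toy inputs here are illustrative, not the repo's code.

```python
import numpy as np


def reparameterize(mu, logvar, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(0.5 * logvar)."""
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    return mu + std * eps


rng = np.random.default_rng(0)
mu = np.zeros((4, 128), dtype=np.float32)           # toy posterior means (batch of 4)
logvar = np.full((4, 128), -2.0, dtype=np.float32)  # toy log-variances
z = reparameterize(mu, logvar, rng)
print(z.shape)  # (4, 128)
```

As logvar becomes very negative the noise vanishes and the sample collapses to mu, which is why `enc="mu"` is the deterministic choice used for reconstruction and arithmetic below.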

In [18]:
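def encode(sents, enc="mu"):
    """Encode tokenized sentences into latent codes, returned in the input order.

    enc="mu" returns the posterior mean; enc="z" draws a sample via reparameterize.
    """
    assert enc in ("mu", "z")
    # get_batches groups sentences by length; `order` maps batch position -> original index.
    batches, order = get_batches(sents, vocab, batch_size, device)
    z = []
    with torch.no_grad():
        for inputs, _ in batches:
            mu, logvar = model.encode(inputs)
            zi = mu if enc == "mu" else reparameterize(mu, logvar)
            z.append(zi.detach().cpu().numpy())
    z = np.concatenate(z, axis=0)
    # Undo the length sort so that row i corresponds to sents[i].
    z_ = np.zeros_like(z)
    z_[np.array(order)] = z
    return z_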
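The last three lines of `encode` undo the reordering done for batching. Assuming `order[j]` holds the original index of the j-th sentence after the length sort, writing `z_[order] = z` sends each row back to its original position. A toy sketch of that inverse permutation, with a hypothetical length sort standing in for `get_batches` and sentence lengths standing in for codes:

```python
import numpy as np

sents = [["a", "b", "c"], ["d"], ["e", "f"]]

# Hypothetical stand-in for get_batches: sort sentence indices by length.
order = sorted(range(len(sents)), key=lambda i: len(sents[i]))  # [1, 2, 0]
sorted_sents = [sents[i] for i in order]

# Pretend each "code" is just the sentence length, computed in sorted order.
codes_sorted = np.array([[len(s)] for s in sorted_sents], dtype=np.float32)

# Scatter row j back to original index order[j].
codes = np.zeros_like(codes_sorted)
codes[np.array(order)] = codes_sorted
print(codes.ravel())  # [3. 1. 2.]
```

The result lines up with the original lengths [3, 1, 2], so downstream code can index codes by the same position as the input sentences.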

In [19]:
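def decode(z, dec="greedy"):
    """Decode latent codes of shape (n, dim_z) into token lists, batch_size codes at a time."""
    sents = []
    with torch.no_grad():
        for i in range(0, len(z), batch_size):
            zi = torch.tensor(z[i : i + batch_size], device=device)
            # generate returns (time, batch) token ids; transpose to one row per sentence.
            outputs = model.generate(zi, max_len, dec).t()
            for s in outputs:
                # Skip the leading <go> token; strip_eos cuts each sentence at <eos>.
                sents.append([vocab.idx2word[idx] for idx in s[1:]])
    return strip_eos(sents)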
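Each decoded row starts with the `<go>` token fed at the first step and may run past `<eos>` up to `max_len`, so `decode` drops position 0 and `strip_eos` truncates at the first `<eos>`. The pure-Python sketch below reproduces that post-processing on a hypothetical toy vocabulary (the real `idx2word` comes from `vocab.txt`, and its special-token order may differ):

```python
def ids_to_sentence(ids, idx2word, eos="<eos>"):
    """Map decoder ids to words, dropping the leading <go> step and cutting at the first <eos>."""
    words = [idx2word[i] for i in ids[1:]]
    return words[: words.index(eos)] if eos in words else words


# Hypothetical toy vocabulary, for illustration only.
idx2word = ["<pad>", "<go>", "<eos>", "<unk>", "i", "will", "be", "back", "!"]
print(ids_to_sentence([1, 4, 5, 6, 7, 8, 2, 0, 0], idx2word))
# ['i', 'will', 'be', 'back', '!']
```

If no `<eos>` is produced within `max_len` steps, the sentence is returned untruncated, which is how overly long generations show up in the outputs.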

# Reconstruction

In [20]:
n = 5
sents = load_sent("/content/text-autoencoders/data/yelp/test.txt")
z = encode(sents)
sents_rec = decode(z)
write_z(z, "/content/text-autoencoders/checkpoints/test.z")
write_sent(sents_rec, "/content/text-autoencoders/checkpoints/test.rec")

In [21]:
for i in range(n):
    print("Original sentence: " + " ".join(sents[i]))
    print("Reconstructed sentence: " + " ".join(sents_rec[i]))
    print("\n")

Original sentence: husband loves the thin crust pizza . 
Reconstructed sentence: my husband loves thin crust pizza . 


Original sentence: breadsticks are great too 
Reconstructed sentence: sausage are great too 


Original sentence: monicals pizza is by far one of my favorite pizzas . 
Reconstructed sentence: <unk> pizza is by far one of my favorite pizzas . 


Original sentence: the traditional thin crust topped with sausage and pinnaple it where its at . 
Reconstructed sentence: the little thin crust with bacon and <unk> it 's where its at . 


Original sentence: i also like this location . 
Reconstructed sentence: i also like this location . 
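Eyeballing five examples says little about reconstruction quality over the whole test set. The helper below is not part of the repo; it is a simple sketch that reports the exact-match rate and a crude position-wise token accuracy for paired sentences. Position-wise matching penalizes insertions heavily (for example "my husband loves thin crust pizza ." against the original), so treat the second number as a rough signal only.

```python
def reconstruction_stats(originals, reconstructions):
    """Return (exact-match rate, mean position-wise token accuracy) over paired sentences."""
    exact = 0
    token_acc = []
    for ref, hyp in zip(originals, reconstructions):
        exact += ref == hyp
        matches = sum(r == h for r, h in zip(ref, hyp))
        token_acc.append(matches / max(len(ref), len(hyp), 1))
    count = len(token_acc)
    return exact / count, sum(token_acc) / count


# Two pairs copied from the output above.
orig = [["i", "also", "like", "this", "location", "."], ["breadsticks", "are", "great", "too"]]
rec = [["i", "also", "like", "this", "location", "."], ["sausage", "are", "great", "too"]]
print(reconstruction_stats(orig, rec))  # (0.5, 0.875)
```

In the notebook it could be called as `reconstruction_stats(sents, sents_rec)` directly after the reconstruction cell, before `sents` is reassigned.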




# Sampling from the prior

In [22]:
n = 10
dim_z = 128

In [23]:
z = np.random.normal(size=(n, dim_z)).astype("f")
sents = decode(z)
write_sent(sents, "/content/text-autoencoders/checkpoints/sample")

In [24]:
for i in range(n):
    sample = ""
    for word in sents[i]:
        sample = sample + word + " "
    print(sample)

special night with the food . 
i ordered the very chorizo . 
the pasta <unk> - three things : ) 
all are the best . 
i had a reservation and received the service . 
i <unk> will probably visit the city . 
the <unk> is . 
i think we were <unk> from <unk> with shipping . 
we 've been to several people and to the service <unk> . 
i wanted the eggs benedict and chili cheese strips . 
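Decoding codes drawn straight from N(0, I) works because the AAE's discriminator `D` (printed in the model summary) is trained to tell prior samples from encoder codes, and the encoder is trained to fool it, pulling the codes toward the prior. The NumPy sketch below shows the shape of that standard AAE objective with randomly initialized weights in place of the trained ones; it illustrates the losses, and is not the repo's training code, which may weight or schedule these terms differently.

```python
import numpy as np

rng = np.random.default_rng(0)
dim_z, hidden = 128, 512

# Randomly initialized stand-in for D:
# Linear(128 -> 512) -> ReLU -> Linear(512 -> 1) -> Sigmoid.
W1 = rng.normal(0.0, 0.05, (dim_z, hidden))
b1 = np.zeros(hidden)
W2 = rng.normal(0.0, 0.05, (hidden, 1))
b2 = np.zeros(1)


def D(z):
    """Probability that each row of z was drawn from the prior."""
    h = np.maximum(z @ W1 + b1, 0.0)
    return 1.0 / (1.0 + np.exp(-(h @ W2 + b2)))


def bce(p, target, eps=1e-7):
    """Mean binary cross-entropy between predictions p and a constant target label."""
    p = np.clip(p, eps, 1 - eps)
    return float(-np.mean(target * np.log(p) + (1 - target) * np.log(1 - p)))


z_enc = rng.normal(0.5, 2.0, (64, dim_z))   # toy "encoder codes", deliberately off-prior
z_prior = rng.standard_normal((64, dim_z))  # samples from N(0, I)

# Discriminator loss: label prior samples 1 and encoder codes 0.
loss_d = bce(D(z_prior), 1.0) + bce(D(z_enc), 0.0)
# Adversarial loss for the encoder: make D label its codes as prior samples.
loss_adv = bce(D(z_enc), 1.0)
print(loss_d, loss_adv)
```

When training succeeds, encoded codes become hard to distinguish from N(0, I) samples, which is what makes `np.random.normal(size=(n, dim_z))` a sensible source of codes for the cell above.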


# Tense transfer (past to present)

In [25]:
n = 10

In [26]:
k = 1

In [27]:
fa, fb, fc = (
    "/content/text-autoencoders/data/yelp/tense/valid.past",
    "/content/text-autoencoders/data/yelp/tense/valid.present",
    "/content/text-autoencoders/data/yelp/tense/test.past",
)
sa, sb, sc = load_sent(fa), load_sent(fb), load_sent(fc)
za, zb, zc = encode(sa), encode(sb), encode(sc)
zd = zc + k * (zb.mean(axis=0) - za.mean(axis=0))
sd = decode(zd)
write_sent(sd, "/content/text-autoencoders/checkpoints/test.past2present")
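The cell above performs latent vector arithmetic: it estimates an attribute direction as the difference between the mean code of present-tense sentences (`zb`) and past-tense sentences (`za`), then adds `k` times that offset to each past-tense test code (`zc`). A minimal NumPy sketch of the same operation on toy Gaussian clusters (not real sentence codes):

```python
import numpy as np


def shift_codes(z_src_attr, z_tgt_attr, z_input, k=1.0):
    """Return z_input + k * (mean(z_tgt_attr) - mean(z_src_attr))."""
    offset = z_tgt_attr.mean(axis=0) - z_src_attr.mean(axis=0)
    return z_input + k * offset


rng = np.random.default_rng(0)
za = rng.normal(-1.0, 0.1, (100, 4))  # toy codes with attribute A (e.g. past tense)
zb = rng.normal(+1.0, 0.1, (100, 4))  # toy codes with attribute B (e.g. present tense)
zc = rng.normal(-1.0, 0.1, (5, 4))    # attribute-A inputs to convert
zd = shift_codes(za, zb, zc, k=1.0)
print(zd.mean(axis=0))  # close to the attribute-B cluster center
```

Because the offset is a single shared vector, everything about each input that is not captured by the attribute direction is, in principle, left in place; the decoded outputs show how well that holds in practice.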

In [28]:
for i in range(n):
    tense = ""
    for word in sd[i]:
        tense = tense + word + " "
    print(tense)

and rather than the pizza is served missing a decent . 
`` the oven '' everything 's - '' according to the waitress . 
reasonably priced . 
you ca n't be more happy with their either . 
they could use a text add , but the rice is nice . 
i got the garlic chicken with vegetables with lo mein . 
i paid over $ <unk> <unk> and it 's just pretty greasy and very bland . 
i walked in and <unk> , book , and lost me look . 
`` you have you a concert before here ? '' 
she asked , <unk> my confused <unk> . 


# Sentiment transfer (negative to positive)

In [29]:
n = 10

## k = 2

In [30]:
k = 2

In [31]:
fa, fb, fc = (
    "/content/text-autoencoders/data/yelp/sentiment/100.neg",
    "/content/text-autoencoders/data/yelp/sentiment/100.pos",
    "/content/text-autoencoders/data/yelp/sentiment/1000.neg",
)
sa, sb, sc = load_sent(fa), load_sent(fb), load_sent(fc)
za, zb, zc = encode(sa), encode(sb), encode(sc)
zd = zc + k * (zb.mean(axis=0) - za.mean(axis=0))
sd = decode(zd)
write_sent(sd, "/content/text-autoencoders/checkpoints/2_1000.neg2pos")

In [32]:
for i in range(n):
    sentiment = ""
    for word in sd[i]:
        sentiment = sentiment + word + " "
    print(sentiment)

the staff is both . 
this place sucks . 
we have been here for times , with family and time . 
1st time , pizza , and was great . 
wo n't go back . 
also love all the <unk> are always on . 
the food is ok , especially , <unk> , and not great location . 
it 's not worth it . 
friendly staff , but very helpful . 
we got the seafood crust and it was really very easy . 


## k = 1.5

In [33]:
k = 1.5

In [34]:
fa, fb, fc = (
    "/content/text-autoencoders/data/yelp/sentiment/100.neg",
    "/content/text-autoencoders/data/yelp/sentiment/100.pos",
    "/content/text-autoencoders/data/yelp/sentiment/1000.neg",
)
sa, sb, sc = load_sent(fa), load_sent(fb), load_sent(fc)
za, zb, zc = encode(sa), encode(sb), encode(sc)
zd = zc + k * (zb.mean(axis=0) - za.mean(axis=0))
sd = decode(zd)
write_sent(sd, "/content/text-autoencoders/checkpoints/1_5_1000.neg2pos")

In [35]:
for i in range(n):
    sentiment = ""
    for word in sd[i]:
        sentiment = sentiment + word + " "
    print(sentiment)

the answer was none . 
this place sucks . 
we have been here for times , with an family time . 
1st time , pizza , was great . 
wo n't go back . 
also love all the <unk> are closed on . 
the food is alright , especially like <unk> , and not this location . 
it 's not worth it . 
friendly staff , but very helpful . 
we got the unique crust and it was really fresh and hard . 


## k = 1

In [36]:
k = 1

In [37]:
fa, fb, fc = (
    "/content/text-autoencoders/data/yelp/sentiment/100.neg",
    "/content/text-autoencoders/data/yelp/sentiment/100.pos",
    "/content/text-autoencoders/data/yelp/sentiment/1000.neg",
)
sa, sb, sc = load_sent(fa), load_sent(fb), load_sent(fc)
za, zb, zc = encode(sa), encode(sb), encode(sc)
zd = zc + k * (zb.mean(axis=0) - za.mean(axis=0))
sd = decode(zd)
write_sent(sd, "/content/text-autoencoders/checkpoints/1_1000.neg2pos")

In [38]:
for i in range(n):
    sentiment = ""
    for word in sd[i]:
        sentiment = sentiment + word + " "
    print(sentiment)

the answer was none . 
this place sucks . 
we have been here several times , with an husband each time . 
1st time , bbq pizza , was horrible . 
wo n't go back . 
also like all <unk> are closed on monday . 
the food was alright , if you like <unk> , just avoid this location . 
it 's not worth it . 
friendly staff , but not extremely helpful . 
we got the thicker crust and it was really fresh and hard . 
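The three sentiment cells re-encode the same three files and differ only in `k`. Since encoding is the expensive step, one option is to encode once and sweep the scale. The helper name `sweep_offsets` and the toy arrays below are illustrative assumptions, not part of the repo; the commented loop shows how it might plug into the notebook's `decode` and `write_sent`.

```python
import numpy as np


def sweep_offsets(za, zb, zc, ks):
    """Return {k: zc + k * (mean(zb) - mean(za))}, computing the attribute offset once."""
    offset = zb.mean(axis=0) - za.mean(axis=0)
    return {k: zc + k * offset for k in ks}


# Toy arrays standing in for the encoded 100.neg, 100.pos and 1000.neg codes.
rng = np.random.default_rng(1)
za = rng.normal(size=(100, 128)).astype("f")
zb = (rng.normal(size=(100, 128)) + 0.5).astype("f")
zc = rng.normal(size=(1000, 128)).astype("f")

shifted = sweep_offsets(za, zb, zc, [1, 1.5, 2])
# In the notebook, each entry could then be decoded and saved, e.g.
#   for k, zd in shifted.items():
#       write_sent(decode(zd), f"/content/text-autoencoders/checkpoints/{k}_1000.neg2pos")
for k, zd in shifted.items():
    # Average distance each code moves; it grows linearly with k.
    print(k, float(np.linalg.norm(zd - zc, axis=1).mean()))
```

The linear growth of the shift is consistent with what the outputs above suggest: larger `k` pushes codes further from where the input sentences were encoded, changing more of the content along with the sentiment.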


# Interpolation

In [39]:
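n = 10  # number of interpolation steps per sentence pair
f1, f2 = (
    "/content/text-autoencoders/data/yelp/interpolate/example.long",
    "/content/text-autoencoders/data/yelp/interpolate/example.short",
)
s1, s2 = load_sent(f1), load_sent(f2)
z1, z2 = encode(s1), encode(s2)
# One path of n codes per (long, short) pair, stacked into a single array.
zi = [interpolate(z1_, z2_, n) for z1_, z2_ in zip(z1, z2)]
zi = np.concatenate(zi, axis=0)
si = decode(zi)
# write_doc expects a list of documents (each a list of sentences), so regroup
# the flat list into chunks of n, one chunk per interpolated pair.
write_doc([si[j : j + n] for j in range(0, len(si), n)], "/content/text-autoencoders/checkpoints/example.int")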

In [40]:
n = 10

In [41]:
for i in range(n):
    interpolation = ""
    for word in si[i]:
        interpolation = interpolation + word + " "
    print(interpolation)

i highly recommend it and i 'll definitely be back ! 
i highly recommend it and i 'll definitely be back ! 
i highly recommend it and i 'll definitely be back ! 
i highly recommend it and i 'll definitely be back ! 
i would absolutely recommend it will be back ! 
i will definitely be back ! 
i will definitely be back ! 
i will be back ! 
i will be back ! 
i will be back ! 
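The ten lines above trace a path between the codes of the long and the short example sentence, and the decoder switches between a few discrete outputs along the way. The repo's `interpolate` helper (available through `from utils import *`) is assumed here to do plain linear interpolation with `n` evenly spaced points that include both endpoints; the NumPy function below is a stand-in under that assumption, not the helper itself.

```python
import numpy as np


def lerp_path(z1, z2, n):
    """Return an (n, d) array of evenly spaced points from z1 (t=0) to z2 (t=1)."""
    t = np.linspace(0.0, 1.0, n)[:, None]  # column of interpolation weights
    return (1.0 - t) * z1[None, :] + t * z2[None, :]


z1 = np.array([0.0, 0.0])
z2 = np.array([1.0, 2.0])
path = lerp_path(z1, z2, 5)
print(path[2])  # [0.5 1. ]
```

Because the endpoints are included, the first and last decoded lines should match the reconstructions of the two input sentences, which is the pattern visible in the output above.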
