# In this notebook
This notebook validates the VAE's results by checking that the reconstruction *generally* looks good.

In [1]:
import sys
import numpy as np  # used by the argsort cells below
import torch
import joblib
sys.path.append("../scripts/")
import vae

In [3]:
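# The checkpoint holds a state dict, not a full model; it is loaded into vae.VAE further down.
state_dict = torch.load('../scripts/model/model_3epoch.pt')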

In [4]:
cv = joblib.load('../data/san_francisco/cached/count_vec.joblib')

In [5]:
tweets = vae.Tweets('../data/san_francisco/')
x_test = tweets.load(test=True)

Loading in the data...

Cached file was found...loading lemmatized tweets from the cache.
Creating the count vector


In [19]:
model = vae.VAE(tweets.vocab_size)

In [20]:
model.load_state_dict(torch.load('../scripts/model/model_3epoch.pt'))

<All keys matched successfully>

In [128]:
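model.eval()

test_doc = x_test[0]
# Forward pass returns s (topic weights), W (topic-word matrix), and the latent mean and log-variance
recon = model(test_doc)
s, W, mu, logvar = recon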

In [129]:
recon = s @ W  # mix the topic-word matrix by the topic weights

In [130]:
mult_params = recon / recon.sum()  # normalize so the word probabilities sum to 1

In [136]:
recon_vals = mult_params.detach().numpy()[0]
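`recon / recon.sum()` divides by the sum over the whole tensor, which is a valid per-document normalization only because the batch holds a single document. With a batch, each row needs its own normalization. A minimal NumPy sketch with made-up shapes (2 documents, 2 topics, 3 words); it assumes `s` and `W` are non-negative, since dividing by the sum does not give a probability distribution otherwise.

```python
import numpy as np

def word_distribution(s: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Map topic weights to one distribution over the vocabulary per document.

    s: (batch, n_topics) non-negative topic weights (assumed)
    W: (n_topics, vocab_size) non-negative topic-word weights (assumed)
    """
    recon = s @ W  # (batch, vocab_size)
    # Normalize each row on its own so every document's probabilities sum to 1.
    return recon / recon.sum(axis=1, keepdims=True)

# Toy values, not taken from the trained model.
toy_s = np.array([[1.0, 0.0],
                  [0.5, 0.5]])
toy_W = np.array([[2.0, 1.0, 1.0],
                  [0.0, 1.0, 3.0]])
toy_probs = word_distribution(toy_s, toy_W)
print(toy_probs)
```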

In [56]:
words = cv.get_feature_names_out()

In [118]:
# Make sure sensible words (symptoms, smoke, fire) made it into the vocabulary
words_to_check = [
    'cough',
    'lung',
    'eye',
    'itch',
    'fire',
    'wildfire',
    'smoke',
    'hurt'
]

for w in words_to_check:
    print(w, ': ', w in words)

cough :  True
lung :  True
eye :  True
itch :  True
fire :  False
wildfire :  True
smoke :  True
hurt :  True


In [137]:
# What are the most probable words in this reconstruction?
for i, v in enumerate(reversed(np.argsort(recon_vals))):
    print(words[v])
    if i > 20:
        break

mexican
keeper
jeopardy
attack
florida
spanish
dem
cater
mullen
tip
mute
dreamforce
sacrifice
max
hacker
hitter
specific
lego
module
sibling
contrary
transportation


In [123]:
# Which words have the highest counts in the original test document?
for i, v in enumerate(reversed(np.argsort(test_doc))):
    print(words[v])
    if i > 20:
        break

just
good
know
like
thank
love
make
people
game
want
say
right
great
look
today
way
think
hit
don
work
time
day


In [138]:
# Least probable words in the reconstruction
for i, v in enumerate(np.argsort(recon_vals)):
    print(words[v])
    if i > 20:
        break

japanese
appt
francis
doom
oil
cowardly
misogynistic
patent
employment
dough
fraud
fit
outing
outfit
wood
quite
omarosa
hahaha
rewatch
title
prob
surgeon


In [125]:
# Which words have the lowest counts in the original test document?
for i, v in enumerate(np.argsort(test_doc)):
    print(words[v])
    if i > 20:
        break

aa
organic
organ
org
oreo
oregon
ordinary
orbit
oracle
optional
optimize
optimistic
optimism
opt
oppression
oppress
opposition
oppose
opportunity
opponent
opioid
opinion
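The four loops above share one argsort-and-break pattern, printing 22 words each because they break once `i > 20`. A small helper, sketched below with a hypothetical name and toy data, makes the count explicit. NumPy's default `argsort` algorithm is not stable, so tied values (such as the many zero entries of a single bag-of-words document) come back in an unspecified order; the helper passes `kind="stable"` so ties are ordered by vocabulary index. `values` must be something `np.asarray` accepts, so a tensor that requires grad needs `.detach().numpy()` first.

```python
import numpy as np

def top_k_words(values, words, k=22, largest=True):
    """Return the k words with the largest (or smallest) values.

    values: 1-D array-like of scores aligned with `words`
    words:  1-D array of vocabulary terms, e.g. cv.get_feature_names_out()
    k:      22 matches the notebook loops, which break once i > 20
    """
    values = np.asarray(values)
    # Stable sort: tied entries are ordered deterministically by vocabulary index.
    order = np.argsort(values, kind="stable")
    if largest:
        order = order[::-1]  # descending
    return [words[i] for i in order[:k]]

# Toy vocabulary and scores, made up for illustration.
vocab = np.array(["cough", "lego", "smoke", "wildfire"])
scores = np.array([0.10, 0.05, 0.50, 0.35])
print(top_k_words(scores, vocab, k=2))
print(top_k_words(scores, vocab, k=2, largest=False))
```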


In [111]:
# Min probability of a word being in this corpus with min_df=100
print("Min: ", 1000/(1.8*10**6))

# Max prob of a word being in this corpus with max_df=0.1
print("Max: ", 0.1)


Min:  0.0005555555555555556
Max:  0.1
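The cell above mixes units: `min_df` is an absolute number of tweets, while `max_df=0.1` is a fraction of them. Converting both into the same unit makes the vocabulary window easier to compare. This sketch assumes about 1.8 million tweets, as in the cell above, and prints both 100 and 1000 because the comment and the computation there use different values. With that corpus size, `max_df=0.1` still admits words that appear in up to 180,000 tweets.

```python
# Approximate corpus size, taken from the 1.8*10**6 used in the cell above.
N_TWEETS = 1_800_000

def count_to_fraction(doc_count: int, n_docs: int = N_TWEETS) -> float:
    """Share of tweets covered by a word that appears in `doc_count` tweets."""
    return doc_count / n_docs

def fraction_to_count(fraction: float, n_docs: int = N_TWEETS) -> int:
    """Number of tweets corresponding to a document-frequency fraction."""
    return round(fraction * n_docs)

for count in (100, 1000):
    print(f"{count} tweets -> {count_to_fraction(count):.2e} of the corpus")
print(f"max_df=0.1 -> up to {fraction_to_count(0.1)} tweets")
```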


In [144]:
mu

tensor([[-15.7212, -34.9691, -36.0413, -35.4208, -30.5385, -33.4864, -35.0444,
         -33.2605, -37.6280, -35.3061, -32.4391, -37.4034, -37.1712, -30.5614,
         -37.9732, -32.9141, -35.1267, -34.4124, -35.7661,   0.3856]],
       grad_fn=<MmBackward0>)

In [148]:
torch.exp(logvar)

tensor([[2.2437e-06, 2.2084e-14, 3.6074e-16, 3.9704e-15, 1.1269e-12, 9.9991e-15,
         7.4339e-16, 2.9564e-15, 3.2217e-14, 1.4587e-15, 2.1465e-15, 1.1409e-14,
         3.5717e-15, 4.5019e-15, 7.6575e-15, 2.9934e-14, 3.7548e-16, 1.6810e-15,
         3.8571e-15, 2.8558e-14]], grad_fn=<ExpBackward0>)

## Conclusion

This model looks like it is probably doing OK, but some vocabulary words are likely too infrequent or too frequent.

We should restrict the vocabulary to words that appear in at least 20 or 50 tweets but in no more than 1000, so that each word shows up often enough, but not too often.

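Also, the latent means (`mu`) are strongly negative (between about -38 and -15 for all but the last component, which is 0.39), with tiny variances (between roughly 4e-16 and 2e-6). This doesn't make sense. We need to constrain our mus to be positive, with small variances.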

TODO:
 - [ ] Change the count vector parameters.
 - [ ] Review the KL divergence term: shouldn't these mus be positive?
 - [ ] Add more components to the topic matrix. 20 seems too small; how about 50?
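As a starting point for reviewing the KL term, here is the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior, $\mathrm{KL} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. This is a sketch under an assumption: the prior and the KL code actually used in the `vae` module are not shown here, and topic models such as ProdLDA use a different (non-zero-mean) Gaussian prior. Under the standard normal prior the KL is minimized at $\mu = 0$, $\sigma^2 = 1$, so it pulls means toward zero from either side and heavily penalizes variances near 1e-15 through the $-\log\sigma^2$ term.

```python
import torch

def gaussian_kl(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dims, one value per row.

    Assumes a standard normal prior, which may differ from the prior in the vae module.
    """
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

# Posterior identical to the prior: the KL is zero.
kl_prior = gaussian_kl(torch.zeros(1, 3), torch.zeros(1, 3))

# Toy values shaped like the notebook's output: one mean far below zero,
# one near zero, both with tiny variances.
toy_mu = torch.tensor([[-35.0, 0.4]])
toy_logvar = torch.log(torch.tensor([[1e-15, 1e-6]]))
kl_toy = gaussian_kl(toy_mu, toy_logvar)
print(kl_prior, kl_toy)
```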