In [5]:
from nltk.corpus import brown
import operator

def get_sentences():
  # returns 57340 of the Brown corpus
  # each sentence is represented as a list of individual string tokens
  return brown.sents()


def get_sentences_with_word2idx():
  sentences = get_sentences()
  indexed_sentences = []

  i = 2
  word2idx = {'START': 0, 'END': 1}
  for sentence in sentences:
    indexed_sentence = []
    for token in sentence:
      token = token.lower()
      if token not in word2idx:
        word2idx[token] = i
        i += 1

      indexed_sentence.append(word2idx[token])
    indexed_sentences.append(indexed_sentence)

  print("Vocab size:", i)
  return indexed_sentences, word2idx


def get_sentences_with_word2idx_limit_vocab(n_vocab=2000, keep_words=KEEP_WORDS):
  sentences = get_sentences()
  indexed_sentences = []

  i = 2
  word2idx = {'START': 0, 'END': 1}
  idx2word = ['START', 'END']

  word_idx_count = {
    0: float('inf'),
    1: float('inf'),
  }

  for sentence in sentences:
    indexed_sentence = []
    for token in sentence:
      token = token.lower()
      if token not in word2idx:
        idx2word.append(token)
        word2idx[token] = i
        i += 1

      # keep track of counts for later sorting
      idx = word2idx[token]
      word_idx_count[idx] = word_idx_count.get(idx, 0) + 1

      indexed_sentence.append(idx)
    indexed_sentences.append(indexed_sentence)



  # restrict vocab size

  # set all the words I want to keep to infinity
  # so that they are included when I pick the most
  # common words
  for word in keep_words:
    word_idx_count[word2idx[word]] = float('inf')

  sorted_word_idx_count = sorted(word_idx_count.items(), key=operator.itemgetter(1), reverse=True)
  word2idx_small = {}
  new_idx = 0
  idx_new_idx_map = {}
  for idx, count in sorted_word_idx_count[:n_vocab]:
    word = idx2word[idx]
    print(word, count)
    word2idx_small[word] = new_idx
    idx_new_idx_map[idx] = new_idx
    new_idx += 1
  # let 'unknown' be the last token
  word2idx_small['UNKNOWN'] = new_idx 
  unknown = new_idx

  assert('START' in word2idx_small)
  assert('END' in word2idx_small)
  for word in keep_words:
    assert(word in word2idx_small)

  # map old idx to new idx
  sentences_small = []
  for sentence in indexed_sentences:
    if len(sentence) > 1:
      new_sentence = [idx_new_idx_map[idx] if idx in idx_new_idx_map else unknown for idx in sentence]
      sentences_small.append(new_sentence)

  return sentences_small, word2idx_small

In [2]:
!pip install nltk

Collecting nltk
  Downloading nltk-3.5.zip (1.4 MB)
Collecting click
  Downloading click-7.1.2-py2.py3-none-any.whl (82 kB)
Collecting joblib
  Downloading joblib-0.16.0-py3-none-any.whl (300 kB)
Collecting regex
  Downloading regex-2020.7.14-cp37-cp37m-win_amd64.whl (268 kB)
Collecting tqdm
  Downloading tqdm-4.48.2-py2.py3-none-any.whl (68 kB)
Using legacy setup.py install for nltk, since package 'wheel' is not installed.
Installing collected packages: click, joblib, regex, tqdm, nltk
    Running setup.py install for nltk: started
    Running setup.py install for nltk: finished with status 'done'
Successfully installed click-7.1.2 joblib-0.16.0 nltk-3.5 regex-2020.7.14 tqdm-4.48.2

You should consider upgrading via the 'c:\users\mazic\appdata\local\programs\python\python37\python.exe -m pip install --upgrade pip' command.





In [10]:
import numpy as np
import random
import nltk

nltk.download('brown')

sentences, word2idx = get_sentences_with_word2idx_limit_vocab(2000)

V = len(word2idx)
print("Vocab size:", V)

start_idx = word2idx['START']
end_idx = word2idx['END']

D = 100
W1 = np.random.randn(V, D) / np.sqrt(V)
W2 = np.random.randn(D, V) / np.sqrt(D)

losses = []
epochs = 1
lr = 1e-2

def softmax(a):
  a = a - a.max()
  exp_a = np.exp(a)
  return exp_a / exp_a.sum(axis=1, keepdims=True)

for epoch in range(epochs):
  random.shuffle(sentences)

  j = 0
  for sentence in sentences:
    sentence = [start_idx] + sentence + [end_idx]
    n = len(sentence)
    inputs = sentence[:n-1]
    targets = sentence[1:]

    hidden = np.tanh(W1[inputs])
    predictions = softmax(hidden.dot(W2))

    loss = -np.sum(np.log(predictions[np.arange(n - 1), targets])) / (n - 1)
    losses.append(loss)

    doutput = predictions # N x V
    doutput[np.arange(n - 1), targets] -= 1
    W2 = W2 - lr * hidden.T.dot(doutput) # (D x N) (N x V)
    dhidden = doutput.dot(W2.T) * (1 - hidden * hidden) # (N x V) (V x D) * (N x D)

    # fastest way
    np.subtract.at(W1, inputs, lr * dhidden)


    if j % 100 == 0:
      print("epoch:", epoch, "sentence: %s/%s" % (j, len(sentences)), "loss:", loss)
    j += 1

  




[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\mazic\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!


START inf
END inf
man inf
paris inf
britain inf
england inf
king inf
woman inf
rome inf
london inf
queen inf
italy inf
france inf
the 69971
, 58334
. 49346
of 36412
and 28853
to 26158
a 23195
in 21337
that 10594
is 10109
was 9815
he 9548
for 9489
`` 8837
'' 8789
it 8760
with 7289
as 7253
his 6996
on 6741
be 6377
; 5566
at 5372
by 5306
i 5164
this 5145
had 5133
? 4693
not 4610
are 4394
but 4381
from 4370
or 4206
have 3942
an 3740
they 3620
which 3561
-- 3432
one 3292
you 3286
were 3284
her 3036
all 3001
she 2860
there 2728
would 2714
their 2669
we 2652
him 2619
been 2472
) 2466
has 2437
( 2435
when 2331
who 2252
will 2245
more 2215
if 2198
no 2139
out 2097
so 1985
said 1961
what 1908
up 1890
its 1858
about 1815
: 1795
into 1791
than 1790
them 1788
can 1772
only 1748
other 1702
new 1635
some 1618
could 1601
time 1598
! 1596
these 1573
two 1412
may 1402
then 1380
do 1363
first 1361
any 1344
my 1318
now 1314
such 1303
like 1292
our 1252
over 1236
me 1181
even 1170
most 1159
made 1125
also 

obvious 92
fell 92
thin 92
pieces 92
management 91
1958 91
measure 91
parents 91
security 91
base 91
entirely 91
civil 91
frequently 91
records 91
structure 91
dinner 91
weight 91
condition 91
mike 91
objective 91
complex 91
produced 90
noted 90
caused 90
equal 90
balance 90
you'll 90
purposes 90
corporation 90
dance 90
kitchen 90
failure 89
pass 89
goes 89
names 89
quickly 89
regard 89
published 89
famous 89
develop 89
clothes 89
laws 88
announced 88
carry 88
cover 88
moreover 88
add 88
greatest 88
check 88
enemy 88
leaving 88
key 88
manager 88
doesn't 88
active 88
break 88
bottom 88
pain 88
relationship 88
sources 88
poetry 88
assistance 87
operating 87
battle 87
companies 87
fixed 87
possibility 87
mary 87
product 87
spoke 87
units 87
touch 87
bright 87
finished 87
carefully 87
facts 87
previous 86
citizens 86
takes 86
e. 86
allowed 86
require 86
workers 86
build 86
patient 86
financial 86
philosophy 86
loss 86
rose 86
died 86
scientific 86
otherwise 86
inches 86
significant 86
seei

Vocab size: 2001
epoch: 0 sentence: 0/57013 loss: 7.607585883952575
epoch: 0 sentence: 100/57013 loss: 7.421682206133521
epoch: 0 sentence: 200/57013 loss: 6.55405851290688
epoch: 0 sentence: 300/57013 loss: 5.718468832227313
epoch: 0 sentence: 400/57013 loss: 6.61468923116405
epoch: 0 sentence: 500/57013 loss: 6.977573529206938
epoch: 0 sentence: 600/57013 loss: 6.346833997603713
epoch: 0 sentence: 700/57013 loss: 5.026065147063738
epoch: 0 sentence: 800/57013 loss: 5.6356320160361895
epoch: 0 sentence: 900/57013 loss: 6.0240260808371024
epoch: 0 sentence: 1000/57013 loss: 5.855175634893521
epoch: 0 sentence: 1100/57013 loss: 5.221749189129582
epoch: 0 sentence: 1200/57013 loss: 4.356356524631772
epoch: 0 sentence: 1300/57013 loss: 4.670681432850839
epoch: 0 sentence: 1400/57013 loss: 5.446488322650342
epoch: 0 sentence: 1500/57013 loss: 4.413810556188633
epoch: 0 sentence: 1600/57013 loss: 4.209868930008858
epoch: 0 sentence: 1700/57013 loss: 5.3853364090601765
epoch: 0 sentence: 180

epoch: 0 sentence: 15100/57013 loss: 5.128822005448149
epoch: 0 sentence: 15200/57013 loss: 4.291289708674629
epoch: 0 sentence: 15300/57013 loss: 4.22165418951235
epoch: 0 sentence: 15400/57013 loss: 5.358611173171433
epoch: 0 sentence: 15500/57013 loss: 2.946408446925263
epoch: 0 sentence: 15600/57013 loss: 4.243136885624621
epoch: 0 sentence: 15700/57013 loss: 2.9434006626295273
epoch: 0 sentence: 15800/57013 loss: 3.377482647567012
epoch: 0 sentence: 15900/57013 loss: 4.041922552672007
epoch: 0 sentence: 16000/57013 loss: 3.8207246681298916
epoch: 0 sentence: 16100/57013 loss: 3.715987277450021
epoch: 0 sentence: 16200/57013 loss: 3.700917964070688
epoch: 0 sentence: 16300/57013 loss: 5.094875751313404
epoch: 0 sentence: 16400/57013 loss: 4.395947573663907
epoch: 0 sentence: 16500/57013 loss: 4.245476319538039
epoch: 0 sentence: 16600/57013 loss: 3.6413029044952108
epoch: 0 sentence: 16700/57013 loss: 3.9232487177369437
epoch: 0 sentence: 16800/57013 loss: 4.262398952470124
epoch: 

epoch: 0 sentence: 30000/57013 loss: 2.5780241009949365
epoch: 0 sentence: 30100/57013 loss: 3.9430467547600294
epoch: 0 sentence: 30200/57013 loss: 4.604180004797428
epoch: 0 sentence: 30300/57013 loss: 5.46602018045493
epoch: 0 sentence: 30400/57013 loss: 4.4372519307246385
epoch: 0 sentence: 30500/57013 loss: 3.7623274645952747
epoch: 0 sentence: 30600/57013 loss: 4.641404479959296
epoch: 0 sentence: 30700/57013 loss: 3.9417693674508616
epoch: 0 sentence: 30800/57013 loss: 6.687169612573779
epoch: 0 sentence: 30900/57013 loss: 3.6220507642352358
epoch: 0 sentence: 31000/57013 loss: 5.3769306888528945
epoch: 0 sentence: 31100/57013 loss: 4.651662044981995
epoch: 0 sentence: 31200/57013 loss: 4.624426989565277
epoch: 0 sentence: 31300/57013 loss: 4.067982240163939
epoch: 0 sentence: 31400/57013 loss: 4.731312994982065
epoch: 0 sentence: 31500/57013 loss: 5.290893941964855
epoch: 0 sentence: 31600/57013 loss: 3.9650480017379826
epoch: 0 sentence: 31700/57013 loss: 3.480855631670552
epo

epoch: 0 sentence: 44900/57013 loss: 4.513638778153842
epoch: 0 sentence: 45000/57013 loss: 3.6269400466617094
epoch: 0 sentence: 45100/57013 loss: 4.9937566818815124
epoch: 0 sentence: 45200/57013 loss: 4.155280701334273
epoch: 0 sentence: 45300/57013 loss: 3.5575655156980717
epoch: 0 sentence: 45400/57013 loss: 5.395948306860619
epoch: 0 sentence: 45500/57013 loss: 4.124477592662484
epoch: 0 sentence: 45600/57013 loss: 2.4584315736220406
epoch: 0 sentence: 45700/57013 loss: 3.758911235512949
epoch: 0 sentence: 45800/57013 loss: 4.231752712323322
epoch: 0 sentence: 45900/57013 loss: 4.5044567826234845
epoch: 0 sentence: 46000/57013 loss: 4.0836249370544975
epoch: 0 sentence: 46100/57013 loss: 3.0828296023218367
epoch: 0 sentence: 46200/57013 loss: 7.436672913094055
epoch: 0 sentence: 46300/57013 loss: 4.333061219012438
epoch: 0 sentence: 46400/57013 loss: 2.88787071236635
epoch: 0 sentence: 46500/57013 loss: 4.320299644901593
epoch: 0 sentence: 46600/57013 loss: 4.014624465245131
epoc