Skip to content
This repository has been archived by the owner on Aug 18, 2021. It is now read-only.

Commit

Permalink
Update glove-word-vectors example for most recent version of torchtext
Browse files Browse the repository at this point in the history
  • Loading branch information
jemgold committed Sep 14, 2017
1 parent 5758a4d commit 7136ea5
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions glove-word-vectors/glove-word-vectors.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,52 +48,59 @@
"source": [
"## Loading word vectors\n",
"\n",
"The `load_word_vectors` function will download and unpack the word vectors directly from http://nlp.stanford.edu/data/"
"Torchtext includes functions to download GloVe (and other) embeddings"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchtext.vocab as vocab"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loading word vectors from ./glove.6B.100d.pt\n",
"Loaded 400000 words\n"
]
}
],
"source": [
"import torch\n",
"from torchtext.vocab import load_word_vectors\n",
"glove = vocab.GloVe(name='6B', dim=100)\n",
"\n",
"wv_dict, wv_arr, wv_size = load_word_vectors('.', 'glove.6B', 100)\n",
"\n",
"print('Loaded', len(wv_arr), 'words')"
"print('Loaded {} words'.format(len(glove.itos)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`load_word_vectors` returns a dictionary of words to indexes, and an array of actual vectors. To get a word vector get the index to get the vector:"
"The returned `GloVe` object includes attributes:\n",
"- `stoi` _string-to-index_ returns a dictionary of words to indexes\n",
"- `itos` _index-to-string_ returns an array of words by index\n",
"- `vectors` returns the actual vectors. To get a word vector get the index to get the vector:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"def get_word(word):\n",
" return wv_arr[wv_dict[word]]"
" return glove.vectors[glove.stoi[word]]"
]
},
{
Expand All @@ -109,14 +116,17 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def closest(d, n=10):\n",
" all_dists = [(w, torch.dist(d, get_word(w))) for w in wv_dict]\n",
"def closest(vec, n=10):\n",
" \"\"\"\n",
" Find the closest words for a given vector\n",
" \"\"\"\n",
" all_dists = [(w, torch.dist(vec, get_word(w))) for w in glove.itos]\n",
" return sorted(all_dists, key=lambda t: t[1])[:n]"
]
},
Expand All @@ -129,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {
"collapsed": true
},
Expand All @@ -149,10 +159,8 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
Expand Down

0 comments on commit 7136ea5

Please sign in to comment.