## DeViSE (Deep Visual-Semantic Embedding)

What if we could map words and images into the same vector space?

```
beagle dog input --> model A --> jumbo jet
beagle dog input --> model B --> corgi
```
Consider models A and B. In traditional classification terms both models are simply wrong, so they get the same score. But in word-vector space "corgi" (another dog breed) is much closer to "beagle" than "jumbo jet" is, so model B is much better than model A.

**Idea:** instead of predicting a class, train a model to predict the word vector of the label.

In [1]:
# Render plots inline and automatically reload edited library modules before each cell runs.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

### Import our libraries

In [5]:
import sys
sys.path.append('../')
from fastai.conv_learner import *
torch.backends.cudnn.benchmark=True

import fastText as ft
import torchvision.transforms as transforms

# Standard per-channel (RGB) ImageNet statistics used to normalise input tensors.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training-time pipeline: random crop resized to 224, random horizontal flip,
# conversion to a tensor, then normalisation.
# NOTE(review): `tfms` is reassigned via fastai's tfms_from_model before the data object
# is built, so this torchvision pipeline appears to be unused.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
tfms = transforms.Compose(_train_steps)

### Setup our paths

In [6]:
# Root of the ImageNet data plus the sub-folders derived from it
PATH = Path('data/imagenet/')
TMP_PATH, PATH_TRN = PATH / 'tmp', PATH / 'train'
# fastText vectors live in a separate data folder
TRANS_PATH = Path('data/translate/')
# A sample validation image, relative to PATH
fname = 'valid/n01440764/ILSVRC2012_val_00007197.JPEG'

### Load the Word Vectors

In [10]:
# Load the pretrained English fastText model stored under TRANS_PATH
ft_vecs = ft.load_model(str(TRANS_PATH / 'wiki.en.bin'))

In [13]:
# Sanity check: first 10 components of the fastText vector for 'king'
ft_vecs.get_word_vector('king')[:10]

array([ 0.03259, -0.18164, -0.29049, -0.10506, -0.16712, -0.07748, -0.5661 , -0.08622, -0.00216,  0.15366],
      dtype=float32)

In [14]:
# Vocabulary and per-word frequencies from the fastText model
vocab, freqs = ft_vecs.get_words(include_freq=True)
ft_word_dict = dict(zip(vocab, freqs))
# Vocabulary ordered from least to most frequent
ft_words = sorted(ft_word_dict, key=ft_word_dict.get)

len(ft_words)

2519370

### Get ImageNet Classes

In [15]:
from fastai.io import get_data

In [16]:
# Download the ImageNet class index (class number -> [WordNet id, class name]) into TMP_PATH
CLASSES_FN = 'imagenet_class_index.json'
classes_url = 'http://files.fast.ai/models/' + CLASSES_FN
get_data(classes_url, TMP_PATH/CLASSES_FN)

imagenet_class_index.json: 41.0kB [00:00, 162kB/s]                             


### Get all English nouns (WordNet)

In [17]:
# Download the WordNet id -> noun list into PATH
WORDS_FN = 'classids.txt'
words_url = 'http://files.fast.ai/data/' + WORDS_FN
get_data(words_url, PATH/WORDS_FN)

classids.txt: 1.74MB [00:01, 1.70MB/s]                            


### Map ImageNet class numbers to words

In [18]:
# Load the ImageNet class index: keys are class numbers (as strings) and each value is a
# [WordNet id, class name] pair. `with` closes the file handle deterministically instead of
# leaving it for the garbage collector.
with (TMP_PATH/CLASSES_FN).open(encoding='utf-8') as f:
    class_dict = json.load(f)
# WordNet id -> class name for the 1000 ImageNet classes
classids_1k = dict(class_dict.values())
nclass = len(class_dict); nclass

1000

In [19]:
# Keys are class numbers as strings; each value is [WordNet id, class name]
class_dict['0']

['n01440764', 'tench']

### Map WordNet class IDs to nouns

In [20]:
# Read the WordNet noun list; each line is "<wordnet id> <noun>".
# `with` closes the file handle instead of leaking it until garbage collection.
with (PATH/WORDS_FN).open(encoding='utf-8') as f:
    classid_lines = f.readlines()
classid_lines[:5]

['n00001740 entity\n',
 'n00001930 physical_entity\n',
 'n00002137 abstraction\n',
 'n00002452 thing\n',
 'n00002684 object\n']

In [21]:
# WordNet id -> noun for every WordNet noun (split() also discards the trailing newline)
classids = dict(line.split() for line in classid_lines)
(len(classids), len(classids_1k))

(82115, 1000)

#### Look up all the nouns in fastText

In [22]:
# Lower-cased word -> fastText vector for the 1M most frequent words. ft_words is sorted by
# ascending frequency, so on a lower-casing collision the more frequent spelling wins.
lc_vec_d = {w.lower(): ft_vecs.get_word_vector(w) for w in ft_words[-1_000_000:]}


def _ids_with_vectors(id_to_word):
    """Return [(id, vector)] for every id whose lower-cased word has a fastText vector."""
    return [(key, lc_vec_d[word.lower()]) for key, word in id_to_word.items()
            if word.lower() in lc_vec_d]


syn_wv = _ids_with_vectors(classids)
syn_wv_1k = _ids_with_vectors(classids_1k)
syn2wv = dict(syn_wv)
len(syn2wv)

49469

#### Save the lookups 

In [23]:
# Cache the synset -> word-vector lookups so later sessions can skip the fastText lookups.
# `with` guarantees each file is flushed and closed as soon as it has been written.
with (TMP_PATH/'syn2wv.pkl').open('wb') as f:
    pickle.dump(syn2wv, f)
with (TMP_PATH/'syn_wv_1k.pkl').open('wb') as f:
    pickle.dump(syn_wv_1k, f)

### CHECKPOINT load 

In [24]:
# Checkpoint: reload the cached lookups written by this notebook.
# Only unpickle files you created yourself — pickle.load can execute arbitrary code.
with (TMP_PATH/'syn2wv.pkl').open('rb') as f:
    syn2wv = pickle.load(f)
with (TMP_PATH/'syn_wv_1k.pkl').open('rb') as f:
    syn_wv_1k = pickle.load(f)

#### The ImageNet localization data is 157 GB, so the rest of this code is not run here

In [None]:
def _collect_split(split_dir):
    """Return (image paths relative to PATH, word vectors) for one split folder.

    Each sub-folder name of `split_dir` is looked up in `syn2wv` (keyed by WordNet id);
    folders without a word vector are skipped, and every image inside a kept folder is
    paired with that folder's vector.
    """
    paths, vecs = [], []
    for class_dir in split_dir.iterdir():
        if class_dir.name not in syn2wv:
            continue
        # grab the fastText word vector for this class
        vec = syn2wv[class_dir.name]
        for f in class_dir.iterdir():
            paths.append(str(f.relative_to(PATH)))
            vecs.append(vec)
    return paths, vecs


# Training images first, then validation images, so the validation set is the tail of `images`.
images, img_vecs = _collect_split(PATH_TRN)
val_images, val_img_vecs = _collect_split(PATH/'valid')
n_val = len(val_images)
images += val_images
img_vecs += val_img_vecs

In [None]:
# Stack the per-image word vectors into a single array, one row per entry of `images`.
img_vecs = np.stack(img_vecs)
assert len(img_vecs) == len(images), "every image needs exactly one target vector"
img_vecs.shape

In [None]:
# Cache the image paths and their target vectors; `with` flushes and closes each file.
with (TMP_PATH/'images.pkl').open('wb') as f:
    pickle.dump(images, f)
with (TMP_PATH/'img_vecs.pkl').open('wb') as f:
    pickle.dump(img_vecs, f)

In [None]:
# load the image paths for ImageNet (relative to PATH)
# Only unpickle files you created yourself — pickle.load can execute arbitrary code.
with (TMP_PATH/'images.pkl').open('rb') as f:
    images = pickle.load(f)

# ... and the corresponding word vector for each image (row i belongs to images[i])
with (TMP_PATH/'img_vecs.pkl').open('rb') as f:
    img_vecs = pickle.load(f)

## Create the model architecture + datasets

In [None]:
n = len(images)

# Validation images were appended after the training images when `images` was built,
# so the last N_VAL entries form the validation set.
# NOTE(review): hard-coded — presumably equals `n_val` from the collection cell, which is not
# pickled and therefore unavailable after the checkpoint reload. Confirm the value.
N_VAL = 28650
val_idxs = list(range(n - N_VAL, n))

# `arch` must be defined before tfms_from_model uses it; it used to be assigned further down
# this cell, which raised NameError on a fresh kernel.
arch = resnet50

# fastai transforms for `arch` at 224px, side-on augmentation, zoom up to 1.1x.
# This replaces the torchvision `tfms` defined in the imports cell.
tfms = tfms_from_model(arch, 224, transforms_side_on, max_zoom=1.1)

# Pass all the ImageNet file names plus the word-vector array, then the validation indexes.
# continuous=True because we are predicting vectors, not class labels.
md = ImageClassifierData.from_names_and_array(PATH, images, img_vecs, val_idxs=val_idxs,
        classes=None, tfms=tfms, continuous=True, bs=256)

# Head on top of `arch` (note: no softmax, the output is a raw vector):
#   md.c     - number of outputs (presumably the word-vector dimension here)
#   is_multi - False: not multi-label
#   is_reg   - True: regression
#   xtra_fc  - extra fully connected layers (one layer of 1024 units)
#   ps       - dropout probability per layer
models = ConvnetBuilder(arch, md.c, is_multi=False, is_reg=True, xtra_fc=[1024], ps=[0.2,0.2])
learn = ConvLearner(md, models, precompute=True)
learn.opt_fn = partial(optim.Adam, betas=(0.9,0.99))


def cos_loss(inp, targ):
    """Mean cosine distance (1 - cosine similarity) between predicted and target vectors.

    In high-dimensional space most points lie far out towards the edges, so L1/L2 distance is
    a poor similarity measure; comparing directions with cosine similarity works better.
    """
    return 1 - F.cosine_similarity(inp, targ).mean()


learn.crit = cos_loss

### Train the model with `precompute=True` to cut down on training time 

Training is quoted as taking more than an hour.

In [None]:
# Learning-rate finder: sweep lr between start_lr and end_lr, then plot loss against lr.
# NOTE(review): end_lr=1e15 is an unusually large upper bound — confirm it is intended.
learn.lr_find(start_lr=1e-4, end_lr=1e15)
learn.sched.plot()

# Hyper-parameters chosen by hand (presumably read off the plot above)
lr = 1e-2
wd = 1e-7
# Train on precomputed activations to cut down on training time
learn.precompute=True
learn.fit(lr, 1, cycle_len=20, wds=wd, use_clr=(20,10))

# Second training round with BatchNorm frozen (presumably fixes its running statistics)
learn.bn_freeze(True)
learn.fit(lr, 1, cycle_len=20, wds=wd, use_clr=(20,10))

# NOTE(review): `lrs` (presumably per-layer-group learning rates) is not used in any cell shown,
# and no training follows precompute=False / freeze_to(1) — a fine-tuning step looks missing.
lrs = np.array([lr/1000,lr/100,lr])
learn.precompute=False
learn.freeze_to(1)

# Checkpoint the weights under the name 'pre0', then reload them
learn.save('pre0')
learn.load('pre0')

# Image Searching

In [None]:
syns, wvs = list(zip(*syn_wv_1k))
wvs = np.array(wvs)

%time pred_wv = learn.predict()

#### Let's take a look at some of the pictures

In [None]:
# Presumably reverses the validation set's input normalisation (used below to display images)
denorm = md.val_ds.denorm

def show_img(im, figsize=None, ax=None):
    """Draw one image with its axes hidden.

    Draws on `ax` when one is supplied; otherwise a new figure of size `figsize` is created.
    Returns the Axes that was drawn on.
    """
    target = ax if ax else plt.subplots(figsize=figsize)[1]
    target.imshow(im)
    target.axis('off')
    return target

def show_imgs(ims, cols, figsize=None):
    """Show a sequence of images in a grid with `cols` columns.

    The row count is rounded up so every image is shown. Grid cells after the last image are
    left blank with their axes hidden. `squeeze=False` keeps `axes` 2-D even for a single row
    or a single cell, so iterating `axes.flat` always works.
    """
    n_imgs = len(ims)
    rows = max(1, (n_imgs + cols - 1) // cols)  # ceil division; at least one row
    _, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n_imgs:
            show_img(ims[i], ax=ax)
        else:
            ax.axis('off')  # blank out unused trailing cells
    plt.tight_layout()

# Show 25 validation items (5 per row) beginning at index `start`; [0] selects the inputs
# (presumably the image array), which denorm converts back for display.
start = 300
val_batch = md.val_ds[start:start + 25]
show_imgs(denorm(val_batch[0]), 5, (10, 10))

<img src='https://snag.gy/OtP8k1.jpg' style='width:700px'>

### Nearest-neighbour search: which word vectors are closest to each predicted 300-d vector?

In [25]:
# super fast library, that searches very quickly
import nmslib

def create_index(a):
    index = nmslib.init(space='angulardist')
    index.addDataPointBatch(a)
    index.createIndex()
    return index

def get_knns(index, vecs):
     return zip(*index.knnQueryBatch(vecs, k=10, num_threads=4))

def get_knn(index, vec): return index.knnQuery(vec, k=10)

ModuleNotFoundError: No module named 'nmslib'

In [None]:
# For each predicted vector, find the nearest ImageNet-1k label vectors, then list the
# top-3 label names for 10 images starting at `start`.
nn_wvs = create_index(wvs)
idxs, dists = get_knns(nn_wvs, pred_wv)
[[classids[syns[j]] for j in nbrs[:3]] for nbrs in idxs[start:start + 10]]

### What if we now search over all WordNet nouns?

In [None]:
# Same search, but over every WordNet noun that has a word vector, not just the 1k labels
all_syns, all_wvs = zip(*syn2wv.items())
all_wvs = np.array(all_wvs)

nn_allwvs = create_index(all_wvs)
idxs, dists = get_knns(nn_allwvs, pred_wv)
[[classids[all_syns[j]] for j in nbrs[:3]] for nbrs in idxs[start:start + 10]]

# Text --> Image Search

In [None]:
# Index the predicted image vectors so images can be searched by a word vector
nn_predwv = create_index(pred_wv)

# Word -> vector lookup for English. `with` closes the file handle deterministically.
# Only unpickle files you trust — pickle.load can execute arbitrary code.
with (TRANS_PATH/'wiki.en.pkl').open('rb') as f:
    en_vecd = pickle.load(f)

## get the vector for boat
vec = en_vecd['boat']
idxs, dists = get_knn(nn_predwv, vec)

# then we only pull images whose vector is close to our 'boat' vector
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]], 3, figsize=(9,3));

<img src='https://snag.gy/bsOHQ4.jpg'>

In [None]:
# Query with the midpoint of the 'engine' and 'boat' vectors and show the 3 closest images
vec = (en_vecd['engine'] + en_vecd['boat']) / 2
idxs, dists = get_knn(nn_predwv, vec)
top_imgs = [open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]]
show_imgs(top_imgs, 3, figsize=(9, 3));

<img src='https://snag.gy/eqK8dz.jpg'>

In [None]:
# Query with the midpoint of the 'sail' and 'boat' vectors and show the 3 closest images
vec = (en_vecd['sail'] + en_vecd['boat']) / 2
idxs, dists = get_knn(nn_predwv, vec)
top_imgs = [open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]]
show_imgs(top_imgs, 3, figsize=(9, 3));

<img src='https://snag.gy/Bz6Hsw.jpg'>