# NLP data
>and all those embedding funs

[Netflix dataset](https://www.kaggle.com/shivamb/netflix-shows) on Kaggle

* Create a PyTorch dataset for text
* A BiLSTM model to predict multiple genres
* Encode the text to vectors using the model we trained
* Search for the closest descriptions

In [1]:
# default_exp data.nlp

In [2]:
# export
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
from pathlib import Path

In [3]:
!pip install forgebox



In [4]:
DATA = Path("data/netflix_titles.csv")

df = pd.read_csv(DATA)

df.sample(10)

,show_id,type,title,director,cast,country,date_added,release_year,rating,duration,listed_in,description
4476,80145625,Movie,God of War,Gordon Chan,"Vincent Zhao, Sammo Kam-Bo Hung, Regina Wan, Y...","China, Hong Kong","December 16, 2017",2017,NR,129 min,"Action & Adventure, International Movies",A maverick leader and a clever young general t...
4327,81094074,Movie,"El Pepe, a Supreme Life",Emir Kusturica,"José Mujica, Emir Kusturica","Argentina, Uruguay, Serbia","December 27, 2019",2018,TV-14,73 min,"Documentaries, International Movies","In this intimate documentary, former Uruguayan..."
6078,80094603,TV Show,Highway Thru Hell,,"Dave Pettitt, Jamie Davis, Adam Gazzola, Kevin...",Canada,"December 3, 2019",2016,TV-PG,3 Seasons,Reality TV,On the hazardous highways of Canada's interior...
5260,80044093,Movie,Team Foxcatcher,Jon Greenhalgh,,United States,"April 29, 2016",2016,TV-14,91 min,"Documentaries, Sports Movies","With never-before seen home video, this film r..."
5464,81021447,Movie,The Silence,John R. Leonetti,"Stanley Tucci, Kiernan Shipka, Miranda Otto, K...",Germany,"April 10, 2019",2019,TV-14,91 min,"Horror Movies, Thrillers",With the world under attack by deadly creature...
1697,81043473,Movie,SGT. Will Gardner,Max Martini,"Max Martini, Omari Hardwick, Lily Rabe, Elisab...",United States,"May 19, 2019",2019,TV-MA,125 min,Dramas,A homeless vet who has PTSD steals a motorcycl...
1178,80028357,Movie,"Love, Rosie",Christian Ditter,"Lily Collins, Sam Claflin, Christian Cooke, Ja...","Germany, United Kingdom","November 20, 2019",2014,R,103 min,"Comedies, International Movies, Romantic Movies","Over the years, as they come and go in each ot..."
893,80190103,Movie,Naan Sigappu Manithan,Thiru,"Vishal, Lakshmi Menon, Saranya Ponvannan, Jaya...",India,"October 1, 2018",2014,TV-MA,147 min,"Action & Adventure, Dramas, International Movies",After his sleeping disorder hinders him from p...
4075,80156767,Movie,La Última Fiesta,"Leandro Mark, Nicolás Silbert","Nicolás Vázquez, Alan Sabbagh, Benjamín Amadeo...",Argentina,"February 1, 2017",2016,TV-MA,104 min,"Comedies, International Movies",Three best buddies are thrown into a wild chas...
5783,80027373,TV Show,Oh No! It's an Alien Invasion,,"Al Mukadam, Dan Chameroy, Seán Cullen, Stacey ...",Canada,"May 31, 2015",2014,TV-Y7-FV,2 Seasons,"Kids' TV, TV Action & Adventure, TV Sci-Fi & F...",Nate and his Super Wicked Extreme Emergency Te...


### So... what Y?

In [5]:
df.listed_in.value_counts()

Documentaries                                                   299
Stand-Up Comedy                                                 273
Dramas, International Movies                                    248
Dramas, Independent Movies, International Movies                186
Comedies, Dramas, International Movies                          174
                                                               ... 
TV Dramas, TV Mysteries, TV Thrillers                             1
Kids' TV, TV Dramas, Teen TV Shows                                1
Romantic TV Shows, Spanish-Language TV Shows, TV Comedies         1
Classic & Cult TV, TV Horror, TV Mysteries                        1
International TV Shows, Spanish-Language TV Shows, TV Horror      1
Name: listed_in, Length: 461, dtype: int64

In [6]:
df.rating.value_counts()

TV-MA       2027
TV-14       1698
TV-PG        701
R            508
PG-13        286
NR           218
PG           184
TV-Y7        169
TV-G         149
TV-Y         143
TV-Y7-FV      95
G             37
UR             7
NC-17          2
Name: rating, dtype: int64

In [7]:
df["listed_in"] = (df.listed_in
                   .str.replace("&", ",", regex=False)         # "Action & Adventure" -> "Action , Adventure"
                   .str.replace(r"\s*,\s*", ",", regex=True))  # remove the spaces around every comma

In [8]:
genre = list(set(i.strip() for i in (",".join(list(df.listed_in))).split(",")))

In [9]:
print(f"Total genre: {len(genre)}\n")
for g in genre:
    print(g,end="\t")

Total genre: 49

Teen TV Shows	Nature TV	Romantic TV Shows	TV Sci-Fi	Action	Faith	Talk Shows	TV Shows	TV Dramas	Thrillers	Sci-Fi	Sports Movies	TV Thrillers	Movies	Cult Movies	Docuseries	TV Comedies	TV Action	Children	Classic Movies	Korean TV Shows	Dramas	TV Horror	Romantic Movies	Spirituality	International TV Shows	Independent Movies	Stand-Up Comedy	Science	Horror Movies	TV Mysteries	Music	Reality TV	Crime TV Shows	Adventure	Classic	Spanish-Language TV Shows	Family Movies	Documentaries	Anime Series	Musicals	Fantasy	Kids' TV	Anime Features	Comedies	British TV Shows	International Movies	LGBTQ Movies	Cult TV	

In [10]:
eye = np.eye(len(genre))
genre_dict = dict((v,eye[k]) for k,v in enumerate(genre))

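def to_nhot(text):
    # element-wise max of the one-hot rows: 1 for every listed genre, 0 elsewhere,
    # even if the same genre shows up twice after splitting on "&"
    return np.max([genre_dict[g.strip()] for g in text.split(",")], axis=0).astype(int)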

df["genre"] = df.listed_in.apply(to_nhot)
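Python randomizes string hashing per session, so iterating over a `set` of strings can yield a different order each time the notebook restarts. The genre-to-column mapping built from `list(set(...))` is therefore not guaranteed to be stable, which matters if the labels or a trained model are reused later. A minimal sketch of the same n-hot encoding with a sorted genre list; `build_genre_index` and `encode_nhot` are hypothetical helper names, and the toy rows assume the comma-cleaned `listed_in` format:

```python
import numpy as np

def build_genre_index(listed_in):
    """Map each genre name to a column index. Sorting makes the mapping
    identical in every Python session."""
    genres = sorted({g.strip() for line in listed_in for g in line.split(",")})
    return {g: i for i, g in enumerate(genres)}

def encode_nhot(text, genre_index):
    """n-hot vector: 1 for every genre listed in `text`, 0 elsewhere."""
    vec = np.zeros(len(genre_index), dtype=np.int64)
    for g in text.split(","):
        vec[genre_index[g.strip()]] = 1  # assignment, so a repeated genre stays at 1
    return vec

# Toy rows in the cleaned, comma-separated format
rows = ["Dramas,International Movies", "Comedies,Dramas", "Documentaries"]
genre_index = build_genre_index(rows)
print(genre_index)
# {'Comedies': 0, 'Documentaries': 1, 'Dramas': 2, 'International Movies': 3}
print(encode_nhot(rows[1], genre_index))
# [1 0 1 0]
```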

In [11]:
PROCESSED = "processed.csv"

In [12]:
df.to_csv(PROCESSED,index = False)

### Process the text

In [13]:
# export
def split_df(df, valid=0.2, ensure_factor=2):
    """
    df: dataframe
    valid: valid ratio, default 0.1
    ensure_factor, ensuring the row number to be the multiplication of this factor, default 2
    return train_df, valid_df
    """
    split_ = (np.random.rand(len(df)) > valid)
    train_df = df[split_].sample(frac=1.).reset_index().drop("index", axis=1)
    valid_df = df[~split_].sample(frac=1.).reset_index().drop("index", axis=1)

    if ensure_factor:
        train_mod = len(train_df) % ensure_factor
        valid_mod = len(valid_df) % ensure_factor
        if train_mod: train_df = train_df[:-train_mod]
        if valid_mod: valid_df = valid_df[:-valid_mod]
    return train_df, valid_df

In [16]:
train_df,val_df = split_df(df,valid=0.1)
print(f"train:{len(train_df)}\tvalid:{len(val_df)}")

train:5624	valid:608


In [17]:
from nltk.tokenize import TweetTokenizer
tkz = TweetTokenizer()
def tokenize(txt):
    return tkz.tokenize(txt)

In [18]:
tokenize("A man returns home after being released from ")

['A', 'man', 'returns', 'home', 'after', 'being', 'released', 'from']

### Build a vocabulary map from the corpus

In [19]:
# export 
from itertools import chain
from multiprocessing import Pool
from collections import Counter
from torch.utils.data.dataset import Dataset

class Vocab(object):
    def __init__(self, iterative, tokenize, max_vocab = 20000,nproc=10):
        """
        Count the most frequent words
        Make the word<=>index mapping:
        0 is <eos>, 1 is <mtk> (missing token), frequent words start at 2
        """
        self.l = list(iterative)
        self.nproc = nproc
        self.max_vocab = max_vocab
        self.tokenize = tokenize
        self.word_beads = self.word_beads_()
        self.counter()
        
    def __len__(self):
        return len(self.words)
        
    def __repr__(self):
        return f"vocab {self.max_vocab}"
        
    def word_beads_(self):
        # tokenize the lines in parallel, then flatten into one list of tokens
        with Pool(self.nproc) as p:
            return list(chain.from_iterable(p.map(self.tokenize, self.l)))
    
    def counter(self):
        # keep the max_vocab-2 most frequent tokens, leaving room for the 2 special tokens
        most_common = Counter(self.word_beads).most_common(self.max_vocab-2)
        self.words = pd.DataFrame(most_common, columns=["tok","ct"])
        self.words["idx"] = np.arange(len(self.words))+2
        self.words = pd.concat([self.words,pd.DataFrame({"tok":["<eos>","<mtk>"],"ct":[-1,-1],"idx":[0,1]})])
        return self.words
    
    def to_i(self):
        self.t2i = dict(zip(self.words["tok"],self.words["idx"]))
        def to_index(t):
            # tokens outside the vocabulary map to 1 (<mtk>)
            return self.t2i.get(t, 1)
        return to_index
    
    def to_t(self):
        # self.words lists idx 2, 3, ... first and <eos> (0), <mtk> (1) last;
        # rolling by 2 puts each token at the position equal to its index
        return np.roll(self.words["tok"].values, 2)
        

In [20]:
vocab = Vocab(df.description,tokenize=tokenize)
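Stripped of the pandas bookkeeping and the multiprocessing pool, the idea behind `Vocab` fits in a few lines. A sketch using only `collections.Counter`; `build_vocab`, `encode` and `decode` are hypothetical names, and the two reserved indices mirror the ones used above (0 for `<eos>`, 1 for `<mtk>`):

```python
from collections import Counter

def build_vocab(token_lists, max_vocab=20000):
    """Index 0 is <eos>, index 1 is <mtk> (unknown token),
    then the most frequent tokens from index 2 on."""
    counts = Counter(tok for toks in token_lists for tok in toks)
    itos = ["<eos>", "<mtk>"] + [t for t, _ in counts.most_common(max_vocab - 2)]
    stoi = {t: i for i, t in enumerate(itos)}
    return itos, stoi

def encode(tokens, stoi):
    # wrap the sentence in <eos> markers; unknown tokens fall back to 1 (<mtk>)
    return [stoi.get(t, 1) for t in ["<eos>"] + tokens + ["<eos>"]]

def decode(indices, itos):
    return " ".join(itos[i] for i in indices)

docs = [["a", "man", "returns", "home"], ["a", "woman", "leaves", "home"]]
itos, stoi = build_vocab(docs, max_vocab=5)  # room for only 3 real tokens
print(decode(encode(["a", "dog", "home"], stoi), itos))
# <eos> a <mtk> home <eos>
```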

### Vocabulary built from all the descriptions

In [21]:
vocab.words

,tok,ct,idx
0,four,99,2
1,drama,99,3
2,brother,97,4
3,himself,97,5
4,evil,95,6
...,...,...,...
19519,Gora,1,19521
19520,High-strung,1,19522
19521,out-of-the-way,1,19523
0,<eos>,-1,0


In [44]:
# export 

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data._utils.collate import default_collate

class seqData(Dataset):
    def __init__(self,lines,vocab,max_len=-1):
        """
        lines: iterative of text, eg. each element a sentence
        vocab:forge.data.nlp.Vocab
        max_len: max length
        """
        self.lines = list(lines)
        self.vocab = vocab
        self.to_i = np.vectorize(vocab.to_i())
        self.to_t = vocab.to_t()
        self.bs=1
        self.max_len=max_len
        
    def __len__(self):
        return len(self.lines)
    
    def __getitem__(self,idx):
        """
        Translate words to indices
        """
        line = self.lines[idx]
        words = self.vocab.tokenize(line)
        if self.max_len>2:
            words = words[:self.max_len-2]
        words = ["<eos>",]+words+["<eos>"]
        return self.to_i(np.array(words))
    
    def backward(self,seq):
        """
        This backward has nothing to do with gradients:
        it maps indices back to tokens, to sanity-check the tokenized line
        """
        return " ".join(self.to_t[seq])

    def collate(self,rows):
        """
        Pad every sentence shorter than the longest one in the batch
        with index 1 (<mtk>), then stack into a (batch, max_len) LongTensor
        """
        max_len = max(len(row) for row in rows)
        return torch.stack([
            torch.cat([torch.as_tensor(row, dtype=torch.long),
                       torch.ones(max_len - len(row), dtype=torch.long)])
            for row in rows])
    
class arrData(Dataset):
    def __init__(self, *arrs):
        self.arr = np.concatenate(arrs,axis=1)
    
    def __len__(self):
        return self.arr.shape[0]
    
    def __getitem__(self,idx):
        return self.arr[idx]
    
    def collate(self,rows):
        return default_collate(rows)
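`seqData.collate` pads every row to the longest one in the batch. PyTorch ships a helper for exactly this, `torch.nn.utils.rnn.pad_sequence`. A sketch that also returns the true lengths, which a packed RNN can use to skip the padding; `pad_with_lengths` is a hypothetical name, and padding with index 1 simply mirrors the collate above:

```python
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_with_lengths(rows, pad_idx=1):
    """Pad variable-length index arrays into one (batch, max_len) LongTensor
    and return the unpadded lengths alongside it."""
    tensors = [torch.as_tensor(r, dtype=torch.long) for r in rows]
    lengths = torch.tensor([len(t) for t in tensors])
    batch = pad_sequence(tensors, batch_first=True, padding_value=pad_idx)
    return batch, lengths

rows = [np.array([0, 5, 7, 0]), np.array([0, 9, 0])]
batch, lengths = pad_with_lengths(rows)
print(batch.tolist())    # [[0, 5, 7, 0], [0, 9, 0, 1]]
print(lengths.tolist())  # [4, 3]
```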

Build the vocabulary and the train/validation datasets

In [45]:
vocab = Vocab(df.description,tokenize=tokenize)

train_seq = seqData(train_df.description,vocab)
train_y = arrData(np.stack(train_df.genre.values))

val_seq = seqData(val_df.description,vocab)
val_y = arrData(np.stack(val_df.genre.values))

Size of the train dataset

In [46]:
len(train_seq),len(train_y)

(5624, 5624)

In [47]:
tokenized_line = train_seq[10]
tokenized_line

array([   0, 4908, 1613, 3527, 2113, 4873,  908, 4819, 8186,  569, 8480,
       8163, 1476,  579, 4989, 1460, 9109, 1544, 8850, 3102,  908, 3233,
        941, 5001, 1454, 2152, 3246, 5055, 1454, 8480, 4974, 8019, 8019,
       8589,  579,    0])

Reconstruct the sentence from the indices

>**<mtk\>** marks a missing token: a word too rare to make it into the vocabulary

In [48]:
train_seq.backward(tokenized_line)

"<eos> Comedian Maria Bamford stars in a series inspired by her own life . It's the sometimes surreal story of a woman who loses – and then finds – her s * * t . <eos>"

### A custom-made collate function

* The collate function does the following:
>Turn a list of dataset rows into one batch tensor, padding the shorter sentences to the longest length in the batch

In [50]:
gen = iter(DataLoader(train_seq,batch_size=16, collate_fn=train_seq.collate))
next(gen).size()

torch.Size([16, 36])

In [63]:
# export
class fuse(Dataset):
    def __init__(self, *datasets):
        """
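        A PyTorch dataset zipping several datasets of equal length;
        each item is a tuple with the row from every dataset
        :param datasets: the datasets to combine; a dataset's own .collate is used if it has one
        """
        self.datasets = datasets
        length_s = set(len(d) for d in self.datasets)
        assert len(length_s) == 1, "dataset lengths do not match"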
        self.length = list(length_s)[0]
        self.collates = list(i.collate if hasattr(i,"collate") else default_collate for i in datasets)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return tuple(d.__getitem__(idx) for d in self.datasets)
    
    def collate(self,rows):
        xs = list(zip(*rows))
        return tuple(func(x) for func, x in zip(self.collates,xs))

### Fusing the datasets

In [66]:
train_ds = fuse(train_seq,train_y)
val_ds = fuse(val_seq,val_y)

### Testing Generator

In [67]:
gen = iter(DataLoader(train_ds,batch_size=16, collate_fn=train_ds.collate))
x,y = next(gen)
print(x.shape,y.shape)

torch.Size([16, 36]) torch.Size([16, 49])


### Model

In [68]:
from torch import nn
import torch

In [69]:
class basicNLP(nn.Module):
    def __init__(self, hs):
        super().__init__()
        self.hs = hs
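        self.emb = nn.Embedding(len(vocab),hs)  # one hs-dim vector per token index
        self.rnn = nn.LSTM(input_size = hs,hidden_size = hs,batch_first = True)
        self.fc = nn.Sequential(*[
            nn.BatchNorm1d(hs*2),
            nn.ReLU(),
            nn.Linear(hs*2,hs*2),
            nn.BatchNorm1d(hs*2),
            nn.ReLU(),
            nn.Linear(hs*2,49),  # one logit per genre (49 genres)
        ])
        
    def encoder(self,x):
        x = self.emb(x)  # (batch, seq_len) -> (batch, seq_len, hs)
        o1,(h1,c1) = self.rnn(x)
        # run the sentence backward through the SAME LSTM: the weights are shared,
        # unlike nn.LSTM(..., bidirectional=True), and padding is read in both passes
        o2,(h2,c2) = self.rnn(x.flip(dims=[1]))
        # last hidden state of each pass, concatenated: (batch, hs*2)
        return torch.cat([h1[0],h2[0]],dim=1)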
        
    def forward(self,x):
        vec = self.encoder(x)
        return self.fc(vec)
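`encoder` runs one `nn.LSTM` twice, so the forward and backward passes share weights, and with right-padded batches the forward pass ends on padding while the backward pass starts on it. A sketch of a more conventional BiLSTM encoder, assuming batches padded as above plus a tensor of true lengths; `BiLSTMEncoder` is a hypothetical class, not part of this notebook:

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

class BiLSTMEncoder(nn.Module):
    """Bidirectional LSTM (separate forward/backward weights) that skips
    padding by packing the batch with the true lengths."""
    def __init__(self, n_vocab, hs):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, hs)
        self.rnn = nn.LSTM(hs, hs, batch_first=True, bidirectional=True)

    def forward(self, x, lengths):
        packed = pack_padded_sequence(self.emb(x), lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)
        _, (h, _) = self.rnn(packed)
        # h: (2, batch, hs); h[0] is the forward state at each sentence's true
        # end, h[1] the backward state after reading back to position 0
        return torch.cat([h[0], h[1]], dim=1)  # (batch, 2 * hs)

torch.manual_seed(0)
enc = BiLSTMEncoder(n_vocab=50, hs=8)
x = torch.tensor([[0, 5, 7, 0], [0, 9, 0, 1]])  # second row padded with 1
lengths = torch.tensor([4, 3])
vec = enc(x, lengths)
print(vec.shape)  # torch.Size([2, 16])
```

Because of the packing, the padded second row encodes to the same vector as the unpadded sentence on its own.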

In [70]:
model = basicNLP(100)

In [71]:
x[:2,:]

tensor([[    0,  1477,  8844,  2473, 10957, 10958,  4973,   942,  8502,  8945,
           940,  8844,   923,  2068,   908,  9196,   957,  2153,  8481,  8851,
          8873,  1454,  2152,  4741,  1460, 10959,   916,  6700,   579,     0,
             1,     1,     1,     1,     1,     1],
        [    0,  9024,   765,  4886,  8086,   948,  8491,  4790,  2467,  2112,
          1460,  8940,  8833,  8781,  5183,  1118,  3231,  8576,  9070, 12932,
           579,     0,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1]])

### What does embedding do?

In [72]:
x.shape,model.emb(x).shape

(torch.Size([16, 36]), torch.Size([16, 36, 100]))
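An embedding layer is a lookup table: index `k` selects row `k` of a learned `(num_embeddings, embedding_dim)` weight matrix. A tiny sketch with made-up sizes showing that `nn.Embedding` gives the same result as plain indexing:

```python
import torch
from torch import nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
idx = torch.tensor([[0, 3, 3], [1, 2, 9]])  # (batch=2, seq_len=3)

out = emb(idx)          # (2, 3, 4): one 4-dim vector per token
manual = emb.weight[idx]  # the same rows, picked by plain indexing
print(out.shape, torch.equal(out, manual))  # torch.Size([2, 3, 4]) True
```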

### What does LSTM return?

For what an LSTM is, read this [awesome blog](https://colah.github.io/posts/2015-08-Understanding-LSTMs/), from which I borrowed the following visualizations

#### The short version
An RNN shares model weights across the time steps of a sequence, just as a convolution shares weights across spatial positions
![RNN unrolled](https://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png)
* The green "A" blocks above are the same layer, with shared weights, applied at every time step
* GRU and LSTM are more advanced versions of the RNN, with gate control
* In a GRU or LSTM, the information flowing along the black arrows above is controlled by gates
* A gate is just a linear layer with a sigmoid activation $\sigma(x)$; its outputs lie in (0,1), hence the name gate. The following illustration shows one of the gates in an LSTM cell, the forget gate
![LSTM forget gate](https://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-focus-f.png)
* The other gates decide other things: how much new information to write into the cell state (input gate) and how much of the cell state to emit as output (output gate)

### In terms of code
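The gate description above can be checked against PyTorch itself. A sketch of one LSTM step written out by hand from an `nn.LSTMCell`'s own weights, relying on PyTorch's documented gate order (input, forget, cell candidate, output) inside `weight_ih`/`weight_hh`; the sizes are arbitrary:

```python
import torch
from torch import nn

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=3, hidden_size=2)
x = torch.randn(4, 3)  # (batch, input_size)
h0 = torch.zeros(4, 2)
c0 = torch.zeros(4, 2)

# All four gates come from one linear map of [x, h], split into 4 blocks
gates = x @ cell.weight_ih.T + cell.bias_ih + h0 @ cell.weight_hh.T + cell.bias_hh
i, f, g, o = gates.chunk(4, dim=1)
i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)  # gates in (0, 1)
g = torch.tanh(g)          # candidate values in (-1, 1)
c1 = f * c0 + i * g        # forget part of the old cell, write new information
h1 = o * torch.tanh(c1)    # expose part of the cell state as the output

h_ref, c_ref = cell(x, (h0, c0))
print(torch.allclose(h1, h_ref, atol=1e-6), torch.allclose(c1, c_ref, atol=1e-6))
```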

In [38]:
%time 
output,(hidden_state, cell_state) = model.rnn(model.emb(x))
for t in (output,hidden_state, cell_state):
    print(t.shape)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 7.15 µs
torch.Size([16, 34, 100])
torch.Size([1, 16, 100])
torch.Size([1, 16, 100])


Dissect the iteration through the sentence: feed one token at a time and carry the hidden and cell states forward

In [39]:
%time
init_hidden = torch.zeros((1,16,100))
init_cell = torch.zeros((1,16,100))
last_h,last_c = init_hidden,init_cell
outputs = []
x_vec = model.emb(x)
for row in range(x.shape[1]):
    last_o, (last_h,last_c) = model.rnn(x_vec[:,row:row+1,:],(last_h,last_c))
    outputs.append(last_o)

CPU times: user 2 µs, sys: 1e+03 ns, total: 3 µs
Wall time: 6.44 µs


In [40]:
manual_iteration_result = torch.cat(outputs,dim=1)

In [41]:
manual_iteration_result.shape

torch.Size([16, 34, 100])

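The two results are, of course, the same. I expected the manual Python iteration to be slower, but the timings above cannot tell: `%time` on a line by itself times an empty statement (hence the microseconds). To time a whole cell, put the `%%time` cell magic on its first line instead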

In [42]:
(manual_iteration_result==output).float().mean()

tensor(1.)
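To actually compare the two approaches, the code has to be inside the timed region. A sketch with `time.perf_counter`, using a fresh randomly initialized `nn.LSTM` and random input shaped like the embedded batch (stand-ins for `model.rnn` and `model.emb(x)`); no numbers are claimed here since they depend on the machine:

```python
import time
import torch
from torch import nn

torch.manual_seed(0)
rnn = nn.LSTM(input_size=100, hidden_size=100, batch_first=True)
x_vec = torch.randn(16, 36, 100)  # (batch, seq_len, hidden), like model.emb(x)

def full_sequence():
    out, _ = rnn(x_vec)
    return out

def step_by_step():
    h = torch.zeros(1, 16, 100)
    c = torch.zeros(1, 16, 100)
    outs = []
    for t in range(x_vec.shape[1]):
        o, (h, c) = rnn(x_vec[:, t:t + 1, :], (h, c))
        outs.append(o)
    return torch.cat(outs, dim=1)

with torch.no_grad():
    for fn in (full_sequence, step_by_step):
        start = time.perf_counter()
        for _ in range(20):
            fn()
        ms = (time.perf_counter() - start) / 20 * 1e3
        print(f"{fn.__name__}: {ms:.2f} ms per call")
    print(torch.allclose(full_sequence(), step_by_step(), atol=1e-5))
```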

### Training

In [43]:
lossf = nn.BCEWithLogitsLoss()
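Each title can have several genres, so this is multi-label classification: every genre is an independent yes/no decision with its own sigmoid, not a softmax over genres. `BCEWithLogitsLoss` takes the raw logits and applies the sigmoid internally, in a numerically stable way. A small sketch with made-up logits showing it matches the hand-written formula:

```python
import torch
from torch import nn

logits = torch.tensor([[2.0, -1.0, 0.0], [-3.0, 0.5, 1.5]])  # (batch, n_genres)
targets = torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]])   # n-hot genres

loss = nn.BCEWithLogitsLoss()(logits, targets)

# By hand: a sigmoid + binary cross-entropy per genre,
# averaged over every (sample, genre) entry
p = torch.sigmoid(logits)
manual = -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p)).mean()
print(torch.allclose(loss, manual, atol=1e-6))  # True
```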

In [44]:
from forgebox.ftorch.train import Trainer
from forgebox.ftorch.callbacks import stat
from forgebox.ftorch.metrics import metric4_bi

In [45]:
model = model.cuda()

In [46]:
t = Trainer(train_ds, val_dataset=val_ds,batch_size=16,callbacks=[stat], val_callbacks=[stat] ,shuffle=True,)

Notice, The Trainer was not initiated with optimizer
            Use the following syntax to initialize optimizer
            t.opt["adm1"] = torch.optim.Adam(m1.parameters())
            t.opt["adg1"] = torch.optim.Adagrad(m2.parameters())
            


In [47]:
t.opt["adm1"] = torch.optim.Adam(model.parameters())

Combined collate function: `fuse.collate` pads the text and default-collates the labels

In [48]:
t.train_data.collate_fn = train_ds.collate
t.val_data.collate_fn = val_ds.collate

In [49]:
@t.step_train
def train_step(self):
    self.opt.zero_all()
    x,y = self.data
    y_= model(x)
    loss = lossf(y_,y.float())
    loss.backward()
    self.opt.step_all()
    acc,rec,prec,f1 = metric4_bi(torch.sigmoid(y_),y)
    return dict((k,v.item()) for k,v in zip(["loss","acc","rec","prec","f1"],(loss,acc,rec,prec,f1)))
                
@t.step_val
def val_step(self):
    x,y = self.data
    y_= model(x)
    loss = lossf(y_,y.float())
    acc,rec,prec,f1 = metric4_bi(torch.sigmoid(y_),y)
    return dict((k,v.item()) for k,v in zip(["loss","acc","rec","prec","f1"],(loss,acc,rec,prec,f1)))
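The logs below report `acc` around 0.95 from the very first epoch, while `f1` starts near 0.07. With 49 genre columns and only a few positives per title, element-wise accuracy is dominated by the zeros, so precision, recall and F1 are the numbers to watch. I have not checked how forgebox's `metric4_bi` is defined; the sketch below is a hypothetical micro-averaged version (threshold 0.5) that illustrates the effect:

```python
import torch

def multilabel_metrics(probs, targets, threshold=0.5):
    """Hypothetical micro-averaged metrics over every (sample, genre) cell."""
    pred = (probs > threshold).float()
    targets = targets.float()
    tp = (pred * targets).sum()
    fp = (pred * (1 - targets)).sum()
    fn = ((1 - pred) * targets).sum()
    acc = (pred == targets).float().mean()
    rec = tp / (tp + fn).clamp(min=1)
    prec = tp / (tp + fp).clamp(min=1)
    f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-8)
    return acc, rec, prec, f1

# 2 titles x 49 genres, 2 true genres each, and a model that predicts nothing
targets = torch.zeros(2, 49)
targets[0, [3, 10]] = 1
targets[1, [5, 20]] = 1
probs = torch.zeros(2, 49)
acc, rec, prec, f1 = multilabel_metrics(probs, targets)
print(acc.item(), f1.item())  # accuracy is 94/98 (about 0.96), F1 is 0
```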

In [50]:
t.train(10)

Training, epoch 0:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.191503,0.943605,0.018355,0.25263,0.066362,0.0,174.5,0.012922
min,0.133593,0.498724,0.0,0.0,0.039216,0.0,0.0,0.0
max,0.717684,0.96301,0.452381,1.0,0.210526,0.0,349.0,4.522615


Validation, epoch 0:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.166195,0.94982,0.014063,0.455285,0.052917,0.0,20.0,0.016862
min,0.133625,0.933673,0.0,0.0,0.042553,0.0,0.0,0.0
max,0.208361,0.960459,0.0625,1.0,0.114286,0.0,40.0,0.691359


Training, epoch 1:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.160953,0.949778,0.040362,0.503684,0.096832,1.0,174.5,0.013533
min,0.11375,0.936224,0.0,0.0,0.038462,1.0,0.0,0.0
max,0.204853,0.966837,0.228571,1.0,0.363636,1.0,349.0,4.736667


Validation, epoch 1:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.159649,0.949788,0.056723,0.472822,0.113787,1.0,20.0,0.013886
min,0.126214,0.941326,0.0,0.0,0.045455,1.0,0.0,0.0
max,0.18382,0.959184,0.162162,1.0,0.272727,1.0,40.0,0.569323


Training, epoch 2:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.150287,0.951005,0.094498,0.619261,0.1644,2.0,174.5,0.013062
min,0.107497,0.936224,0.0,0.0,0.04,2.0,0.0,0.0
max,0.211989,0.966837,0.25641,1.0,0.4,2.0,349.0,4.571776


Validation, epoch 2:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.158233,0.9501,0.080853,0.499042,0.151939,2.0,20.0,0.013032
min,0.100352,0.938775,0.0,0.0,0.041667,2.0,0.0,0.0
max,0.199564,0.96301,0.2,1.0,0.333333,2.0,40.0,0.534303


Training, epoch 3:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.13905,0.952826,0.155567,0.647484,0.247037,3.0,174.5,0.013327
min,0.102404,0.936224,0.0,0.0,0.038462,3.0,0.0,0.0
max,0.18937,0.970663,0.382353,1.0,0.510638,3.0,349.0,4.664611


Validation, epoch 3:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.159843,0.949633,0.105194,0.460275,0.17306,3.0,20.0,0.012334
min,0.120851,0.9375,0.0,0.0,0.040816,3.0,0.0,0.0
max,0.208686,0.964286,0.28125,0.909091,0.392157,3.0,40.0,0.505692


Training, epoch 4:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.125611,0.955864,0.235954,0.693594,0.347217,4.0,174.5,0.012675
min,0.078176,0.938775,0.065217,0.333333,0.113208,4.0,0.0,0.0
max,0.186781,0.977041,0.516129,1.0,0.64,4.0,349.0,4.436386


Validation, epoch 4:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.165055,0.947704,0.137215,0.453738,0.206593,4.0,20.0,0.012974
min,0.10309,0.932398,0.043478,0.142857,0.070175,4.0,0.0,0.0
max,0.230252,0.959184,0.243243,1.0,0.36,4.0,40.0,0.531945


Training, epoch 5:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.110611,0.95971,0.333478,0.72226,0.450937,5.0,174.5,0.013399
min,0.072956,0.938775,0.1,0.4,0.16,5.0,0.0,0.0
max,0.182841,0.977041,0.625,1.0,0.677966,5.0,349.0,4.689789


Validation, epoch 5:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.175897,0.946429,0.152444,0.391154,0.22165,5.0,20.0,0.012744
min,0.121843,0.934949,0.0,0.0,0.039216,5.0,0.0,0.0
max,0.224849,0.960459,0.342857,0.75,0.421053,5.0,40.0,0.522507


Training, epoch 6:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.0963,0.963971,0.430203,0.756548,0.543497,6.0,174.5,0.012643
min,0.060433,0.941326,0.170732,0.473684,0.264151,6.0,0.0,0.0
max,0.181005,0.980867,0.727273,1.0,0.780488,6.0,349.0,4.425211


Validation, epoch 6:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.185505,0.946584,0.185838,0.430642,0.256144,6.0,20.0,0.012818
min,0.116973,0.927296,0.054054,0.125,0.075472,6.0,0.0,0.0
max,0.259556,0.964286,0.393939,0.692308,0.486486,6.0,40.0,0.525531


Training, epoch 7:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.083081,0.967727,0.507944,0.780587,0.611133,7.0,174.5,0.014348
min,0.052558,0.942602,0.238095,0.484848,0.327869,7.0,0.0,0.0
max,0.140412,0.984694,0.777778,1.0,0.823529,7.0,349.0,5.021668


Validation, epoch 7:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.204599,0.945153,0.19655,0.418075,0.263833,7.0,20.0,0.012317
min,0.13229,0.924745,0.021277,0.071429,0.032787,7.0,0.0,0.0
max,0.391081,0.964286,0.333333,1.0,0.44,7.0,40.0,0.504995


Training, epoch 8:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.072308,0.971904,0.582242,0.811531,0.674335,8.0,174.5,0.013954
min,0.044403,0.951531,0.342857,0.545455,0.421053,8.0,0.0,0.0
max,0.113198,0.989796,0.875,1.0,0.897436,8.0,349.0,4.884068


Validation, epoch 8:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.219491,0.941669,0.185854,0.345028,0.244243,8.0,20.0,0.015648
min,0.138023,0.927296,0.0,0.0,0.083333,8.0,0.0,0.0
max,0.326187,0.959184,0.341463,0.6,0.413793,8.0,40.0,0.641576


Training, epoch 9:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.062107,0.975899,0.656051,0.835032,0.731127,9.0,174.5,0.013415
min,0.031462,0.957908,0.390244,0.62963,0.5,9.0,0.0,0.0
max,0.103414,0.992347,0.933333,1.0,0.914286,9.0,349.0,4.695169


Validation, epoch 9:

,loss,acc,rec,prec,f1,epoch,iter,timestamp
mean,0.232948,0.942415,0.191221,0.342991,0.249418,9.0,20.0,0.01403
min,0.139447,0.92602,0.0,0.0,0.076923,9.0,0.0,0.0
max,0.320201,0.969388,0.411765,0.619048,0.464286,9.0,40.0,0.575232


### Search for similar descriptions

In [51]:
model = model.eval()
dl = DataLoader(train_seq, batch_size=32, collate_fn=train_seq.collate)

In [52]:
result = []
with torch.no_grad():  # inference only, no gradients needed
    for x in dl:
        x_vec = model.encoder(x.cuda())  # (batch, hs*2)
        result.append(x_vec.cpu())

One vector representing each sentence

In [53]:
result_vec = torch.cat(result,dim=0).detach().numpy()
result_vec.shape

(5590, 200)

In [54]:
def to_idx(line):
    words = train_seq.vocab.tokenize(line)
    words = ["<eos>",]+words+["<eos>"]
    return train_seq.to_i(np.array(words))[None,:]

In [55]:
to_idx("this"), to_idx("to be or not to be")

(array([[  0, 913,   0]]),
 array([[   0, 2153, 5056, 1442, 1438, 2153, 5056,    0]]))

In [56]:
def to_vec(line):
    vec = torch.LongTensor(to_idx(line)).cuda()
    return model.encoder(vec).cpu().detach().numpy()

In [57]:
to_vec("this"), to_vec("to be or not to be")

(array([[-0.2864027 , -0.1188629 , -0.02617935, -0.24850701, -0.20578709,
          0.06889898,  0.05878495, -0.22658055,  0.15024377, -0.31303164,
          0.4958601 , -0.00204021,  0.17621423, -0.2538225 , -0.3451157 ,
         -0.23131247, -0.06265341, -0.17155428, -0.00899762, -0.2577241 ,
         -0.02896317, -0.4555603 ,  0.6856887 , -0.70418745,  0.11082602,
         -0.23981036, -0.21201135, -0.43933266, -0.40148616, -0.48364577,
          0.42605698,  0.41181046, -0.14798354,  0.05320957, -0.4300459 ,
         -0.06580015,  0.01534137,  0.02928652, -0.53414524, -0.02051809,
         -0.47986796, -0.12036817, -0.00229292,  0.2772647 ,  0.1102341 ,
         -0.40527657, -0.11229473, -0.42787483,  0.40304342,  0.3992268 ,
          0.1696693 , -0.6523132 , -0.3679182 ,  0.0087082 , -0.01391423,
         -0.21790238, -0.3263417 , -0.29073295,  0.45376575, -0.02547877,
         -0.08805989, -0.04059793,  0.32122698, -0.10253229,  0.29216015,
         -0.29081804,  0.7000031 , -0.

In [58]:
def l2norm(x):
    """
    Row-wise L2 norm, reshaped to (n, 1) so it broadcasts against x
    """
    return np.linalg.norm(x,2,1).reshape(-1,1)

In [59]:
pd.set_option("max_colwidth",150)

def search(line):
    vec = to_vec(line)
    sim = ((vec* result_vec)/l2norm(result_vec)).sum(-1)
    return pd.DataFrame({"text":train_seq.lines,"sim":sim})\
        .sort_values(by="sim",ascending=False)
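`search` divides each stored vector's dot product with the query by that vector's norm, but not by the query's norm. The query norm is the same for every row, so the ranking matches cosine similarity; only the scores are scaled, which is why they exceed 1. A NumPy sketch of full cosine similarity on random stand-in vectors, checking that both scores rank rows identically; `cosine_top_k` is a hypothetical helper:

```python
import numpy as np

def cosine_top_k(query, matrix, k=3):
    """Rank rows of `matrix` (n, d) by cosine similarity to `query` (1, d)."""
    q = query / np.linalg.norm(query, axis=1, keepdims=True)
    m = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    sim = (m @ q.T).ravel()  # (n,), values in [-1, 1]
    order = np.argsort(-sim)[:k]
    return order, sim[order]

rng = np.random.default_rng(0)
matrix = rng.normal(size=(100, 200))  # stand-in for result_vec
query = rng.normal(size=(1, 200))     # stand-in for to_vec(line)

order, sims = cosine_top_k(query, matrix, k=5)
# the notebook's score skips only the division by ||query||
partial = (query * matrix / np.linalg.norm(matrix, axis=1, keepdims=True)).sum(-1)
print(order, np.argsort(-partial)[:5])  # the same five row indices
```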

In [60]:
search("Experience our planet's natural beauty").head(10)

,text,sim
4699,"The year is 2041 and a dispute between a man and his wife has set the human race back. In this battle of the sexes, the primitive life isn't so si...",2.79601
1184,"Ya-nuo's been raised as a boy. Now at age 25, she's caught the eye of a triad leader's sister. But what happens when she reveals her true gender?",2.747936
47,"Scouted by a famous Spanish club, Valt Aoi heads to Spain. With their sights on the World League, he and his teammates face the European League fi...",2.684218
1116,"Led by seventh-grader C.J., three students who have been warned about the dangers of high school decide to make the best of their middle-school ye...",2.633265
5380,"Love is in the air as Zoe and friends go on a quest to find a fabled Maid's Stone. But when rivalry blinds them to danger, it's Raven to the rescue!",2.550404
4741,"The whole huggable gang is back, bringing tales of caring and sharing to a new generation. And now the Care Bear Cousins are here to join the fun!",2.54675
4609,"Captain Atomic – once a superhero, now a sock puppet – can only activate his powers with the help of Joey, his new partner and biggest fan.",2.545124
3574,"Sometimes being shady is the only way to survive, a fact these sneaky animal ""hustlers"" – including orcas, owls and otters – use to their advantage.",2.517003
2686,Comedian Maria Bamford stars in a series inspired by her own life. It's the sometimes surreal story of a woman who loses – and then finds – her s**t.,2.510907
1039,"Thom tells her grandson about his grandfather, a sailor who left 61 years ago to seek his fortune, and asks him to find out what happened to her s...",2.49974


In [61]:
search("love story,marriage, girl").head(10)

,text,sim
1588,"To her Indian parents' dismay, London-born Jasmeet ""Jazz"" Malhotra longs for everything Western, including her British boyfriend. On a family trip...",2.468083
4285,"Quick to throw punches in the name of justice, a young man must find a calmer way to win over the pacifist father of the girl he wishes to marry.",2.414226
504,"When Parisian Elsa gets hung up on her ex, her best friends secretly hire a male escort to help her move on. But their plan works a little too well.",2.391777
2595,"Motu and Patlu want to help a circus lion get back to the jungle. On the way, the three become caught up in an exciting adventure in the forest.",2.36902
652,"Hoping to find a magical root, a monster has captured farmers in the land of Vyom. It’s up to Bheem and the gang to foil his plan and save the kin...",2.363151
2042,"Tired of her passionless marriage, Marianne wants a separation from her husband, Gustav, who, in response, decides to make a big change of his own.",2.361425
1668,"In 1962 Brooklyn, a Puerto Rican teen who joins a gang is seduced by violence and heroin. But can his mother's love and faith in God save him?",2.344456
2478,"Maya finally hooks up with her online dream girl, only to discover she’s deeply involved with an older sugar daddy – a man Maya knows all too well.",2.337689
1324,"After an argument with her dad, a young woman from a family of macho truck drivers is kicked out of the home and must make her own success as a tr...",2.334984
177,"Little Singham is in London to meet the queen, but when the famed Kohinoor Diamond gets stolen, the kid cop goes on a wild, citywide hunt for the ...",2.310921


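The results are not very convincing yet. The training logs also show overfitting: the mean training loss keeps falling (0.19 to 0.06), while the mean validation loss is lowest around epoch 2 (0.158) and climbs to 0.23 by epoch 9. Usually, more data should make the search more accurate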