## Load the Dataset

MatchZoo expects a list of *quintuples* as training data:

```python
train = [('qid0', 'did0', 'query 0', 'document 0', 'label 0'),
         ('qid0', 'did1', 'query 0', 'document 1', 'label 1'),
          ...,
         ('qid1', 'did2', 'query 1', 'document 2', 'label 3')]
```

The corresponding columns are `(text_left_id, text_right_id, text_left, text_right, label)`. In an information retrieval task, `text_left` is the *query* and `text_right` is the *document*.

For testing, MatchZoo expects a list of *quadruples* as input, since labels are not needed:

```python
test = [('qid9', 'did5', 'query 9', 'document 5'),
         ...,
        ('qid2', 'did7', 'query 2', 'document 7')]
```
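
As a quick sanity check before handing data to MatchZoo, the two record shapes described above can be verified in plain Python. This is a minimal sketch; `check_records` is a hypothetical helper and not part of MatchZoo.

```python
def check_records(records, stage):
    """Verify every record has the tuple width expected for `stage`.

    Hypothetical helper (not part of MatchZoo): 'train' records are
    quintuples, 'predict' records are quadruples without a label.
    """
    expected = {'train': 5, 'predict': 4}[stage]
    for i, record in enumerate(records):
        if len(record) != expected:
            raise ValueError(
                'record {} has {} fields, expected {}'.format(i, len(record), expected))
    return len(records)


train_sample = [('qid0', 'did0', 'query 0', 'document 0', 'label 0'),
                ('qid0', 'did1', 'query 0', 'document 1', 'label 1')]
test_sample = [('qid9', 'did5', 'query 9', 'document 5')]

print(check_records(train_sample, stage='train'))
print(check_records(test_sample, stage='predict'))
```

Passing a quadruple with `stage='train'` raises a `ValueError` that names the offending record.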

In [4]:
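def read_data(path, stage):
    """Read a tab-separated file with a header row into a list of tuples."""
    def scan_file():
        with open(path) as in_file:
            next(in_file)  # skip header
            for line in in_file:
                yield line.strip().split('\t')
    if stage == 'train':
        # quintuples: (qid, did, query, document, label)
        return [(qid, did, q, d, label) for qid, did, q, d, label in scan_file()]
    elif stage == 'predict':
        # quadruples: any column after the fourth is dropped
        return [(elem[0], elem[1], elem[2], elem[3]) for elem in scan_file()]
    raise ValueError("stage must be 'train' or 'predict', got {!r}".format(stage))

train = read_data('data/matchzoo_input.txt', stage='train')
predict = read_data('data/matchzoo_predict.txt', stage='predict')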
rank = read_data('data/matchzoo_rank.txt', stage='predict')
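
`read_data` assumes a tab-separated file with one header row followed by one record per line. The sketch below writes a tiny file in that assumed layout and parses it the same way; the column names and sample rows are made up for illustration.

```python
import os
import tempfile

# Hypothetical sample file: header row, then tab-separated records.
sample = (
    "qid\tdid\tquery\tdocument\tlabel\n"
    "q1\td1\tcomputer health\tterminals and eye strain\t1\n"
    "q1\td2\tcomputer health\tstock market report\t0\n"
)

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
    tmp.write(sample)
    sample_path = tmp.name

rows = []
with open(sample_path) as in_file:
    next(in_file)  # skip the header row
    for line in in_file:
        rows.append(tuple(line.rstrip('\n').split('\t')))
os.remove(sample_path)

print(rows[0])
```

In this sketch every field, including the label, is parsed as a string.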

In [5]:
print(train[0])
print(predict[0])
print(rank[0])

('350', 'FT934-11789', 'Health and Computer Terminals', "11 18 20,000 29 70 93 931029 _an a a a a a a a a a a action action after against against agency agree also although an an and and and and and and and and and and and anything appeal are arms as as as as as as as ascribe at at at authentic award be be because because been being being bernard between books both britain brought but by by by by by by care case case case case cast causal cause charter claim claim claim clerical colleague come comp company company computer computer concept condition condition condition conditions conditions confidence confuse considering continue correspondent costs could could country court court court court court court court criticise damages damages damages describe disappointed dismiss disorder dj2dcad8ft doubt down due ec editor elbow emergence emotional employ employee employer even exist expert factor felt financial first for for for for former ft future future gbz go greatest had had hand he he

## Preprocessing

In [6]:
from matchzoo import preprocessor
dssm_preprocessor = preprocessor.DSSMPreprocessor()
datapack_train = dssm_preprocessor.fit_transform(train, stage='train')

Using TensorFlow backend.



In [173]:
type(datapack_train)

matchzoo.datapack.DataPack

In [174]:
# left-hand records: indexed by `id_left`, with the processed `text_left` and its length
datapack_train.left.head()

| id_left | text_left | length_left |
|---|---|---|
| 350 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...] | 11196 |


In [175]:
# right-hand records: indexed by `id_right`, with the processed `text_right` and its length
datapack_train.right.head()

| id_right | text_right | length_right |
|---|---|---|
| FT934-11789 | [0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...] | 11196 |
| LA091090-0108 | [0.0, 0.0, 2.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...] | 11196 |
| LA120789-0021 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...] | 11196 |
| LA031990-0076 | [0.0, 2.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, ...] | 11196 |
| FT921-12910 | [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...] | 11196 |


In [176]:
# relations: one row per (`id_left`, `id_right`) pair with its label
datapack_train.relation.head()

| | id_left | id_right | label |
|---|---|---|---|
| 0 | 350 | FT934-11789 | 1 |
| 1 | 350 | LA091090-0108 | 1 |
| 2 | 350 | LA120789-0021 | 1 |
| 3 | 350 | LA031990-0076 | 1 |
| 4 | 350 | FT921-12910 | 1 |


In [177]:
# other information stored during the pre-processing process
datapack_train.context.keys()

dict_keys(['term_index', 'input_shapes'])

In [178]:
# vocabulary size
len(datapack_train.context['term_index'])

11195

In [179]:
# DSSM input shapes are dynamic: they depend on the tri-letters
# generated from the data, so they are computed during pre-processing
datapack_train.context['input_shapes']

[(11196,), (11196,)]
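
The tri-letter representation comes from the word hashing step of DSSM [Huang et al. 2013]: each word is wrapped in boundary markers and split into letter trigrams, and a text becomes a count vector over the trigram vocabulary. The following is a minimal sketch of that idea in plain Python. It is not MatchZoo's implementation, the function names are hypothetical, and reserving slot 0 for unseen trigrams is an assumption made here to illustrate why a vector can be one element longer than the vocabulary.

```python
def letter_trigrams(word):
    """Split a word into letter trigrams, using '#' as the boundary marker."""
    padded = '#' + word + '#'
    return [padded[i:i + 3] for i in range(len(padded) - 2)]


def build_term_index(texts):
    """Map every trigram seen in `texts` to an index, starting at 1."""
    term_index = {}
    for text in texts:
        for word in text.lower().split():
            for gram in letter_trigrams(word):
                term_index.setdefault(gram, len(term_index) + 1)
    return term_index


def hash_text(text, term_index):
    """Turn a text into a trigram count vector; slot 0 collects unseen trigrams."""
    vector = [0.0] * (len(term_index) + 1)
    for word in text.lower().split():
        for gram in letter_trigrams(word):
            vector[term_index.get(gram, 0)] += 1.0
    return vector


term_index = build_term_index(['good day'])
print(letter_trigrams('good'))             # ['#go', 'goo', 'ood', 'od#']
print(len(hash_text('good', term_index)))  # 8
```

Because the vector length follows from the vocabulary built at fit time, the input shape is only known after pre-processing.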

## Data Generation

In [180]:
from matchzoo import generators
from matchzoo import tasks
generator_train = generators.PointGenerator(
    inputs=datapack_train, task=tasks.Ranking(), batch_size=64, stage='train')
#generator_predict = generators.PointGenerator(
#   inputs=datapack_predict, task=tasks.Ranking(), batch_size=64, stage='predict')
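
Conceptually, a point-wise generator treats every (left, right, label) relation as an independent example and serves the examples in fixed-size batches. The sketch below illustrates that batching in plain Python; `point_batches` is a hypothetical stand-in, not MatchZoo's `PointGenerator`.

```python
def point_batches(relations, batch_size):
    """Yield (pairs, labels) batches from (id_left, id_right, label) rows."""
    for start in range(0, len(relations), batch_size):
        chunk = relations[start:start + batch_size]
        pairs = [(left, right) for left, right, _ in chunk]
        labels = [label for _, _, label in chunk]
        yield pairs, labels


relations = [('350', 'FT934-11789', 1),
             ('350', 'LA091090-0108', 1),
             ('350', 'LA120789-0021', 1)]
for pairs, labels in point_batches(relations, batch_size=2):
    print(len(pairs), labels)
```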

## Training

In [181]:
from matchzoo import models, load_model
from matchzoo import losses
from matchzoo import tasks
from matchzoo import metrics
dssm_model = models.DSSMModel()

In [182]:
# handle dynamic input shapes of DSSM
input_shapes = datapack_train.context['input_shapes']
dssm_model.params['input_shapes'] = input_shapes

In [183]:
dssm_model.params['task'] = tasks.Ranking()
dssm_model.params['task'].metrics = ['mae', 'map']

In [184]:
dssm_model.guess_and_fill_missing_params()
print(dssm_model.params)

name                          DSSMModel
model_class                   <class 'matchzoo.models.dssm_model.DSSMModel'>
input_shapes                  [(11196,), (11196,)]
task                          <matchzoo.tasks.ranking.Ranking object at 0x1473e1ef0>
optimizer                     adam
w_initializer                 glorot_normal
b_initializer                 zeros
dim_fan_out                   128
dim_hidden                    300
activation_hidden             tanh
num_hidden_layers             2


In [185]:
dssm_model.build()
dssm_model.compile()
dssm_model.fit_generator(generator_train, steps_per_epoch=20, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1413a4048>

In [189]:
X, Y = generator_train[0]
dssm_model.evaluate(X, Y)



{'loss': 0.0005072275525890291,
 'mean_absolute_error': 0.008711242116987705,
 'mean_average_precision(0)': 1.0}
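
The values above are computed on a training batch, so they reflect fit rather than generalization. As a reminder of what the ranking metric measures, here is a minimal sketch of average precision in plain Python. It is not MatchZoo's implementation, and treating the `(0)` suffix as a relevance threshold on the label is an assumption.

```python
def average_precision(y_true, y_pred, threshold=0):
    """Average precision for one query.

    Documents are ranked by predicted score; a document counts as
    relevant when its true label is greater than `threshold`.
    """
    ranked = sorted(zip(y_true, y_pred), key=lambda pair: pair[1], reverse=True)
    hits = 0
    precision_sum = 0.0
    for position, (label, _) in enumerate(ranked, start=1):
        if label > threshold:
            hits += 1
            precision_sum += hits / position
    return precision_sum / hits if hits else 0.0


def mean_average_precision(queries, threshold=0):
    """Mean of the per-query average precision values."""
    scores = [average_precision(t, p, threshold) for t, p in queries]
    return sum(scores) / len(scores)


# Relevant documents ranked 1st and 3rd: (1/1 + 2/3) / 2
print(average_precision([1, 0, 1], [0.9, 0.8, 0.7]))
```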

## Prediction Function

In [None]:
# Score the ranked list once to obtain the top score and the k-th score;
# `predict_proba` reads `kthscore` as its relevance threshold.
datapack_rank = dssm_preprocessor.fit_transform(rank, stage='predict')
generator_rank = generators.PointGenerator(
    inputs=datapack_rank, task=tasks.Ranking(), batch_size=len(rank), stage='predict')
X_rank, _ = generator_rank[0]
k = 10
ranking = dssm_model.predict(X_rank)
rank_list = [r[0] for r in ranking]
rank_list.sort(reverse=True)
topscore = rank_list[0]
kthscore = rank_list[k]  # 0-based index k, so the list needs more than k scores
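
Indexing `rank_list[k]` raises an `IndexError` when fewer than `k + 1` documents were scored. A minimal sketch of a more defensive cutoff follows; `kth_score` is a hypothetical helper, and falling back to the lowest score for short lists is a design choice made here for illustration.

```python
def kth_score(scores, k):
    """Return the score at 0-based position `k` of the descending ranking.

    Hypothetical helper: for lists shorter than k + 1 it falls back to
    the lowest score instead of raising IndexError.
    """
    if not scores:
        raise ValueError('scores must not be empty')
    ranked = sorted(scores, reverse=True)
    return ranked[min(k, len(ranked) - 1)]


print(kth_score([0.2, 0.9, 0.5, 0.7], k=2))   # 0.5
print(kth_score([0.2, 0.9], k=10))            # 0.2
```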

In [202]:
import numpy as np

def predict_proba(doc_text):
    """Return hard (irrelevant, relevant) pseudo-probabilities for LIME.

    Relies on the globals `qid`, `did`, and `query` (set in the LIME loop)
    and on `kthscore`, the score threshold for the top-k cutoff.
    """
    # Give every perturbed document a unique id derived from the original one.
    did_list = [did + "_PRED_" + str(i) for i in range(1, len(doc_text) + 1)]
    predict_data = [(qid, pred_did, query, doc)
                    for pred_did, doc in zip(did_list, doc_text)]

    datapack_predict = dssm_preprocessor.fit_transform(predict_data, stage='predict')
    generator_predict = generators.PointGenerator(
        inputs=datapack_predict, task=tasks.Ranking(), batch_size=len(doc_text), stage='predict')
    X_predict, _ = generator_predict[0]

    # A document counts as relevant (1) when its score beats the k-th ranked score.
    pred = dssm_model.predict(X_predict)
    relevant = np.array([int(p[0] > kthscore) for p in pred])
    # Column 0 is P(irrelevant), column 1 is P(relevant); both are 0 or 1.
    return np.column_stack((1 - relevant, relevant))


## LIME Initialization

In [203]:
from lime.lime_text import LimeTextExplainer
import re

# `predict_proba` reads `qid`, `did`, and `query` as module-level globals,
# so the loop below rebinds them for each training row.
token_pattern = re.compile(r"(?u)\b\w\w+\b")  # tokens of two or more word characters
tokenizer = token_pattern.findall
explainer = LimeTextExplainer(class_names=["irrelevant", "relevant"], split_expression=tokenizer)
for row in train:
    (qid, did, query, document_text, label) = row
    exp = explainer.explain_instance(document_text, predict_proba, num_features=6)
    print(exp.as_list())
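
The tokenizer pattern `(?u)\b\w\w+\b` keeps only tokens of two or more word characters, so single letters and punctuation never become LIME features, while number fragments such as `000` do. A short standalone check of that behavior (the sample sentence is made up):

```python
import re

token_pattern = re.compile(r"(?u)\b\w\w+\b")

sample = "A VDU screen, e.g. a 14-inch one, costs 20,000."
print(token_pattern.findall(sample))
# ['VDU', 'screen', '14', 'inch', 'one', 'costs', '20', '000']
```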


[('producer', -0.04499260845471026), ('could', 0.033609363656580946), ('well', 0.03219270320107857), ('that', -0.032107310274637785), ('offer', -0.029095985311904406), ('only', 0.028932719112119233)]


[('glare', -0.04167734696674195), ('latest', -0.03816768657701775), ('and', -0.036245562855794354), ('color', 0.03332728953536575), ('subside', 0.030957281445283404), ('when', 0.0304228538274462)]


[('employee', -0.02804804636878338), ('mandate', 0.026286065150812783), ('times', 0.02414927480481701), ('city', 0.023951349275549722), ('nation', -0.02262026167516275), ('000', -0.0213634961683032)]


[('recently', 0.03610064492401655), ('steve', 0.03576243833882852), ('surmise', -0.03462137910159381), ('important', -0.03179451819662839), ('mega', -0.03048926865087421), ('george', -0.030428452921853803)]


[('screen', 0.03369836144881544), ('mr', 0.03311800198568993), ('equipment', -0.028626011696739936), ('ucw', -0.02643550214992452), ('consultation', -0.026410826795520658), ('vdu', 0.026277978047998817)]


[('agno', -0.03613075067880314), ('will', -0.030454284221461292), ('into', 0.025084886355885103), ('to', 0.02301383448046211), ('about', 0.022425418965871795), ('nation', 0.02206508165810989)]


[('indefinitely', 0.03399273436291269), ('receive', 0.03087812458887266), ('meet', -0.029174408174068224), ('liability', -0.027955117341882524), ('involved', 0.027404939889662767), ('hold', -0.02698216686936044)]



KeyboardInterrupt: 

In [19]:
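# Requires a module-level `generator_predict`; the cell that builds it is
# commented out under "Data Generation".
X_predict, _ = generator_predict[0]
scores = dssm_model.predict(X_predict)
# zipping with range(10) limits the printout to the first 10 pairs
for id_left, id_right, score, _ in zip(X_predict.id_left, X_predict.id_right, scores, range(10)):
    print("{}/{} is predicted as {}".format(id_left, id_right, score))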

Q1733/D1642-14 is predicted as [0.14073879]
Q47/D47-10 is predicted as [0.36516124]
Q2766/D1764-5 is predicted as [-0.02338599]
Q1326/D1268-0 is predicted as [0.01347637]
Q744/D722-12 is predicted as [-0.09433301]
Q952/D920-3 is predicted as [0.7045051]
Q2189/D206-7 is predicted as [0.14717689]
Q1586/D1504-2 is predicted as [-0.33544147]
Q1416/D1349-22 is predicted as [0.08000302]
Q312/D311-2 is predicted as [0.25624424]


## Model Persistence

You can persist a trained model with `model.save()` and restore it with the `load_model` function:

In [20]:
dssm_model.save('/tmp/my_dssm_model')
loaded_dssm_model = load_model('/tmp/my_dssm_model')

In [21]:
(loaded_dssm_model.predict(X) == dssm_model.predict(X)).all()

True

## Reference

[Huang et al. 2013] Po-Sen Huang, Xiaodong He, Jianfeng Gao, Li Deng, Alex Acero, and Larry Heck. 2013. Learning deep structured semantic models for web search using clickthrough data. In Proc. CIKM. ACM, 2333–2338.