In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from importlib import reload 

In [2]:
from deeprank.dataset import DataLoader, PairGenerator, ListGenerator

In [3]:
loader = DataLoader('./config/letor07_mp_fold1.model')

[./data/letor/r5w/word_dict.txt]
	Word dict size: 193367
[./data/letor/r5w/qid_query.txt]
	Data size: 1692
[./data/letor/r5w/docid_doc.txt]
	Data size: 65323
[./data/letor/r5w/embed_wiki-pdc_d50_norm]
	Embedding size: 109282
Generate numpy embed: (193368, 50)


In [4]:
import json
letor_config = json.loads(open('./config/letor07_mp_fold1.model').read())

In [5]:
Letor07Path = letor_config['data_dir']

letor_config['fill_word'] = loader._PAD_
letor_config['embedding'] = loader.embedding
letor_config['feat_size'] = loader.feat_size
letor_config['vocab_size'] = loader.embedding.shape[0]
letor_config['embed_dim'] = loader.embedding.shape[1]

pair_gen = PairGenerator(rel_file=Letor07Path + '/relation.train.fold%d.txt'%(letor_config['fold']), 
                         config=letor_config)

[./data/letor/r5w/relation.train.fold1.txt]
	Instance size: 47828
Pair Instance Count: 325439


In [6]:
from deeprank import select_module
from deeprank import rank_module

select_module = reload(select_module)
rank_module = reload(rank_module)

In [7]:
select_net = select_module.IdentityNet(config=letor_config)
select_net.train()

IdentityNet()

In [8]:
letor_config['simmat_channel'] = 1
letor_config['conv_params'] = [(8, 2, 10)]
letor_config['fc_params'] = [50]
letor_config['dpool_size'] = [3, 10]
letor_config['lr'] = 0.001
letor_config['finetune_embed'] = False
rank_net = rank_module.MatchPyramidNet(config=letor_config)
rank_net.embedding.weight.data.copy_(torch.from_numpy(loader.embedding))
rank_net.train()
optimizer = optim.Adam(rank_net.parameters(), lr=letor_config['lr'])

In [9]:
for i in range(300):
    X1, X1_len, X2, X2_len, Y, F = pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)
    X1, X2, X1_len, X2_len = select_net(X1, X2, X1_len, X2_len)
    output = rank_net(X1, X2, X1_len, X2_len)
    loss = rank_net.pair_loss(output, Y)
    print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

1.0000371932983398
1.0046019554138184
1.0072566270828247
1.0112677812576294
1.0037473440170288
0.9999621510505676
1.00369393825531
0.9890027046203613
0.9978377819061279
0.9909675717353821
1.0006822347640991
1.0122286081314087
1.0002899169921875
1.0096266269683838
0.9986898899078369
1.0043926239013672
1.0029313564300537
0.9921618103981018
0.9960930347442627
0.9940978288650513
0.9959312081336975
0.994845986366272
1.0009658336639404
0.9946739673614502
1.0121084451675415
0.9943233728408813
1.0034382343292236
1.0075438022613525
0.9976112246513367
0.9931148290634155
1.002960443496704
1.0069024562835693
1.004291296005249
0.9950054883956909
1.0024023056030273
0.993139922618866
0.9710673093795776
1.0031425952911377
0.9864755868911743
1.0035897493362427
0.9892934560775757
0.9935711026191711
0.9907528162002563
0.9925571084022522
0.9940065741539001
0.9865570068359375
0.993522584438324
0.9920896291732788
0.9883055686950684
1.0044223070144653
0.9981510043144226
0.9896656274795532
0.9869481325149536


KeyboardInterrupt: 

In [19]:
X1, X1_len, X2, X2_len, Y, F = pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)

In [18]:
print(X1.shape, X1_len.shape, X2.shape, X2_len.shape, Y.shape, F.shape)

(128, 20) (128,) (128, 500) (128,) (128,) (128, 0)


In [22]:
print(X1.shape, X1_len.shape, X2.shape, X2_len.shape, Y.shape, F.shape)

torch.Size([128, 20]) torch.Size([128]) torch.Size([128, 200]) torch.Size([128]) (128,) (128, 0)


In [21]:
X1, X2, X1_len, X2_len = select_net(X1, X2, X1_len, X2_len)

In [32]:
X1[9]

tensor([   119,    120,      4, 193367, 193367, 193367, 193367, 193367, 193367,
        193367, 193367, 193367, 193367, 193367, 193367, 193367, 193367, 193367,
        193367, 193367])