In [1]:
# Automatically reload imported modules before each cell runs, so edits to the
# local deeprank package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from __future__ import print_function

import argparse
import random
from importlib import reload

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
from deeprank.dataset import DataLoader, PairGenerator, ListGenerator
from deeprank import utils

In [4]:
# Seed every RNG that may drive weight init or batch sampling so the run is
# repeatable under Restart & Run All.
# NOTE(review): the sampling source inside deeprank.dataset's generators is not
# visible here; numpy and Python `random` are seeded on the assumption that
# they draw from one of them — confirm.
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);  # trailing ';' hides the Generator repr in the cell output

<torch._C.Generator at 0x7f2c840dbaf0>

In [5]:
# Load the vocabulary, query/document corpora and embedding matrix from the
# paths listed in this model config (see the log lines printed below).
loader_config_path = './config/letor07_mp_fold1.model'
loader = DataLoader(loader_config_path)

[./data/letor/r5w/word_dict.txt]
	Word dict size: 193367
[./data/letor/r5w/qid_query.txt]
	Data size: 1692
[./data/letor/r5w/docid_doc.txt]
	Data size: 65323
[./data/letor/r5w/embed_wiki-pdc_d50_norm]
	Embedding size: 109282
Generate numpy embed: (193368, 50)


In [6]:
import json

# Re-read the model config as a plain dict so later cells can add derived
# entries (embedding, vocab size, network hyper-parameters) to it.
# A `with` block closes the file handle instead of leaking it.
with open('./config/letor07_mp_fold1.model') as config_file:
    letor_config = json.load(config_file)

# Runs on CPU; change to torch.device("cuda") to train on a GPU.
device = torch.device("cpu")

In [7]:
Letor07Path = letor_config['data_dir']

# Fold the loader-derived values into the config that the generators and
# networks read.
embed_shape = loader.embedding.shape
letor_config.update({
    'fill_word': loader._PAD_,
    'embedding': loader.embedding,
    'feat_size': loader.feat_size,
    'vocab_size': embed_shape[0],
    'embed_dim': embed_shape[1],
    'pad_value': loader._PAD_,
})

# Training relation file for the configured fold, e.g. <data_dir>/relation.train.fold1.txt.
train_rel_path = '%s/relation.train.fold%d.txt' % (Letor07Path, letor_config['fold'])
pair_gen = PairGenerator(rel_file=train_rel_path, config=letor_config)

[./data/letor/r5w/relation.train.fold1.txt]
	Instance size: 47828
Pair Instance Count: 325439


In [8]:
from deeprank import select_module
from deeprank import rank_module

In [9]:
# NOTE(review): this IdentityNet is discarded, because the next cell rebinds
# `select_net` to a QueryCentricNet. This cell only matters if that one is skipped.
select_net = select_module.IdentityNet(config=letor_config).to(device)
select_net.train()

In [10]:
# From here on `select_net` is a QueryCentricNet (it replaces the IdentityNet
# built in the previous cell).
# NOTE(review): q_limit / d_limit / max_match / win_size are interpreted inside
# deeprank.select_module (not visible here). Presumably they are query/document
# length caps and match-window settings; confirm there.
letor_config.update(q_limit=20, d_limit=500, max_match=5, win_size=5)
select_net = select_module.QueryCentricNet(config=letor_config).to(device)
select_net.train()

In [11]:
# Alternative (inactive) MatchPyramid setup kept for reference: 8 conv channels
# with a 2x10 kernel and a 50-unit fully-connected layer. Unlike the active
# configuration, it never moves the network to `device`.
# letor_config['simmat_channel'] = 1
# letor_config['conv_params'] = [(8, 2, 10)]
# letor_config['fc_params'] = [50]
# letor_config['dpool_size'] = [3, 10]
# letor_config['lr'] = 0.001
# letor_config['finetune_embed'] = False
# rank_net = rank_module.MatchPyramidNet(config=letor_config)
# rank_net.embedding.weight.data.copy_(torch.from_numpy(loader.embedding))
# rank_net.train()
# optimizer = optim.Adam(rank_net.parameters(), lr=letor_config['lr'])

In [12]:
# Active MatchPyramid configuration. `conv_params` (8, 3, 3) becomes a single
# Conv2d with 8 output channels and a 3x3 kernel, followed by adaptive max
# pooling to 3x10 and a 200-unit linear layer. The printed module structure
# further down shows the result.
letor_config.update({
    'simmat_channel': 1,
    'conv_params': [(8, 3, 3)],
    'fc_params': [200],
    'dpool_size': [3, 10],
    'lr': 0.001,
    'finetune_embed': False,
})
rank_net = rank_module.MatchPyramidNet(config=letor_config).to(device)

# Initialise the embedding table from the loader's embedding matrix. Doing the
# copy under no_grad replaces the legacy `.data.copy_` idiom.
with torch.no_grad():
    rank_net.embedding.weight.copy_(torch.from_numpy(loader.embedding))
rank_net.train()
optimizer = optim.Adam(rank_net.parameters(), lr=letor_config['lr'])

In [13]:
def to_device(*variables):
    """Lazily convert numpy arrays to tensors on the global ``device``.

    Args:
        *variables: numpy arrays, each converted with ``torch.from_numpy``.

    Returns:
        An iterator yielding one tensor per input, in input order. It is
        intended to be unpacked directly, e.g. ``a, b = to_device(x, y)``.
    """
    for array in variables:
        yield torch.from_numpy(array).to(device)

In [14]:
import time

# Train the ranking network with its pairwise loss on batches drawn from
# pair_gen. Only rank_net's parameters are registered with the optimizer.
NUM_TRAIN_BATCHES = 150  # optimisation steps, one sampled batch each

train_losses = []
start_t = time.time()
for _ in range(NUM_TRAIN_BATCHES):
    # The 8th batch field is named `feat`, not `F`. Reusing `F` silently
    # shadowed the `torch.nn.functional as F` import at the top of the notebook.
    X1, X1_len, X1_id, X2, X2_len, X2_id, Y, feat = \
        pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)
    # X1_id / X2_id stay as returned by get_batch; only these arrays become tensors.
    X1, X1_len, X2, X2_len, Y, feat = \
        to_device(X1, X1_len, X2, X2_len, Y, feat)
    X1, X2, X1_len, X2_len = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
    X2, X2_len = utils.data_adaptor(X2, X2_len, select_net, rank_net, letor_config)
    output = rank_net(X1, X2, X1_len, X2_len, 0)
    loss = rank_net.pair_loss(output, Y)
    # Keep the loss curve for later inspection and still log each value.
    train_losses.append(loss.item())
    print(train_losses[-1])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
end_t = time.time()
print('Time Cost: %s s' % (end_t-start_t))

1.0005669593811035
0.9666146636009216
0.9545133113861084
0.9290570020675659
0.9250833988189697
0.9516744017601013
0.9926365613937378
0.9482710361480713
0.8718242645263672
0.8936516046524048
0.8728057146072388
0.8509190082550049
0.9060134291648865
0.90357905626297
0.8386606574058533
0.9735718965530396
0.8126947283744812
0.8493989706039429
0.8504531383514404
0.8859429359436035
0.8048561215400696
0.9653592705726624
1.0039255619049072
0.7623535394668579
0.8342338800430298
0.9242572784423828
0.8864844441413879
0.8077096939086914
0.7404946088790894
1.0208207368850708
0.9004018306732178
0.822470486164093
0.9969710111618042
0.8135121464729309
0.87203049659729
0.8742700815200806
0.7792972326278687
0.8595035076141357
0.7425932288169861
0.7918533682823181
0.8069121837615967
0.7917323112487793
0.689857006072998
0.6483471989631653
0.9997082352638245
0.9095818400382996
0.8643707036972046
0.969832181930542
0.9153090119361877
0.9198892116546631
0.8066182136535645
0.8728025555610657
0.7627391219139099


In [15]:
# Persist only the learned parameters (state_dict) of each network.
# NOTE(review): `select_net` is the QueryCentricNet built earlier, not an
# IdentityNet, despite the "identity" file name.
for module_to_save, ckpt_path in ((select_net, "identity.ckpt"),
                                  (rank_net, "matchpyramid.ckpt")):
    torch.save(module_to_save.state_dict(), ckpt_path)

In [16]:
# Pickle the whole module objects. Reloading them requires the deeprank classes
# to be importable under the same module paths; the state_dict checkpoints
# above are the more portable format.
whole_model_paths = {"identity.model": select_net, "matchpyramid.model": rank_net}
for model_path, whole_model in whole_model_paths.items():
    torch.save(whole_model, model_path)

In [17]:
# Show the trained ranking network's layer structure as the cell output.
rank_net

MatchPyramidNet(
  (embedding): Embedding(193368, 50, padding_idx=193367)
  (conv_sequential): Sequential(
    (0): Conv2d(1, 8, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
  )
  (dpool_layer): AdaptiveMaxPool2d(output_size=[3, 10])
  (fc_sequential): Sequential(
    (0): Linear(in_features=240, out_features=200, bias=True)
  )
  (out_layer): Linear(in_features=200, out_features=1, bias=True)
)

In [18]:
# Evaluate the reloaded models on the test fold: MAP is computed per batch
# yielded by ListGenerator and then averaged over batches.
# NOTE: torch.load unpickles arbitrary Python objects, so only load trusted
# files. On PyTorch >= 2.6 (weights_only defaults to True), loading whole
# modules like these also needs weights_only=False.
select_net_e = torch.load(f='identity.model', map_location=device)
rank_net_e = torch.load(f='matchpyramid.model', map_location=device)
# The saved modules were in training mode; switch to inference behaviour
# (dropout off, frozen batch-norm statistics, if either network has them).
select_net_e.eval()
rank_net_e.eval()

list_gen = ListGenerator(rel_file=Letor07Path+'/relation.test.fold%d.txt'%(letor_config['fold']),
                         config=letor_config)
map_v = 0.0
map_c = 0.0

with torch.no_grad():
    # `feat` (not `F`) avoids shadowing the torch.nn.functional import.
    for X1, X1_len, X1_id, X2, X2_len, X2_id, Y, feat in \
            list_gen.get_batch(data1=loader.query_data, data2=loader.doc_data):
        X1, X1_len, X2, X2_len, Y, feat = to_device(X1, X1_len, X2, X2_len, Y, feat)
        X1, X2, X1_len, X2_len = select_net_e(X1, X2, X1_len, X2_len, X1_id, X2_id)
        # Adapt using the reloaded models, not the in-memory training ones,
        # so the whole evaluation pipeline runs on the checkpoints.
        X2, X2_len = utils.data_adaptor(X2, X2_len, select_net_e, rank_net_e, letor_config)
        pred = rank_net_e(X1, X2, X1_len, X2_len, 0)
        map_v += utils.eval_MAP(pred.tolist(), Y.tolist())
        map_c += 1.0

# Fail with a clear message instead of a bare ZeroDivisionError.
if map_c == 0:
    raise RuntimeError('ListGenerator yielded no test batches; cannot compute MAP')
map_v /= map_c

print('[Test]', map_v)

[./data/letor/r5w/relation.test.fold1.txt]
	Instance size: 13652
List Instance Count: 336
[Test] 0.40943836617689755
