In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from __future__ import print_function

import argparse
import json
import random
import time
from importlib import reload

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
from deeprank.dataset import DataLoader, PairGenerator, ListGenerator
from deeprank import utils

In [4]:
# Single seed shared by every random number generator used in this notebook.
seed = 1234

# Batch sampling happens inside PairGenerator, whose RNG is not visible from
# this notebook, so seed the stdlib and NumPy generators as well as torch.
# NOTE(review): confirm which RNG PairGenerator actually draws from.
random.seed(seed)
np.random.seed(seed)

# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(seed)

<torch._C.Generator at 0x7f1d081a5b30>

In [5]:
# Build the project data loader from the model config file. According to the
# recorded output it reads the word dict, queries, documents and embeddings.
config_path = './config/letor07_mp_fold1.model'
loader = DataLoader(config_path)

[./data/letor/r5w/word_dict.txt]
	Word dict size: 193367
[./data/letor/r5w/qid_query.txt]
	Data size: 1692
[./data/letor/r5w/docid_doc.txt]
	Data size: 65323
[./data/letor/r5w/embed_wiki-pdc_d50_norm]
	Embedding size: 109282
Generate numpy embed: (193368, 50)


In [6]:
# Parse the experiment configuration. The context manager closes the file
# handle as soon as parsing is done instead of leaving it to the garbage
# collector.
with open('./config/letor07_mp_fold1.model') as config_file:
    letor_config = json.load(config_file)

# All tensors and modules in this notebook are placed on this device.
# Swap the two lines below to run on a GPU.
#device = torch.device("cuda")
device = torch.device("cpu")

In [7]:
# Directory holding the relation files, taken from the parsed config.
Letor07Path = letor_config['data_dir']

# Copy loader-derived values into the shared config dict so that the
# generators and networks built from it can read them.
embedding_shape = loader.embedding.shape
letor_config.update({
    'fill_word': loader._PAD_,
    'embedding': loader.embedding,
    'feat_size': loader.feat_size,
    'vocab_size': embedding_shape[0],
    'embed_dim': embedding_shape[1],
    'pad_value': loader._PAD_,
})

# Training pairs come from the relation file of the fold named in the config.
train_rel_file = Letor07Path + '/relation.train.fold%d.txt' % (letor_config['fold'])
pair_gen = PairGenerator(rel_file=train_rel_file, config=letor_config)

[./data/letor/r5w/relation.train.fold1.txt]
	Instance size: 47828
Pair Instance Count: 325439


In [8]:
from deeprank import select_module
from deeprank import rank_module

In [9]:
# Settings for the selection network, written into the shared config dict.
letor_config.update({
    'q_limit': 20,
    'd_limit': 2000,
    'max_match': 20,
    'win_size': 5,
    'finetune_embed': True,
    'lr': 0.0001,
})

# Build the pointer network on the target device and initialise its embedding
# table from the loader's embedding matrix.
select_net = select_module.PointerNet(config=letor_config).to(device)
select_embed_init = torch.from_numpy(loader.embedding)
select_net.embedding.weight.data.copy_(select_embed_init)
select_net.train()

# The optimizer takes the learning rate that was just written to the config.
select_optimizer = optim.RMSprop(select_net.parameters(), lr=letor_config['lr'])

In [10]:
# Settings for the ranking network. 'lr' and 'finetune_embed' overwrite the
# values the select-net cell stored in the same shared dict.
# NOTE(review): whether PointerNet re-reads those keys later is not visible
# here - confirm that the overwrite is harmless.
letor_config.update({
    'simmat_channel': 1,
    'conv_params': [(8, 5, 5)],
    'fc_params': [],
    'dpool_size': [3, 10],
    'lr': 0.005,
    'finetune_embed': False,
})

# Build the MatchPyramid ranker on the target device and initialise its
# embedding table from the loader's embedding matrix.
rank_net = rank_module.MatchPyramidNet(config=letor_config).to(device)
rank_embed_init = torch.from_numpy(loader.embedding)
rank_net.embedding.weight.data.copy_(rank_embed_init)
rank_net.train()

# The optimizer takes the learning rate that was just written to the config.
rank_optimizer = optim.Adam(rank_net.parameters(), lr=letor_config['lr'])

In [11]:
def to_device(*variables):
    """Convert NumPy arrays to tensors placed on the notebook-wide ``device``.

    Args:
        *variables: Any number of NumPy arrays.

    Returns:
        A tuple with one tensor per input array, in the same order. With no
        arguments an empty tuple is returned.

    The conversion is done eagerly and returned as a tuple rather than a lazy
    generator, so a bad input fails inside this call and the result can be
    indexed, measured with ``len`` or iterated more than once. Callers that
    unpack the result are unaffected.
    """
    return tuple(torch.from_numpy(variable).to(device) for variable in variables)

In [12]:
def show_text(x):
    """Print the words behind a sequence of word ids on one line.

    Every element of ``x`` must provide ``.item()``; the value it returns is
    used as a key into ``loader.word_dict``. The looked-up words are joined
    with single spaces and printed.
    """
    words = []
    for token in x:
        words.append(loader.word_dict[token.item()])
    print(' '.join(words))

In [13]:
# X1, X1_len, X1_id, X2, X2_len, X2_id, Y, F = \
#         pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)
# X1, X1_len, X2, X2_len, Y, F = \
#         to_device(X1, X1_len, X2, X2_len, Y, F)
# show_text(X2[0])
# X1, X2_new, X1_len, X2_len_new = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
# show_text(X1[0])
# for i in range(5):
#     print(i, end=' ')
#     show_text(X2_new[0][i])

In [14]:
# X1 = X1[:1]
# X1_len = X1_len[:1]
# X2 = X2[:1]
# X2_len = X2_len[:1]
# X1_id = X1_id[:1]
# X2_id = X2_id[:1]

In [15]:
# show_text(X2[0])
# X1, X2_new, X1_len, X2_len_new = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
# show_text(X1[0])
# for i in range(5):
#     print(i, end=' ')
#     show_text(X2_new[0][i])

In [21]:
# Joint training: each step draws one batch of pairs, lets the select net
# rewrite the documents, scores them with the rank net, and then updates both
# networks from that single forward pass.
n_steps = 150  # number of training steps run by this cell

start_t = time.time()
for step in range(n_steps):
    # One Step Forward
    # The last batch item is bound to `feats`, not `F`, so the
    # `torch.nn.functional as F` alias from the import cell is not clobbered.
    X1, X1_len, X1_id, X2, X2_len, X2_id, Y, feats = \
        pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)
    X1, X1_len, X2, X2_len, Y, feats = \
        to_device(X1, X1_len, X2, X2_len, Y, feats)
    X1, X2, X1_len, X2_len = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
    X2, X2_len = utils.data_adaptor(X2, X2_len, select_net, rank_net, letor_config)
    output = rank_net(X1, X2, X1_len, X2_len, 0)
    reward = rank_net.pair_reward(output)

    # Update Rank Net
    rank_loss = rank_net.pair_loss(output, Y)
    print('rank loss:', rank_loss.item())
    rank_optimizer.zero_grad()
    rank_loss.backward()
    rank_optimizer.step()

    # Update Select Net
    # NOTE(review): the select loss is computed from the ranker's reward; how
    # the reward enters the loss is defined in select_module, not shown here.
    select_loss = select_net.loss(reward)
    print('select loss:', select_loss.item())
    select_optimizer.zero_grad()
    select_loss.backward()
    select_optimizer.step()

end_t = time.time()
print('Time Cost: %s s' % (end_t-start_t))

rank loss: 0.748028039932251
select loss: -81.83063507080078
rank loss: 0.6956724524497986
select loss: -84.35591125488281
rank loss: 0.7158995866775513
select loss: -74.6168441772461
rank loss: 0.7266180515289307
select loss: -74.89917755126953
rank loss: 0.7044918537139893
select loss: -86.19282531738281
rank loss: 0.7713779211044312
select loss: -51.628108978271484
rank loss: 0.7648971080780029
select loss: -65.54090118408203
rank loss: 0.675477921962738
select loss: -107.93647766113281
rank loss: 0.680754542350769
select loss: -93.41030883789062
rank loss: 0.6644151210784912
select loss: -98.0960464477539
rank loss: 0.695019006729126
select loss: -84.1122817993164
rank loss: 0.7311475872993469
select loss: -86.5036849975586
rank loss: 0.6501885056495667
select loss: -98.3669662475586
rank loss: 0.661821722984314
select loss: -86.73979187011719
rank loss: 0.5587347149848938
select loss: -131.02963256835938
rank loss: 0.7607327103614807
select loss: -84.34652709960938
rank loss: 0.69

select loss: -56.31305694580078
rank loss: 0.7258126735687256
select loss: -82.53302764892578
rank loss: 0.7282374501228333
select loss: -56.28727722167969
rank loss: 0.7001456618309021
select loss: -88.99053955078125
rank loss: 0.7024471759796143
select loss: -75.17156219482422
rank loss: 0.6416345238685608
select loss: -112.61994934082031
rank loss: 0.7538624405860901
select loss: -62.988216400146484
rank loss: 0.759031355381012
select loss: -96.0082015991211
rank loss: 0.8021743893623352
select loss: -65.6800308227539
rank loss: 0.6040209531784058
select loss: -133.02374267578125
rank loss: 0.5320473909378052
select loss: -133.5270538330078
rank loss: 0.7553613185882568
select loss: -61.046348571777344
rank loss: 0.9230668544769287
select loss: -56.47019577026367
rank loss: 0.7782972455024719
select loss: -42.31257629394531
rank loss: 0.7545278072357178
select loss: -63.388450622558594
rank loss: 0.6015785336494446
select loss: -107.89686584472656
rank loss: 0.6541085839271545
selec

In [22]:
# Persist only the parameters (state dicts) of both networks.
for net, ckpt_path in ((select_net, "identity.ckpt"),
                       (rank_net, "matchpyramid.ckpt")):
    torch.save(net.state_dict(), ckpt_path)

In [23]:
# Persist the complete module objects; the evaluation cell reloads these files
# with torch.load.
for trained_module, model_path in ((select_net, "identity.model"),
                                   (rank_net, "matchpyramid.model")):
    torch.save(trained_module, model_path)

In [24]:
# Display the ranking network's layer structure as the cell output.
rank_net

MatchPyramidNet(
  (embedding): Embedding(193368, 50, padding_idx=193367)
  (conv_sequential): Sequential(
    (0): Conv2d(1, 8, kernel_size=[5, 5], stride=(1, 1), padding=[2, 2])
  )
  (dpool_layer): AdaptiveMaxPool2d(output_size=[3, 10])
  (fc_sequential): Sequential()
  (out_layer): Linear(in_features=240, out_features=1, bias=True)
)

In [None]:
# Reload the networks saved by the cell above. These files are whole-module
# pickles, and torch.load unpickles them - only load files you produced
# yourself.
select_net_e = torch.load(f='identity.model')
rank_net_e = torch.load(f='matchpyramid.model')

# The modules were saved while in training mode; switch them to evaluation
# mode before running inference.
select_net_e.eval()
rank_net_e.eval()

list_gen = ListGenerator(rel_file=Letor07Path+'/relation.test.fold%d.txt'%(letor_config['fold']),
                         config=letor_config)
map_v = 0.0  # running sum of per-batch MAP values, turned into the mean below
map_c = 0.0  # number of batches evaluated

with torch.no_grad():
    # The last batch item is bound to `feats`, not `F`, so the
    # `torch.nn.functional as F` alias from the import cell is not clobbered.
    for X1, X1_len, X1_id, X2, X2_len, X2_id, Y, feats in \
        list_gen.get_batch(data1=loader.query_data, data2=loader.doc_data):
        #print(X1.shape, X2.shape, Y.shape)
        X1, X1_len, X2, X2_len, Y, feats = to_device(X1, X1_len, X2, X2_len, Y, feats)
        X1, X2, X1_len, X2_len = select_net_e(X1, X2, X1_len, X2_len, X1_id, X2_id)
        # Adapt with the reloaded networks, so evaluation does not depend on
        # the in-memory training objects still being around.
        X2, X2_len = utils.data_adaptor(X2, X2_len, select_net_e, rank_net_e, letor_config)
        #print(X1.shape, X2.shape, Y.shape)
        pred = rank_net_e(X1, X2, X1_len, X2_len, 0)
        map_o = utils.eval_MAP(pred.tolist(), Y.tolist())
        #print(pred.shape, Y.shape)
        map_v += map_o
        map_c += 1.0

# Average over the evaluated batches; with no batches the sum stays 0.0
# instead of raising ZeroDivisionError.
if map_c > 0:
    map_v /= map_c

print('[Test]', map_v)

[./data/letor/r5w/relation.test.fold1.txt]
	Instance size: 13652
List Instance Count: 336
