In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from importlib import reload 

In [3]:
from deeprank.dataset import DataLoader, PairGenerator, ListGenerator
from deeprank import utils

In [4]:
# Seed PyTorch's global random number generator.
# NOTE(review): only torch is seeded here; if the deeprank batch generators
# draw from `random` or numpy, those would need their own seed — confirm.
seed = 1234
torch.manual_seed(seed)

<torch._C.Generator at 0x7fcd38708b10>

In [5]:
# Build the deeprank DataLoader from the fold-1 model config. The log printed
# below shows it reading the word dict, queries, documents and embeddings.
loader = DataLoader('./config/letor07_mp_fold1.model')

[./data/letor/r5w/word_dict.txt]
	Word dict size: 193367
[./data/letor/r5w/qid_query.txt]
	Data size: 1692
[./data/letor/r5w/docid_doc.txt]
	Data size: 65323
[./data/letor/r5w/embed_wiki-pdc_d50_norm]
	Embedding size: 109282
Generate numpy embed: (193368, 50)


In [6]:
import json

# Parse the model config as JSON. The context manager closes the file handle
# deterministically instead of leaving an open descriptor for the garbage
# collector to clean up.
with open('./config/letor07_mp_fold1.model') as config_file:
    letor_config = json.load(config_file)

# Device used for every tensor and module in this notebook; swap in the
# commented line to run on a GPU.
#device = torch.device("cuda")
device = torch.device("cpu")

In [7]:
# Root directory of the LETOR data, taken from the config file.
Letor07Path = letor_config['data_dir']

# Copy the loader-derived settings into the shared config in one step.
# 'fill_word' and 'pad_value' are both set to the loader's padding id.
letor_config.update({
    'fill_word': loader._PAD_,
    'embedding': loader.embedding,
    'feat_size': loader.feat_size,
    'vocab_size': loader.embedding.shape[0],
    'embed_dim': loader.embedding.shape[1],
    'pad_value': loader._PAD_,
})

# Training relation file for the fold named in the config.
train_rel_file = Letor07Path + '/relation.train.fold%d.txt' % (letor_config['fold'])
pair_gen = PairGenerator(rel_file=train_rel_file, config=letor_config)

[./data/letor/r5w/relation.train.fold1.txt]
	Instance size: 47828
Pair Instance Count: 325439


In [8]:
from deeprank import select_module
from deeprank import rank_module

In [9]:
# Settings written into the shared config before the selection network is
# built. 'lr' and 'finetune_embed' are overwritten again by the rank-net cell.
letor_config.update(
    q_limit=20,
    d_limit=2000,
    max_match=20,
    win_size=5,
    finetune_embed=True,
    lr=0.0001,
)

# Build the selection network on the target device and initialise its
# embedding table from the loader's embedding matrix.
select_net = select_module.PointerNet(config=letor_config).to(device)
select_embed_init = torch.from_numpy(loader.embedding)
select_net.embedding.weight.data.copy_(select_embed_init)
select_net.train()

select_optimizer = optim.RMSprop(select_net.parameters(), lr=letor_config['lr'])

In [10]:
# Settings written into the shared config before the ranking network is built.
# This replaces the 'lr' and 'finetune_embed' values set for the select net.
letor_config.update(
    simmat_channel=1,
    conv_params=[(8, 5, 5)],
    fc_params=[],
    dpool_size=[3, 10],
    lr=0.005,
    finetune_embed=False,
)

# Build the ranking network on the target device and initialise its embedding
# table from the loader's embedding matrix.
rank_net = rank_module.MatchPyramidNet(config=letor_config).to(device)
rank_embed_init = torch.from_numpy(loader.embedding)
rank_net.embedding.weight.data.copy_(rank_embed_init)
rank_net.train()

rank_optimizer = optim.Adam(rank_net.parameters(), lr=letor_config['lr'])

In [11]:
def to_device(*variables):
    """Lazily convert numpy arrays to tensors on the notebook's ``device``.

    Yields one tensor per positional argument, in order, each created with
    ``torch.from_numpy`` and then moved to the global ``device``. The result
    is a generator, so callers normally unpack it straight away.
    """
    for array in variables:
        yield torch.from_numpy(array).to(device)

In [12]:
def show_text(x):
    """Print the words for a sequence of word ids on one line.

    Every element of ``x`` must expose ``.item()``; the value it returns is
    looked up in ``loader.word_dict`` and the words are joined with spaces.
    """
    words = []
    for token in x:
        words.append(loader.word_dict[token.item()])
    print(' '.join(words))

In [13]:
# X1, X1_len, X1_id, X2, X2_len, X2_id, Y, F = \
#         pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)
# X1, X1_len, X2, X2_len, Y, F = \
#         to_device(X1, X1_len, X2, X2_len, Y, F)
# show_text(X2[0])
# X1, X2_new, X1_len, X2_len_new = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
# show_text(X1[0])
# for i in range(5):
#     print(i, end=' ')
#     show_text(X2_new[0][i])

In [14]:
# X1 = X1[:1]
# X1_len = X1_len[:1]
# X2 = X2[:1]
# X2_len = X2_len[:1]
# X1_id = X1_id[:1]
# X2_id = X2_id[:1]

In [15]:
# show_text(X2[0])
# X1, X2_new, X1_len, X2_len_new = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
# show_text(X1[0])
# for i in range(5):
#     print(i, end=' ')
#     show_text(X2_new[0][i])

In [16]:
import time

# Joint training: each step draws one batch of pairs, lets the select net
# produce the document input for the rank net, then updates both networks.
start_t = time.time()
for i in range(450):
    # One Step Forward
    # The feature array is bound to `feats`, not `F`, so the
    # `torch.nn.functional as F` alias from the import cell stays usable.
    X1, X1_len, X1_id, X2, X2_len, X2_id, Y, feats = \
        pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)
    X1, X1_len, X2, X2_len, Y, feats = \
        to_device(X1, X1_len, X2, X2_len, Y, feats)
    X1, X2, X1_len, X2_len = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
    X2, X2_len = utils.data_adaptor(X2, X2_len, select_net, rank_net, letor_config)
    output = rank_net(X1, X2, X1_len, X2_len, 0)
    # The reward is derived from the rank net's output and is what the select
    # net's loss is computed from below.
    reward = rank_net.pair_reward(output, mode=0)

    # Update Rank Net
    rank_loss = rank_net.pair_loss(output, Y)
    print('rank loss:', rank_loss.item())
    rank_optimizer.zero_grad()
    rank_loss.backward()
    rank_optimizer.step()

    # Update Select Net
    select_loss = select_net.loss(reward)
    print('select loss:', select_loss.item())
    select_optimizer.zero_grad()
    select_loss.backward()
    select_optimizer.step()

end_t = time.time()
print('Time Cost: %s s' % (end_t-start_t))

rank loss: 1.0111258029937744
select loss: 51.4038200378418
rank loss: 0.9823170304298401
select loss: -79.29951477050781
rank loss: 0.9718998074531555
select loss: -27.89471435546875
rank loss: 0.959214985370636
select loss: -69.80249786376953
rank loss: 0.9314579963684082
select loss: -74.5136947631836
rank loss: 0.8925216794013977
select loss: -79.93017578125
rank loss: 0.8344653844833374
select loss: -88.05590057373047
rank loss: 0.8918383717536926
select loss: -51.391963958740234
rank loss: 0.7842301726341248
select loss: -74.7845230102539
rank loss: 0.8350582718849182
select loss: -102.76737976074219
rank loss: 0.8012187480926514
select loss: -65.6501693725586
rank loss: 0.7545879483222961
select loss: -97.87397766113281
rank loss: 0.9591776132583618
select loss: -36.935096740722656
rank loss: 0.7639358639717102
select loss: -56.21630859375
rank loss: 0.7523254156112671
select loss: -116.61343383789062
rank loss: 0.8277143239974976
select loss: -74.69215393066406
rank loss: 1.000

select loss: -67.86016845703125
rank loss: 0.6045598983764648
select loss: -137.84571838378906
rank loss: 0.7201898097991943
select loss: -91.00216674804688
rank loss: 0.8189139366149902
select loss: -65.8313980102539
rank loss: 0.7066715955734253
select loss: -46.81977081298828
rank loss: 0.8612008094787598
select loss: -28.019329071044922
rank loss: 0.8388642072677612
select loss: -70.15685272216797
rank loss: 0.77521151304245
select loss: -62.716026306152344
rank loss: 0.7203966379165649
select loss: -79.6053237915039
rank loss: 0.6798901557922363
select loss: -111.943359375
rank loss: 0.6476219892501831
select loss: -105.19648742675781
rank loss: 0.8216173648834229
select loss: -39.673439025878906
rank loss: 0.6259744167327881
select loss: -142.5208740234375
rank loss: 0.6977403163909912
select loss: -81.45513916015625
rank loss: 0.7512609958648682
select loss: -65.55528259277344
rank loss: 0.6839714050292969
select loss: -95.92532348632812
rank loss: 0.6863589286804199
select loss

rank loss: 0.7152947187423706
select loss: -88.991943359375
rank loss: 0.6973288655281067
select loss: -88.56939697265625
rank loss: 0.7197948694229126
select loss: -86.3812026977539
rank loss: 0.5923389196395874
select loss: -123.95072937011719
rank loss: 0.8021805286407471
select loss: -67.90987396240234
rank loss: 0.6344258189201355
select loss: -119.424072265625
rank loss: 0.8237828612327576
select loss: -56.22405242919922
rank loss: 0.6606104373931885
select loss: -126.36861419677734
rank loss: 0.7399793863296509
select loss: -74.97708892822266
rank loss: 0.84929358959198
select loss: -86.64076232910156
rank loss: 0.736971378326416
select loss: -84.11141967773438
rank loss: 0.5642281174659729
select loss: -130.98094177246094
rank loss: 0.721390962600708
select loss: -112.62434387207031
rank loss: 0.81866455078125
select loss: -51.54115295410156
rank loss: 0.779106616973877
select loss: -67.90081787109375
rank loss: 0.7958600521087646
select loss: -67.72911834716797
rank loss: 0.62

select loss: -93.91915893554688
rank loss: 0.5992094874382019
select loss: -107.67223358154297
rank loss: 0.5176311135292053
select loss: -126.38518524169922
rank loss: 0.6462628245353699
select loss: -88.95600891113281
rank loss: 0.6691640615463257
select loss: -98.78107452392578
rank loss: 0.6022146344184875
select loss: -112.77884674072266
rank loss: 0.7653400897979736
select loss: -89.23381805419922
rank loss: 0.722693681716919
select loss: -93.69613647460938
rank loss: 0.941752016544342
select loss: -23.53929901123047
rank loss: 0.7670688629150391
select loss: -98.39593505859375
rank loss: 0.7463536262512207
select loss: -110.03741455078125
rank loss: 0.7603667974472046
select loss: -74.91754150390625
rank loss: 0.7662746906280518
select loss: -93.69799041748047
rank loss: 0.6000296473503113
select loss: -119.4875717163086
rank loss: 0.6907034516334534
select loss: -70.47405242919922
rank loss: 0.6462081074714661
select loss: -84.94280242919922
rank loss: 0.6752679944038391
select

In [17]:
# Write each network's state dict (parameters only) to its checkpoint file.
for net, ckpt_path in ((select_net, "identity.ckpt"), (rank_net, "matchpyramid.ckpt")):
    torch.save(net.state_dict(), ckpt_path)

In [18]:
# Serialize the complete module objects. The evaluation cell reloads these
# two files with torch.load.
for net, model_path in ((select_net, "identity.model"), (rank_net, "matchpyramid.model")):
    torch.save(net, model_path)

In [19]:
# Display the module repr to inspect the rank net's layer layout.
rank_net

MatchPyramidNet(
  (embedding): Embedding(193368, 50, padding_idx=193367)
  (conv_sequential): Sequential(
    (0): Conv2d(1, 8, kernel_size=[5, 5], stride=(1, 1), padding=[2, 2])
  )
  (dpool_layer): AdaptiveMaxPool2d(output_size=[3, 10])
  (fc_sequential): Sequential()
  (out_layer): Linear(in_features=240, out_features=1, bias=True)
)

In [20]:
# Reload the serialized models for evaluation.
# NOTE(security): these files hold whole pickled modules, so torch.load
# executes pickle code — only load files you produced yourself.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects pickled modules; loading the state-dict checkpoints into freshly
# built nets is the more portable route — confirm against the installed version.
# map_location places the tensors on this notebook's device even if the
# files were saved from a different one.
select_net_e = torch.load(f='identity.model', map_location=device)
rank_net_e = torch.load(f='matchpyramid.model', map_location=device)
# Switch both modules to inference mode so any train/eval-dependent behaviour
# is not left in training mode.
select_net_e.eval()
rank_net_e.eval()

list_gen = ListGenerator(rel_file=Letor07Path+'/relation.test.fold%d.txt'%(letor_config['fold']),
                         config=letor_config)
map_v = 0.0  # running sum of per-list MAP, turned into the mean after the loop
map_c = 0.0  # number of lists evaluated

with torch.no_grad():
    # The feature array is bound to `feats`, not `F`, so the
    # `torch.nn.functional as F` alias from the import cell stays usable.
    for X1, X1_len, X1_id, X2, X2_len, X2_id, Y, feats in \
        list_gen.get_batch(data1=loader.query_data, data2=loader.doc_data):
        X1, X1_len, X2, X2_len, Y, feats = to_device(X1, X1_len, X2, X2_len, Y, feats)
        X1, X2, X1_len, X2_len = select_net_e(X1, X2, X1_len, X2_len, X1_id, X2_id)
        # Pass the reloaded models, matching the networks actually used for
        # selection and scoring in this loop (previously the in-memory
        # training nets were passed here).
        X2, X2_len = utils.data_adaptor(X2, X2_len, select_net_e, rank_net_e, letor_config)
        pred = rank_net_e(X1, X2, X1_len, X2_len, 0)
        map_o = utils.eval_MAP(pred.tolist(), Y.tolist())
        map_v += map_o
        map_c += 1.0

# Average over the evaluated lists; skip the division when the generator
# yielded nothing so an empty test set reports 0.0 instead of raising.
if map_c > 0:
    map_v /= map_c

print('[Test]', map_v)

[./data/letor/r5w/relation.test.fold1.txt]
	Instance size: 13652
List Instance Count: 336
[Test] 0.4359386539000405
