In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from importlib import reload 

In [3]:
from deeprank.dataset import DataLoader, PairGenerator, ListGenerator
from deeprank import utils

In [4]:
seed = 1234
torch.manual_seed(seed)

<torch._C.Generator at 0x7f4b85d28a90>

In [5]:
loader = DataLoader('./config/letor07_mp_fold1.model')

[./data/letor/r5w/word_dict.txt]
	Word dict size: 193367
[./data/letor/r5w/qid_query.txt]
	Data size: 1692
[./data/letor/r5w/docid_doc.txt]
	Data size: 65323
[./data/letor/r5w/embed_wiki-pdc_d50_norm]
	Embedding size: 109282
Generate numpy embed: (193368, 50)


In [6]:
import json
letor_config = json.loads(open('./config/letor07_mp_fold1.model').read())
#device = torch.device("cuda")
#device = torch.device("cpu")
select_device = torch.device("cpu")
rank_device = torch.device("cuda")

In [7]:
Letor07Path = letor_config['data_dir']

letor_config['fill_word'] = loader._PAD_
letor_config['embedding'] = loader.embedding
letor_config['feat_size'] = loader.feat_size
letor_config['vocab_size'] = loader.embedding.shape[0]
letor_config['embed_dim'] = loader.embedding.shape[1]
letor_config['pad_value'] = loader._PAD_

pair_gen = PairGenerator(rel_file=Letor07Path + '/relation.train.fold%d.txt'%(letor_config['fold']), 
                         config=letor_config)

[./data/letor/r5w/relation.train.fold1.txt]
	Instance size: 47828
Pair Instance Count: 325439


In [8]:
from deeprank import select_module
from deeprank import rank_module

In [9]:
letor_config['max_match'] = 20
letor_config['win_size'] = 5
select_net = select_module.QueryCentricNet(config=letor_config, out_device=rank_device)
select_net = select_net.to(select_device)
select_net.train()

QueryCentricNet()

In [10]:
'''
letor_config['q_limit'] = 20
letor_config['d_limit'] = 2000
letor_config['max_match'] = 20
letor_config['win_size'] = 5
letor_config['finetune_embed'] = True
letor_config['lr'] = 0.0001
select_net = select_module.PointerNet(config=letor_config)
select_net = select_net.to(device)
select_net.embedding.weight.data.copy_(torch.from_numpy(loader.embedding))
select_net.train()
select_optimizer = optim.RMSprop(select_net.parameters(), lr=letor_config['lr'])
'''

"\nletor_config['q_limit'] = 20\nletor_config['d_limit'] = 2000\nletor_config['max_match'] = 20\nletor_config['win_size'] = 5\nletor_config['finetune_embed'] = True\nletor_config['lr'] = 0.0001\nselect_net = select_module.PointerNet(config=letor_config)\nselect_net = select_net.to(device)\nselect_net.embedding.weight.data.copy_(torch.from_numpy(loader.embedding))\nselect_net.train()\nselect_optimizer = optim.RMSprop(select_net.parameters(), lr=letor_config['lr'])\n"

In [11]:
letor_config["dim_q"] = 1
letor_config["dim_d"] = 1
letor_config["dim_weight"] = 1
letor_config["c_reduce"] = [1, 1]
letor_config["k_reduce"] = [1, 50]
letor_config["s_reduce"] = 1
letor_config["p_reduce"] = [0, 0]

letor_config["c_en_conv_out"] = 4
letor_config["k_en_conv"] = 3
letor_config["s_en_conv"] = 1
letor_config["p_en_conv"] = 1

letor_config["en_pool_out"] = [1, 1]
letor_config["en_leaky"] = 0.2

letor_config["dim_gru_hidden"] = 3

letor_config['lr'] = 0.005
letor_config['finetune_embed'] = False

rank_net = rank_module.DeepRankNet(config=letor_config)
rank_net = rank_net.to(rank_device)
rank_net.embedding.weight.data.copy_(torch.from_numpy(loader.embedding))
rank_net.train()
rank_optimizer = optim.Adam(rank_net.parameters(), lr=letor_config['lr'])

In [12]:
def to_device(*variables, device):
    return (torch.from_numpy(variable).to(device) for variable in variables)

In [13]:
def show_text(x):
    print(' '.join([loader.word_dict[w.item()] for w in x]))

In [14]:
X1, X1_len, X1_id, X2, X2_len, X2_id, Y, F = \
        pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)
X1, X1_len, X2, X2_len, Y, F = \
        to_device(X1, X1_len, X2, X2_len, Y, F, device=rank_device)

show_text(X2[0])

X1, X2_new, X1_len, X2_len_new, X2_pos = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)

show_text(X1[0])
for i in range(5):
    print(i, end=' ')
    show_text(X2_new[0][i])

coin vend amusement machine service repair occupational outlook handbook edition department labor bureau labor statistics bulletin coin vend amusement machine service repair nature work working conditions employment training qualifications advancement job outlook earnings related occupation source additional information significant points most worker learn skill job opportunity good person knowledge electronics nature work section back top coin vend amusement machine familiar sight offices convenience stores arcade casino coin operate machine give change dispense refreshments test senses spit lottery ticket nearly turn coin vend amusement machine service repair install service stock machine keep good working order vend machine service call route driver visit machine dispense soft drink candy snack item collect money machine restock merchandise change label indicate new selection keep machine clean appealing vend machine repair call mechanics technician make sure machine operate correct

slot machine malfunction $$ [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
0 jukebox video games pinball machine slot machine make sure various lever
1 video games pinball machine jukebox slot machine similar type amusement equipment
2 [PAD] [PAD] coin vend amusement machine service repair occupational outlook handbook
3 statistics bulletin coin vend amusement machine service repair nature work working
4 back top coin vend amusement machine familiar sight offices convenience stores


In [15]:
print(X2_pos)

[tensor([215., 546.,   3.,  19.,  58.,  68.,  83.,  89.,  95., 101., 110., 119.,
        123., 130., 137., 142., 155., 169., 171., 179., 191., 196., 235., 284.,
        333.], device='cuda:0'), tensor([ 0., 11.,  1.], device='cuda:0'), tensor([347., 349., 139., 353., 386.], device='cuda:0'), tensor([ 15., 273., 294., 322., 350., 356., 383., 393., 405., 457., 506., 520.,
        571., 575., 599., 610., 612., 684., 691., 711., 712., 121.],
       device='cuda:0'), tensor([339., 346., 367., 380., 391.,  19.,  61.,  71., 149., 177., 208., 225.,
        503., 610.], device='cuda:0'), tensor([ 47., 327., 369., 407.,  58., 184., 260., 281., 333., 384., 435., 452.],
       device='cuda:0'), tensor([  2.,  19.,  21.,  25.,  36., 111.,   1.,  18.,  20.,  28.,  34.,  35.,
        110.], device='cuda:0'), tensor([ 3., 14.,  1., 12., 19., 49.], device='cuda:0'), tensor([  1.,   8.,  10.,  21.,  34.,  49.,  69.,  74., 110., 120., 148., 172.,
        192., 208., 254., 267., 279., 293., 297.,   2.,   

In [16]:
# X1 = X1[:1]
# X1_len = X1_len[:1]
# X2 = X2[:1]
# X2_len = X2_len[:1]
# X1_id = X1_id[:1]
# X2_id = X2_id[:1]

In [17]:
# show_text(X2[0])
# X1, X2_new, X1_len, X2_len_new = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
# show_text(X1[0])
# for i in range(5):
#     print(i, end=' ')
#     show_text(X2_new[0][i])

In [18]:
import time
start_t = time.time()
for i in range(1000):
    # One Step Forward
    X1, X1_len, X1_id, X2, X2_len, X2_id, Y, F = \
        pair_gen.get_batch(data1=loader.query_data, data2=loader.doc_data)
    X1, X1_len, X2, X2_len, Y, F = \
        to_device(X1, X1_len, X2, X2_len, Y, F, device=select_device)
    X1, X2, X1_len, X2_len, X2_pos = select_net(X1, X2, X1_len, X2_len, X1_id, X2_id)
    X2, X2_len = utils.data_adaptor(X2, X2_len, select_net, rank_net, letor_config)
    output = rank_net(X1, X2, X1_len, X2_len, X2_pos)
    
    # Update Rank Net
    rank_loss = rank_net.pair_loss(output, Y)
    print('rank loss:', rank_loss.item())
    rank_optimizer.zero_grad()
    rank_loss.backward()
    rank_optimizer.step()
    
end_t = time.time()
print('Time Cost: %s s' % (end_t-start_t))

rank loss: 0.9951812624931335
rank loss: 1.0048960447311401
rank loss: 0.9940654039382935
rank loss: 1.0042060613632202
rank loss: 1.0156229734420776
rank loss: 1.0219076871871948
rank loss: 1.0000731945037842
rank loss: 1.0086194276809692
rank loss: 0.9833398461341858
rank loss: 0.9997835159301758
rank loss: 1.001913070678711
rank loss: 1.0131293535232544
rank loss: 1.0235469341278076
rank loss: 1.0080866813659668
rank loss: 1.0018788576126099
rank loss: 0.9935066103935242
rank loss: 0.9837742447853088
rank loss: 0.9826953411102295
rank loss: 1.001282811164856
rank loss: 0.9963813424110413
rank loss: 1.002352237701416
rank loss: 0.9916420578956604
rank loss: 0.9915006756782532
rank loss: 0.9980853199958801
rank loss: 0.9874790906906128
rank loss: 1.0032700300216675
rank loss: 0.9756821990013123
rank loss: 0.9791268706321716
rank loss: 1.0030436515808105
rank loss: 1.0126398801803589
rank loss: 1.0038189888000488
rank loss: 1.0133439302444458
rank loss: 1.0059489011764526
rank loss: 1.

rank loss: 0.8080474138259888
rank loss: 0.8481745719909668
rank loss: 0.8028175234794617
rank loss: 0.934680700302124
rank loss: 0.8829318284988403
rank loss: 0.7785511016845703
rank loss: 0.8013278841972351
rank loss: 0.8055276870727539
rank loss: 0.6992752552032471
rank loss: 0.6995631456375122
rank loss: 0.7595901489257812
rank loss: 0.8387193083763123
rank loss: 0.8899760842323303
rank loss: 0.8644739985466003
rank loss: 0.7523307800292969
rank loss: 0.8586511015892029
rank loss: 0.8151028156280518
rank loss: 0.7556069493293762
rank loss: 0.7688696384429932
rank loss: 0.785063624382019
rank loss: 0.7328717112541199
rank loss: 0.8722243905067444
rank loss: 0.7965525388717651
rank loss: 0.7584728598594666
rank loss: 0.7989668250083923
rank loss: 0.7456743717193604
rank loss: 0.7686604261398315
rank loss: 0.8510354161262512
rank loss: 0.8256372809410095
rank loss: 0.8128499388694763
rank loss: 0.7991430759429932
rank loss: 0.7391169667243958
rank loss: 0.8185761570930481
rank loss: 0

rank loss: 0.737925112247467
rank loss: 0.6597367525100708
rank loss: 0.7327467203140259
rank loss: 0.770077645778656
rank loss: 0.6913379430770874
rank loss: 0.7471276521682739
rank loss: 0.6643778681755066
rank loss: 0.5819051265716553
rank loss: 0.8806478977203369
rank loss: 0.6706067323684692
rank loss: 0.6732643246650696
rank loss: 0.7268778681755066
rank loss: 0.7179862856864929
rank loss: 0.8077387809753418
rank loss: 0.7792778611183167
rank loss: 0.806076169013977
rank loss: 0.6893807649612427
rank loss: 0.793549656867981
rank loss: 0.7210186719894409
rank loss: 0.6901152729988098
rank loss: 0.7053525447845459
rank loss: 0.71075439453125
rank loss: 0.7472701072692871
rank loss: 0.767518162727356
rank loss: 0.6974484324455261
rank loss: 0.7199957370758057
rank loss: 0.744120180606842
rank loss: 0.8410906195640564
rank loss: 0.8157952427864075
rank loss: 0.7587652206420898
rank loss: 0.7769191861152649
rank loss: 0.7250230312347412
rank loss: 0.7740619778633118
rank loss: 0.71559

rank loss: 0.6410656571388245
rank loss: 0.7186697125434875
rank loss: 0.6708215475082397
rank loss: 0.7179554104804993
rank loss: 0.6138277053833008
rank loss: 0.7500765919685364
rank loss: 0.6288455724716187
rank loss: 0.7329901456832886
rank loss: 0.7868987917900085
rank loss: 0.6716737747192383
rank loss: 0.6971812844276428
rank loss: 0.8783232569694519
rank loss: 0.604184091091156
rank loss: 0.6997931003570557
rank loss: 0.7203173637390137
rank loss: 0.7120301723480225
rank loss: 0.7207215428352356
rank loss: 0.6839776635169983
rank loss: 0.6435736417770386
rank loss: 0.6985917687416077
rank loss: 0.7194593548774719
rank loss: 0.7111312747001648
rank loss: 0.7160836458206177
rank loss: 0.6214325428009033
rank loss: 0.7272852659225464
rank loss: 0.6682885885238647
rank loss: 0.5944069027900696
rank loss: 0.741405725479126
rank loss: 0.6204160451889038
rank loss: 0.6566944718360901
rank loss: 0.6274212598800659
rank loss: 0.8463570475578308
rank loss: 0.7300096154212952
rank loss: 0

In [19]:
torch.save(select_net, "qcentric.model")
torch.save(rank_net, "deeprank.model")

In [20]:
select_net_e = torch.load(f='qcentric.model')
rank_net_e = torch.load(f='deeprank.model')

list_gen = ListGenerator(rel_file=Letor07Path+'/relation.test.fold%d.txt'%(letor_config['fold']),
                         config=letor_config)
map_v = 0.0
map_c = 0.0

with torch.no_grad():
    for X1, X1_len, X1_id, X2, X2_len, X2_id, Y, F in \
        list_gen.get_batch(data1=loader.query_data, data2=loader.doc_data):
        print(X1.shape, X2.shape, Y.shape)
        X1, X1_len, X2, X2_len, Y, F = to_device(X1, X1_len, X2, X2_len, Y, F, device=select_device)
        X1, X2, X1_len, X2_len, X2_pos = select_net_e(X1, X2, X1_len, X2_len, X1_id, X2_id)
        X2, X2_len = utils.data_adaptor(X2, X2_len, select_net, rank_net, letor_config)
        #print(X1.shape, X2.shape, Y.shape)
        pred = rank_net_e(X1, X2, X1_len, X2_len, X2_pos)
        map_o = utils.eval_MAP(pred.tolist(), Y.tolist())
        #print(pred.shape, Y.shape)
        map_v += map_o
        map_c += 1.0
    map_v /= map_c

print('[Test]', map_v)

[./data/letor/r5w/relation.test.fold1.txt]
	Instance size: 13652
List Instance Count: 336
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(50, 20) (50, 2000) (50,)
(51, 20) (51, 2000) (51,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)


(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
(40, 20) (40, 2000) (40,)
[Test] 0.39690111471051653
