In [2]:
% matplotlib inline
import pandas as pd
import pickle as pkl
import string
import numpy as np; np.random.seed(7)
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
import re
import time

In [None]:
# global
HIDDEN_DIM = 120

In [3]:
w2v_map = {}
with open('data/vectors_pruned.200.txt', 'r') as src:
    src = src.read().strip().split('\n')
    for line in src:
        wv = line.strip().split(' ')
        word = wv.pop(0)
        w2v_map[word] = np.array(list(map(float, wv)))

w2i_map = {}
w2v_matrix = np.zeros(( len((w2v_map.keys())), 200 ))
for i, (key, val) in enumerate(w2v_map.items()):
    w2i_map[key] = i
    w2v_matrix[i] = val

def w2v(w):
    return w2v_matrix[w2i_map[w]]

def sen2w(sen):
    processed = []
    sen = sen.strip().split()
    if len(sen) > 100:
        sen = sen[:100]
    for w in sen:
        #ignore date
        if re.match(r'\d{1,}-\d{1,}-\d{1,}', w):
            continue
        if re.match(r'\d{1,}:\d{1,}', w):
            continue
        
        if w in w2i_map:
            processed += [w]
        else:
            separated = re.findall(r"[^\W\d_]+|\d+|[=`%$\^\-@;\[&_*>\].<~|+\d+]", w)
            if len(set(separated)) == 1:
                continue
            if separated.count('*') > 3 or separated.count('=') > 3:
                continue
            for separate_w in separated:
                if separate_w in w2i_map:
                    processed += [separate_w]
    return processed

In [4]:
# fixed context len = 125
context_repre = {}
with open('data/text_tokenized.txt', 'r') as src:
    src = src.read().strip().split('\n')
    for line in src:
        context = line.strip().split('\t')
        qid = context.pop(0)
        if len(context) == 1:
            context_repre[int(qid)] = {'t': sen2w(context[0]), 'b': None}
        else:
            context_repre[int(qid)] = {'t':sen2w(context[0]), 'b': sen2w(context[1])}

In [5]:
def build_set_pair_with_idx(df):
    idx_set = {}
    for idx, row in df.iterrows():
        idx_set[row['Q']] = {'pos': np.array(list(map(int, row['Q+'].split(' ')))), \
                             'neg': np.array(list(map(int, row['Q-'].split(' '))))}
    return idx_set

train_df = pd.read_csv('data/train_random.txt', header=None, delimiter='\t', names=['Q','Q+','Q-'])
train_idx_set = build_set_pair_with_idx(train_df)

In [6]:
train_df.head()

Unnamed: 0,Q,Q+,Q-
0,262144,211039,227387 413633 113297 356390 256881 145638 2962...
1,491522,65911,155119 402211 310669 383107 131731 299465 1633...
2,240299,168608 390642,368007 70009 48077 376760 438005 228888 142340...
3,196614,205184,334471 163710 376791 441664 159963 406360 4300...
4,360457,321532,151863 501857 217578 470017 125838 31836 42066...


In [8]:
def contxt2vec(title, body=None):
    
    if body == None:
        body = []
    
    title_v = np.zeros( (len(title), 200) )
    
    for i, t in enumerate(title):
        title_v[i] = w2v(t)
    
    if len(body) > 0:
        body_v = np.zeros( (len(body), 200) )
        for i, b in enumerate(body):
            body_v[i] = w2v(b)
    
        return title_v, body_v
    
    return title_v, None

In [25]:
def process_contxt_batch(qids, idx_set):
    
    batch_title, batch_body = [], []
    max_title_len, max_body_len = 0, 0
    title_len, body_len = [], []
    counter = 0
#     y = []
    
    for qid in qids:
        
        q_title, q_body = context_repre[qid]['t'], context_repre[qid]['b']
        q_pos = idx_set[qid]['pos']
        
        if len(q_pos) > 20:
            q_pos = q_pos[:20]

        for qid_pos in q_pos:

            # query Q
            title_len += [len(q_title)]
            batch_title += [ q_title ]
            max_title_len = max(max_title_len, len(q_title))
            if not q_body:
                body_len += [len(q_title)]
                batch_body += [ q_title ]
            else:
                batch_body += [ q_body ]
                body_len += [len(q_body)]
                max_body_len = max(max_body_len, len(q_body))
#             y += [1]
            # pos Q
            title, body = context_repre[qid_pos]['t'], context_repre[qid_pos]['b']
            title_len += [len(title)]
            batch_title += [ title ]
            max_title_len = max(max_title_len, len(title))
            if not body:
                body_len += [len(title)]
                batch_body += [ title ]
            else:
                batch_body += [ body ]
                body_len += [len(body)]
                max_body_len = max(max_body_len, len(body))
#             y += [1]
            # neg Q
            q_neg = idx_set[qid]['neg']
            q_neg_sample_indices = np.random.choice(range(100), size=20)
            q_random_neg = q_neg[q_neg_sample_indices]
            
            for qid_neg in q_random_neg:
                title, body = context_repre[qid_neg]['t'], context_repre[qid_neg]['b']
                title_len += [len(title)]
                batch_title += [ title ]
                max_title_len = max(max_title_len, len(title))
                if not body:
                    body_len += [len(title)]
                    batch_body += [ title ]
                else:
                    batch_body += [ body ]
                    body_len += [len(body)]
                    max_body_len = max(max_body_len, len(body))
#                 y += [0]
    # (max_seq_len, batch_size, feature_len)
    padded_batch_title = np.zeros(( max_title_len, len(batch_title), 200)) 
    padded_batch_body = np.zeros(( max_body_len, len(batch_body),  200))
    
    for i, (title, body) in enumerate(zip(batch_title, batch_body)):
        title_repre, body_repre = contxt2vec(title, body)
        padded_batch_title[:title_len[i], i] = title_repre
        padded_batch_body[:body_len[i], i] = body_repre
    #np.array(y).reshape(-1,1)
    return padded_batch_title, padded_batch_body, \
                np.array(title_len).reshape(-1,1), np.array(body_len).reshape(-1,1)

# Eval

In [137]:
def read_annotations(path, K_neg=20, prune_pos_cnt=10):
    lst = [ ]
    with open(path) as fin:
        for line in fin:
            parts = line.split("\t")
            pid, pos, neg = parts[:3]
            pos = pos.split()
            neg = neg.split()
            if len(pos) == 0 or (len(pos) > prune_pos_cnt and prune_pos_cnt != -1): continue
            if K_neg != -1:
                np.random.shuffle(neg)
                neg = neg[:K_neg]
            s = set()
            qids = [ ]
            qlabels = [ ]
            for q in neg:
                if q not in s:
                    qids.append(q)
                    qlabels.append(0 if q not in pos else 1)
                    s.add(q)
            for q in pos:
                if q not in s:
                    qids.append(q)
                    qlabels.append(1)
                    s.add(q)
            lst.append((pid, qids, qlabels))

    return lst

def cos_sim(qv, qv_):
    return torch.sum(qv * qv_, dim=1) / (torch.sqrt(torch.sum(qv ** 2, dim=1)) * torch.sqrt(torch.sum(qv_ ** 2, dim=1)))
    
# create eval batch 
def process_eval_batch(qid, data):
    qid_dict = data[qid]
    qs = qid_dict['q']
    max_title_len, max_body_len = 0, 0
    title_len, body_len = [], []
    batch_title, batch_body = [], []
    for qid_ in [qid] + qs:
        title, body = context_repre[qid_]['t'], context_repre[qid_]['b']
        title_len += [len(title)]
        batch_title += [ title ]
        max_title_len = max(max_title_len, len(title))
        if not body:
            body_len += [len(title)]
            batch_body += [ title ]
        else:
            batch_body += [ body ]
            body_len += [len(body)]
            max_body_len = max(max_body_len, len(body))
        
    padded_batch_title = np.zeros(( max_title_len, len(batch_title), 200)) 
    padded_batch_body = np.zeros(( max_body_len, len(batch_body),  200))
    
    for i, (title, body) in enumerate(zip(batch_title, batch_body)):
        title_repre, body_repre = contxt2vec(title, body)
        padded_batch_title[:title_len[i], i] = title_repre
        padded_batch_body[:body_len[i], i] = body_repre
    
    return padded_batch_title, padded_batch_body, \
           np.array(title_len).reshape(-1,1), np.array(body_len).reshape(-1,1) 
    
def evaluate(embeddings): # (n x 240)
    qs = embeddings[0]
    qs_ = embeddings[1:]
    cos_scores = cos_sim(qs.expand(len(embeddings)-1, qs.size(0)), qs_)
    return cos_scores

def precision(at, labels):
    res = []
    for item in labels:
        tmp = item[:at]
        if any(val==1 for val in item):
            res.append(np.sum(tmp) / len(tmp) if len(tmp) != 0 else 0.0)
    return sum(res)/len(res) if len(res) != 0 else 0.0

def MAP(labels):
    scores = []
    missing_MAP = 0
    for item in labels:
        temp = []
        count = 0.0
        for i,val in enumerate(item):
            
            if val == 1:
                count += 1.0
                temp.append(count/(i+1))
            if len(temp) > 0:
                scores.append(sum(temp) / len(temp))
            else:
                missing_MAP += 1
    return sum(scores)/len(scores) if len(scores) > 0 else 0.0
    
def MRR(labels):
    scores = []
    for item in labels:
        for i,val in enumerate(item):
            if val == 1:
                scores.append(1.0/(i+1))
                break
    return sum(scores)/len(scores) if len(scores) > 0 else 0.0

In [118]:
def do_eval(eval_map, eval_data, embedding_layer, eval_name):
    
    labels = []
    
    for qid_ in eval_map.keys():
        
        eval_title_batch, eval_body_batch, eval_title_len, eval_body_len = eval_map[qid_] # process_eval_batch(qid_, eval_data)
        embedding_layer.title_hidden = embedding_layer.init_hidden(eval_title_batch.shape[1])
        embedding_layer.body_hidden = embedding_layer.init_hidden(eval_body_batch.shape[1])
        eval_title_qs = Variable(torch.FloatTensor(eval_title_batch))
        eval_body_qs = Variable(torch.FloatTensor(eval_body_batch))
        embeddings = embedding_layer(eval_title_qs, eval_body_qs, eval_title_len, eval_body_len)
        cos_scores = evaluate(embeddings)
        labels.append(np.array(eval_data[qid_]['label'])[np.argsort(cos_scores.data.numpy())][::-1])

    print (eval_name + ' Performance P@5', precision(5, labels))
    print (eval_name + ' Performance P@1', precision(1, labels))
    print (eval_name + ' Performance MAP', MAP(labels))
    print (eval_name + ' Performance MRR', MRR(labels))

In [119]:
# DEV SET
dev = read_annotations('data/dev.txt')
dev_data = {}
for item in dev:
    qid = int(item[0])
    dev_data[qid] = {}
    dev_data[qid]['q'] = list(map(int, item[1]))
    dev_data[qid]['label'] = item[2]

# TEST SET
test = read_annotations('data/test.txt')
test_data = {}
for item in test:
    qid = int(item[0])
    test_data[qid] = {}
    test_data[qid]['q'] = list(map(int, item[1]))
    test_data[qid]['label'] = item[2]

In [120]:
dev_map = {}
for qid_ in dev_data.keys():
    dev_map[qid_] = process_eval_batch(qid_, dev_data)

test_map = {}
for qid_ in test_data.keys():
    test_map[qid_] = process_eval_batch(qid_, test_data)

# Train Utility

In [159]:
def build_mask(seq_len):
    mask = []
    for i, s in enumerate(seq_len):
        s_mask = np.zeros((np.max(seq_len), 1))
        s_mask[:int(s)] = np.ones((int(s), 1))
        mask += [s_mask]
    return mask

def build_mask3d(seq_len):
    mask = np.zeros((np.max(seq_len), len(seq_len), 1))
    for i, s in enumerate(seq_len):
        mask[:int(s), i] = np.ones((int(s), 1))
    return mask

def multi_margin_loss(margin=0.30):
    
    def loss_func(embeddings):
        # a batch of embeddings
        blocked_embeddings = embeddings.view(-1, 22, HIDDEN_DIM *2)
        q_vecs = blocked_embeddings[:,0,:]
        pos_vecs = blocked_embeddings[:,1,:]
        neg_vecs = blocked_embeddings[:,2:,:]

        pos_scores = torch.sum(q_vecs * pos_vecs, dim=1) / (torch.sqrt(torch.sum(q_vecs ** 2, dim=1)) \
                                                   * torch.sqrt(torch.sum(pos_vecs ** 2, dim=1)))

        neg_scores = torch.sum(torch.unsqueeze(q_vecs, dim=1) * neg_vecs, dim=2) \
        / (torch.unsqueeze(torch.sqrt(torch.sum(q_vecs ** 2, dim=1)),dim=1) * torch.sqrt(torch.sum( neg_vecs ** 2, dim=2)))
        neg_scores = torch.max(neg_scores, dim=1)[0]

        diff = neg_scores - pos_scores + margin
        loss = torch.mean((diff > 0).float() * diff)
        return loss

    return loss_func

# Model

In [156]:
class EmbeddingLayer(nn.Module):
    
    def __init__(self, input_size, hidden_size, layer_type, num_layer=1, kernel_size=None):
        
        super(EmbeddingLayer, self).__init__()
        
        self.num_layer = num_layer
        
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        
        if layer_type == 'lstm':
            
            self.layer_type = 'lstm'
            #self.title_embedding_layer = nn.LSTM(input_size, hidden_size)
            #self.body_embedding_layer = nn.LSTM(input_size, hidden_size)
            self.embedding_layer = nn.LSTM(input_size, hidden_size, bidirectional=True)
            self.tanh = nn.Tanh()
        
        elif layer_type == 'cnn':
            self.layer_type = 'cnn'
            self.embedding_layer = nn.Sequential(
                        nn.Conv1d(in_channels = 200,
                                  out_channels = self.hidden_size,
                                  kernel_size = kernel_size),
                        nn.Tanh())

    def init_hidden(self, batch_size):
        return (Variable(torch.zeros(self.num_layer*2, batch_size, self.hidden_size)), \
                Variable(torch.zeros(self.num_layer*2, batch_size, self.hidden_size)))

    def forward(self, title, body, title_len, body_len):
            
        if self.layer_type == 'lstm':
            
            
            title_lstm_out, self.title_hidden = self.embedding_layer(title, (self.tanh(self.title_hidden[0]), \
                                                                   self.tanh(self.title_hidden[1])))
            body_lstm_out, self.body_hidden = self.embedding_layer(body, (self.tanh(self.body_hidden[0]), \
                                                                   self.tanh(self.body_hidden[1])))
            
            
            title_mask = Variable(torch.FloatTensor(build_mask3d(title_len)))
            title_embeddings = torch.sum(title_lstm_out * title_mask, dim=0) / torch.sum(title_mask, dim=0)
            
            body_mask = Variable(torch.FloatTensor(build_mask3d(body_len)))
            body_embeddings = torch.sum(body_lstm_out * body_mask, dim=0) / torch.sum(body_mask, dim=0)
            
            embeddings = ( title_embeddings + body_embeddings ) / 2
        
            return embeddings

In [145]:
def save_model(mdl, path):
    # saving model params
    torch.save(mdl.state_dict(), path)

def restore_model(mdl_skeleton, path):
    # restoring params to the mdl skeleton
    mdl_skeleton.load_state_dict(torch.load(path))
    return mdl

# Train

In [160]:
def train(layer_type, embedding_layer, batch_size=25, num_epoch=100, id_set=train_idx_set, eval=True):
    
#     if layer_type == 'lstm':
#         embedding_layer = EmbeddingLayer(200, 240, 'lstm')
#     elif layer_type == 'cnn':
#         embedding_layer = EmbeddingLayer(200, 240, 'cnn', kernel_size=3)
        
    optimizer = torch.optim.Adam(embedding_layer.parameters(), lr=0.001)
    # criterion = torch.nn.MultiMarginLoss()
    criterion = multi_margin_loss()
    
    qids = list(id_set.keys())
    num_batch = len(qids) // batch_size
    
    for epoch in range(1, num_epoch + 1):
        
        for batch_idx in range(1, num_batch + 1):
            
            batch_x_qids = qids[ ( batch_idx - 1 ) * batch_size: batch_idx * batch_size ]
            batch_title, batch_body, title_len, body_len = process_contxt_batch(batch_x_qids, train_idx_set)
            
            if layer_type == 'lstm':
                embedding_layer.title_hidden = embedding_layer.init_hidden(batch_title.shape[1])
                embedding_layer.body_hidden = embedding_layer.init_hidden(batch_body.shape[1])
            
            title_qs = Variable(torch.FloatTensor(batch_title))#, requires_grad=True)
            body_qs = Variable(torch.FloatTensor(batch_body))#, requires_grad=True)
            
            embeddings = embedding_layer(title_qs, body_qs, title_len, body_len)
#             blocked_embeddings = embeddings.view(-1, 22, HIDDEN_DIM * 2)
#             q_vecs = blocked_embeddings[:,0,:]
#             pos_neg_vecs = blocked_embeddings[:,1:,:]
            
#             # cosine similarity
#             scores = torch.sum(torch.unsqueeze(q_vecs, dim=1) * pos_neg_vecs, dim=2) \
#                 / (torch.unsqueeze(torch.sqrt(torch.sum(q_vecs ** 2, dim=1)), dim=1) *\
#                    torch.sqrt(torch.sum( pos_neg_vecs ** 2, dim=2)))
#             target = Variable(torch.zeros(scores.size(0)).type(torch.LongTensor)) 
            
            loss = criterion(embeddings)
            #loss = criterion(scores, target)
            print ('epoch:{}/{}, batch:{}/{}, loss:{}'.format(epoch, num_epoch, batch_idx, num_batch, loss.data[0]))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
            if eval and batch_idx % 25 == 0: # lstm for now
                print ('evaluating ....')
                do_eval(dev_map, dev_data, embedding_layer, 'Dev')
                print ('------------------')
                do_eval(test_map, test_data, embedding_layer, 'Test')

In [144]:
model = EmbeddingLayer(200, HIDDEN_DIM, 'lstm') # loss margin = 0.5
train('lstm', model, batch_size=25, num_epoch=10)

epoch:1/10, batch:1/508, loss:0.5026151537895203
epoch:1/10, batch:2/508, loss:0.502778172492981
epoch:1/10, batch:3/508, loss:0.5006894469261169
epoch:1/10, batch:4/508, loss:0.49986058473587036
epoch:1/10, batch:5/508, loss:0.5012609362602234
epoch:1/10, batch:6/508, loss:0.49954280257225037
epoch:1/10, batch:7/508, loss:0.5005972981452942
epoch:1/10, batch:8/508, loss:0.49912452697753906
epoch:1/10, batch:9/508, loss:0.4989265203475952
epoch:1/10, batch:10/508, loss:0.49831342697143555
epoch:1/10, batch:11/508, loss:0.5000904202461243
epoch:1/10, batch:12/508, loss:0.5000025629997253
epoch:1/10, batch:13/508, loss:0.49849817156791687
epoch:1/10, batch:14/508, loss:0.4986778497695923
epoch:1/10, batch:15/508, loss:0.4995615482330322
epoch:1/10, batch:16/508, loss:0.49869799613952637
epoch:1/10, batch:17/508, loss:0.49957963824272156
epoch:1/10, batch:18/508, loss:0.4992402493953705
epoch:1/10, batch:19/508, loss:0.49848881363868713
epoch:1/10, batch:20/508, loss:0.4981097877025604
ep

epoch:1/10, batch:156/508, loss:0.3436394929885864
epoch:1/10, batch:157/508, loss:0.30853477120399475
epoch:1/10, batch:158/508, loss:0.3045615255832672
epoch:1/10, batch:159/508, loss:0.3891583979129791
epoch:1/10, batch:160/508, loss:0.40409210324287415
epoch:1/10, batch:161/508, loss:0.34816503524780273
epoch:1/10, batch:162/508, loss:0.3790401518344879
epoch:1/10, batch:163/508, loss:0.40466371178627014
epoch:1/10, batch:164/508, loss:0.3888358771800995
epoch:1/10, batch:165/508, loss:0.3659917116165161
epoch:1/10, batch:166/508, loss:0.38307780027389526
epoch:1/10, batch:167/508, loss:0.3464403450489044
epoch:1/10, batch:168/508, loss:0.3292739689350128
epoch:1/10, batch:169/508, loss:0.27529025077819824
epoch:1/10, batch:170/508, loss:0.3382743000984192
epoch:1/10, batch:171/508, loss:0.36284366250038147
epoch:1/10, batch:172/508, loss:0.3087102770805359
epoch:1/10, batch:173/508, loss:0.334555983543396
epoch:1/10, batch:174/508, loss:0.29747191071510315
epoch:1/10, batch:175/50

epoch:1/10, batch:310/508, loss:0.3460395932197571
epoch:1/10, batch:311/508, loss:0.31791192293167114
epoch:1/10, batch:312/508, loss:0.3219287097454071
epoch:1/10, batch:313/508, loss:0.299938827753067
epoch:1/10, batch:314/508, loss:0.2909744381904602
epoch:1/10, batch:315/508, loss:0.2915215492248535
epoch:1/10, batch:316/508, loss:0.3896145522594452
epoch:1/10, batch:317/508, loss:0.34968647360801697
epoch:1/10, batch:318/508, loss:0.3023347854614258
epoch:1/10, batch:319/508, loss:0.345836341381073
epoch:1/10, batch:320/508, loss:0.31608057022094727
epoch:1/10, batch:321/508, loss:0.3416295349597931
epoch:1/10, batch:322/508, loss:0.2625827193260193
epoch:1/10, batch:323/508, loss:0.27997732162475586
epoch:1/10, batch:324/508, loss:0.33261600136756897
epoch:1/10, batch:325/508, loss:0.33321765065193176
epoch:1/10, batch:326/508, loss:0.33245015144348145
epoch:1/10, batch:327/508, loss:0.29648256301879883
epoch:1/10, batch:328/508, loss:0.3275316655635834
epoch:1/10, batch:329/508

epoch:1/10, batch:463/508, loss:0.28970664739608765
epoch:1/10, batch:464/508, loss:0.26896750926971436
epoch:1/10, batch:465/508, loss:0.3158222436904907
epoch:1/10, batch:466/508, loss:0.3548130989074707
epoch:1/10, batch:467/508, loss:0.354170024394989
epoch:1/10, batch:468/508, loss:0.30167487263679504
epoch:1/10, batch:469/508, loss:0.3403187394142151
epoch:1/10, batch:470/508, loss:0.2833174765110016
epoch:1/10, batch:471/508, loss:0.3017997443675995
epoch:1/10, batch:472/508, loss:0.31458860635757446
epoch:1/10, batch:473/508, loss:0.2985970675945282
epoch:1/10, batch:474/508, loss:0.24410149455070496
epoch:1/10, batch:475/508, loss:0.2699381411075592
epoch:1/10, batch:476/508, loss:0.2831593453884125
epoch:1/10, batch:477/508, loss:0.33156293630599976
epoch:1/10, batch:478/508, loss:0.3193717300891876
epoch:1/10, batch:479/508, loss:0.2889942526817322
epoch:1/10, batch:480/508, loss:0.27773547172546387
epoch:1/10, batch:481/508, loss:0.2296927571296692
epoch:1/10, batch:482/508

epoch:2/10, batch:110/508, loss:0.2238190621137619
epoch:2/10, batch:111/508, loss:0.22265294194221497
epoch:2/10, batch:112/508, loss:0.2873879373073578
epoch:2/10, batch:113/508, loss:0.2538389563560486
epoch:2/10, batch:114/508, loss:0.35955604910850525
epoch:2/10, batch:115/508, loss:0.33673274517059326
epoch:2/10, batch:116/508, loss:0.2686859965324402
epoch:2/10, batch:117/508, loss:0.3220653533935547
epoch:2/10, batch:118/508, loss:0.30356574058532715
epoch:2/10, batch:119/508, loss:0.27062103152275085
epoch:2/10, batch:120/508, loss:0.32498806715011597
epoch:2/10, batch:121/508, loss:0.29050949215888977
epoch:2/10, batch:122/508, loss:0.2985553443431854
epoch:2/10, batch:123/508, loss:0.2955058515071869
epoch:2/10, batch:124/508, loss:0.25236427783966064
epoch:2/10, batch:125/508, loss:0.2133544534444809
epoch:2/10, batch:126/508, loss:0.26590195298194885
epoch:2/10, batch:127/508, loss:0.27163830399513245
evaluating ....
Dev Performance P@5 0.377639751553
Dev Performance P@1 0

epoch:2/10, batch:257/508, loss:0.2860791087150574
epoch:2/10, batch:258/508, loss:0.32776716351509094
epoch:2/10, batch:259/508, loss:0.31448066234588623
epoch:2/10, batch:260/508, loss:0.2873518466949463
epoch:2/10, batch:261/508, loss:0.2835582494735718
epoch:2/10, batch:262/508, loss:0.21262376010417938
epoch:2/10, batch:263/508, loss:0.23099285364151
epoch:2/10, batch:264/508, loss:0.3300468921661377
epoch:2/10, batch:265/508, loss:0.24554404616355896
epoch:2/10, batch:266/508, loss:0.33389902114868164
epoch:2/10, batch:267/508, loss:0.3125488758087158
epoch:2/10, batch:268/508, loss:0.3255211412906647
epoch:2/10, batch:269/508, loss:0.30353912711143494
epoch:2/10, batch:270/508, loss:0.29936525225639343
epoch:2/10, batch:271/508, loss:0.3105025589466095
epoch:2/10, batch:272/508, loss:0.2713741660118103
epoch:2/10, batch:273/508, loss:0.2454729527235031
epoch:2/10, batch:274/508, loss:0.26846587657928467
epoch:2/10, batch:275/508, loss:0.2546684145927429
epoch:2/10, batch:276/508

epoch:2/10, batch:410/508, loss:0.27778229117393494
epoch:2/10, batch:411/508, loss:0.23350785672664642
epoch:2/10, batch:412/508, loss:0.2629665732383728
epoch:2/10, batch:413/508, loss:0.30703479051589966
epoch:2/10, batch:414/508, loss:0.28826048970222473
epoch:2/10, batch:415/508, loss:0.26076552271842957
epoch:2/10, batch:416/508, loss:0.29807403683662415
epoch:2/10, batch:417/508, loss:0.3000863492488861
epoch:2/10, batch:418/508, loss:0.26149219274520874
epoch:2/10, batch:419/508, loss:0.2685050368309021
epoch:2/10, batch:420/508, loss:0.19896741211414337
epoch:2/10, batch:421/508, loss:0.27996423840522766
epoch:2/10, batch:422/508, loss:0.2815401256084442
epoch:2/10, batch:423/508, loss:0.2817631959915161
epoch:2/10, batch:424/508, loss:0.2883480489253998
epoch:2/10, batch:425/508, loss:0.2439553141593933
epoch:2/10, batch:426/508, loss:0.26374220848083496
epoch:2/10, batch:427/508, loss:0.29972630739212036
epoch:2/10, batch:428/508, loss:0.2568334937095642
epoch:2/10, batch:42

epoch:3/10, batch:56/508, loss:0.27423927187919617
epoch:3/10, batch:57/508, loss:0.27509769797325134
epoch:3/10, batch:58/508, loss:0.18880820274353027
epoch:3/10, batch:59/508, loss:0.3019920289516449
epoch:3/10, batch:60/508, loss:0.28015244007110596
epoch:3/10, batch:61/508, loss:0.2744930684566498
epoch:3/10, batch:62/508, loss:0.20420804619789124
epoch:3/10, batch:63/508, loss:0.30691054463386536
epoch:3/10, batch:64/508, loss:0.32762622833251953
epoch:3/10, batch:65/508, loss:0.3215194642543793
epoch:3/10, batch:66/508, loss:0.30190256237983704
epoch:3/10, batch:67/508, loss:0.22764849662780762
epoch:3/10, batch:68/508, loss:0.2205575853586197
epoch:3/10, batch:69/508, loss:0.21595411002635956
epoch:3/10, batch:70/508, loss:0.23625540733337402
epoch:3/10, batch:71/508, loss:0.2658565938472748
epoch:3/10, batch:72/508, loss:0.26144859194755554
epoch:3/10, batch:73/508, loss:0.19822362065315247
epoch:3/10, batch:74/508, loss:0.26010099053382874
epoch:3/10, batch:75/508, loss:0.322

epoch:3/10, batch:210/508, loss:0.2434147447347641
epoch:3/10, batch:211/508, loss:0.2558598518371582
epoch:3/10, batch:212/508, loss:0.37800005078315735
epoch:3/10, batch:213/508, loss:0.20981402695178986
epoch:3/10, batch:214/508, loss:0.30235588550567627
epoch:3/10, batch:215/508, loss:0.2826959788799286
epoch:3/10, batch:216/508, loss:0.24250176548957825
epoch:3/10, batch:217/508, loss:0.2992197871208191
epoch:3/10, batch:218/508, loss:0.2629643976688385
epoch:3/10, batch:219/508, loss:0.19408287107944489
epoch:3/10, batch:220/508, loss:0.19412493705749512
epoch:3/10, batch:221/508, loss:0.2577503025531769
epoch:3/10, batch:222/508, loss:0.29893097281455994
epoch:3/10, batch:223/508, loss:0.3149651288986206
epoch:3/10, batch:224/508, loss:0.26383447647094727
epoch:3/10, batch:225/508, loss:0.29929471015930176
epoch:3/10, batch:226/508, loss:0.21213072538375854
epoch:3/10, batch:227/508, loss:0.23090480268001556
epoch:3/10, batch:228/508, loss:0.31558331847190857
epoch:3/10, batch:2

epoch:3/10, batch:363/508, loss:0.289591908454895
epoch:3/10, batch:364/508, loss:0.3599427342414856
epoch:3/10, batch:365/508, loss:0.1588636189699173
epoch:3/10, batch:366/508, loss:0.2288312464952469
epoch:3/10, batch:367/508, loss:0.2701561152935028
epoch:3/10, batch:368/508, loss:0.20979104936122894
epoch:3/10, batch:369/508, loss:0.2986891567707062
epoch:3/10, batch:370/508, loss:0.1722392737865448
epoch:3/10, batch:371/508, loss:0.25000759959220886
epoch:3/10, batch:372/508, loss:0.2352309226989746
epoch:3/10, batch:373/508, loss:0.3217352628707886
epoch:3/10, batch:374/508, loss:0.2817058563232422
epoch:3/10, batch:375/508, loss:0.23171792924404144
epoch:3/10, batch:376/508, loss:0.332952618598938
epoch:3/10, batch:377/508, loss:0.2376551628112793
epoch:3/10, batch:378/508, loss:0.19681565463542938
epoch:3/10, batch:379/508, loss:0.2531086802482605
epoch:3/10, batch:380/508, loss:0.2672154903411865
epoch:3/10, batch:381/508, loss:0.18012070655822754
evaluating ....
Dev Performa

epoch:4/10, batch:2/508, loss:0.24041998386383057
epoch:4/10, batch:3/508, loss:0.3200435936450958
epoch:4/10, batch:4/508, loss:0.2114938497543335
epoch:4/10, batch:5/508, loss:0.30499574542045593
epoch:4/10, batch:6/508, loss:0.31289756298065186
epoch:4/10, batch:7/508, loss:0.24708640575408936
epoch:4/10, batch:8/508, loss:0.1642632633447647
epoch:4/10, batch:9/508, loss:0.297350138425827
epoch:4/10, batch:10/508, loss:0.1795019954442978
epoch:4/10, batch:11/508, loss:0.21369075775146484
epoch:4/10, batch:12/508, loss:0.2803500294685364
epoch:4/10, batch:13/508, loss:0.1691911369562149
epoch:4/10, batch:14/508, loss:0.22765208780765533
epoch:4/10, batch:15/508, loss:0.23459947109222412
epoch:4/10, batch:16/508, loss:0.19273756444454193
epoch:4/10, batch:17/508, loss:0.2851704955101013
epoch:4/10, batch:18/508, loss:0.29575687646865845
epoch:4/10, batch:19/508, loss:0.17685487866401672
epoch:4/10, batch:20/508, loss:0.17783381044864655
epoch:4/10, batch:21/508, loss:0.239254459738731

epoch:4/10, batch:157/508, loss:0.23594997823238373
epoch:4/10, batch:158/508, loss:0.19763651490211487
epoch:4/10, batch:159/508, loss:0.25291723012924194
epoch:4/10, batch:160/508, loss:0.2525399625301361
epoch:4/10, batch:161/508, loss:0.24783939123153687
epoch:4/10, batch:162/508, loss:0.278709352016449
epoch:4/10, batch:163/508, loss:0.3223874270915985
epoch:4/10, batch:164/508, loss:0.28209877014160156
epoch:4/10, batch:165/508, loss:0.2693496346473694
epoch:4/10, batch:166/508, loss:0.3017217218875885
epoch:4/10, batch:167/508, loss:0.23064807057380676
epoch:4/10, batch:168/508, loss:0.2436118721961975
epoch:4/10, batch:169/508, loss:0.17443174123764038
epoch:4/10, batch:170/508, loss:0.21813975274562836
epoch:4/10, batch:171/508, loss:0.30737099051475525
epoch:4/10, batch:172/508, loss:0.23890022933483124
epoch:4/10, batch:173/508, loss:0.23747028410434723
epoch:4/10, batch:174/508, loss:0.20366856455802917
epoch:4/10, batch:175/508, loss:0.26482436060905457
epoch:4/10, batch:1

epoch:4/10, batch:310/508, loss:0.21495231986045837
epoch:4/10, batch:311/508, loss:0.20262759923934937
epoch:4/10, batch:312/508, loss:0.23852457106113434
epoch:4/10, batch:313/508, loss:0.19837042689323425
epoch:4/10, batch:314/508, loss:0.19560427963733673
epoch:4/10, batch:315/508, loss:0.21853548288345337
epoch:4/10, batch:316/508, loss:0.28820592164993286
epoch:4/10, batch:317/508, loss:0.32830867171287537
epoch:4/10, batch:318/508, loss:0.22656570374965668
epoch:4/10, batch:319/508, loss:0.25254860520362854
epoch:4/10, batch:320/508, loss:0.23788179457187653
epoch:4/10, batch:321/508, loss:0.24188975989818573
epoch:4/10, batch:322/508, loss:0.2470225840806961
epoch:4/10, batch:323/508, loss:0.2317601889371872
epoch:4/10, batch:324/508, loss:0.23982089757919312
epoch:4/10, batch:325/508, loss:0.27079546451568604
epoch:4/10, batch:326/508, loss:0.2692287266254425
epoch:4/10, batch:327/508, loss:0.20037007331848145
epoch:4/10, batch:328/508, loss:0.23771953582763672
epoch:4/10, bat

epoch:4/10, batch:463/508, loss:0.2724207937717438
epoch:4/10, batch:464/508, loss:0.1995088905096054
epoch:4/10, batch:465/508, loss:0.2377437949180603
epoch:4/10, batch:466/508, loss:0.25949904322624207
epoch:4/10, batch:467/508, loss:0.27278339862823486
epoch:4/10, batch:468/508, loss:0.2623048722743988
epoch:4/10, batch:469/508, loss:0.2593543529510498
epoch:4/10, batch:470/508, loss:0.28782618045806885
epoch:4/10, batch:471/508, loss:0.2801878750324249
epoch:4/10, batch:472/508, loss:0.2774294912815094
epoch:4/10, batch:473/508, loss:0.22916428744792938
epoch:4/10, batch:474/508, loss:0.21200969815254211
epoch:4/10, batch:475/508, loss:0.2268708050251007
epoch:4/10, batch:476/508, loss:0.21956926584243774
epoch:4/10, batch:477/508, loss:0.3093561828136444
epoch:4/10, batch:478/508, loss:0.25246912240982056
epoch:4/10, batch:479/508, loss:0.22743161022663116
epoch:4/10, batch:480/508, loss:0.19321569800376892
epoch:4/10, batch:481/508, loss:0.18387193977832794
epoch:4/10, batch:482

epoch:5/10, batch:110/508, loss:0.2792927920818329
epoch:5/10, batch:111/508, loss:0.18742507696151733
epoch:5/10, batch:112/508, loss:0.22436989843845367
epoch:5/10, batch:113/508, loss:0.20207613706588745
epoch:5/10, batch:114/508, loss:0.29226458072662354
epoch:5/10, batch:115/508, loss:0.23603539168834686
epoch:5/10, batch:116/508, loss:0.20349618792533875
epoch:5/10, batch:117/508, loss:0.23915785551071167
epoch:5/10, batch:118/508, loss:0.25877514481544495
epoch:5/10, batch:119/508, loss:0.1663273125886917
epoch:5/10, batch:120/508, loss:0.26819634437561035
epoch:5/10, batch:121/508, loss:0.22131703794002533
epoch:5/10, batch:122/508, loss:0.24279989302158356
epoch:5/10, batch:123/508, loss:0.24301575124263763
epoch:5/10, batch:124/508, loss:0.20106272399425507
epoch:5/10, batch:125/508, loss:0.20618030428886414
epoch:5/10, batch:126/508, loss:0.22747381031513214
epoch:5/10, batch:127/508, loss:0.18711744248867035
evaluating ....
Dev Performance P@5 0.398757763975
Dev Performance

epoch:5/10, batch:256/508, loss:0.17769354581832886
epoch:5/10, batch:257/508, loss:0.26144930720329285
epoch:5/10, batch:258/508, loss:0.2440420687198639
epoch:5/10, batch:259/508, loss:0.2430574595928192
epoch:5/10, batch:260/508, loss:0.24134106934070587
epoch:5/10, batch:261/508, loss:0.24582868814468384
epoch:5/10, batch:262/508, loss:0.1772305816411972
epoch:5/10, batch:263/508, loss:0.23247958719730377
epoch:5/10, batch:264/508, loss:0.27415338158607483
epoch:5/10, batch:265/508, loss:0.16723093390464783
epoch:5/10, batch:266/508, loss:0.2662244141101837
epoch:5/10, batch:267/508, loss:0.2843785285949707
epoch:5/10, batch:268/508, loss:0.25946709513664246
epoch:5/10, batch:269/508, loss:0.22602175176143646
epoch:5/10, batch:270/508, loss:0.2564050853252411
epoch:5/10, batch:271/508, loss:0.20254814624786377
epoch:5/10, batch:272/508, loss:0.23773017525672913
epoch:5/10, batch:273/508, loss:0.21339692175388336
epoch:5/10, batch:274/508, loss:0.20723706483840942
epoch:5/10, batch:

KeyboardInterrupt: 

In [147]:
save_model(model, 'models/lstm_bi_epoch=4.5_margin=.5_hidden=120')

In [None]:
model_margin_p3 = EmbeddingLayer(200, HIDDEN_DIM, 'lstm')
train('lstm', model, batch_size=25, num_epoch=10)

epoch:1/10, batch:1/508, loss:0.06204039976000786
epoch:1/10, batch:2/508, loss:0.06698644161224365
epoch:1/10, batch:3/508, loss:0.08204396069049835
epoch:1/10, batch:4/508, loss:0.07271836698055267
epoch:1/10, batch:5/508, loss:0.11348847299814224
epoch:1/10, batch:6/508, loss:0.12347452342510223
epoch:1/10, batch:7/508, loss:0.07170307636260986
epoch:1/10, batch:8/508, loss:0.05405590310692787
epoch:1/10, batch:9/508, loss:0.13060428202152252
epoch:1/10, batch:10/508, loss:0.07260049134492874
epoch:1/10, batch:11/508, loss:0.08073266595602036
epoch:1/10, batch:12/508, loss:0.14527538418769836
epoch:1/10, batch:13/508, loss:0.05466698110103607
epoch:1/10, batch:14/508, loss:0.08360808342695236
epoch:1/10, batch:15/508, loss:0.0836087018251419
epoch:1/10, batch:16/508, loss:0.09609977155923843
epoch:1/10, batch:17/508, loss:0.10398983210325241
epoch:1/10, batch:18/508, loss:0.1288212686777115
epoch:1/10, batch:19/508, loss:0.03824116662144661
epoch:1/10, batch:20/508, loss:0.034473661

epoch:1/10, batch:129/508, loss:0.07693225145339966
epoch:1/10, batch:130/508, loss:0.13850216567516327
epoch:1/10, batch:131/508, loss:0.1792794018983841
epoch:1/10, batch:132/508, loss:0.09958145767450333
epoch:1/10, batch:133/508, loss:0.08008226007223129
epoch:1/10, batch:134/508, loss:0.1331702023744583
epoch:1/10, batch:135/508, loss:0.07687525451183319
epoch:1/10, batch:136/508, loss:0.08159011602401733
epoch:1/10, batch:137/508, loss:0.08214867860078812
epoch:1/10, batch:138/508, loss:0.08653660118579865
epoch:1/10, batch:139/508, loss:0.08022867143154144
epoch:1/10, batch:140/508, loss:0.0887611135840416
epoch:1/10, batch:141/508, loss:0.12374649196863174
epoch:1/10, batch:142/508, loss:0.08957083523273468
epoch:1/10, batch:143/508, loss:0.07674036175012589
epoch:1/10, batch:144/508, loss:0.07492111623287201
epoch:1/10, batch:145/508, loss:0.06743984669446945
epoch:1/10, batch:146/508, loss:0.09372374415397644
epoch:1/10, batch:147/508, loss:0.07385239005088806
epoch:1/10, bat

epoch:1/10, batch:255/508, loss:0.06415783613920212
epoch:1/10, batch:256/508, loss:0.11833930760622025
epoch:1/10, batch:257/508, loss:0.1142667680978775
epoch:1/10, batch:258/508, loss:0.12293071299791336
epoch:1/10, batch:259/508, loss:0.13567987084388733
epoch:1/10, batch:260/508, loss:0.13079805672168732
epoch:1/10, batch:261/508, loss:0.0925537720322609
epoch:1/10, batch:262/508, loss:0.04757543280720711
epoch:1/10, batch:263/508, loss:0.07988748699426651
epoch:1/10, batch:264/508, loss:0.14379945397377014
epoch:1/10, batch:265/508, loss:0.055405762046575546
epoch:1/10, batch:266/508, loss:0.11227753013372421
epoch:1/10, batch:267/508, loss:0.11142270267009735
epoch:1/10, batch:268/508, loss:0.11688664555549622
epoch:1/10, batch:269/508, loss:0.07529149204492569
epoch:1/10, batch:270/508, loss:0.09170898795127869
epoch:1/10, batch:271/508, loss:0.11453507095575333
epoch:1/10, batch:272/508, loss:0.07868916541337967
epoch:1/10, batch:273/508, loss:0.07169362157583237
epoch:1/10, b

epoch:1/10, batch:381/508, loss:0.05859388783574104
epoch:1/10, batch:382/508, loss:0.1103350892663002
epoch:1/10, batch:383/508, loss:0.08043211698532104
epoch:1/10, batch:384/508, loss:0.06245983764529228
epoch:1/10, batch:385/508, loss:0.11905223876237869
epoch:1/10, batch:386/508, loss:0.07976298034191132
epoch:1/10, batch:387/508, loss:0.0702722817659378
epoch:1/10, batch:388/508, loss:0.09807973355054855
epoch:1/10, batch:389/508, loss:0.13630197942256927
epoch:1/10, batch:390/508, loss:0.06581863760948181
epoch:1/10, batch:391/508, loss:0.06663747876882553
epoch:1/10, batch:392/508, loss:0.08755485713481903
epoch:1/10, batch:393/508, loss:0.11267269402742386
epoch:1/10, batch:394/508, loss:0.10319242626428604
epoch:1/10, batch:395/508, loss:0.17395727336406708
epoch:1/10, batch:396/508, loss:0.1303078532218933
epoch:1/10, batch:397/508, loss:0.07981802523136139
epoch:1/10, batch:398/508, loss:0.07730397582054138
epoch:1/10, batch:399/508, loss:0.14149557054042816
epoch:1/10, bat

epoch:1/10, batch:507/508, loss:0.08553814888000488
epoch:1/10, batch:508/508, loss:0.06603405624628067
epoch:2/10, batch:1/508, loss:0.0494852289557457
epoch:2/10, batch:2/508, loss:0.09168995916843414
epoch:2/10, batch:3/508, loss:0.1366259753704071
epoch:2/10, batch:4/508, loss:0.06433810293674469
epoch:2/10, batch:5/508, loss:0.11617773026227951
epoch:2/10, batch:6/508, loss:0.10645077377557755
epoch:2/10, batch:7/508, loss:0.06478830426931381
epoch:2/10, batch:8/508, loss:0.057500630617141724
epoch:2/10, batch:9/508, loss:0.09177706390619278
epoch:2/10, batch:10/508, loss:0.05528249964118004
epoch:2/10, batch:11/508, loss:0.08676209300756454
epoch:2/10, batch:12/508, loss:0.13217806816101074
epoch:2/10, batch:13/508, loss:0.039482150226831436
epoch:2/10, batch:14/508, loss:0.08643946796655655
epoch:2/10, batch:15/508, loss:0.06677962094545364
epoch:2/10, batch:16/508, loss:0.06086498126387596
epoch:2/10, batch:17/508, loss:0.09507714956998825
epoch:2/10, batch:18/508, loss:0.11380

# Debugging

In [155]:
qids = list(train_idx_set.keys())[:25]
t, b, tl, bl = process_contxt_batch(batch_x_qids, train_idx_set)

In [157]:
embedding_layer = EmbeddingLayer(200, 240, 'lstm')
embedding_layer.title_hidden = embedding_layer.init_hidden(t.shape[1])
embedding_layer.body_hidden = embedding_layer.init_hidden(b.shape[1])

In [158]:
title_qs = Variable(torch.FloatTensor(t))
body_qs = Variable(torch.FloatTensor(b))
embeddings = embedding_layer(title_qs, body_qs, tl, bl)

In [175]:
target = Variable(torch.LongTensor([1,1] + [0]*20))

In [176]:
loss = torch.nn.MultiMarginLoss()

In [179]:
loss(embeddings[:22], target)

Variable containing:
 1.0260
[torch.FloatTensor of size 1]

In [196]:
blocked_embeddings = embeddings.view(-1, 22, 240)

In [198]:
q_vecs = blocked_embeddings[:,0,:]
pos_vecs = blocked_embeddings[:,1,:]
neg_vecs = blocked_embeddings[:,2:,:]

In [199]:
pos_scores = torch.sum(q_vecs * pos_vecs, dim=1) / (torch.sqrt(torch.sum(q_vecs ** 2, dim=1)) \
                                    * torch.sqrt(torch.sum(pos_vecs ** 2, dim=1)))
neg_scores = torch.sum(torch.unsqueeze(q_vecs, dim=1) * neg_vecs, dim=2) \
/ (torch.unsqueeze(torch.sqrt(torch.sum(q_vecs ** 2, dim=1)), dim=1) * torch.sqrt(torch.sum( neg_vecs ** 2, dim=2)))

In [203]:
pos_scores[0], neg_scores[0]

(Variable containing:
  0.9956
 [torch.FloatTensor of size 1], Variable containing:
  0.9919
  0.9907
  0.9932
  0.9865
  0.9866
  0.9890
  0.9923
  0.9835
  0.9876
  0.9871
  0.9884
  0.9846
  0.9868
  0.9855
  0.9879
  0.9884
  0.9886
  0.9864
  0.9892
  0.9893
 [torch.FloatTensor of size 20])

In [223]:
q_vecs = blocked_embeddings[:,0,:]

In [224]:
pn_vecs = blocked_embeddings[:,1:,:]

In [225]:
scores = torch.sum(torch.unsqueeze(q_vecs, dim=1) * pn_vecs, dim=2) \
/ (torch.unsqueeze(torch.sqrt(torch.sum(q_vecs ** 2, dim=1)), dim=1) * torch.sqrt(torch.sum( pn_vecs ** 2, dim=2)))

In [229]:
scores

Variable containing:

Columns 0 to 9 
 0.9956  0.9919  0.9907  0.9932  0.9865  0.9866  0.9890  0.9923  0.9835  0.9876
 0.9903  0.9830  0.9841  0.9789  0.9809  0.9809  0.9718  0.9705  0.9728  0.9721
 0.9930  0.9897  0.9855  0.9897  0.9890  0.9820  0.9798  0.9854  0.9760  0.9834
 0.9867  0.9819  0.9823  0.9876  0.9819  0.9851  0.9798  0.9759  0.9894  0.9874
 0.9920  0.9864  0.9870  0.9819  0.9849  0.9806  0.9822  0.9809  0.9870  0.9651
 0.9929  0.9910  0.9900  0.9828  0.9876  0.9922  0.9914  0.9906  0.9916  0.9895
 0.9964  0.9934  0.9812  0.9916  0.9866  0.9881  0.9790  0.9898  0.9876  0.9864
 0.9941  0.9806  0.9755  0.9790  0.9821  0.9760  0.9773  0.9871  0.9782  0.9830
 0.9853  0.9871  0.9801  0.9824  0.9804  0.9677  0.9870  0.9842  0.9805  0.9844
 0.9833  0.9861  0.9755  0.9871  0.9849  0.9815  0.9762  0.9896  0.9809  0.9874
 0.9978  0.9925  0.9840  0.9818  0.9911  0.9924  0.9908  0.9857  0.9918  0.9911
 0.9929  0.9882  0.9847  0.9853  0.9857  0.9858  0.9930  0.9908  0.9923  0.9881
 0

In [230]:
criterion = torch.nn.MultiMarginLoss()

In [234]:
target = Variable(torch.zeros(scores.size(0)).type(torch.LongTensor)) 
criterion(scores, target)

Variable containing:
 0.9469
[torch.FloatTensor of size 1]