In [1]:
from tqdm import tqdm

import numpy as np
import numpy.random as random

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.nn.utils.rnn import pad_sequence

import io

In [2]:
def load_embeddings(fname, get_embeddings=True, get_w2i=False, get_i2w=False, skip_first_line=True):
    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
    
    if skip_first_line:
        fin.readline()
    
    num_embeddings = 0
    
    word2idx = {}
    idx2word = {}

    embeddings = []

    for line in fin:
        line = line.rstrip().split(' ')
        
        if get_w2i:
            word2idx[line[0]] = num_embeddings
        if get_i2w:
            idx2word[num_embeddings] = line[0]
        if get_embeddings:       
            embeddings.append([float(num) for num in line[1:]])
        
        num_embeddings += 1
        
        
    return torch.FloatTensor(embeddings), word2idx, idx2word

In [3]:
word2idx = load_embeddings('../embeddings/wiki-news-300d-1M.vec', get_embeddings=False, get_w2i=True)[1]

In [4]:
class MultiplicativeAttention(nn.Module):
      
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        
        attn = torch.matmul(q , k.transpose(-2, -1) / math.sqrt(q.size(-1)))
        
        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1) == 1, -1e9)
        
        attn = self.dropout(F.softmax(attn, dim=-1))        
        res = torch.matmul(attn, v)

        return res, attn

class AdditiveSelfAttention(nn.Module):
    
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        
        self.w = nn.Linear(d_model, d_model)
        self.q = torch.nn.Parameter(torch.FloatTensor(d_model).uniform_(-0.1, 0.1))
    
    def forward(self, x, mask=None):
        attn = torch.tanh(self.dropout(self.w(x)))        
        attn = torch.matmul(attn, self.q)
        
        if mask is not None:
            attn = attn.masked_fill(mask == 1, -1e9)
        
        attn = self.dropout(F.softmax(attn, dim=-1))

        
        res = torch.einsum('ijk, ij->ik', x, attn)
        return res, attn

    
class MultiHeadAttention(nn.Module):
    
    def __init__(self, d_model, num_heads, d_qk, d_v, track_agreement=False, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.d_qk = d_qk
        self.d_v = d_v
        
        self.num_heads = num_heads
        
        self.dropout = nn.Dropout(dropout)
        
        self.w_q = nn.Linear(d_model, num_heads * d_qk, bias=False)
        self.w_k = nn.Linear(d_model, num_heads * d_qk, bias=False)
        self.w_v = nn.Linear(d_model, num_heads * d_v, bias=False)
        
        self.w_fc = nn.Linear(num_heads * d_v, d_model, bias=False)
        
        self.attention = MultiplicativeAttention(dropout=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
    
        self.track_agreement = track_agreement
        self.v_agreement = 0

    def forward(self, q, k, v, mask=None):     
        batch_size = q.shape[0]
        seq_size = q.shape[1]
        
        q_proj = self.w_q(q).view(q.shape[0], q.shape[1], self.num_heads, self.d_qk)
        k_proj = self.w_k(k).view(k.shape[0], k.shape[1], self.num_heads, self.d_qk)
        v_proj = self.w_v(v).view(v.shape[0], v.shape[1], self.num_heads, self.d_v) 

        if self.track_agreement:
            self.v_agreement += torch.einsum('bshd, bsnd->', F.normalize(v_proj, dim=3), F.normalize(v_proj, dim=3)) / self.num_heads**2

        if mask is None:
            q, attn = self.attention(q_proj.transpose(1, 2), k_proj.transpose(1, 2), v_proj.transpose(1, 2))
        else:
            q, attn = self.attention(q_proj.transpose(1, 2), k_proj.transpose(1, 2), v_proj.transpose(1, 2), mask.unsqueeze(1))
        
        q = q.transpose(1, 2).contiguous()
        q = q.view(batch_size, seq_size, -1)

        q = self.dropout(self.w_fc(q))

        q = self.layer_norm(q)
        
        return q, attn

    def clear_agreement(self):
        self.v_agreement = 0

class NonlinearFF(nn.Module):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)

        x = self.layer_norm(x)

        return x
    
class TitleEmbedding(nn.Module):
    def __init__(self, num_embeddings, d_model, num_heads, d_qk, d_v, d_hid=None, embeddings=None, track_agreement=False, padding_idx=0, dropout=0.1):
        super().__init__()

        if embeddings is not None:
            self.embeddings = nn.Embedding.from_pretrained(embeddings, freeze=False, sparse=True, padding_idx=padding_idx)
        else:
            self.embeddings = nn.Embedding(num_embeddings, d_model, sparse=True, padding_idx=0)
            
        self.mh_attn = MultiHeadAttention(d_model, num_heads, d_qk, d_v, track_agreement=track_agreement, dropout=dropout)
        self.nff = NonlinearFF(d_model, d_hid if d_hid is not None else d_model * 4, dropout=dropout)
        self.add_attn = AdditiveSelfAttention(d_model, dropout=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        
        self.padding_idx = padding_idx
        
    def forward(self, title):    
        mask = (title == self.padding_idx).byte()
        
        q = k = v = self.embeddings(title)
        title, attn = self.mh_attn(q, k ,v, mask=mask)
        
        title = self.nff(title)
        title, add_attn = self.add_attn(title, mask=mask)
        
        title = self.layer_norm(title)
        
        return title
    
    def load_embeddings(embeddings):
        self.embeddings = nn.Embedding.from_pretrained(embeddings, freeze=False, sparse=True)
    
    


In [5]:
class NegativeSamplingLoss(nn.Module):
    def __init__(self, sample_pool, embedding, num_samples, device='cpu'):
        super(NegativeSamplingLoss, self).__init__()
        
        self.sample_pool = sample_pool
        self.num_samples = num_samples
        
        self.embedding = embedding
        
        self.device = device
        
    def forward(self, x, y):
        batch_size = x.shape[0]
        
        sample_indices = torch.FloatTensor(batch_size * self.num_samples).uniform_(0, len(self.sample_pool) - 1).long()        
        sample_indices, tmp = torch.broadcast_tensors(sample_indices.unsqueeze(1), 
                                                      torch.arange(batch_size * self.num_samples * self.sample_pool.shape[1])
                                                      .view(batch_size * self.num_samples, self.sample_pool.shape[1]))
        
        n = torch.gather(self.sample_pool, 0, sample_indices).to(self.device)
        
        x = self.embedding(x)
        y = self.embedding(y)
        n = self.embedding(n).neg()
        
        target_loss = (torch.einsum('ij,ij->i', x, y) / x.shape[-1]).sigmoid().mean() 
        noise_loss = (torch.einsum('ij, ikj->ik', x, n.view(batch_size, self.num_samples, n.shape[-1])) / (x.shape[-1])).sigmoid().log().sum(1).mean()
        
        return -(target_loss + noise_loss)        

In [6]:
dataset = torch.load('../data/dataset.pt')

In [8]:
device = torch.device('cuda:0')

title_embedding = torch.load('../models/model_reg_init.pt')
nss = NegativeSamplingLoss(dataset.tensors[1], title_embedding, 5, device=device).to(device)

In [None]:
#torch.save(title_embedding, '../models/model_reg_init.pt')

In [None]:
#title_embedding.to(torch.device('cpu'))
#nss.to(torch.device('cpu'))

In [11]:
class MultipleOptimizer:
    def __init__(self, *op):
        self.optimizers = op

    def zero_grad(self):
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        for op in self.optimizers:
            op.step()

sparse_params = []
dense_params = []

for name, param in title_embedding.named_parameters():
    if name == 'embeddings.weight':
        sparse_params.append(param)
    else:
        dense_params.append(param)
        
opt_dense = torch.optim.Adam(dense_params, lr=1e-4, weight_decay=0.0001)
opt_sparse = torch.optim.SparseAdam(sparse_params, lr=1e-4)

optimizer = MultipleOptimizer(opt_sparse, opt_dense)

train_loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=64)

In [12]:
losses = []

In [13]:
num_epochs = 5
batch_size = 64
reg_coeff = 1.

for epoch in range(num_epochs):
        epoch_loss = 0        
        
        for idx, batch in enumerate(train_loader):
            x, y = batch[0].to(device), batch[1].to(device)
            
            optimizer.zero_grad()
            loss = nss(x, y)
            
            loss += reg_coeff * title_embedding.mh_attn.v_agreement
            title_embedding.mh_attn.clear_agreement()
            
            if loss.isnan() or loss.isinf():
                continue
            
            loss.backward()
            optimizer.step()
        
            epoch_loss += loss.item()
            
            if (idx + 1) % batch_size == 0:
                print('loss at batch {}: {}'.format((idx+1), loss.item()))
                losses.append(loss.item())
                
            
            if (idx + 1) % 16384 == 0:
                torch.save(title_embedding, '../models/model{}_{}.pt'.format(epoch, (int)(idx+1) / 16384))
            
        print('loss at end of epoch {}: {}'.format(epoch, epoch_loss/len(train_loader)))
        
        with open('../models/model{}.pt'.format(epoch+1), 'wb') as of:
            torch.save(title_embedding, of)

loss at batch 64: 66.75123596191406
loss at batch 128: 46.99741744995117
loss at batch 192: 33.35817337036133
loss at batch 256: 26.671594619750977
loss at batch 320: 21.124557495117188
loss at batch 384: 17.891103744506836
loss at batch 448: 14.759486198425293
loss at batch 512: 13.05744457244873
loss at batch 576: 11.62479019165039
loss at batch 640: 10.634286880493164
loss at batch 704: 9.653124809265137
loss at batch 768: 8.771020889282227
loss at batch 832: 8.396596908569336
loss at batch 896: 7.720708847045898
loss at batch 960: 7.804649353027344
loss at batch 1024: 7.037797927856445
loss at batch 1088: 6.838610649108887
loss at batch 1152: 6.1069536209106445
loss at batch 1216: 6.123982906341553
loss at batch 1280: 5.810311794281006
loss at batch 1344: 5.484839916229248
loss at batch 1408: 5.290788173675537
loss at batch 1472: 5.220921993255615
loss at batch 1536: 5.000507354736328
loss at batch 1600: 5.039775371551514
loss at batch 1664: 4.806284427642822
loss at batch 1728: 4.

loss at batch 13696: 2.816551923751831
loss at batch 13760: 2.8801352977752686
loss at batch 13824: 2.9046051502227783
loss at batch 13888: 2.9112181663513184
loss at batch 13952: 2.855391502380371
loss at batch 14016: 2.899574041366577
loss at batch 14080: 2.8524420261383057
loss at batch 14144: 2.88330340385437
loss at batch 14208: 2.92085337638855
loss at batch 14272: 2.88553524017334
loss at batch 14336: 2.8869450092315674
loss at batch 14400: 2.81593656539917
loss at batch 14464: 2.8988704681396484
loss at batch 14528: 2.7620739936828613
loss at batch 14592: 2.8919174671173096
loss at batch 14656: 3.039677619934082
loss at batch 14720: 2.925189733505249
loss at batch 14784: 2.9111697673797607
loss at batch 14848: 2.8044509887695312
loss at batch 14912: 2.9497694969177246
loss at batch 14976: 2.844222068786621
loss at batch 15040: 2.9476261138916016
loss at batch 15104: 2.8522894382476807
loss at batch 15168: 2.8199918270111084
loss at batch 15232: 2.910844087600708
loss at batch 1

loss at batch 27008: 2.7816121578216553
loss at batch 27072: 2.727949857711792
loss at batch 27136: 2.8570003509521484
loss at batch 27200: 2.8516318798065186
loss at batch 27264: 2.810133218765259
loss at batch 27328: 2.765998125076294
loss at batch 27392: 2.8323867321014404
loss at batch 27456: 2.8149282932281494
loss at batch 27520: 2.7756378650665283
loss at batch 27584: 2.8505735397338867
loss at batch 27648: 2.8033416271209717
loss at batch 27712: 2.823866128921509
loss at batch 27776: 2.848714590072632
loss at batch 27840: 2.7931625843048096
loss at batch 27904: 2.9010963439941406
loss at batch 27968: 2.7272305488586426
loss at batch 28032: 2.766465187072754
loss at batch 28096: 2.894645929336548
loss at batch 28160: 2.7393741607666016
loss at batch 28224: 2.7678189277648926
loss at batch 28288: 2.773531198501587
loss at batch 28352: 2.8636231422424316
loss at batch 28416: 2.844917058944702
loss at batch 28480: 2.769022226333618
loss at batch 28544: 2.847860336303711
loss at bat

loss at batch 40320: 2.8124825954437256
loss at batch 40384: 2.8456873893737793
loss at batch 40448: 2.7978808879852295
loss at batch 40512: 2.763720750808716
loss at batch 40576: 2.817284345626831
loss at batch 40640: 2.7488749027252197
loss at batch 40704: 2.700432062149048
loss at batch 40768: 2.791775941848755
loss at batch 40832: 2.9596965312957764
loss at batch 40896: 2.7570412158966064
loss at batch 40960: 2.791029214859009
loss at batch 41024: 2.91772198677063
loss at batch 41088: 2.717960834503174
loss at batch 41152: 2.8026070594787598
loss at batch 41216: 2.7738583087921143
loss at batch 41280: 2.881950855255127
loss at batch 41344: 2.7257463932037354
loss at batch 41408: 2.7600693702697754
loss at batch 41472: 2.782583236694336
loss at batch 41536: 2.771139144897461
loss at batch 41600: 2.9173591136932373
loss at batch 41664: 2.7050440311431885
loss at batch 41728: 2.819368362426758
loss at batch 41792: 2.669616222381592
loss at batch 41856: 2.784446954727173
loss at batch 

loss at batch 53632: 2.915766477584839
loss at batch 53696: 2.766934633255005
loss at batch 53760: 2.825538396835327
loss at batch 53824: 2.7493653297424316
loss at batch 53888: 2.7898037433624268
loss at batch 53952: 2.7268104553222656
loss at batch 54016: 2.7507729530334473
loss at batch 54080: 2.773775100708008
loss at batch 54144: 2.6594395637512207
loss at batch 54208: 2.846226692199707
loss at batch 54272: 2.761691093444824
loss at batch 54336: 2.692112445831299
loss at batch 54400: 2.6366827487945557
loss at batch 54464: 2.8336589336395264
loss at batch 54528: 2.802342414855957
loss at batch 54592: 2.630979537963867
loss at batch 54656: 2.6844379901885986
loss at batch 54720: 2.6554057598114014
loss at batch 54784: 2.784353256225586
loss at batch 54848: 2.659458637237549
loss at batch 54912: 2.7063894271850586
loss at batch 54976: 2.699251651763916
loss at batch 55040: 2.688809394836426
loss at batch 55104: 2.715444564819336
loss at batch 55168: 2.6963589191436768
loss at batch 

loss at batch 66944: 2.9473347663879395
loss at batch 67008: 2.5864102840423584
loss at batch 67072: 2.520852565765381
loss at batch 67136: 2.799804210662842
loss at batch 67200: 2.6862597465515137
loss at batch 67264: 2.7933738231658936
loss at batch 67328: 2.6698970794677734
loss at batch 67392: 2.5951974391937256
loss at batch 67456: 2.620413303375244
loss at batch 67520: 2.779829740524292
loss at batch 67584: 2.853372097015381
loss at batch 67648: 2.632086992263794
loss at batch 67712: 2.673696517944336
loss at batch 67776: 2.6419601440429688
loss at batch 67840: 2.661815881729126
loss at batch 67904: 2.5982353687286377
loss at batch 67968: 2.6739749908447266
loss at batch 68032: 2.8772528171539307
loss at batch 68096: 2.6666688919067383
loss at batch 68160: 2.51983642578125
loss at batch 68224: 2.785557985305786
loss at batch 68288: 2.6238722801208496
loss at batch 68352: 2.66142201423645
loss at batch 68416: 2.566540241241455
loss at batch 68480: 2.5835258960723877
loss at batch 

loss at batch 80256: 2.821998119354248
loss at batch 80320: 2.7124440670013428
loss at batch 80384: 2.6940081119537354
loss at batch 80448: 2.4829349517822266
loss at batch 80512: 2.4458892345428467
loss at batch 80576: 2.5071804523468018
loss at batch 80640: 2.664264440536499
loss at batch 80704: 2.4588985443115234
loss at batch 80768: 2.549326181411743
loss at batch 80832: 2.6436219215393066
loss at batch 80896: 2.617842674255371
loss at batch 80960: 2.3521270751953125
loss at batch 81024: 2.467338800430298
loss at batch 81088: 2.5303940773010254
loss at batch 81152: 2.782472610473633
loss at batch 81216: 2.5059385299682617
loss at batch 81280: 2.447650671005249
loss at batch 81344: 2.5330820083618164
loss at batch 81408: 2.523319959640503
loss at batch 81472: 2.422126293182373
loss at batch 81536: 2.5131874084472656
loss at batch 81600: 2.5770139694213867
loss at batch 81664: 2.3259060382843018
loss at batch 81728: 2.388587236404419
loss at batch 81792: 2.307211399078369
loss at bat

loss at batch 93568: 2.3390607833862305
loss at batch 93632: 2.2263875007629395
loss at batch 93696: 2.5165069103240967
loss at batch 93760: 2.434410572052002
loss at batch 93824: 2.813323974609375
loss at batch 93888: 2.298464298248291
loss at batch 93952: 2.735755443572998
loss at batch 94016: 2.3553519248962402
loss at batch 94080: 2.350369930267334
loss at batch 94144: 2.3329994678497314
loss at batch 94208: 2.192598342895508
loss at batch 94272: 2.205881357192993
loss at batch 94336: 2.3918635845184326
loss at batch 94400: 2.5609264373779297
loss at batch 94464: 2.3739256858825684
loss at batch 94528: 2.577970504760742
loss at batch 94592: 2.266463279724121
loss at batch 94656: 2.2190585136413574
loss at batch 94720: 2.466902732849121
loss at batch 94784: 2.3105387687683105
loss at batch 94848: 2.083522319793701
loss at batch 94912: 2.5458762645721436
loss at batch 94976: 2.2807235717773438
loss at batch 95040: 2.379901647567749
loss at batch 95104: 2.457388401031494
loss at batch

loss at batch 106688: 2.1128644943237305
loss at batch 106752: 2.2824926376342773
loss at batch 106816: 1.9358755350112915
loss at batch 106880: 2.1067512035369873
loss at batch 106944: 2.314486503601074
loss at batch 107008: 2.286648988723755
loss at batch 107072: 2.3851191997528076
loss at batch 107136: 2.2403063774108887
loss at batch 107200: 2.407989025115967
loss at batch 107264: 2.127560615539551
loss at batch 107328: 2.0651285648345947
loss at batch 107392: 2.1717376708984375
loss at batch 107456: 2.3138022422790527
loss at batch 107520: 2.3063132762908936
loss at batch 107584: 2.1799936294555664
loss at batch 107648: 1.9328861236572266
loss at batch 107712: 2.312328577041626
loss at batch 107776: 2.0110974311828613
loss at batch 107840: 2.544574022293091
loss at batch 107904: 2.4224960803985596
loss at batch 107968: 2.0228137969970703
loss at batch 108032: 2.2161800861358643
loss at batch 108096: 2.0692782402038574
loss at batch 108160: 2.179335832595825
loss at batch 108224: 2

loss at batch 119616: 2.096625804901123
loss at batch 119680: 2.1915876865386963
loss at batch 119744: 2.0507328510284424
loss at batch 119808: 1.7179436683654785
loss at batch 119872: 1.9967347383499146
loss at batch 119936: 2.739015579223633
loss at batch 120000: 1.8611878156661987
loss at batch 120064: 1.9733201265335083
loss at batch 120128: 1.953633189201355
loss at batch 120192: 2.1981959342956543
loss at batch 120256: 2.1641950607299805
loss at batch 120320: 1.9544620513916016
loss at batch 120384: 1.7727874517440796
loss at batch 120448: 2.3924829959869385
loss at batch 120512: 1.8315553665161133
loss at batch 120576: 2.380122661590576
loss at batch 120640: 1.9153242111206055
loss at batch 120704: 2.3052408695220947
loss at batch 120768: 2.233654260635376
loss at batch 120832: 1.915078043937683
loss at batch 120896: 1.8165638446807861
loss at batch 120960: 1.8588647842407227
loss at batch 121024: 1.983673095703125
loss at batch 121088: 2.417133092880249
loss at batch 121152: 2.

loss at batch 132544: 2.203211784362793
loss at batch 132608: 2.270289421081543
loss at batch 132672: 1.9125767946243286
loss at batch 132736: 1.8594121932983398
loss at batch 132800: 1.983016014099121
loss at batch 132864: 1.71877920627594
loss at batch 132928: 2.2158360481262207
loss at batch 132992: 2.111052989959717
loss at batch 133056: 1.7610068321228027
loss at batch 133120: 1.794886589050293
loss at batch 133184: 1.9457365274429321
loss at batch 133248: 1.9893457889556885
loss at batch 133312: 1.8516536951065063
loss at batch 133376: 1.6202583312988281
loss at batch 133440: 1.915684461593628
loss at batch 133504: 2.0485661029815674
loss at batch 133568: 2.0632143020629883
loss at batch 133632: 2.0346806049346924
loss at batch 133696: 1.8127721548080444
loss at batch 133760: 2.0226187705993652
loss at batch 133824: 2.0884456634521484
loss at batch 133888: 1.9393126964569092
loss at batch 133952: 1.9001413583755493
loss at batch 134016: 1.8320338726043701
loss at batch 134080: 2.

loss at batch 145472: 2.134284019470215
loss at batch 145536: 1.558081030845642
loss at batch 145600: 1.8190181255340576
loss at batch 145664: 1.8624484539031982
loss at batch 145728: 1.7239240407943726
loss at batch 145792: 1.5920964479446411
loss at batch 145856: 2.508662223815918
loss at batch 145920: 1.9139999151229858
loss at batch 145984: 1.7023824453353882
loss at batch 146048: 1.7131294012069702
loss at batch 146112: 2.1901161670684814
loss at batch 146176: 1.702165961265564
loss at batch 146240: 1.7070679664611816
loss at batch 146304: 1.655532956123352
loss at batch 146368: 1.8542927503585815
loss at batch 146432: 1.9908299446105957
loss at batch 146496: 1.6107523441314697
loss at batch 146560: 1.647210955619812
loss at batch 146624: 1.5285272598266602
loss at batch 146688: 1.6775530576705933
loss at batch 146752: 1.9583704471588135
loss at batch 146816: 2.596186876296997
loss at batch 146880: 1.7386478185653687
loss at batch 146944: 2.367631196975708
loss at batch 147008: 2.

loss at batch 158400: 1.601049780845642
loss at batch 158464: 1.6048848628997803
loss at batch 158528: 2.237887382507324
loss at batch 158592: 2.1324288845062256
loss at batch 158656: 1.8737506866455078
loss at batch 158720: 2.015547037124634
loss at batch 158784: 1.7053449153900146
loss at batch 158848: 1.4822779893875122
loss at batch 158912: 1.9389257431030273
loss at batch 158976: 1.5690630674362183
loss at batch 159040: 1.483041763305664
loss at batch 159104: 1.8182108402252197
loss at batch 159168: 1.5927214622497559
loss at batch 159232: 2.195535182952881
loss at batch 159296: 2.0783376693725586
loss at batch 159360: 2.338857889175415
loss at batch 159424: 1.3193624019622803
loss at batch 159488: 1.6255762577056885
loss at batch 159552: 1.4662278890609741
loss at batch 159616: 1.8440414667129517
loss at batch 159680: 1.3137773275375366
loss at batch 159744: 1.8073910474777222
loss at batch 159808: 1.5212979316711426
loss at batch 159872: 1.5443812608718872
loss at batch 159936: 

loss at batch 171328: 2.2494688034057617
loss at batch 171392: 2.1444430351257324
loss at batch 171456: 1.688788890838623
loss at batch 171520: 1.535819411277771
loss at batch 171584: 1.8298050165176392
loss at batch 171648: 1.4294745922088623
loss at batch 171712: 1.6886062622070312
loss at batch 171776: 1.9141879081726074
loss at batch 171840: 1.64595365524292
loss at batch 171904: 1.5395158529281616
loss at batch 171968: 1.827136516571045
loss at batch 172032: 1.3391330242156982
loss at batch 172096: 1.8637131452560425
loss at batch 172160: 1.6028982400894165
loss at batch 172224: 1.497110366821289
loss at batch 172288: 1.5726869106292725
loss at batch 172352: 1.5347509384155273
loss at batch 172416: 1.7072923183441162
loss at batch 172480: 1.4709020853042603
loss at batch 172544: 1.6145151853561401
loss at batch 172608: 2.2525761127471924
loss at batch 172672: 1.5799810886383057
loss at batch 172736: 1.9884024858474731
loss at batch 172800: 1.9335769414901733
loss at batch 172864: 

loss at batch 184256: 1.6337573528289795
loss at batch 184320: 1.3104822635650635
loss at batch 184384: 1.3708858489990234
loss at batch 184448: 1.5700057744979858
loss at batch 184512: 1.5503966808319092
loss at batch 184576: 1.646257996559143
loss at batch 184640: 1.6448359489440918
loss at batch 184704: 1.593392014503479
loss at batch 184768: 1.3351023197174072
loss at batch 184832: 2.191152334213257
loss at batch 184896: 1.4691681861877441
loss at batch 184960: 2.429070234298706
loss at batch 185024: 1.4211026430130005
loss at batch 185088: 1.9899530410766602
loss at batch 185152: 1.8489041328430176
loss at batch 185216: 1.5203841924667358
loss at batch 185280: 1.7534191608428955
loss at batch 185344: 1.4878567457199097
loss at batch 185408: 1.450958490371704
loss at batch 185472: 1.2689753770828247
loss at batch 185536: 2.232100248336792
loss at batch 185600: 1.4098191261291504
loss at batch 185664: 2.07075834274292
loss at batch 185728: 1.4546904563903809
loss at batch 185792: 2.

loss at batch 11456: 2.6859188079833984
loss at batch 11520: 1.4592647552490234
loss at batch 11584: 1.6462191343307495
loss at batch 11648: 1.4173026084899902
loss at batch 11712: 1.6707391738891602
loss at batch 11776: 2.1782760620117188
loss at batch 11840: 1.6148937940597534
loss at batch 11904: 1.6781636476516724
loss at batch 11968: 1.910778522491455
loss at batch 12032: 1.3070650100708008
loss at batch 12096: 2.03185772895813
loss at batch 12160: 2.0678298473358154
loss at batch 12224: 1.408552885055542
loss at batch 12288: 1.2818409204483032
loss at batch 12352: 1.3263415098190308
loss at batch 12416: 1.4318764209747314
loss at batch 12480: 1.425008773803711
loss at batch 12544: 1.6922088861465454
loss at batch 12608: 2.13870906829834
loss at batch 12672: 1.4728742837905884
loss at batch 12736: 1.2752351760864258
loss at batch 12800: 1.548811674118042
loss at batch 12864: 1.4488898515701294
loss at batch 12928: 2.2755630016326904
loss at batch 12992: 1.5152071714401245
loss at 

loss at batch 24704: 1.3474465608596802
loss at batch 24768: 1.7657803297042847
loss at batch 24832: 1.4091780185699463
loss at batch 24896: 1.3706903457641602
loss at batch 24960: 1.499097466468811
loss at batch 25024: 1.295345664024353
loss at batch 25088: 1.6962002515792847
loss at batch 25152: 1.3895957469940186
loss at batch 25216: 1.2811427116394043
loss at batch 25280: 1.246339201927185
loss at batch 25344: 1.293062448501587
loss at batch 25408: 1.5050016641616821
loss at batch 25472: 1.4188240766525269
loss at batch 25536: 1.4606834650039673
loss at batch 25600: 1.4388186931610107
loss at batch 25664: 1.3177781105041504
loss at batch 25728: 1.785177230834961
loss at batch 25792: 1.5129873752593994
loss at batch 25856: 1.3706518411636353
loss at batch 25920: 2.4910531044006348
loss at batch 25984: 1.231197476387024
loss at batch 26048: 1.6261509656906128
loss at batch 26112: 1.51939058303833
loss at batch 26176: 1.4553358554840088
loss at batch 26240: 1.5003184080123901
loss at 

loss at batch 37952: 1.1831907033920288
loss at batch 38016: 1.4776099920272827
loss at batch 38080: 1.510644555091858
loss at batch 38144: 2.16936993598938
loss at batch 38208: 1.5486377477645874
loss at batch 38272: 1.2005616426467896
loss at batch 38336: 1.4735326766967773
loss at batch 38400: 1.6698955297470093
loss at batch 38464: 1.4448119401931763
loss at batch 38528: 1.4077224731445312
loss at batch 38592: 1.3983525037765503
loss at batch 38656: 1.2728406190872192
loss at batch 38720: 1.3324402570724487
loss at batch 38784: 1.3455169200897217
loss at batch 38848: 1.5780043601989746
loss at batch 38912: 1.3684868812561035
loss at batch 38976: 1.534103274345398
loss at batch 39040: 1.6419296264648438
loss at batch 39104: 1.5289592742919922
loss at batch 39168: 1.364574909210205
loss at batch 39232: 1.5066109895706177
loss at batch 39296: 1.451624870300293
loss at batch 39360: 1.3010140657424927
loss at batch 39424: 1.3691118955612183
loss at batch 39488: 1.9857182502746582
loss a

loss at batch 51200: 2.1877524852752686
loss at batch 51264: 1.3580201864242554
loss at batch 51328: 1.2527623176574707
loss at batch 51392: 1.2133424282073975
loss at batch 51456: 1.1486387252807617
loss at batch 51520: 1.2204854488372803
loss at batch 51584: 2.4216394424438477
loss at batch 51648: 1.610790491104126
loss at batch 51712: 1.5114918947219849
loss at batch 51776: 0.9596057534217834
loss at batch 51840: 1.315780758857727
loss at batch 51904: 1.2247287034988403
loss at batch 51968: 1.555325984954834
loss at batch 52032: 1.3328242301940918
loss at batch 52096: 1.6013100147247314
loss at batch 52160: 1.442153811454773
loss at batch 52224: 1.6122807264328003
loss at batch 52288: 1.3478187322616577
loss at batch 52352: 1.7553346157073975
loss at batch 52416: 1.3088953495025635
loss at batch 52480: 2.396237850189209
loss at batch 52544: 1.535386562347412
loss at batch 52608: 1.3732677698135376
loss at batch 52672: 1.096889853477478
loss at batch 52736: 1.2309954166412354
loss at

loss at batch 64448: 1.0541563034057617
loss at batch 64512: 1.0542476177215576
loss at batch 64576: 1.4600673913955688
loss at batch 64640: 1.3554391860961914
loss at batch 64704: 1.3420755863189697
loss at batch 64768: 1.3126546144485474
loss at batch 64832: 1.5440335273742676
loss at batch 64896: 1.6679800748825073
loss at batch 64960: 1.2868483066558838
loss at batch 65024: 1.3626399040222168
loss at batch 65088: 1.1217142343521118
loss at batch 65152: 1.021314024925232
loss at batch 65216: 1.5684555768966675
loss at batch 65280: 1.5149649381637573
loss at batch 65344: 1.1531236171722412
loss at batch 65408: 1.2148257493972778
loss at batch 65472: 1.2140721082687378
loss at batch 65536: 1.2466192245483398
loss at batch 65600: 1.4432514905929565
loss at batch 65664: 1.4813655614852905
loss at batch 65728: 1.1811723709106445
loss at batch 65792: 1.2218519449234009
loss at batch 65856: 1.2808061838150024
loss at batch 65920: 1.310281753540039
loss at batch 65984: 1.2263480424880981
lo

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

plt.xscale('log')
plt.plot(range(len(losses)), losses,'bx')

In [None]:
title_embedding.mh_attn

In [None]:
dummy_title = torch.LongTensor([1, 9, 17])
padded_title = torch.LongTensor([1, 9, 17, 0])

w_embeddings = nn.Embedding.from_pretrained(embeddings, freeze=False, sparse=True, padding_idx=0)
w_embeddings.weight.data[0] = torch.zeros(300)
mult_attn = MultiplicativeAttention()
add_attn = AdditiveSelfAttention(300)
mh_attn = MultiHeadAttention(300, 12, 25, 25)
nff = NonlinearFF(300, 25)

torch.set_printoptions(linewidth=120,threshold=100)
mult_attn.eval()
mh_attn.eval()
nff.eval()
add_attn.eval()
mult_attn.eval()
title_embedding.eval()

def transform(title, mask=None): 
    z = title
    s = mask
    print('title: {}'.format(title))
    
    q = k = v = w_embeddings(title)
    
    print('embedded title: \n{}'.format(q))
    print(q.shape)
    
    title, attn = mh_attn(q, k, v, mask)
    
    print('transformed title: \n{}'.format(title))
    print(title.shape)
    
    title = nff(title)
    
    print('title after nonlinear ff: \n {}'.format(title))
    print(title.shape)
    
    title, importance = add_attn(title, mask)
    
    print('importance of each word: \n {}'.format(importance))
    print(importance.shape)
    
    print('title after everything: \n {}'.format(title))
    print(title.shape)
    
    
    y = title_embedding(z, mask=s)
    print(y)
    print(y.shape)

transform(torch.LongTensor([[1, 5, 7, 2, 9, 8, 0, 0], [2, 5, 5, 0, 0, 0, 0, 0]]),        
                    mask= torch.ByteTensor([[0,0,0,0,0,0,1,1],[0,0,0,1,1,1,1,1]]))