In [1]:
from tqdm import tqdm

import numpy as np
import numpy.random as random

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.nn.utils.rnn import pad_sequence

import io

In [None]:
def load_embeddings(fname, get_embeddings=True, get_w2i=False, get_i2w=False, skip_first_line=True):
    """Read a text embedding file with one ``<word> <float> <float> ...`` entry per line.

    Args:
        fname: Path of the file. It is decoded as UTF-8 and undecodable bytes
            are ignored.
        get_embeddings: When True, parse the numbers following the word on
            every line into a vector.
        get_w2i: When True, fill the word -> row index mapping.
        get_i2w: When True, fill the row index -> word mapping.
        skip_first_line: When True, the first line is discarded before parsing
            (presumably a header line -- confirm against the file format).

    Returns:
        A tuple ``(embeddings, word2idx, idx2word)``. ``embeddings`` is a
        ``torch.FloatTensor`` built from the parsed vectors (empty when
        ``get_embeddings`` is False); each mapping is an empty dict unless the
        corresponding flag was set. Row indices follow the order of the lines.
    """
    word2idx = {}
    idx2word = {}
    embeddings = []

    # The context manager guarantees the handle is closed, including when a
    # malformed line makes float() raise halfway through the file.
    with io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        if skip_first_line:
            fin.readline()

        for row, raw_line in enumerate(fin):
            fields = raw_line.rstrip().split(' ')
            word = fields[0]

            if get_w2i:
                word2idx[word] = row
            if get_i2w:
                idx2word[row] = word
            if get_embeddings:
                embeddings.append([float(num) for num in fields[1:]])

    return torch.FloatTensor(embeddings), word2idx, idx2word

In [None]:
# Only the word -> row index vocabulary is needed here, so vector parsing is
# switched off and element 1 of the returned tuple is kept.
word2idx = load_embeddings(
    '../embeddings/wiki-news-300d-1M.vec',
    get_embeddings=False,
    get_w2i=True,
)[1]

In [5]:
class MultiplicativeAttention(nn.Module):
    """Scaled dot-product attention with dropout on the attention weights."""

    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: Query, key and value tensors; the last two axes of ``k``
                are swapped before the product with ``q``.
            mask: Optional tensor. An axis is inserted at position 1 and every
                position equal to 1 has its score replaced by -1e9 before the
                softmax.

        Returns:
            ``(res, attn)``: the weighted values and the (dropped-out)
            attention weights.
        """
        # The keys are scaled by 1/sqrt(d) before the product with the queries.
        scaled_keys = k.transpose(-2, -1) / math.sqrt(q.size(-1))
        scores = torch.matmul(q, scaled_keys)

        if mask is not None:
            blocked = mask.unsqueeze(1) == 1
            scores = scores.masked_fill(blocked, -1e9)

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        return torch.matmul(weights, v), weights

class AdditiveAttention(nn.Module):
    """Pools a sequence into a single vector with a learned query.

    Each position is scored by ``tanh(w(x)) . q``; the softmax of those scores
    weights the sum of the *input* vectors.
    """

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Registered as a submodule but not used by forward() in this class.
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.w = nn.Linear(d_model, d_model)
        query_init = torch.FloatTensor(d_model).uniform_(-0.1, 0.1)
        self.q = torch.nn.Parameter(query_init)

    def forward(self, x, mask=None):
        """Pool ``x`` over its second axis.

        Args:
            x: Rank-3 tensor; the einsum below treats it as
                (batch, position, feature).
            mask: Optional tensor matching the score shape; positions equal to
                1 get a score of -1e9 before the softmax.

        Returns:
            ``(res, attn)``: the pooled vectors and the (dropped-out) weights.
        """
        # Dropout is applied to the projection before the tanh.
        hidden = torch.tanh(self.dropout(self.w(x)))
        scores = torch.matmul(hidden, self.q)

        if mask is not None:
            scores = scores.masked_fill(mask == 1, -1e9)

        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        pooled = torch.einsum('ijk, ij->ik', x, weights)
        return pooled, weights

    
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected to ``num_heads`` query/key/value heads, attended
    with ``MultiplicativeAttention``, merged, projected back to ``d_model`` and
    passed through dropout and LayerNorm. No residual connection is added.

    When ``track_agreement`` is True, every forward call adds to
    ``self.v_agreement`` the sum (over batch, positions and all ordered head
    pairs) of the dot products between L2-normalised value projections,
    divided by ``num_heads ** 2``. Call ``clear_agreement()`` to reset it.
    """

    def __init__(self, d_model, num_heads, d_qk, d_v, track_agreement=False, dropout=0.1):
        super().__init__()

        # Sizes
        self.d_model = d_model
        self.d_qk = d_qk
        self.d_v = d_v

        self.num_heads = num_heads

        self.dropout = nn.Dropout(dropout)

        # Bias-free input projections, one slice of the output per head.
        self.w_q = nn.Linear(d_model, num_heads * d_qk, bias=False)
        self.w_k = nn.Linear(d_model, num_heads * d_qk, bias=False)
        self.w_v = nn.Linear(d_model, num_heads * d_v, bias=False)

        # Output projection applied to the concatenated heads.
        self.w_fc = nn.Linear(num_heads * d_v, d_model, bias=False)

        self.attention = MultiplicativeAttention(dropout=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.track_agreement = track_agreement
        self.v_agreement = 0

    def forward(self, q, k, v, mask=None):
        """Run attention.

        Args:
            q, k, v: Tensors whose first two axes are used as (batch, sequence)
                and whose last axis is fed to the ``d_model``-wide projections.
            mask: Optional mask; an axis is inserted at position 1 before it is
                handed to the inner attention.

        Returns:
            ``(output, attn)``: the normalised output with the query's batch
            and sequence sizes, and the attention weights of the inner module.
        """
        batch_size, seq_size = q.shape[0], q.shape[1]

        def split_heads(projected, source, head_dim):
            # (batch, seq, heads * dim) -> (batch, seq, heads, dim)
            return projected.view(source.shape[0], source.shape[1], self.num_heads, head_dim)

        q_proj = split_heads(self.w_q(q), q, self.d_qk)
        k_proj = split_heads(self.w_k(k), k, self.d_qk)
        v_proj = split_heads(self.w_v(v), v, self.d_v)

        if self.track_agreement:
            unit_left = F.normalize(v_proj, dim=3)
            unit_right = F.normalize(v_proj, dim=3)
            pairwise = torch.einsum('bshd, bsnd->', unit_left, unit_right)
            self.v_agreement += pairwise / self.num_heads**2

        # Move the head axis in front of the sequence axis for the attention call.
        heads_q = q_proj.transpose(1, 2)
        heads_k = k_proj.transpose(1, 2)
        heads_v = v_proj.transpose(1, 2)

        if mask is not None:
            context, attn = self.attention(heads_q, heads_k, heads_v, mask.unsqueeze(1))
        else:
            context, attn = self.attention(heads_q, heads_k, heads_v)

        # Merge the heads back into one feature axis.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_size, -1)

        context = self.dropout(self.w_fc(context))

        return self.layer_norm(context), attn

    def clear_agreement(self):
        """Reset the accumulated head-agreement term to 0."""
        self.v_agreement = 0

class NonlinearFF(nn.Module):
    """Two-layer ReLU feed-forward block followed by dropout and LayerNorm.

    The block maps ``d_in -> d_hid -> d_in``; no residual connection is added.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``layer_norm(dropout(w_2(relu(w_1(x)))))``."""
        hidden = F.relu(self.w_1(x))
        projected = self.dropout(self.w_2(hidden))
        return self.layer_norm(projected)
    
class TitleEmbedding(nn.Module):
    """Encodes a batch of token-id sequences into one vector per sequence.

    Pipeline: embedding lookup -> multi-head self-attention -> feed-forward
    block -> additive-attention pooling -> LayerNorm.

    Args:
        num_embeddings: Vocabulary size (used only when ``embeddings`` is None).
        d_model: Embedding / model width.
        num_heads, d_qk, d_v: Passed to ``MultiHeadAttention``.
        d_hid: Hidden width of the feed-forward block; defaults to
            ``4 * d_model`` when None.
        embeddings: Optional pretrained weight tensor; when given, the table is
            initialised from it and stays trainable.
        track_agreement: Passed to ``MultiHeadAttention``.
        padding_idx: Token id treated as padding, both by the embedding table
            and when building the attention mask.
        dropout: Dropout probability passed to every submodule.
    """

    def __init__(self, num_embeddings, d_model, num_heads, d_qk, d_v, d_hid=None, embeddings=None, track_agreement=False, padding_idx=0, dropout=0.1):
        super().__init__()

        if embeddings is not None:
            self.embeddings = nn.Embedding.from_pretrained(embeddings, freeze=False, sparse=True, padding_idx=padding_idx)
        else:
            # Honour the caller's padding_idx here as well, so the embedding
            # table and the mask built in forward() agree on the padding id.
            self.embeddings = nn.Embedding(num_embeddings, d_model, sparse=True, padding_idx=padding_idx)

        self.mh_attn = MultiHeadAttention(d_model, num_heads, d_qk, d_v, track_agreement=track_agreement, dropout=dropout)
        self.nff = NonlinearFF(d_model, d_hid if d_hid is not None else d_model * 4, dropout=dropout)
        self.add_attn = AdditiveAttention(d_model, dropout=dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.padding_idx = padding_idx

    def forward(self, title):
        """Embed ``title``.

        Args:
            title: Integer tensor of token ids, used as (batch, sequence).
                Positions equal to ``padding_idx`` are masked out in both
                attention stages.

        Returns:
            Tensor with one ``d_model``-wide vector per sequence.
        """
        # 1 marks padding positions.
        mask = (title == self.padding_idx).byte()

        # Self-attention: queries, keys and values are the same embeddings.
        q = k = v = self.embeddings(title)
        title, attn = self.mh_attn(q, k, v, mask=mask)

        title = self.nff(title)
        title, add_attn = self.add_attn(title, mask=mask)

        title = self.layer_norm(title)

        return title

    def load_embeddings(self, embeddings):
        """Replace the embedding table with trainable, sparse pretrained weights.

        The original definition lacked ``self`` and therefore could not be
        called on an instance. The module's ``padding_idx`` is forwarded so the
        new table is configured like the one built in ``__init__``.
        """
        self.embeddings = nn.Embedding.from_pretrained(
            embeddings, freeze=False, sparse=True, padding_idx=self.padding_idx
        )
    
    


In [None]:
# Everything below runs on the first CUDA device.
device = torch.device('cuda:0')

# torch.load unpickles the whole saved object, so only trusted files should be
# loaded this way. The result is then moved to the training device.
title_embedding = torch.load(r'C:\Users\Tadija\Desktop\wikipedia\tensor\model_oooo.pt')
title_embedding = title_embedding.to(device)

In [None]:
class MultipleOptimizer:
    """Fans ``zero_grad()`` and ``step()`` out to several optimizers, in order."""

    def __init__(self, *op):
        # Tuple of the wrapped optimizers, in the order they were passed.
        self.optimizers = op

    def _call_on_all(self, method_name):
        """Invoke the zero-argument method ``method_name`` on every optimizer."""
        for wrapped in self.optimizers:
            getattr(wrapped, method_name)()

    def zero_grad(self):
        """Clear the gradients held by every wrapped optimizer."""
        self._call_on_all('zero_grad')

    def step(self):
        """Apply one update with every wrapped optimizer."""
        self._call_on_all('step')

# The embedding table gets its own optimizer; every other parameter is
# optimised together. (The table is built with sparse=True, which is presumably
# why it is paired with SGD rather than Adam -- confirm.)
sparse_params = []
dense_params = []

for name, param in title_embedding.named_parameters():
    (sparse_params if name == 'embeddings.weight' else dense_params).append(param)

opt_sparse = torch.optim.SGD(sparse_params, lr=1e-3)
opt_dense = torch.optim.Adam(dense_params, lr=1e-3)

# Single handle that steps / zeroes both optimizers (sparse one first).
optimizer = MultipleOptimizer(opt_sparse, opt_dense)

In [None]:
# Loss history: the training loop appends to it and the plotting cell reads it.
losses = list()

In [None]:
# Training hyper-parameters.
num_epochs = 5

batch_size = 128    # (x, y) pairs per optimisation step
num_samples = 32    # negative titles drawn per pair

reg_coeff = 1.0     # weight of the head-agreement term in the loss

# Model width, read back from the loaded model.
d_model = title_embedding.mh_attn.d_model

In [None]:
# Make sure the model sits on the training device; as the last expression of
# the cell this also echoes the module.
title_embedding.to(device=device)

In [None]:

# Running sum of per-batch losses since the last report, and the number of
# batches that sum covers. The counter is needed because the sum is carried
# across dataset parts, so a report does not always cover exactly 1000 batches.
avg_batch_loss = 0
batches_since_report = 0

# Constant targets for F.cosine_embedding_loss: +1 for pairs that should be
# similar, -1 for pairs that should be dissimilar. Built once, not per batch.
positive_target = torch.Tensor([1]).to(device)
negative_target = torch.Tensor([-1]).to(device)

# Bounded by num_epochs (the original `while True` never terminated and left
# num_epochs unused).
for epoch in range(num_epochs):
    for dataset_part in range(0, 7):

        dataset = torch.load(r'C:\Users\Tadija\Desktop\wikipedia\tensor\dataset{}.pt'.format(dataset_part))

        print('loaded dataset part {}'.format(dataset_part + 1))

        len_dataset = len(dataset.tensors[1])
        len_sequence = dataset.tensors[1].shape[1]

        train_loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)

        for idx, batch in enumerate(train_loader):

            title_embedding.mh_attn.clear_agreement()

            x, y = batch[0].to(device), batch[1].to(device)

            sz = x.shape[0]

            optimizer.zero_grad()

            # Negative sampling: num_samples random rows of the second tensor
            # per batch element. randint's upper bound is exclusive, so every
            # row (including the last one) can be drawn.
            sample_indices = torch.randint(0, len_dataset, (sz * num_samples,))
            n = dataset.tensors[1][sample_indices].to(device)

            x = title_embedding(x)
            y = title_embedding(y)
            n = title_embedding(n)

            # Pull each (x, y) pair together.
            target_loss = F.cosine_embedding_loss(x, y, positive_target, margin=0.5)

            # Repeat every x once per negative so it lines up row-by-row with n.
            x = x.unsqueeze(1).expand(sz, num_samples, d_model).flatten(0, 1)

            # Push x away from its negatives: summed over the negatives of one
            # element, averaged over the batch.
            noise_loss = F.cosine_embedding_loss(x, n, negative_target, margin=0.5, reduction='none')
            noise_loss = noise_loss.view(sz, num_samples, 1).sum(1).mean()

            loss = target_loss + noise_loss + reg_coeff * title_embedding.mh_attn.v_agreement

            loss.backward()

            # Drop the reference to the agreement term (and its graph).
            title_embedding.mh_attn.clear_agreement()

            optimizer.step()

            avg_batch_loss += loss.item()
            batches_since_report += 1

            if (idx + 1) % 1000 == 0:
                # Average over the batches actually accumulated (the original
                # divided the sum by batch_size, which is not a batch count).
                print('avg loss at batch {}: {}'.format((idx+1), avg_batch_loss / batches_since_report))
                losses.append(avg_batch_loss / (batches_since_report * (1 + num_samples)))
                avg_batch_loss = 0
                batches_since_report = 0

            # At most 10000 batches are used from each dataset part.
            if (idx + 1) == 10000:
                break

    #torch.save(title_embedding.to('cpu'), r'C:\Users\Tadija\Desktop\wikipedia\tensor\model_{}.pt'.format(dataset_part))
    #title_embedding.to(device)

In [None]:
#del title_embedding
#title_embedding = torch.load(r'C:\Users\Tadija\Desktop\wikipedia\tensor\model_ooo.pt', map_location='cuda')

In [None]:
# Pickle the whole model from CPU memory, then move it back to the training
# device (Module.to works in place, so title_embedding itself is moved).
torch.save(
    title_embedding.to('cpu'),
    r'C:\Users\Tadija\Desktop\wikipedia\tensor\model_ft.pt',
)
title_embedding.to(device)

In [None]:
# Pickle the whole model from CPU memory; the model is left on the CPU.
torch.save(
    title_embedding.to('cpu'),
    r'C:\Users\Tadija\Desktop\wikipedia\tensor\model_0_owo.pt',
)

In [None]:
# Move the model to host memory (echoes the module as the cell output).
title_embedding.cpu()

In [None]:
# Pickle the model object as it currently is (no device change in this cell).
torch.save(
    title_embedding,
    r'C:\Users\Tadija\Desktop\wikipedia\tensor\model_test.pt',
)

In [None]:
import matplotlib.pyplot as plt

# Recorded losses as blue crosses on a logarithmic y axis. With no explicit x
# values, matplotlib plots against the index of each entry.
plt.yscale('log')
plt.plot(losses, 'bx')

In [None]:
# Scratch objects for inspecting the individual layers step by step.
dummy_title = torch.LongTensor([1, 9, 17])
padded_title = torch.LongTensor([1, 9, 17, 0])

# NOTE(review): `embeddings` is not assigned in any visible cell (the only
# load_embeddings call keeps just the vocabulary) -- it has to be loaded before
# this cell can run.
w_embeddings = nn.Embedding.from_pretrained(embeddings, freeze=False, sparse=True, padding_idx=0)
# Force the padding row to an all-zero vector.
w_embeddings.weight.data[0] = torch.zeros(300)

# Freshly initialised stand-alone layers; their weights are unrelated to the
# trained title_embedding.
mult_attn = MultiplicativeAttention()
# The class defined above is AdditiveAttention; the former name
# AdditiveSelfAttention does not exist and raised NameError.
add_attn = AdditiveAttention(300)
mh_attn = MultiHeadAttention(300, 12, 25, 25)
nff = NonlinearFF(300, 25)

torch.set_printoptions(linewidth=120,threshold=100)

# Evaluation mode switches dropout off so the printed values are deterministic.
mult_attn.eval()
mh_attn.eval()
nff.eval()
add_attn.eval()
title_embedding.eval()

def transform(title, mask=None):
    """Debug helper: print every intermediate result of the encoder pipeline.

    ``title`` is pushed through the notebook-level stand-alone layers
    (``w_embeddings`` -> ``mh_attn`` -> ``nff`` -> ``add_attn``), printing each
    tensor and its shape, and finally through the full ``title_embedding``
    model for comparison.

    Args:
        title: Integer tensor of token ids.
        mask: Optional padding mask handed to ``mh_attn`` and ``add_attn``
            (1 marks a masked position).

    Returns:
        None; everything is printed.
    """
    token_ids = title
    print('title: {}'.format(title))

    q = k = v = w_embeddings(title)

    print('embedded title: \n{}'.format(q))
    print(q.shape)

    title, attn = mh_attn(q, k, v, mask)

    print('transformed title: \n{}'.format(title))
    print(title.shape)

    title = nff(title)

    print('title after nonlinear ff: \n {}'.format(title))
    print(title.shape)

    title, importance = add_attn(title, mask)

    print('importance of each word: \n {}'.format(importance))
    print(importance.shape)

    print('title after everything: \n {}'.format(title))
    print(title.shape)

    # TitleEmbedding.forward() accepts only the token ids and derives the
    # padding mask itself; passing mask= raised TypeError.
    y = title_embedding(token_ids)
    print(y)
    print(y.shape)

# Two padded example titles (0 = padding) with the matching mask (1 = padding).
transform(
    torch.LongTensor([[1, 5, 7, 2, 9, 8, 0, 0],
                      [2, 5, 5, 0, 0, 0, 0, 0]]),
    mask=torch.ByteTensor([[0, 0, 0, 0, 0, 0, 1, 1],
                           [0, 0, 0, 1, 1, 1, 1, 1]]),
)

In [3]:
# Load the pickled model object (torch.load unpickles; use trusted files only).
a = torch.load(f='../models/model1.pt')

In [4]:
# Re-save only the weights (state dict) instead of the whole pickled model.
torch.save(obj=a.state_dict(), f='model.pt')

In [6]:
# Fresh model with the architecture the saved weights are loaded into below.
b = TitleEmbedding(
    num_embeddings=999994,
    d_model=300,
    num_heads=12,
    d_qk=25,
    d_v=25,
)

In [7]:
# Read back the state dict written to 'model.pt' earlier in the notebook.
model_state_dict = torch.load(f='model.pt')
    

In [None]:
# The loaded object is a state dict (a mapping), which is not callable;
# list its entry names instead.
model_state_dict.keys()

In [8]:
# Copy the saved weights into the freshly built model.
b.load_state_dict(state_dict=model_state_dict)

<All keys matched successfully>

In [10]:
# Rebind `a` (earlier the model object loaded from '../models/model1.pt') to
# the state dict of `b`.
a = b.state_dict()

In [11]:
# Store the head count next to the weights.
# NOTE(review): this entry is not a parameter of the model, so a strict
# load_state_dict() on the saved file will presumably reject it unless the key
# is removed first -- confirm how the file is consumed.
a.update(num_heads=12)

In [12]:
# Overwrite 'model.pt' with the state dict that now also carries 'num_heads'.
torch.save(obj=a, f='model.pt')

In [15]:
# Show the entry names of the dictionary that was just saved.
a.keys()

odict_keys(['embeddings.weight', 'mh_attn.w_q.weight', 'mh_attn.w_k.weight', 'mh_attn.w_v.weight', 'mh_attn.w_fc.weight', 'mh_attn.layer_norm.weight', 'mh_attn.layer_norm.bias', 'nff.w_1.weight', 'nff.w_1.bias', 'nff.w_2.weight', 'nff.w_2.bias', 'nff.layer_norm.weight', 'nff.layer_norm.bias', 'add_attn.q', 'add_attn.layer_norm.weight', 'add_attn.layer_norm.bias', 'add_attn.w.weight', 'add_attn.w.bias', 'layer_norm.weight', 'layer_norm.bias', 'num_heads'])