In [35]:
# Automatic reload of local libraries
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [36]:
fqueries = '../ruwikIR/processed_queries.csv'
fdocs = '../ruwikIR/processed_documents.csv'
fqrels = '../ruwikIR/qrels'

emb_file = '/home/mrim/data/embeddings/cc.ru.300.bin'

In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import fasttext

model = fasttext.load_model(emb_file)

def embedding_matrix(text, max_len):
    words = text.split()
    matrix = np.empty(())
    dim = model.get_dimension()
    matrix = np.zeros((max_len, dim))
    for i in range(min(len(words), max_len)):
        matrix[i] = model[words[i]]
    return matrix

def build_emb_input(batch):
    output = []
    for triple in batch:
        q, d1, d2 = triple
        q_m = embedding_matrix(q, max_len = 10)
        d1_m = embedding_matrix(d1, max_len = 200)
        d2_m = embedding_matrix(d2, max_len = 200)
        output.append(np.array([q_m, d1_m, d2_m]))
    return np.asarray(output)

def reshape_4d(tensor):
    return torch.from_numpy(tensor).float().view(1, tensor.shape[1], 1, tensor.shape[0])




In [38]:
class Autoencoder(nn.Module):
    def __init__(self, layer_size, dropout_prob=0.6, ):
        super().__init__()
        self.layer_size = layer_size
        self.fc = nn.ModuleList([])
        for i in range(len(layer_size)-1):
            self.fc.append(nn.Conv2d(layer_size[i], layer_size[i+1], (1, 5 if i == 0 else 1)))
        self.dropout = nn.Dropout(p=dropout_prob, inplace=False)

    def forward(self, x):
        for i in range(len(self.fc)):
            x = self.dropout(F.relu(self.fc[i](x)))
        x=torch.mean(x, 3, keepdim=True)
        return x
    

# TODO:
2. Интегрировать tensorboardx

In [47]:
import torch.optim as optim
from utils import ModelInputGenerator

mi_generator = ModelInputGenerator(fdocs, fqueries, fqrels)
batch_num = 1
autoencoder = Autoencoder([300, 100, 5000])
criterion = nn.MarginRankingLoss(margin=1.0)
optimizer = optim.SGD(autoencoder.parameters(), lr=0.001, momentum=0.9)
reg_lambda = 10e-7 


for epoch in range(1):  # loop over the dataset multiple times
    running_loss = 0.0
    mi_generator.reset()
    for b in range(batch_num):
        batch = mi_generator.generate_batch(size=4)
        out_batch = build_emb_input(batch)
        for i in range(out_batch.shape[0]):
            # get the inputs; data is a list of [inputs, labels]
            query, d1, d2 = out_batch[i]
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize

            q_out = autoencoder(reshape_4d(query))
            d1_out = autoencoder(reshape_4d(d1))
            d2_out = autoencoder(reshape_4d(d2))
            
            reg_term = torch.cat((q_out, d1_out, d2_out), dim=1).sum(dim=1, keepdim=True)
            x1 = (q_out * d1_out).sum(dim=1, keepdim=True)
            x2 = (q_out * d2_out).sum(dim=1, keepdim=True)

            target = torch.ones(1)
            loss = criterion(x1, x2, target) + reg_lambda * reg_term
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 200 == 199:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

print('Finished Training')

Preprocessing data started...
Finished.
Finished Training


In [48]:
def zeros(x):
    return len([i for i, e in enumerate(x) if e == 0])

def get_zeros(x):
    q, d1, d2 = x
    qa = autoencoder(reshape_4d(q)).view(-1)
    d1a = autoencoder(reshape_4d(d1)).view(-1)
    d2a = autoencoder(reshape_4d(d2)).view(-1)
    return zeros(qa), zeros(d1a), zeros(d2a)

In [49]:
mi_generator.reset(4)
batch = mi_generator.generate_batch(size=20)
out_batch = build_emb_input(batch)
    
for x in out_batch:
    q, d1, d2 = get_zeros(x)
    print("Iteration #"+str(i)+ ":")
    print("Zeros in query: ", q)
    print("Zeros in doc1: ", d1)
    print("Zeros in doc2: ", d2)

Iteration #3:
Zeros in query:  990
Zeros in doc1:  796
Zeros in doc2:  596
Iteration #3:
Zeros in query:  987
Zeros in doc1:  817
Zeros in doc2:  796
Iteration #3:
Zeros in query:  989
Zeros in doc1:  819
Zeros in doc2:  716
Iteration #3:
Zeros in query:  995
Zeros in doc1:  823
Zeros in doc2:  782
Iteration #3:
Zeros in query:  992
Zeros in doc1:  792
Zeros in doc2:  800
Iteration #3:
Zeros in query:  993
Zeros in doc1:  800
Zeros in doc2:  800
Iteration #3:
Zeros in query:  993
Zeros in doc1:  791
Zeros in doc2:  826
Iteration #3:
Zeros in query:  994
Zeros in doc1:  824
Zeros in doc2:  768
Iteration #3:
Zeros in query:  997
Zeros in doc1:  792
Zeros in doc2:  836
Iteration #3:
Zeros in query:  992
Zeros in doc1:  795
Zeros in doc2:  819
Iteration #3:
Zeros in query:  994
Zeros in doc1:  806
Zeros in doc2:  768
Iteration #3:
Zeros in query:  989
Zeros in doc1:  807
Zeros in doc2:  804
Iteration #3:
Zeros in query:  990
Zeros in doc1:  830
Zeros in doc2:  797
Iteration #3:
Zeros in qu

In [34]:
torch.save(autoencoder.state_dict(), './autoencoder.pth')