In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from prettytable import PrettyTable
from torchtext import data
import torchtext
from torchtext.data import Field,LabelField
import pdb
import pandas as pd
import torch.optim as optim
import torch
import torch.nn as nn
import torch.optim as optim
import time
import torch.nn.functional as f
from prettytable import PrettyTable
import logging
import copy
from utils import *

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

### RANKER

In [None]:
class AverageMeter(object):
    """Keeps the most recent value together with a running sum, count and mean.

    Attributes (all reset to 0 by ``reset``):
        val:   the value passed to the latest ``update`` call.
        sum:   accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg:   ``sum / count`` as of the latest update.
    """

    _FIELDS = ("val", "avg", "sum", "count")

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every tracked statistic."""
        for field in self._FIELDS:
            setattr(self, field, 0)

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        
class Timer(object):
    """Stopwatch that accumulates wall-clock seconds across stop/resume cycles.

    The timer is running as soon as it is constructed. ``reset``, ``resume``
    and ``stop`` return the timer itself so calls can be chained.
    """

    def __init__(self):
        self.running = True
        self.total = 0
        self.start = time.time()

    def reset(self):
        """Discard the accumulated time and start running from now."""
        self.running, self.total = True, 0
        self.start = time.time()
        return self

    def resume(self):
        """Start running again if stopped; no effect while already running."""
        if self.running:
            return self
        self.running = True
        self.start = time.time()
        return self

    def stop(self):
        """Stop running and bank the time elapsed since the last start."""
        if not self.running:
            return self
        self.running = False
        self.total += time.time() - self.start
        return self

    def time(self):
        """Return banked seconds, plus the current stretch when running."""
        if not self.running:
            return self.total
        return self.total + time.time() - self.start

class Ranker(object):
    """Training / inference wrapper around a ``DSSMModel``.

    Holds the network, the loss function and (once ``init_optimizer`` has been
    called) the optimizer, and provides helpers to save and restore them.

    Args:
        src_dict: vocabulary mapping; only ``len(src_dict)`` is used here (to
            size the embedding table). It is also stored in saved files.
        unk: unknown-word token; stored but not otherwise used by this class.
        embedding_dim: size of the word embeddings.
        hidden_dim: hidden size of each MLP tower.
        output_dim: output size of each MLP tower (passed to ``DSSMModel`` as
            ``tagset_size``).
    """

    # Module logger, so the class does not depend on a notebook-level `logger`
    # name (none is defined in this file outside of the star import).
    _log = logging.getLogger(__name__)

    def __init__(self, src_dict, unk="<UNK>", embedding_dim=128, hidden_dim=64,
                 output_dim=128):
        self.src_dict = src_dict
        self.unk = unk
        self.word_to_ix = src_dict
        self.updates = 0
        # Defined up front so `update` can report a missing optimizer cleanly.
        self.optimizer = None

        self.network = DSSMModel(
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            vocab_size=len(self.word_to_ix),
            tagset_size=output_dim,
        )
        self.criterion = self.compute_loss

    @staticmethod
    def compute_loss(predictions, target):
        """Cross-entropy between softmax(predictions) and ``target`` weights.

        Applies log-softmax over the last dimension of ``predictions``,
        multiplies element-wise by ``target``, sums over dimension 1, negates,
        and returns the mean over the remaining dimension.
        """
        log_probs = f.log_softmax(predictions, dim=-1)
        loss = -(log_probs * target).sum(1)
        return loss.mean()

    def layer_wise_parameters(self):
        """Return a PrettyTable listing name, shape and size of every trainable parameter."""
        table = PrettyTable()
        table.field_names = ["Layer Name", "Output Shape", "Param #"]
        table.align["Layer Name"] = "l"
        table.align["Output Shape"] = "r"
        table.align["Param #"] = "r"
        for name, parameters in self.network.named_parameters():
            if parameters.requires_grad:
                table.add_row([name, str(list(parameters.shape)), parameters.numel()])
        return table

    # NOTE: a `load_embeddings` method was stubbed out here and never implemented.

    def init_optimizer(self, state_dict=None, fix_embeddings=False,
                       learning_rate=0.001, momentum=0, weight_decay=0):
        """Create the SGD optimizer over the trainable parameters.

        Args:
            state_dict: optional optimizer state (as produced by
                ``optimizer.state_dict()``) to restore; its tensors are moved
                to the global ``device``.
            fix_embeddings: when true, freeze the word-embedding parameters
                before collecting the trainable ones.
            learning_rate, momentum, weight_decay: forwarded to ``optim.SGD``.
        """
        if fix_embeddings:
            for p in self.network.word_embeddings.parameters():
                p.requires_grad = False

        parameters = [p for p in self.network.parameters() if p.requires_grad]

        self.optimizer = optim.SGD(parameters, learning_rate,
                                   momentum=momentum,
                                   weight_decay=weight_decay)

        if state_dict is not None:
            self.optimizer.load_state_dict(state_dict)
            # The restored state may live on another device; move it over.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)

    # --------------------------------------------------------------------------
    # Learning
    # --------------------------------------------------------------------------

    def update(self, rows, grad_clipping=10):
        """Forward one example, back-propagate and step the optimizer.

        Args:
            rows: indexable with ``rows[0]`` = query, ``rows[1]`` = documents,
                ``rows[2]`` = labels; passed unchanged to the network and the
                loss (presumably already tensors -- confirm against the
                data-preparation helper).
            grad_clipping: maximum gradient norm.

        Returns:
            The loss tensor for this step.

        Raises:
            RuntimeError: if ``init_optimizer`` has not been called.
        """
        if getattr(self, 'optimizer', None) is None:
            raise RuntimeError('No optimizer set.')

        # Train mode
        self.network.train()
        queries, documents, labels = rows[0], rows[1], rows[2]

        # Run forward
        scores = self.network(queries, documents)
        loss = self.criterion(scores, labels)

        # Clear gradients and run backward
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients (in-place variant; the un-suffixed name is deprecated)
        torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                       grad_clipping)

        # Update parameters
        self.optimizer.step()
        self.updates += 1

        return loss

    # --------------------------------------------------------------------------
    # Prediction
    # --------------------------------------------------------------------------

    def predict(self, ex):
        """Score the documents of one example and return softmax probabilities.

        Args:
            ex: mapping with ``'que_rep'`` (query) and ``'doc_rep'``
                (documents). Extra keys such as ``'que_len'`` / ``'doc_len'``
                are accepted and ignored, because the network takes only the
                query and the documents.

        Returns:
            Softmax over the last dimension of the network scores, computed
            in eval mode without gradient tracking.
        """
        # Eval mode
        self.network.eval()

        documents = ex['doc_rep']
        queries = ex['que_rep']

        with torch.no_grad():
            scores = self.network(queries, documents)
            scores = f.softmax(scores, dim=-1)

        return scores

    # --------------------------------------------------------------------------
    # Saving and loading
    # --------------------------------------------------------------------------

    def save(self, filename):
        """Write the network weights and vocabulary to ``filename`` (best effort)."""
        state_dict = copy.copy(self.network.state_dict())
        state_dict.pop('fixed_embedding', None)
        params = {
            'state_dict': state_dict,
            'src_dict': self.src_dict,
        }
        try:
            torch.save(params, filename)
        except Exception:
            # Deliberately non-fatal; KeyboardInterrupt/SystemExit still propagate.
            self._log.warning('WARN: Saving failed... continuing anyway.')

    def checkpoint(self, filename, epoch):
        """Write weights, vocabulary, epoch and optimizer state to ``filename`` (best effort)."""
        params = {
            'state_dict': self.network.state_dict(),
            'src_dict': self.src_dict,
            'epoch': epoch,
            'optimizer': self.optimizer.state_dict(),
        }
        try:
            torch.save(params, filename)
        except Exception:
            # Deliberately non-fatal; KeyboardInterrupt/SystemExit still propagate.
            self._log.warning('WARN: Saving failed... continuing anyway.')

    @staticmethod
    def _restore(src_dict, state_dict):
        """Build a Ranker sized to match ``state_dict`` and load the weights into it."""
        sizes = {}
        # Saved files do not record the layer sizes, so recover them from the
        # query tower's Linear weights, which are (out_features, in_features).
        first_layer = state_dict.get('mlp_search_query.0.weight')
        last_layer = state_dict.get('mlp_search_query.2.weight')
        if first_layer is not None:
            sizes['hidden_dim'], sizes['embedding_dim'] = first_layer.shape
        if last_layer is not None:
            sizes['output_dim'] = last_layer.shape[0]

        model = Ranker(src_dict, **sizes)
        model.network.load_state_dict(state_dict)
        return model

    @staticmethod
    def load(filename):
        """Restore a Ranker written by ``save``.

        NOTE: ``torch.load`` unpickles the file; only load files you trust.
        """
        Ranker._log.info('Loading model %s', filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        return Ranker._restore(saved_params['src_dict'], saved_params['state_dict'])

    @staticmethod
    def load_checkpoint(filename):
        """Restore a Ranker and its optimizer from a file written by ``checkpoint``.

        Returns:
            ``(model, epoch)`` where ``epoch`` is the value stored in the file.

        NOTE: ``torch.load`` unpickles the file; only load files you trust.
        """
        Ranker._log.info('Loading model %s', filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        model = Ranker._restore(saved_params['src_dict'], saved_params['state_dict'])
        model.init_optimizer(saved_params['optimizer'])
        return model, saved_params['epoch']

### DSSM

In [None]:
class DSSMModel(nn.Module):
    """Two-tower scorer: cosine similarity between a query and each document.

    Both inputs are embedded with one shared embedding table, passed through
    dropout, max-pooled over the token axis, and projected by separate
    two-layer tanh MLPs (one for the query, one for the documents).
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, dropout=0.5):
        super(DSSMModel, self).__init__()

        # Token-id lookup used for queries and documents alike; row 1 is padding.
        embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=1)
        self.word_embeddings = embedding.to(device)

        self.dropout_embedding = nn.Dropout(p=dropout).to(device)

        # Query tower first, then document tower (keeps initialisation order).
        self.mlp_search_query = self._tower(embedding_dim, hidden_dim, tagset_size)
        self.mlp_details = self._tower(embedding_dim, hidden_dim, tagset_size)

    @staticmethod
    def _tower(in_dim, mid_dim, out_dim):
        """Linear -> Tanh -> Linear -> Tanh stack, placed on the global device."""
        layers = [
            nn.Linear(in_dim, mid_dim),
            nn.Tanh(),
            nn.Linear(mid_dim, out_dim),
            nn.Tanh(),
        ]
        return nn.Sequential(*layers).to(device)

    def _embed_and_pool(self, token_ids):
        """Embed ids, apply dropout, and take the max over dimension 1."""
        vectors = self.dropout_embedding(self.word_embeddings(token_ids))
        pooled, _ = vectors.max(1)
        return pooled

    def forward(self, batch_queries, batch_docs):
        """Return cosine scores of shape (1, number of documents).

        ``batch_queries`` is reshaped to ``(1, batch_queries.shape[0])`` and
        ``batch_docs`` to ``(1, batch_docs.shape[0], batch_docs.shape[1])``,
        i.e. every call is handled as a single batch entry.
        """
        n_batch = 1

        queries = batch_queries.reshape(n_batch, batch_queries.shape[0])
        docs = batch_docs.reshape(n_batch, batch_docs.shape[0], batch_docs.shape[1])
        assert queries.shape[0] == docs.shape[0]

        n_batch = queries.shape[0]
        n_docs, doc_len = docs.shape[1], docs.shape[2]

        # Query: the extra trailing axis leaves a singleton "document" axis
        # after pooling, which broadcasts against the documents below.
        query_vec = self._embed_and_pool(queries.unsqueeze(2))

        # Documents: flatten (batch, docs) so each document is pooled separately.
        flat_docs = docs.view(n_batch * n_docs, doc_len).unsqueeze(2)
        doc_vecs = self._embed_and_pool(flat_docs).view(n_batch, n_docs, -1)

        query_out = self.mlp_search_query(query_vec)
        doc_out = self.mlp_details(doc_vecs)

        return f.cosine_similarity(query_out, doc_out, dim=2)
    

## TRAIN

In [None]:
def train(model, train_itr, epochs=10, batch_size=32,
          data_workers=5, lr_decay=0.95,
          checkpoint_path='ranker.pt' + '.checkpoint'):
    """Run the training loop and write a checkpoint after every epoch.

    Args:
        model: object exposing ``optimizer`` (with ``param_groups``),
            ``update(rows)`` returning a loss with ``.item()``, and
            ``checkpoint(filename, epoch)``.
        train_itr: iterable of training rows; it is iterated once per epoch,
            so it must be re-iterable when ``epochs > 1``.
        epochs: number of passes over ``train_itr``.
        batch_size: unused; kept for interface compatibility.
        data_workers: unused; kept for interface compatibility.
        lr_decay: factor applied to every parameter group's learning rate at
            the start of each epoch (including the first).
        checkpoint_path: file the per-epoch checkpoint is written to.
    """
    for epoch in range(epochs):
        ce_loss = AverageMeter()

        # Decay every parameter group, not only the first one.
        for group in model.optimizer.param_groups:
            group['lr'] = group['lr'] * lr_decay

        # Wrap the iterable itself so tqdm can see its length, and show the
        # running loss in the bar instead of computing it and dropping it.
        progress = tqdm(train_itr)
        progress.set_description('Epoch = %d [ce_loss = x.xx]' % epoch)

        for rows in progress:
            net_loss = model.update(rows)
            ce_loss.update(net_loss.item(), 1)

            progress.set_description(
                'Epoch = %d [ce_loss = %.2f]' % (epoch, ce_loss.avg)
            )
            torch.cuda.empty_cache()

        # Checkpoint
        model.checkpoint(checkpoint_path, epoch + 1)
        

## Load training data

In [None]:
# Location of the training CSV and the text columns the vocabulary is built from.
TRAIN_CSV_PATH = 'data/train_dssm2.csv'
VOCAB_COLUMNS = ['product_title', 'search_term', 'brand', 'product_description']

print('Loading Train Data...')
main_data = pd.read_csv(TRAIN_CSV_PATH)

# makeVocabDict / makeTensorData are not defined in this notebook; presumably
# they come from the `utils` star import -- confirm there for their contracts.
src_dict = makeVocabDict(main_data, VOCAB_COLUMNS)
tabData = makeTensorData(main_data, src_dict)


## Build the ranker and train

In [None]:
# Build the ranker over the vocabulary, attach an SGD optimizer with the
# default settings, then run the training loop on the prepared data.
model = Ranker(src_dict=src_dict)
model.init_optimizer()

train(model=model, train_itr=tabData)