In [1]:
import sys
import os
sys.path.append('../')

import gensim
from torch.utils.data import DataLoader

from dataset import LSTMAmazonRandomPrefix
from model import PadSequence, LSTMClassifier
from utils import process_df, init_random_seed
from tqdm import tqdm
import pandas as pd
import numpy as np
import random
import spacy
import torch

In [2]:
# Fixed seed handed to utils.init_random_seed (presumably seeds the Python,
# NumPy and torch RNGs — confirm in utils) so runs are reproducible.
SEED = 9001
init_random_seed(SEED)

# Modify the following configuration values (paths and evaluation mode) before running the rest of the notebook

In [None]:
# Paths left empty must be filled in before running the rest of the notebook.
w2v_model_path = ''  # gensim KeyedVectors file used to map tokens to ids
# NOTE(review): not referenced in the visible cells — confirm it is still needed.
original_model_path = "bert-base-uncased"
model_dir = "../lstm_models/"
model_fname = ''  # checkpoint file name inside model_dir
test_set_fname = ''  # CSV with (at least) 'title' and 'label' columns
# 'random' / 'subsets': the test frame is passed through process_df;
# 'complete': token lengths are counted directly with the spaCy tokenizer.
mode = 'random' # Can be 'random', 'complete' or 'subsets'

# Fail fast on a typo: an unknown mode would otherwise skip both preprocessing
# branches and only surface later as a missing 'tok_len' column.
VALID_MODES = ('random', 'complete', 'subsets')
if mode not in VALID_MODES:
    raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")

In [None]:
# Blank English spaCy pipeline. In this notebook it is passed to process_df and
# used to count and truncate title tokens.
spacy_tokenizer = spacy.blank(name="en")

def reduce_title(title, max_len, tokenizer_spacy):
    """Return the text of the first ``max_len`` tokens of ``title``.

    Args:
        title: Raw title string.
        max_len: Number of leading tokens to keep (used as a slice bound).
        tokenizer_spacy: Callable mapping a string to a token sequence whose
            slices expose ``.text`` (e.g. a spaCy ``Language`` object).

    Returns:
        The text of the sliced token span.
    """
    doc = tokenizer_spacy(title)
    prefix_span = doc[:max_len]
    return prefix_span.text

In [None]:
# Load the test split. With keep_default_na=False, pandas' default NA strings
# (e.g. empty fields) are kept as text; only '$$$__$$$' is parsed as missing.
test = pd.read_csv(test_set_fname, keep_default_na=False, na_values=['$$$__$$$'])

# Collapse every run of whitespace inside a title into a single space.
test['title'] = test['title'].apply(lambda raw_title: " ".join(raw_title.split()))

if mode in ('random', 'subsets'):
    # NOTE(review): process_df presumably adds the 'tok_len' column the
    # evaluation cell filters on — confirm in utils.
    test = process_df(test, spacy_tokenizer)
elif mode == 'complete':
    # Token count per title, measured with the same spaCy tokenizer used for truncation.
    test['tok_len'] = test['title'].apply(lambda raw_title: len(spacy_tokenizer(raw_title)))

In [None]:
w2v_model = gensim.models.KeyedVectors.load(w2v_model_path)
# Two extra ids are reserved after the word2vec rows: one for unknown tokens
# and one for padding.
vocab_size = len(w2v_model.index_to_key) + 2
embedding_dim = w2v_model.vector_size

# Only the ids of the two special rows are needed here. The embedding weights
# themselves are restored from the checkpoint by load_state_dict, so there is
# no need to copy every vector into a Python list (or draw a random UNK vector).
num_w2v_vectors = w2v_model.vectors.shape[0]
unk_token_id = num_w2v_vectors
pad_token_id = num_w2v_vectors + 1

bidirection = True
hidden_size = 384
num_layers = 1
dropout = 0.1

# NOTE(review): the number of classes is taken from the *test* set. If the test
# split lacks a class seen during training, the output layer will be smaller
# than the checkpoint's and loading the state dict would likely fail — confirm.
model = LSTMClassifier(vocab_size, embedding_dim, hidden_size, num_layers, bidirection, dropout,
                       test['label'].nunique(), pad_token_id=pad_token_id)

In [None]:
from metrics import get_all_classes_stats, accuracy, recall, precision, rank_predictions_metrics

# For each original title length, truncate the matching titles to every prefix
# length and record metrics in stats[length][prefix_length].
original_token_lengths = [12,]
prefix_lengths = list(range(1,7))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = os.path.join(model_dir, model_fname)
# map_location lets a checkpoint saved on GPU be loaded on the CPU fallback.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.eval()
stats = {}
for length in original_token_lengths:
    stats[length] = {}
    print("working on length", length)
    # The length filter does not depend on the prefix, so compute it once.
    length_subset = test[test.tok_len == length].reset_index(drop=True)
    if length_subset.empty:
        # Nothing to score; np.argmax(..., axis=1) below would fail on an empty array.
        print("no titles with", length, "tokens; skipping")
        continue
    for prefix_length in prefix_lengths:
        stats[length][prefix_length] = {}
        # assign() returns a new frame, so length_subset is never mutated.
        tmp = length_subset.assign(
            title=length_subset['title'].apply(lambda x: reduce_title(x, prefix_length, spacy_tokenizer)))

        test_dl = DataLoader(LSTMAmazonRandomPrefix(tmp, "title", "label", w2v_model.key_to_index,
                                                    unk_token_id=unk_token_id, random_=False),
                             batch_size=64, num_workers=4,
                             collate_fn=PadSequence(pad_token_id=pad_token_id), shuffle=False)
        all_labels = []
        all_preds_scores = []
        with torch.no_grad():
            for batch in tqdm(test_dl, desc=f"len={length} prefix={prefix_length}"):
                inp_ids, lengths, labels = batch
                inp_ids = inp_ids.to(device)
                lengths = lengths.to(device)
                outputs = model(inp_ids, lengths)
                all_preds_scores.extend(list(outputs.cpu().numpy()))
                # Labels are only collected for scoring, so they stay on the CPU.
                all_labels.extend(list(labels.cpu().numpy()))

        all_preds_scores, all_labels = np.array(all_preds_scores), np.array(all_labels)
        all_preds = np.argmax(all_preds_scores, axis=1)

        # Macro metrics: unweighted mean of the per-class values, as percentages.
        precision_stats = get_all_classes_stats(precision, all_preds, all_labels)
        macro_precision = np.mean([precision_stats[i] for i in precision_stats]) * 100
        recall_stats = get_all_classes_stats(recall, all_preds, all_labels)
        macro_recall = np.mean([recall_stats[i] for i in recall_stats]) * 100
        accuracy_value = accuracy(all_preds, all_labels) * 100

        stats[length][prefix_length]['MACRO_PRECISION'] = macro_precision
        stats[length][prefix_length]['MACRO_RECALL'] = macro_recall
        stats[length][prefix_length]['ACCURACY'] = accuracy_value

        for k in [3,5]:
            hits, hit_rate = rank_predictions_metrics(all_preds_scores, all_labels, k)
            stats[length][prefix_length][f'HITS@{k}'] = hit_rate