In [1]:
import os 
import re 
import pandas as pd
import numpy as np 
import pickle 
from tokenizers import BertWordPieceTokenizer
import sys
sys.path.append('../scripts/')
from TwoClassHeadClassificationTransformer import *
from ClassificationDatasetFromDict import *
import pickle 
import torch 
import torch.nn as nn 


In [28]:

# Reproducibility: seed torch's default CPU generator and the current CUDA
# device's generator with the same fixed value.
SEED = 3007
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# The notebook-wide names (model, tokenizer, idx2class, idx2subclass,
# class2names, subclass2names, device) are ordinary module-level variables.
# A `global` statement at module scope has no effect, so none is declared.

# Target device for the checkpoint load and for inference tensors.
device = 'cpu'

def save_pickle(obj, filepath):
    """Pickle ``obj`` and write it to ``filepath`` (opened in binary write mode).

    Args:
        obj: Any picklable Python object.
        filepath: Destination path; an existing file is overwritten.
    """
    with open(filepath, 'wb') as sink:
        pickle.dump(obj, file=sink)

def load_pickle(filepath):
    """Read ``filepath`` in binary mode and return the object pickled in it.

    Unpickling can execute arbitrary code, so only pass files from a
    trusted source.

    Args:
        filepath: Path to a file previously written with ``pickle``.

    Returns:
        The deserialized Python object.
    """
    with open(filepath, 'rb') as source:
        restored = pickle.load(source)
    return restored

# Coarse question-class codes -> display names used in prediction results.
class2names = dict(
    DESC="DESCRIPTION",
    ENTY="ENTITY",
    ABBR="ABBREVIATION",
    HUM="HUMAN",
    NUM="NUMERIC",
    LOC="LOCATION",
)

subclass2names = {
    'manner': 'manner',
    'cremat': 'creative',
    'animal': 'animal',
    'exp': 'expression abbreviated',
    'ind': 'individual',
    'gr': 'group',
    'title': 'title',
    'def': 'definition',
    'date': 'date',
    'reason': 'reason',
    'event': 'event',
    'state': 'state',
    'desc': 'description',
    'count': 'count',
    'letter': 'letter',
    'religion': 'religion',
    'food': 'food',
    'country': 'country',
    'color': 'color',
    'termeq': 'term',
    'body': 'body',
    'dismed': 'diseases and medicine',
    'mount': 'mountains',
    'money': 'money',
    'product': 'product',
    'period': 'period',
    'substance': 'substance',
    'city': 'city',
    'sport': 'sport',
    'plant': 'plant',
    'techmeth': 'techniques and methods',
    'volsize': 'size, area and volume',
    'instru': 'musical instrument',
    'abb': 'abbreviation',
    'speed': 'speed',
    'word': 'word',
    'lang': 'languages',
    'perc': 'percentage or fractions',
    'code': 'code (postcodes or other codes)',
    'dist': 'distance',
    'temp': 'temperature',
    'symbol': 'symbol',
    'ord': 'order or ranks',
    'veh': 'vehicles',
    'weight': 'weight',
    'currency': 'currency',
    'other': 'other'
}



# Integer index -> label-code lookups, read from pickled files under ../data.
# These are unpickled, so the files must come from a trusted source.
idx2class = load_pickle(
    '../data/idx2class.pkl'
)
idx2subclass = load_pickle(
    '../data/idx2subclass.pkl'
)

# WordPiece tokenizer built from the custom vocabulary file; input text is
# lower-cased and accents are stripped.
tokenizer = BertWordPieceTokenizer(
    '../data/bert-word-piece-custom-wikitext-vocab-10k-vocab.txt',
    lowercase=True,
    strip_accents=True,
)


In [29]:
# Persist both display-name mappings under ../data as pickle files.
save_pickle(obj=subclass2names, filepath='../data/subclass2names.pkl')
save_pickle(obj=class2names, filepath='../data/class2names.pkl')

In [3]:

# --- Model / data configuration -------------------------------------------
# The vocabulary size comes from the tokenizer; the remaining values are
# fixed here. Presumably they must match the configuration the checkpoint
# was trained with -- TODO confirm against the training script.
vocab_size = tokenizer.get_vocab_size()
pad_id, CLS_label_id = 0, 2
num_class_heads = 2
lst_num_cat_in_classes = [6, 47]  # one category count per classification head
seq_len = 100                     # passed to the model as `num_pos`
# batch_size / num_workers are not used by any cell visible in this notebook.
batch_size, num_workers = 256, 3

# --- Build the network and restore the saved weights ----------------------
model = TwoClassHeadClassificationTransformer(
    vocab_size=vocab_size,
    pad_id=pad_id,
    CLS_label_id=CLS_label_id,
    num_class_heads=num_class_heads,
    lst_num_cat_in_classes=lst_num_cat_in_classes,
    num_pos=seq_len,
)

# The checkpoint is a dict: 'model_dict' holds the state dict, and 'epoch',
# 'accuracy_class' and 'accuracy_subclass' hold the reported metadata.
model_dict = torch.load(
    '../models/classification_model_state_dict_best.pth',
    map_location=device,
)
model.load_state_dict(model_dict['model_dict'])

# Move to the target device and switch to evaluation mode.
model = model.to(device).eval()

print(f'''
Model saved at: {model_dict['epoch']}
    Accuracy Class: {model_dict['accuracy_class']}
    Accuracy Subclass: {model_dict['accuracy_subclass']}
''')



Model saved at: 165
    Accuracy Class: 0.9926846636887422
    Accuracy Subclass: 0.9776008286398653



In [23]:
def predictQuestionClassSubclass(text):
    """Predict the coarse class and fine-grained subclass of a question.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device`` to score
    the text, then maps the winning indices through ``idx2class`` /
    ``idx2subclass`` and ``class2names`` / ``subclass2names``.

    Args:
        text: The question to classify.

    Returns:
        A dict with two keys, ``"class"`` and ``"subclass"``, each holding
        the display name of the highest-scoring category of that head.

    Raises:
        KeyError: If a predicted index or label code is missing from one of
            the lookup tables.
    """
    token_ids = tokenizer.encode(text).ids

    # Shape (1, n_tokens): a batch containing the single encoded question.
    # NOTE(review): the ids are integers but are sent as a float tensor, as
    # in the original code -- confirm the model expects/casts this dtype.
    # NOTE(review): the model was built with num_pos=seq_len and the input is
    # not truncated here -- confirm longer questions are handled.
    tokens = torch.FloatTensor(token_ids).unsqueeze(0).to(device)

    # Inference only: model.eval() does not disable gradient tracking, so
    # turn autograd off explicitly to avoid building a graph per call.
    with torch.no_grad():
        cls_, subcls = model(tokens)

    # Each head is reduced along dim 1 to the index of its top score.
    clsIdx = cls_.argmax(dim=1).item()
    subclsIdx = subcls.argmax(dim=1).item()

    return {
        "class": class2names[idx2class[clsIdx]],
        "subclass": subclass2names[idx2subclass[subclsIdx]]
    }



In [25]:
# Smoke test: classify a short definitional question and display the result.
predictQuestionClassSubclass('what is time?')

{'class': 'DESCRIPTION', 'subclass': 'definition'}

In [24]:
# Smoke test: classify a "who" question and display the result.
# NOTE(review): the recorded output reports subclass 'group' although the
# question asks about a single person -- worth checking the model here.
predictQuestionClassSubclass('who was the first man to land on moon?')

{'class': 'HUMAN', 'subclass': 'group'}

In [27]:
# NOTE(review): this cell re-declares `subclass2names`, which is already
# assigned in the setup cell earlier in the notebook. Keep a single
# definition so the two copies cannot drift apart.
# Entries are looked up by key, so the position of 'other' (last here,
# index 14 in idx2subclass) does not affect predictions.
subclass2names = {
    'manner': 'manner',
    'cremat': 'creative',
    'animal': 'animal',
    'exp': 'expression abbreviated',
    'ind': 'individual',
    'gr': 'group',
    'title': 'title',
    'def': 'definition',
    'date': 'date',
    'reason': 'reason',
    'event': 'event',
    'state': 'state',
    'desc': 'description',
    'count': 'count',
    'letter': 'letter',
    'religion': 'religion',
    'food': 'food',
    'country': 'country',
    'color': 'color',
    'termeq': 'term',
    'body': 'body',
    'dismed': 'diseases and medicine',
    'mount': 'mountains',
    'money': 'money',
    'product': 'product',
    'period': 'period',
    'substance': 'substance',
    'city': 'city',
    'sport': 'sport',
    'plant': 'plant',
    'techmeth': 'techniques and methods',
    'volsize': 'size, area and volume',
    'instru': 'musical instrument',
    'abb': 'abbreviation',
    'speed': 'speed',
    'word': 'word',
    'lang': 'languages',
    'perc': 'percentage or fractions',
    'code': 'code (postcodes or other codes)',
    'dist': 'distance',
    'temp': 'temperature',
    'symbol': 'symbol',
    'ord': 'order or ranks',
    'veh': 'vehicles',
    'weight': 'weight',
    'currency': 'currency',
    'other': 'other'
}

In [26]:
# Display the index -> subclass-code mapping loaded from
# '../data/idx2subclass.pkl' (its values are the keys of subclass2names).
idx2subclass

{0: 'manner',
 1: 'cremat',
 2: 'animal',
 3: 'exp',
 4: 'ind',
 5: 'gr',
 6: 'title',
 7: 'def',
 8: 'date',
 9: 'reason',
 10: 'event',
 11: 'state',
 12: 'desc',
 13: 'count',
 14: 'other',
 15: 'letter',
 16: 'religion',
 17: 'food',
 18: 'country',
 19: 'color',
 20: 'termeq',
 21: 'body',
 22: 'dismed',
 23: 'mount',
 24: 'money',
 25: 'product',
 26: 'period',
 27: 'substance',
 28: 'city',
 29: 'sport',
 30: 'plant',
 31: 'techmeth',
 32: 'volsize',
 33: 'instru',
 34: 'abb',
 35: 'speed',
 36: 'word',
 37: 'lang',
 38: 'perc',
 39: 'code',
 40: 'dist',
 41: 'temp',
 42: 'symbol',
 43: 'ord',
 44: 'veh',
 45: 'weight',
 46: 'currency'}