In [None]:
import torch
import numpy as np
from transformers import BertConfig, BertTokenizerFast
import sys
sys.path.append('../')
from utils.model import BertForBilevelClassification
from run_backtest import _get_positive_for_event_single


MODEL_DIR = '../models/model_seed24'
MAX_LEN = 256

# Label map for the token classifier: key is the output class id *as a string*,
# value is the event name. Class ids follow the order of the tuple below, so
# the last class ('NoEvent') is id 11.
index2event = {
    str(label_id): event_name
    for label_id, event_name in enumerate((
        'Acquisitions',
        'Clinical Trials',
        'Dividend Cut',
        'Dividend Increase',
        'Guidance Change',
        'New Contract',
        'Regular Dividend',
        'Reverse Stock Split',
        'Special Dividend',
        'Stock Repurchase',
        'Stock Split',
        'NoEvent',
    ))
}
# Reverse lookup: event name -> class id string.
event2index = {event_name: label_id for label_id, event_name in index2event.items()}
# Integer id of the "no event" class.
NOEVENT_ID = int(event2index['NoEvent'])

# Per-event flag; presumably True = event is treated as positive (bullish) news
# and False = negative news -- TODO confirm against how callers use it.
# Note: 'Sentiment' is not one of the model's classes above and 'NoEvent' is
# absent; presumably consumed elsewhere -- verify.
IS_POSITIVE = dict.fromkeys(
    (
        'Acquisitions',
        'Clinical Trials',
        'Dividend Cut',
        'Dividend Increase',
        'Guidance Change',
        'New Contract',
        'Regular Dividend',
        'Reverse Stock Split',
        'Special Dividend',
        'Stock Repurchase',
        'Stock Split',
        'Sentiment',
    ),
    True,
)
# The only two events flagged False (reassignment keeps insertion order).
IS_POSITIVE['Dividend Cut'] = False
IS_POSITIVE['Reverse Stock Split'] = False

# Model configuration: load config, tokenizer and fine-tuned weights from MODEL_DIR.
config = BertConfig.from_pretrained(MODEL_DIR)
# Size the classification head from the label map (one output class per
# index2event entry, NoEvent included) instead of a hard-coded 12, so the two
# cannot drift apart.
config.num_labels = len(index2event)
# Custom attribute; presumably read by BertForBilevelClassification -- confirm.
config.max_seq_length = MAX_LEN
tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
model = BertForBilevelClassification.from_pretrained(MODEL_DIR, config=config)
# Switch to evaluation mode: this module is only used for inference.
model.eval()

def _get_positive_for_event_single(pred):
    """Group token positions by the event class predicted for them.

    Args:
        pred: 1-D sequence (array or list) of integer class ids, one per token.
            Entries equal to -100 are treated as ``NoEvent`` (presumably -100
            marks ignored positions, per the usual label convention -- confirm).
            The caller's object is not modified.

    Returns:
        dict mapping event name (an ``index2event`` value) to a NumPy array of
        the token indices predicted as that event. ``NoEvent`` is never a key,
        so an empty dict means no event was found.
    """
    # Work on a copy: the original mutated the caller's array in place, and a
    # plain list would have turned the mask assignment into ``pred[False] = ...``
    # (overwriting element 0).
    pred = np.array(pred)
    pred[pred == -100] = NOEVENT_ID
    results = {}
    # Skip NoEvent explicitly. The old ``len(tags) > 1`` + ``remove`` logic
    # dropped an event covering every token and raised KeyError when no token
    # was NoEvent.
    for tag in np.unique(pred):
        if tag == NOEVENT_ID:
            continue
        results[index2event[str(tag)]] = np.where(pred == tag)[0]
    return results

def event_detector(article):
    """Detect events in an article with the loaded token-classification model.

    The article is tokenized with special tokens and truncated to ``MAX_LEN``
    tokens, so only its leading part is examined. The text is then run through
    ``model``, and each token is assigned its highest-scoring class.

    Args:
        article: Article text to classify.

    Returns:
        dict mapping each detected event name to a NumPy array of token
        positions, or ``False`` if no event was detected. Positions are indices
        into the tokenized input, where special tokens are included because
        ``add_special_tokens=True``.
    """
    with torch.no_grad():
        # Use MAX_LEN rather than a second hard-coded 256 so the two stay in
        # sync. Longer articles are truncated, not rejected.
        input_ids = tokenizer.encode_plus(
            article,
            add_special_tokens=True,
            max_length=MAX_LEN,
            truncation=True,
            padding=True,
        )["input_ids"]
        # Add a batch dimension of size 1.
        model_input = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
        # First model output, batch dim removed: per-token class scores (logits).
        # Assumes shape (seq_len, num_labels) -- TODO confirm for this model.
        logits = model(model_input)[0].squeeze(0).cpu().numpy()

    # Most likely class id for each token.
    pred = np.argmax(logits, axis=1)
    # Map class ids to {event name: token positions}, excluding NoEvent.
    results = _get_positive_for_event_single(pred)

    # Existing contract: return False (not an empty dict) when nothing is found.
    return results if results else False