In [None]:
import os
import pandas as pd
import numpy as np
import pickle
import shutil
from datetime import datetime

# tqdm
import tqdm
from tqdm import tqdm_notebook, trange

# torch
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification, AdamW

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

# locals
from config import Config
from input_example import InputExample
from input_features import InputFeatures, convert_example_to_feature
from twitter import Twitter

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The maximum total input sequence length after WordPiece tokenization.
# Sequences longer than this will be truncated, and sequences shorter than this will be padded.
MAX_SEQ_LENGTH = 128

# Batch sizes: EVAL_BATCH_SIZE drives the prediction DataLoader below;
# TRAIN_BATCH_SIZE is not referenced by the cells shown in this notebook.
TRAIN_BATCH_SIZE = 24
EVAL_BATCH_SIZE = 8

# Optimisation settings (not referenced by the prediction-only cells below).
LEARNING_RATE = 2e-5
NUM_TRAIN_EPOCHS = 1
RANDOM_SEED = 42
GRADIENT_ACCUMULATION_STEPS = 1
WARMUP_PROPORTION = 0.1

# Selects the label dtype when building tensors: 'classification' or 'regression'.
OUTPUT_MODE = 'classification'

# Conventional file names for a saved model's configuration and weights.
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

# Apply the seed so that any randomly initialised weights (and anything else
# drawing from numpy/torch RNGs) are the same on every Restart & Run All.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

In [None]:
class Predictor(object):
    """Fetch tweets about a stock and score them with a BERT sequence classifier.

    Intended call order: ``set_model()``, ``set_stock()``, ``set_date_range()``,
    ``load_data()`` and finally ``classify_tweets()``.
    """

    # Size of the classification head; shared by model loading and the
    # empty-result shape returned by classify_tweets().
    NUM_LABELS = 2

    def __init__(self):
        self.twitter = Twitter()
        self.config = Config()

    def set_stock(self, stock):
        """Remember the ticker symbol used to build the default search query."""
        self.stock = stock

    def set_date_range(self, from_date, to_date):
        """Store the search window.

        Args:
            from_date: start of the window as a ``YYYY-MM-DD`` string.
            to_date: end of the window as a ``YYYY-MM-DD`` string.

        Raises:
            ValueError: if either string does not match ``%Y-%m-%d``.
        """
        self.to_date = datetime.strptime(to_date, "%Y-%m-%d")
        self.from_date = datetime.strptime(from_date, "%Y-%m-%d")

    def set_model(self, filename = None):
        """Load a ``BertForSequenceClassification`` model and move it to ``device``.

        Args:
            filename: name of a saved model inside the configured cache
                directory. When ``None`` the stock ``'bert-base-cased'``
                checkpoint is loaded instead.
                NOTE(review): the stock checkpoint presumably carries no
                trained classification head, so its scores are only
                meaningful after fine-tuned weights are loaded -- confirm.
        """
        cache_dir = self.config.get_cache_dir()
        if filename is None:
            model_source = 'bert-base-cased'
        else:
            model_source = os.path.join(cache_dir, filename)

        self.model = BertForSequenceClassification.from_pretrained(
            model_source, cache_dir=cache_dir, num_labels=self.NUM_LABELS
        )
        self.model.to(device)

    def load_data(self, query = None, count = 10):
        """Download tweets and store them as ``self.eval_data`` (a ``TensorDataset``).

        Args:
            query: search string; defaults to ``"$<stock>"`` using the symbol
                given to ``set_stock``.
            count: forwarded to ``Twitter.get_online_tweets``.

        Raises:
            ValueError: if ``OUTPUT_MODE`` is neither ``"classification"``
                nor ``"regression"``.
        """
        # Resolve the label dtype first so a misconfigured OUTPUT_MODE fails
        # with a clear message before any network request is made.
        if OUTPUT_MODE == "classification":
            label_dtype = torch.long
        elif OUTPUT_MODE == "regression":
            label_dtype = torch.float
        else:
            raise ValueError(f"Unsupported OUTPUT_MODE: {OUTPUT_MODE!r}")

        if query is None:
            query = f"${self.stock}"

        tweets = self.twitter.get_online_tweets(
            query,
            self.from_date.strftime("%Y-%m-%d"),
            self.to_date.strftime("%Y-%m-%d"),
            count,
        )
        features = self.twitter.conv2features(tweets)

        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=label_dtype)

        self.eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    def classify_tweets(self):
        """Run the model over ``self.eval_data`` and return the raw logits.

        Returns:
            A numpy array with one row of logits per example, in dataset
            order. When the dataset is empty, an array of shape
            ``(0, NUM_LABELS)`` is returned.
        """
        # Sequential sampling keeps the output rows aligned with the input tweets.
        eval_sampler = SequentialSampler(self.eval_data)
        eval_dataloader = DataLoader(self.eval_data, sampler=eval_sampler, batch_size=EVAL_BATCH_SIZE)

        self.model.eval()
        batch_logits = []

        # The label column is part of the dataset but is not needed for prediction.
        for input_ids, input_mask, segment_ids, _ in tqdm_notebook(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)

            with torch.no_grad():
                # Keyword arguments are required here: in `transformers` the
                # positional order after input_ids is attention_mask, then
                # token_type_ids, so passing them positionally swaps the two.
                output = self.model(
                    input_ids,
                    attention_mask=input_mask,
                    token_type_ids=segment_ids,
                )

            batch_logits.append(output[0].detach().cpu().numpy())

        if not batch_logits:
            # No tweets were loaded: return an empty result instead of failing.
            return np.empty((0, self.NUM_LABELS), dtype=np.float32)

        # Concatenate once at the end rather than re-copying the array per batch.
        return np.concatenate(batch_logits, axis=0)

In [None]:
# Build the predictor; with no filename, set_model falls back to the
# 'bert-base-cased' checkpoint.
predictor = Predictor()
predictor.set_model(filename=None)

In [None]:
# Score up to 50 tweets mentioning $AAPL in the given date window and
# print the resulting logits.
predictor.set_stock(stock="AAPL")
predictor.set_date_range(from_date="2019-12-06", to_date="2019-12-07")
predictor.load_data(query=None, count=50)
tweet_logits = predictor.classify_tweets()
print(tweet_logits)