In [1]:
import os
import pandas as pd
import numpy as np
import pickle
import shutil
from datetime import datetime

# tqdm
import tqdm
from tqdm import tqdm_notebook, trange

# torch
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification, AdamW

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

# locals
from config import Config
from input_example import InputExample
from input_features import InputFeatures, convert_example_to_feature
from twitter import Twitter

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The maximum total input sequence length after WordPiece tokenization.
# Sequences longer than this will be truncated, and sequences shorter than this will be padded.
MAX_SEQ_LENGTH = 128

TRAIN_BATCH_SIZE = 24
EVAL_BATCH_SIZE = 8
LEARNING_RATE = 2e-5
NUM_TRAIN_EPOCHS = 1
RANDOM_SEED = 42
GRADIENT_ACCUMULATION_STEPS = 1
WARMUP_PROPORTION = 0.1
OUTPUT_MODE = 'classification'

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

In [99]:
class Predictor(object):
    
    def __init__(self):
        self.twitter = Twitter()
        self.config = Config()
    
    def set_stock(self, stock):
        self.stock = stock
    
    def set_date_range(self, from_date, to_date):
        self.to_date = datetime.strptime(to_date, "%Y-%m-%d")
        self.from_date = datetime.strptime(from_date, "%Y-%m-%d")
        
    def set_model(self, filename = None):
        if filename == None:
            self.model = BertForSequenceClassification.from_pretrained(
                'bert-base-cased', cache_dir=self.config.get_cache_dir(), num_labels=2
            )
        else:
            self.model = BertForSequenceClassification.from_pretrained(
                os.path.join(self.config.get_cache_dir(), filename), cache_dir=self.config.get_cache_dir(), num_labels=2
            )
            
        self.model.to(device)
        
    def load_data(self, query = None, count = 10):
        if query == None:
            query = f"${self.stock}"
            
        tweets = self.twitter.get_online_tweets(query, self.from_date.strftime("%Y-%m-%d"), self.to_date.strftime("%Y-%m-%d"), count)
        features = self.twitter.conv2features(tweets)
        
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        
        if OUTPUT_MODE == "classification":
            all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        elif OUTPUT_MODE == "regression":
            all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
            
        self.eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        
    def classify_tweets(self):
        # Run prediction for full data
        eval_sampler = SequentialSampler(self.eval_data)
        eval_dataloader = DataLoader(self.eval_data, sampler=eval_sampler, batch_size=EVAL_BATCH_SIZE)
        
        # Start predicts stage
        self.model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = []
        
        for input_ids, input_mask, segment_ids, label_ids in tqdm_notebook(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                output = self.model(input_ids, segment_ids, input_mask, labels=None)
                
            logits = output[0].detach().cpu().numpy()
            
            if len(preds) == 0:
                preds.append(logits)
            else:
                preds[0] = np.append(preds[0], logits, axis=0)
                
        preds = preds[0]
        return preds

In [100]:
predictor = Predictor()
predictor.set_model()

Retrieve twitter credential: ../twitter-cred\credentials.txt


Using cache found in C:\Users\ivangundampc/.cache\torch\hub\huggingface_pytorch-transformers_master
INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at C:\Users\ivangundampc\.cache\torch\transformers\5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at ../cache/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.d7a3af18ce3a2ab7c0f48f04dc8daff45ed9a3ed333b9e9a79d012a0dedf87a6
INFO:transformers.configuration_utils:Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  

In [101]:
predictor.set_stock("AAPL")
predictor.set_date_range("2019-12-06", "2019-12-07")
predictor.load_data(count = 50)
print(predictor.classify_tweets())

HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…


[[-0.12702045  0.2871619 ]
 [ 0.12181091  0.34446043]
 [-0.04136693  0.21541844]
 [ 0.06951762  0.31560788]
 [ 0.06356005  0.32711464]
 [-0.04136693  0.21541844]
 [-0.04136693  0.21541844]
 [-0.04136693  0.21541844]
 [-0.04136693  0.21541844]
 [-0.00856753  0.21816802]
 [ 0.01549479  0.26508442]
 [-0.01316893  0.2843372 ]
 [-0.00220382  0.16379139]
 [ 0.07010119  0.2905017 ]
 [-0.14275637  0.27899   ]
 [ 0.03465731  0.27055955]
 [ 0.07325664  0.3143697 ]
 [-0.06785884  0.25928167]
 [ 0.01559635  0.30415982]
 [ 0.09444959  0.3143896 ]
 [-0.01773433  0.24542034]
 [ 0.18109378  0.2981025 ]
 [-0.00061217  0.29476646]
 [-0.0136187   0.29402113]
 [-0.03233413  0.25972164]
 [ 0.08239074  0.1640906 ]
 [ 0.13346691  0.20968746]
 [-0.04789746  0.3101174 ]
 [-0.061235    0.18748933]
 [ 0.21715312  0.3103689 ]
 [ 0.08565958  0.30528954]
 [-0.05189344  0.2944364 ]
 [ 0.30211806  0.26764444]
 [ 0.09242627  0.3356067 ]
 [-0.00258614  0.1888335 ]
 [ 0.02731377  0.24190482]
 [-0.0336304   0.1195385 ]
