In [1]:
import os
import pandas as pd
import numpy as np
import pickle
import shutil
from datetime import datetime, timedelta
import yfinance as yf
import holidays

# tqdm
import tqdm
from tqdm import tqdm_notebook, trange

# torch
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification, AdamW

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

# locals
from config import Config
from input_example import InputExample
from input_features import InputFeatures, convert_example_to_feature
from twitter import Twitter

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The maximum total input sequence length after WordPiece tokenization.
# Sequences longer than this will be truncated, and sequences shorter than this will be padded.
MAX_SEQ_LENGTH = 128

TRAIN_BATCH_SIZE = 24
EVAL_BATCH_SIZE = 8
LEARNING_RATE = 2e-5
NUM_TRAIN_EPOCHS = 1
RANDOM_SEED = 42
GRADIENT_ACCUMULATION_STEPS = 1
WARMUP_PROPORTION = 0.1
OUTPUT_MODE = 'classification'

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

In [3]:
class Predictor(object):
    
    def __init__(self):
        self.twitter = Twitter()
        self.config = Config()
        self.us_holidays = holidays.CountryHoliday('US')
    
    def set_stock(self, stock):
        self.stock = stock
    
    def set_date_range(self, from_date, to_date):
        self.to_date = datetime.strptime(to_date, "%Y-%m-%d")
        self.from_date = datetime.strptime(from_date, "%Y-%m-%d")
        
    def set_model(self, path = None):
        self.model = BertForSequenceClassification.from_pretrained(
            'bert-base-cased', cache_dir=self.config.get_cache_dir(), num_labels=2
        )
        
        if path != None:
            self.model.load_state_dict(torch.load(path))
            
        self.model.to(device)
        
    def _conv_eval_data(self, tweets, max_seq_length = 48):
        features = self.twitter.conv2features(tweets, max_seq_length = max_seq_length)
        
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        
        if OUTPUT_MODE == "classification":
            all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        elif OUTPUT_MODE == "regression":
            all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
            
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return eval_data
        
    def load_data(self, query, from_date, to_date, count = 16):
        tweets = self.twitter.get_online_tweets(query, from_date, to_date, count)
        print([x.text_a for x in tweets])
        return self._conv_eval_data(tweets)
    
    def load_test_data(self, filename, from_date, to_date):
        tweets = self.twitter.get_offline_tweets(filename, from_date, to_date)
        return self._conv_eval_data(tweets)
        
    def load_custom_data(self):
        text = [
            "very good, the price must be rising, everyone must buy now.",
            "unbeliveable good, this price is reasonable and it must go up next month",
            "NO, please don't believe it, must go down later, fake!",
            "very bad, impossble low price is coming, sell them add, dangerous!"
        ]
        
        tweets = [InputExample(i, t, label=0) for i, t in enumerate(text)]
        print([x.text_a for x in tweets])
        return self._conv_eval_data(tweets)
        
    def classify_tweets(self, eval_data, threshold = 0.5):
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=EVAL_BATCH_SIZE)
        
        # Start predicts stage
        self.model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = []
        bert_predicted = []
        
        for input_ids, input_mask, segment_ids, label_ids in tqdm_notebook(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                output = self.model(input_ids, segment_ids, input_mask, labels=None)
                
            logits = output[0].detach().cpu().numpy()
            
            if len(preds) == 0:
                preds.append(logits)
            else:
                preds[0] = np.append(preds[0], logits, axis=0)
                
            bert_predicted += list(logits[:, 0] > threshold)
            
        preds = preds[0]
        # print(preds)
        return bert_predicted
    
    def _get_market_days(self, from_date, to_date):
        all_days = pd.bdate_range(from_date, to_date).tolist()
        market_days = list(filter(lambda x: not(x in self.us_holidays), all_days))
        return market_days
    
    def evaluate(self, filename):
        market_days = self._get_market_days(self.from_date, self.to_date)
        pred = []
        ref_days = dict()
        # df = pd.DataFrame(columns=['Date', 'Pred'])
        for i, day in enumerate(market_days):
            if i == 0:
                d = day - timedelta(1)
            else:
                d = market_days[i-1]
            ref_days[day] = d
        
        for i, day in enumerate(market_days):
            eval_data = predictor.load_test_data(filename, ref_days[day], day)
            result = predictor.classify_tweets(eval_data)
            value = 1 if result.count(True) > result.count(False) else 0 if result.count(True) == result.count(False) else -1
            pred = np.append(pred, value)
            # df.loc[i] = {'Date': day, 'Pred': value}
        
        return market_days, pred
        # return df
    
    def predicts(self, query, count = 16):
        market_days = self._get_market_days(self.from_date, self.to_date)
        pred = []
        ref_days = dict()
        # df = pd.DataFrame(columns=['Date', 'Pred'])
        for i, day in enumerate(market_days):
            if i == 0:
                d = day - timedelta(1)
            else:
                d = market_days[i-1]
            ref_days[day] = d
        
        for i, day in enumerate(market_days):
            # self.set_date_range(.strftime("%Y-%m-%d"), day.strftime("%Y-%m-%d"))
            eval_data = predictor.load_data(query=query, from_date = ref_days[day], to_date = day, count=count)
            result = predictor.classify_tweets(eval_data)
            value = 1 if result.count(True) > result.count(False) else 0 if result.count(True) == result.count(False) else -1
            pred = np.append(pred, value)
            # df.loc[i] = {'Date': day.strftime("%Y-%m-%d"), 'Pred': value}
                        
        return market_days, pred
        # return df

    def verify(self, market_days, pred):
        ticker = yf.Ticker(self.stock)
        hist = ticker.history(start=self.from_date.strftime("%Y-%m-%d"), end=self.to_date.strftime("%Y-%m-%d"))
        hist = hist.reset_index()
        hist = hist.loc[hist["Date"].isin(market_days)]
        # print(hist)
        
        diff = hist["Close"] - hist["Open"]
        verify = [1 if p > 0 else 0 if p == 0 else -1 for p in diff]
        
        #if len(pred) != len(verify):
        #    raise Exception("length of pred is {}, does not match with length of verify {}".format(len(pred), len(verify)))
        
        result = []
        for p, v in zip(pred, verify):
            result = np.append(result, p == v)
                
        correct = (result == True).sum()
        accuracy = correct / len(pred)
        
        return result, accuracy
    
    def output(self, market_days, pred, filename = "result.js"):
        ticker = yf.Ticker(self.stock)
        hist = ticker.history(start=self.from_date.strftime("%Y-%m-%d"), end=self.to_date.strftime("%Y-%m-%d"))
        hist = hist.reset_index()
        hist = hist.loc[hist["Date"].isin(market_days)]
        
        date = hist["Date"].tolist()
        close = hist["Close"].tolist()
        df = pd.DataFrame(columns=['date', 'value', 'state'])
        
        for i, (d, c, p) in enumerate(zip(date, close, pred)):
            df.loc[i] = {'date': d.strftime("%Y-%m-%d"), 'value': c, 'state': int(p)}
            
        with open(filename, "w") as file:
            file.write("var my_data = [")
            for index, row in df.iterrows():
                file.write('{"date":"' + row['date'] + '","value":' + str(row['value']) + ',"state":' + str(row['state']) + '},')
            file.write("];")
            
        return df

In [4]:
predictor = Predictor()
# predictor.set_model()
predictor.set_model("../output/appl_100/pytorch_model.bin")

Retrieve twitter credential: ../twitter-cred\credentials.txt


Using cache found in C:\Users\ivangundampc/.cache\torch\hub\huggingface_pytorch-transformers_master
INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at C:\Users\ivangundampc\.cache\torch\transformers\5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at ../cache/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.d7a3af18ce3a2ab7c0f48f04dc8daff45ed9a3ed333b9e9a79d012a0dedf87a6
INFO:transformers.configuration_utils:Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  

In [5]:
predictor.set_stock("AAPL")
#predictor.set_date_range("2019-12-03", "2019-12-10")
#market_days, pred = predictor.predicts(query = "$AAPL", count = 20)
predictor.set_date_range("2015-01-01", "2015-03-30")
market_days, pred = predictor.evaluate("aapl.tsv")

HBox(children=(IntProgress(value=0, description='Evaluating', max=9, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=13, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=9, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=10, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=9, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=6, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=4, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=5, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=5, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=3, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=6, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=17, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=31, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=23, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=11, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=10, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=6, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=13, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=15, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=15, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=12, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=17, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=14, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=6, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=6, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=12, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=14, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=9, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=10, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=10, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=11, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=11, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=14, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=21, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=20, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=14, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=11, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=11, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=8, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=9, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=12, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=11, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=7, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=5, style=ProgressStyle(description_width='in…




HBox(children=(IntProgress(value=0, description='Evaluating', max=12, style=ProgressStyle(description_width='i…




In [6]:
result, accuracy = predictor.verify(market_days, pred)
print("Result: {}, accuracy: {}".format(result, accuracy))

INFO:numexpr.utils:Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


Result: [1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 0. 0. 1. 1. 1. 0. 1. 0. 1. 1. 1. 0. 1.
 1. 0. 1. 0. 0. 1. 0. 0. 0. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0.
 1. 1. 1. 1. 0. 0. 0. 1. 1. 0. 0.], accuracy: 0.5333333333333333


In [7]:
df = predictor.output(market_days, pred)