In [1]:
import json
import csv
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from tqdm import tqdm_notebook
import torch.optim as optim
import torchtext
from torchtext.data import Field, BucketIterator, Iterator, TabularDataset

## Load Data

In [2]:
with open('../../data/processed/splits/train/stock_data.json', 'r') as inp:
    train_stock = json.load(inp)
with open('../../data/processed/splits/valid/stock_data.json', 'r') as inp:
    valid_stock = json.load(inp)
with open('../../data/processed/splits/test/stock_data.json', 'r') as inp:
    test_stock = json.load(inp)

In [3]:
ID = Field(
    sequential=False
)
TRANSCRIPT = Field(
    sequential=True,
    fix_length=11000,
    lower=True
)
LABEL = Field(
    sequential=False
)

In [19]:
class StockDataset(data.Dataset):
    def __init__(self, examples):
        examples = np.array(examples)
        self.labels = examples[:,-1]
        self.market_cap = examples[:,-2]
        self.examples = np.array(examples[:,:-2].tolist())

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.examples)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Load data and get label
        X = torch.tensor(self.examples[index])
        auxiliary = self.market_cap[index]
        y = self.labels[index]

        return X, auxiliary, y

In [16]:
def build_datasets():
    train, valid, test = TabularDataset.splits(
        path='../../data/processed/splits',
        format='csv',
        skip_header=True,
        train='train/transcripts.csv',
        validation='valid/transcripts.csv',
        test='test/transcripts.csv',
        fields=[('id', ID), ('transcript', TRANSCRIPT), ('post_high', LABEL)]
    )
    glove = torchtext.vocab.GloVe(name='6B', dim=50)
    TRANSCRIPT.build_vocab(train, valid, test, vectors=glove)
    return train, valid, test

In [17]:
train, valid, test = build_datasets()

.vector_cache/glove.6B.zip: 862MB [06:26, 2.23MB/s]                               
100%|█████████▉| 398578/400000 [00:30<00:00, 26768.09it/s]

In [23]:
print(train[0].transcript)

["['_',", "'s',", "'p',", "'_',", "'s',", "'t',", "'a',", "'r',", "'t',", "'_',", "'o',", "'p',", "'e',", "'r',", "'a',", "'t',", "'o',", "'r',", "'_',", "'s',", "'p',", "'_',", "'e',", "'n',", "'d',", "'_',", "'g',", "'o',", "'o',", "'d',", "'_masked_',", "'o',", "'r',", "'n',", "'i',", "'n',", "'g',", "'a',", "'n',", "'d',", "'w',", "'e',", "'l',", "'c',", "'o',", "'_masked_',", "'e',", "'t',", "'o',", "'t',", "'h',", "'e',", "'o',", "'r',", "'e',", "'i',", "'l',", "'l',", "'y',", "'a',", "'u',", "'t',", "'o',", "'_masked_',", "'o',", "'t',", "'i',", "'v',", "'e',", "'i',", "'n',", "'c',", "'s',", "'e',", "'c',", "'o',", "'n',", "'d',", "'q',", "'u',", "'a',", "'r',", "'t',", "'e',", "'r',", "'_',", "'m',", "'a',", "'s',", "'k',", "'e',", "'d',", "'_',", "'e',", "'a',", "'r',", "'n',", "'i',", "'n',", "'g',", "'s',", "'c',", "'o',", "'n',", "'f',", "'e',", "'r',", "'e',", "'n',", "'c',", "'e',", "'c',", "'a',", "'l',", "'l',", "'_masked_',", "'y',", "'n',", "'a',", "'_masked_',", "'e