## Part 1: Import the "all the news" data and the NASDAQ index, and package them together in a NumPy array

The idea: unfortunately, we don't have much targeted stock data. So the thought is that we might be able to take a large variety of news article titles (from 2016-2018) and use them to try to predict the overall NASDAQ index.

From the set of all articles published on each day, we will split them into groups of 5 or 6 and join the titles into one string to form our X. Our y is the direction of the NASDAQ's movement on that day: 1 if the index closed at or above its open, -1 otherwise. This gives us around 7,500 data points.

Note: this dataset should have very high variance. Potentially, we could start the learning with this dataset and then fine-tune the model on each particular stock. We'll have to play around with it.
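The grouping described above can also be sketched with a pandas merge and groupby. This is a minimal sketch under assumptions, not the notebook's own loop (which appears further down): the helper `make_examples` and the toy `articles`/`prices` frames are hypothetical, the Open/Close values are the 2016-06-28 row printed later in this notebook, and titles are drawn without replacement so each title lands in at most one group per day.

```python
import numpy as np
import pandas as pd


def make_examples(articles, prices, k=5, seed=0):
    """Hypothetical helper: group same-day titles into chunks of k and label each chunk.

    `articles` needs 'date' and 'title' columns and `prices` needs 'date', 'Open'
    and 'Close' columns, with dates stored as 'YYYY-MM-DD' strings in both.
    """
    rng = np.random.default_rng(seed)
    # label each trading day: +1 if the index closed at or above its open, else -1
    labels = prices.assign(y=np.where(prices['Close'] >= prices['Open'], 1, -1))[['date', 'y']]
    # the inner join keeps only days that have both articles and a trading session
    merged = articles.dropna(subset=['title']).merge(labels, on='date', how='inner')

    examples = []
    for _, day in merged.groupby('date'):
        titles = day['title'].to_numpy()
        y = int(day['y'].iloc[0])
        if len(titles) < k:
            # too few titles: keep them all as one example
            examples.append(('. \n '.join(titles), y))
            continue
        # shuffle once, then cut into len(titles) // k disjoint groups of k titles
        order = rng.permutation(len(titles))
        for g in range(len(titles) // k):
            group = titles[order[g * k:(g + 1) * k]]
            examples.append(('. \n '.join(group), y))
    return examples


articles = pd.DataFrame({
    'date': ['2016-06-28'] * 6 + ['2016-06-25'] * 2,
    'title': [f'title {i}' for i in range(8)],
})
prices = pd.DataFrame({'date': ['2016-06-28'], 'Open': [4643.93], 'Close': [4691.87]})
examples = make_examples(articles, prices, k=5)
print(len(examples), examples[0][1])  # 1 1
```

The 2016-06-25 titles are dropped because that day has no price row, which mirrors the date intersection done later in the notebook.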

In [6]:
'''
    Checking out this dataset: https://www.kaggle.com/snapcrack/all-the-news#articles1.csv
'''
import pandas as pd
import csv


datadir = '../crawl_data/all_the_news_data/all-the-news/'
datafile = datadir + 'articles1.csv'

df1 = pd.read_csv(datafile)

df1.head()

,Unnamed: 0,id,title,publication,author,date,year,month,url,content
0,0,17283,House Republicans Fret About Winning Their Hea...,New York Times,Carl Hulse,2016-12-31,2016.0,12.0,,WASHINGTON — Congressional Republicans have...
1,1,17284,Rift Between Officers and Residents as Killing...,New York Times,Benjamin Mueller and Al Baker,2017-06-19,2017.0,6.0,,"After the bullet shells get counted, the blood..."
2,2,17285,"Tyrus Wong, ‘Bambi’ Artist Thwarted by Racial ...",New York Times,Margalit Fox,2017-01-06,2017.0,1.0,,"When Walt Disney’s “Bambi” opened in 1942, cri..."
3,3,17286,"Among Deaths in 2016, a Heavy Toll in Pop Musi...",New York Times,William McDonald,2017-04-10,2017.0,4.0,,"Death may be the great equalizer, but it isn’t..."
4,4,17287,Kim Jong-un Says North Korea Is Preparing to T...,New York Times,Choe Sang-Hun,2017-01-02,2017.0,1.0,,"SEOUL, South Korea — North Korea’s leader, ..."


In [27]:
# investigate this data further: keep only New York Times articles.
# df.where keeps every row and fills the non-matching ones with NaN,
# so afterwards drop the rows whose date is now missing.
df2 = df1.where(df1['publication'] == 'New York Times')
df2 = df2[df2['date'].notna()]
print(df2.shape)


(7803, 10)


In [28]:
#df2['date'].sort_values()

df2[df2['date'] == '2016-04-18']

,Unnamed: 0,id,title,publication,author,date,year,month,url,content
2534,2534.0,20114.0,Yahoo’s Suitors Uncover Few Financial Details ...,New York Times,Vindu Goel and Michael J. de la Merced,2016-04-18,2016.0,4.0,,SAN FRANCISCO — As Yahoo prepares to accept...
2535,2535.0,20115.0,Dilma Rousseff Targeted in Brazil by Lawmakers...,New York Times,Simon Romero and Vinod Sreeharsha,2016-04-18,2016.0,4.0,,"BRASÍLIA — Paulo Maluf, a Brazilian congres..."
2540,2540.0,20120.0,Chinese General Visits Disputed Spratly Island...,New York Times,Chris Buckley,2016-04-18,2016.0,4.0,,BEIJING — China’s most senior uniformed mil...
2560,2560.0,20141.0,"On Crime Bill and the Clintons, Young Blacks C...",New York Times,Farah Stockman,2016-04-18,2016.0,4.0,,"Rufus Farmer, 33, was tired of all the ways he..."
2562,2562.0,20143.0,Brazil’s Lower House of Congress Votes for Imp...,New York Times,Andrew Jacobs,2016-04-18,2016.0,4.0,,BRASÍLIA — Brazilian legislators voted on S...
2565,2565.0,20146.0,Scenes of Ruin After Ecuador Earthquake Kills ...,New York Times,Maggy Ayala and Nicholas Casey,2016-04-18,2016.0,4.0,,"QUITO, Ecuador — The strongest earthquake t..."
2568,2568.0,20149.0,What Should We Expect From the Supreme Court’s...,New York Times,Emily Bazelon and Eric Posner,2016-04-18,2016.0,4.0,,"This morning, the United States Supreme Court ..."
2569,2569.0,20150.0,"Your Monday Evening Briefing: Ecuador, Immigra...",New York Times,Jonah Bromwich and Sandra Stevenson,2016-04-18,2016.0,4.0,,(Want to get this briefing by email? Here’s th...
5206,5206.0,23090.0,"Your Monday Briefing: Bernie Sanders, Dilma Ro...",New York Times,Adeel Hassan,2016-04-18,2016.0,4.0,,(Want to get this briefing by email? Here’s t...
5209,5209.0,23093.0,Spare a Swipe? New York City Eases Rules for a...,New York Times,Joseph Goldstein,2016-04-18,2016.0,4.0,,It is New York’s own version of hitchhiking: a...
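The float-valued `Unnamed: 0` and `id` columns in the output above (e.g. `2534.0`) come from filtering with `df1.where` rather than boolean indexing. A small sketch of the difference, on a hypothetical toy frame standing in for `df1`:

```python
import pandas as pd

# toy frame standing in for df1 (hypothetical values)
df = pd.DataFrame({'id': [17283, 17284, 17285],
                   'publication': ['New York Times', 'Breitbart', 'New York Times']})
is_nyt = df['publication'] == 'New York Times'

# where() keeps every row and fills the non-matching ones with NaN,
# which forces integer columns such as 'id' to float
masked = df.where(is_nyt)
print(masked.shape, masked['id'].dtype)  # (3, 2) float64

# boolean indexing drops the non-matching rows and keeps the original dtypes
subset = df[is_nyt]
print(subset.shape, subset['id'].dtype)  # (2, 2) int64
```

Boolean indexing also makes the follow-up `notna()` filter unnecessary.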


In [None]:
'''
    Essentially, the same problem exists everywhere. There is only enough data for around a year or two.
    So how do I create enough training examples to train a GRU?
    
    * one idea: try to predict the NASDAQ or the Dow, using groups of 5 articles published on the same day.
    * use just the New York Times, or just Breitbart, or something.
'''

# I want to do this as quickly as possible - it's only a one time thing after all...
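Choosing a single source (just the New York Times, just Breitbart) depends on how many days each one covers and how many titles it publishes per day. A sketch of that count on a hypothetical toy stand-in for `df1`; on the real data the same two lines would run on `df1` directly:

```python
import pandas as pd

# toy stand-in for df1 (hypothetical rows; the last one has no date)
df = pd.DataFrame({
    'publication': ['New York Times', 'New York Times', 'Breitbart', 'Breitbart', 'Breitbart'],
    'date': ['2016-04-18', '2016-04-18', '2016-04-18', '2016-04-19', None],
})

# articles per (publication, day); groupby drops rows with a missing date by default
per_day = df.groupby(['publication', 'date']).size()
# number of days each publication covers, and its average number of articles per day
summary = per_day.groupby(level='publication').agg(['count', 'mean'])
print(summary)
```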



In [60]:
'''
    In this cell, pull in the stock market data (just the NASDAQ Composite index) from a CSV file.
'''

import os
from datetime import datetime


def extract_historical_stock(stocks_data_dir, sfile):
    '''Load one price file into a DataFrame with columns date, Open, Close, ticker.'''
    full_path = os.path.join(stocks_data_dir, sfile)
    stockdf1 = pd.read_csv(full_path)
    sfile_tkr = sfile[:-4]  # strip the '.csv' extension to get a ticker label

    # convert dates from MM/DD/YYYY to YYYY-MM-DD so they match the article dates
    dates = stockdf1['Date'].apply(lambda x: datetime.strptime(x, '%m/%d/%Y').strftime('%Y-%m-%d')).rename('date')
    # the price column names in this CSV start with a space (check stockdf1.columns);
    # convert the prices to float
    close_prc = stockdf1[' Close/Last'].apply(float).rename('Close')
    open_prc = stockdf1[' Open'].apply(float).rename('Open')

    r1 = pd.concat([dates, open_prc, close_prc], axis=1)
    r1['ticker'] = sfile_tkr
    return r1


stocks_data_dir = '../crawl_data/stock_data'
sfile = 'nasdaq_10yr.csv'
historical_stock_prices = extract_historical_stock(stocks_data_dir, sfile)
print(historical_stock_prices[historical_stock_prices.ticker=='nasdaq_10yr'].head())

         date     Open    Close       ticker
0  2020-03-18  6902.32  6989.84  nasdaq_10yr
1  2020-03-17  7072.00  7334.78  nasdaq_10yr
2  2020-03-16  7392.73  6904.59  nasdaq_10yr
3  2020-03-13  7610.39  7874.88  nasdaq_10yr
4  2020-03-12  7398.58  7201.80  nasdaq_10yr
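The per-row `strptime` lambda and the space-prefixed column names can both be avoided when reading the file. A sketch on a hypothetical two-row excerpt that keeps only the columns the notebook uses (the real file may contain more columns); `skipinitialspace=True` strips the blank after each comma, including in the header row:

```python
import io

import pandas as pd

# hypothetical excerpt in the same layout: leading spaces after commas, MM/DD/YYYY dates
raw = io.StringIO(
    "Date, Close/Last, Open\n"
    "03/18/2020, 6989.84, 6902.32\n"
    "03/17/2020, 7334.78, 7072.00\n"
)
# ' Open' is read as 'Open', and ' 6902.32' is parsed as a float
prices = pd.read_csv(raw, skipinitialspace=True)
# vectorized date conversion instead of a per-row strptime lambda
prices['date'] = pd.to_datetime(prices['Date'], format='%m/%d/%Y').dt.strftime('%Y-%m-%d')
prices = prices.rename(columns={'Close/Last': 'Close'})[['date', 'Open', 'Close']]
print(prices)
```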


In [75]:
'''
    Next, let's create our dataset. For each date in our data, pull groups of k news articles
    published on that date (k = 5 below),
    and match them up with the actual stock price movement on that date.

    Question: how should I execute the training?
        1. Combine all headlines into a single sequence.
        2. Recurrently cycle through sequences:
            i. Run the GRU on each title separately, obtaining n different outputs (one per title).
            ii. Push all n outputs into a final MLP.
            ~ Has this been done before? Comb through the literature.

'''

#first, let's grab the dates.
art_dates = df1['date'].unique()
hsp_dates = historical_stock_prices['date']
s1 = set(hsp_dates) & set(art_dates)
datelist = list(s1)

print("total original article dates: {}".format(len(art_dates)))
print("article dates, stock movement known: {}".format(len(datelist)))


total original article dates: 983
article dates, stock movement known: 721
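Option 2 from the question above (run the GRU on each title, then feed the n outputs to an MLP) can be sketched in PyTorch. This is a hypothetical illustration, not the notebook's model: `TitleSetClassifier`, the mean pooling over titles, and all sizes are assumptions.

```python
import torch
from torch import nn


class TitleSetClassifier(nn.Module):
    """Hypothetical sketch of option 2: one shared GRU encodes every title,
    the n title vectors are mean-pooled, and a small MLP classifies the day."""

    def __init__(self, vocab_size, emb_dim=50, hidden_dim=64, n_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, n_classes)
        )

    def forward(self, x):
        # x: (batch, n_titles, seq_len) integer token ids
        b, n, t = x.shape
        emb = self.embedding(x.reshape(b * n, t))   # every title is its own sequence
        _, h = self.gru(emb)                        # h: (1, b * n, hidden_dim)
        title_vecs = h[-1].reshape(b, n, -1)        # one vector per title
        pooled = title_vecs.mean(dim=1)             # mean pooling ignores title order
        return self.mlp(pooled)                     # (batch, n_classes) raw logits


torch.manual_seed(0)
model = TitleSetClassifier(vocab_size=100)
x = torch.randint(1, 100, (4, 5, 12))  # 4 days, 5 titles per day, 12 tokens per title
logits = model(x)
print(logits.shape)  # torch.Size([4, 2])
```

Pooling makes the prediction independent of the order in which titles were sampled. With real, padded titles, the final hidden state also runs over padding tokens unless the batch is packed with `torch.nn.utils.rnn.pack_padded_sequence`.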


In [84]:
# next, grab all titles for a given date.
import numpy as np

k = 5  # number of titles in every "data point"

x_titles = []
for date in datelist:
    print(date)
    # obtain all articles with a title on this date, and print how many there are
    dx_data = df1[(df1['date'] == date) & df1['title'].notna()]
    titles = np.array(dx_data['title'])
    print(titles.shape)

    # obtain the y for this date: -1 if the index closed below its open, else 1
    daterw = historical_stock_prices[historical_stock_prices['date'] == date]
    oprice, cprice = daterw['Open'].iloc[0], daterw['Close'].iloc[0]
    y = -1 if oprice > cprice else 1
    print(daterw)
    print(y)

    # build the data points; note that np.array([text, y]) stores y as a string
    t = len(titles)
    if t < k:
        # fewer than k titles: use all of them as a single data point
        x_titles.append(np.array(['. \n '.join(titles), y]))
    else:
        # t // k data points, each a random sample of k distinct titles from this date
        for i in range(t // k):
            indices = np.random.choice(t, k, replace=False)
            t_list = titles[indices]
            x_titles.append(np.array(['. \n '.join(t_list), y]))

2016-06-28
(106,)
           date     Open    Close       ticker
936  2016-06-28  4643.93  4691.87  nasdaq_10yr
1
2016-05-25
(64,)
           date     Open    Close       ticker
959  2016-05-25  4877.18  4894.89  nasdaq_10yr
1
2016-02-19
(67,)
            date     Open    Close       ticker
1026  2016-02-19  4464.67  4504.43  nasdaq_10yr
1
2012-12-26
(4,)
            date    Open    Close       ticker
1818  2012-12-26  3012.6  2990.16  nasdaq_10yr
-1
2017-06-20
(69,)
           date     Open    Close       ticker
690  2017-06-20  6229.62  6188.03  nasdaq_10yr
-1
2016-12-22
(97,)
           date     Open    Close       ticker
812  2016-12-22  5472.01  5447.42  nasdaq_10yr
-1
2016-07-18
(94,)
           date     Open    Close       ticker
923  2016-07-18  5034.99  5055.78  nasdaq_10yr
1
2016-05-03
(63,)
           date     Open    Close       ticker
975  2016-05-03  4780.88  4763.22  nasdaq_10yr
-1
2017-01-20
(143,)
           date     Open    Close       ticker
794  2017-01-20  5556.87 

(6,)
            date    Open    Close       ticker
1290  2015-02-02  4650.6  4676.69  nasdaq_10yr
1
2015-06-24
(11,)
            date     Open    Close       ticker
1191  2015-06-24  5151.38  5122.41  nasdaq_10yr
-1
2013-08-15
(1,)
            date     Open    Close       ticker
1658  2013-08-15  3669.27  3606.12  nasdaq_10yr
-1
2016-12-20
(93,)
           date     Open    Close       ticker
814  2016-12-20  5473.53  5483.94  nasdaq_10yr
1
2013-04-03
(2,)
            date     Open   Close       ticker
1752  2013-04-03  3254.86  3218.6  nasdaq_10yr
-1
2015-03-18
(23,)
            date     Open    Close       ticker
1259  2015-03-18  4926.83  4982.83  nasdaq_10yr
1
2015-06-16
(28,)
            date     Open    Close       ticker
1197  2015-06-16  5023.58  5055.55  nasdaq_10yr
1
2015-01-06
(3,)
            date     Open    Close       ticker
1308  2015-01-06  4666.85  4592.74  nasdaq_10yr
-1
2013-11-20
(1,)
            date     Open    Close       ticker
1590  2013-11-20  3940.98  3921.2

(1,)
            date     Open    Close       ticker
1860  2012-10-23  2989.44  2990.46  nasdaq_10yr
1
2016-01-29
(59,)
            date     Open    Close       ticker
1040  2016-01-29  4512.09  4613.95  nasdaq_10yr
1
2015-04-01
(19,)
            date     Open    Close       ticker
1249  2015-04-01  4894.36  4880.23  nasdaq_10yr
-1
2016-04-19
(77,)
           date    Open    Close       ticker
985  2016-04-19  4968.3  4940.33  nasdaq_10yr
-1
2015-05-28
(34,)
            date     Open    Close       ticker
1210  2015-05-28  5096.34  5097.98  nasdaq_10yr
1
2016-08-08
(97,)
           date     Open    Close       ticker
908  2016-08-08  5223.54  5213.14  nasdaq_10yr
-1
2016-09-15
(120,)
           date    Open    Close       ticker
881  2016-09-15  5178.1  5249.69  nasdaq_10yr
1
2015-06-09
(27,)
            date     Open    Close       ticker
1202  2015-06-09  5013.13  5013.87  nasdaq_10yr
1
2015-07-10
(12,)
            date     Open   Close       ticker
1180  2015-07-10  4981.24  4997.7 

(1,)
            date     Open    Close       ticker
1737  2013-04-24  3269.33  3269.65  nasdaq_10yr
1
2015-12-17
(1,)
            date     Open    Close       ticker
1068  2015-12-17  5087.17  5002.55  nasdaq_10yr
-1
2017-02-17
(123,)
           date     Open    Close       ticker
774  2017-02-17  5807.31  5838.58  nasdaq_10yr
1
2015-11-16
(2,)
            date     Open    Close       ticker
1090  2015-11-16  4916.14  4984.62  nasdaq_10yr
1
2015-06-26
(38,)
            date     Open    Close       ticker
1189  2015-06-26  5113.26  5080.51  nasdaq_10yr
-1
2013-07-02
(1,)
            date     Open   Close       ticker
1689  2013-07-02  3434.49  3433.4  nasdaq_10yr
-1
2016-03-23
(60,)
            date     Open    Close       ticker
1003  2016-03-23  4813.87  4768.86  nasdaq_10yr
-1
2016-02-04
(74,)
            date     Open    Close       ticker
1036  2016-02-04  4492.48  4509.56  nasdaq_10yr
1
2017-01-11
(128,)
           date     Open    Close       ticker
800  2017-01-11  5550.72  556

(150,)
           date     Open    Close       ticker
802  2017-01-09  5527.58  5531.82  nasdaq_10yr
1
2015-03-05
(3,)
            date     Open    Close       ticker
1268  2015-03-05  4979.95  4982.81  nasdaq_10yr
1
2015-01-15
(23,)
            date     Open    Close       ticker
1301  2015-01-15  4657.46  4570.82  nasdaq_10yr
-1
2013-02-20
(1,)
            date     Open    Close       ticker
1781  2013-02-20  3213.59  3164.41  nasdaq_10yr
-1
2015-07-27
(4,)
            date     Open    Close       ticker
1169  2015-07-27  5055.92  5039.78  nasdaq_10yr
-1
2017-05-22
(90,)
           date     Open    Close       ticker
710  2017-05-22  6098.25  6133.62  nasdaq_10yr
1
2016-09-16
(88,)
           date     Open    Close       ticker
880  2016-09-16  5238.71  5244.57  nasdaq_10yr
1
2016-04-21
(81,)
           date     Open    Close       ticker
983  2016-04-21  4949.12  4945.89  nasdaq_10yr
-1
2013-06-28
(3,)
            date     Open    Close       ticker
1691  2013-06-28  3401.86  3403.2

(2,)
            date     Open    Close       ticker
1307  2015-01-07  4626.84  4650.47  nasdaq_10yr
1
2016-12-15
(125,)
           date     Open    Close       ticker
817  2016-12-15  5443.51  5456.85  nasdaq_10yr
1
2013-07-05
(2,)
            date     Open    Close       ticker
1687  2013-07-05  3443.67  3479.38  nasdaq_10yr
1
2016-11-22
(108,)
           date     Open    Close       ticker
833  2016-11-22  5384.75  5386.35  nasdaq_10yr
1
2015-04-07
(24,)
            date    Open    Close       ticker
1246  2015-04-07  4917.5  4910.23  nasdaq_10yr
-1
2014-07-21
(1,)
            date    Open   Close       ticker
1425  2014-07-21  4421.2  4424.7  nasdaq_10yr
1
2016-01-20
(64,)
            date     Open    Close       ticker
1047  2016-01-20  4405.22  4471.69  nasdaq_10yr
1
2017-01-19
(141,)
           date     Open    Close       ticker
795  2017-01-19  5560.61  5540.08  nasdaq_10yr
-1
2017-04-27
(112,)
           date     Open    Close       ticker
727  2017-04-27  6038.47  6048.94  n

(128,)
           date     Open    Close       ticker
875  2016-09-23  5327.43  5305.75  nasdaq_10yr
-1
2015-07-24
(10,)
            date     Open    Close       ticker
1170  2015-07-24  5166.91  5088.63  nasdaq_10yr
-1
2012-07-03
(1,)
            date     Open    Close       ticker
1938  2012-07-03  2950.81  2976.08  nasdaq_10yr
1
2014-01-07
(1,)
            date     Open    Close       ticker
1559  2014-01-07  4128.57  4153.18  nasdaq_10yr
1
2016-09-29
(98,)
           date     Open    Close       ticker
871  2016-09-29  5311.31  5269.15  nasdaq_10yr
-1
2015-03-17
(22,)
            date     Open    Close       ticker
1260  2015-03-17  4912.65  4937.43  nasdaq_10yr
1
2015-08-05
(2,)
            date     Open    Close       ticker
1162  2015-08-05  5132.77  5139.94  nasdaq_10yr
1
2017-03-28
(137,)
           date    Open    Close       ticker
748  2017-03-28  5836.5  5875.14  nasdaq_10yr
1
2016-11-04
(80,)
           date     Open    Close       ticker
845  2016-11-04  5034.41  5046.37

(86,)
           date     Open    Close       ticker
749  2017-03-27  5776.33  5840.37  nasdaq_10yr
1
2016-08-11
(114,)
           date     Open   Close       ticker
905  2016-08-11  5222.15  5228.4  nasdaq_10yr
1
2016-06-14
(94,)
           date     Open    Close       ticker
946  2016-06-14  4836.67  4843.55  nasdaq_10yr
1
2016-01-07
(61,)
            date    Open    Close       ticker
1055  2016-01-07  4736.4  4689.43  nasdaq_10yr
-1
2016-08-17
(66,)
           date     Open    Close       ticker
901  2016-08-17  5228.44  5228.66  nasdaq_10yr
1
2017-03-24
(109,)
           date     Open    Close       ticker
750  2017-03-24  5839.33  5828.74  nasdaq_10yr
-1
2013-09-13
(1,)
            date     Open    Close       ticker
1638  2013-09-13  3715.97  3722.18  nasdaq_10yr
1
2015-05-18
(6,)
            date     Open    Close       ticker
1217  2015-05-18  5040.92  5078.44  nasdaq_10yr
1
2016-06-15
(81,)
           date     Open    Close       ticker
945  2016-06-15  4855.08  4834.93  nasd

(74,)
           date     Open    Close       ticker
696  2017-06-12  6153.56  6175.46  nasdaq_10yr
1
2013-05-15
(1,)
            date     Open    Close       ticker
1722  2013-05-15  3462.61  3471.62  nasdaq_10yr
1
2014-11-13
(1,)
            date     Open    Close       ticker
1343  2014-11-13  4681.56  4680.14  nasdaq_10yr
-1
2013-01-03
(1,)
            date     Open    Close       ticker
1813  2013-01-03  3112.26  3100.57  nasdaq_10yr
-1
2016-03-03
(59,)
            date     Open    Close       ticker
1017  2016-03-03  4698.38  4707.42  nasdaq_10yr
1
2016-04-11
(59,)
           date     Open   Close       ticker
991  2016-04-11  4873.39  4833.4  nasdaq_10yr
-1
2015-03-27
(24,)
            date     Open    Close       ticker
1252  2015-03-27  4863.74  4891.22  nasdaq_10yr
1
2017-06-08
(99,)
           date     Open    Close       ticker
698  2017-06-08  6311.73  6321.76  nasdaq_10yr
1
2017-04-06
(153,)
           date     Open    Close       ticker
741  2017-04-06  5870.52  5878.95 

(85,)
           date     Open   Close       ticker
957  2016-05-27  4904.05  4933.5  nasdaq_10yr
1
2015-12-24
(1,)
            date     Open    Close       ticker
1063  2015-12-24  5046.19  5048.49  nasdaq_10yr
1
2016-03-17
(66,)
            date     Open    Close       ticker
1007  2016-03-17  4752.62  4774.99  nasdaq_10yr
1
2017-03-23
(102,)
           date     Open    Close       ticker
751  2017-03-23  5812.31  5817.69  nasdaq_10yr
1
2012-09-24
(1,)
            date     Open    Close       ticker
1881  2012-09-24  3155.35  3160.78  nasdaq_10yr
1
2017-06-01
(69,)
           date     Open    Close       ticker
703  2017-06-01  6215.91  6246.83  nasdaq_10yr
1
2016-03-21
(63,)
            date     Open    Close       ticker
1005  2016-03-21  4787.31  4808.87  nasdaq_10yr
1
2013-03-08
(1,)
            date     Open    Close       ticker
1769  2013-03-08  3232.09  3244.37  nasdaq_10yr
1
2013-05-24
(1,)
            date     Open    Close       ticker
1715  2013-05-24  3459.42  3459.14  n

(61,)
            date     Open    Close       ticker
1016  2016-03-04  4715.76  4717.02  nasdaq_10yr
1


In [81]:
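# why does y never equal -1?
# Reading both prices from the 'Open' column makes oprice == cprice, so y is always 1.
# Compare Open with Close instead (and filter on the 'date' column, not on a column named after the date).
check_date = '2016-04-15'

daterw = historical_stock_prices[historical_stock_prices['date'] == check_date]
oprice, cprice = daterw['Open'].iloc[0], daterw['Close'].iloc[0]
y = -1 if oprice > cprice else 1
print(y)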
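A cheap guard against this kind of bug is to check the label balance before training. A minimal sketch; `label_balance` is a hypothetical helper, and the sample labels are the first eight `y` values printed by the loop above. On the real data it could be called as `label_balance(trdata[:, 1])`, where the labels are strings.

```python
import numpy as np


def label_balance(labels):
    """Fraction of examples carrying each label value (hypothetical helper)."""
    values, counts = np.unique(np.asarray(labels), return_counts=True)
    return dict(zip(values.tolist(), (counts / counts.sum()).tolist()))


# first eight labels printed above (+1 = close >= open, -1 = close < open)
ys = [1, 1, 1, -1, -1, -1, 1, -1]
print(label_balance(ys))  # {-1: 0.5, 1: 0.5}
```

A result like `{1: 1.0}` would have flagged the Open/Open comparison immediately.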


In [87]:
trdata = np.vstack(x_titles)  # (n, 2) array of [joined titles, label]; ~7.5k data points, which should work, although it's still on the low end
print(trdata.shape)
print(trdata[:3])

(7489, 2)
[['Benghazi Committee Releases Final Report, Slams Clinton. \n In American Markets, Panic Begins to Subside - The New York Times. \n Barack Obama: ‘Mr. Trump Embodies Global Elites’ Not Working Class - Breitbart. \n Reactions to the Supreme Court Ruling on Texas’ Abortion Law - The New York Times. \n Stacey Dash Talks ’Dash America’ Movement: Conservatives Can ’Take Back Influence and Power’ in Hollywood - Breitbart'
  '1']
 ['‘Brexit’ Opens Uncertain Chapter in Britain’s Storied History - The New York Times. \n Benghazi Committee Releases Final Report, Slams Clinton. \n The 15 Questions About Benghazi Barack Obama Does Not Want To Answer - Breitbart. \n Jenner on ’SI’ 40 Years Later: My ‘Macho Male’ Olympics Body Disguised ‘The Woman Living Inside Me’ - Breitbart. \n Sharyl Attkisson: Obama and Clinton Lied to the Public, ‘Impeded the Investigation’ - Breitbart'
  '1']
 ['Hamas Ends Week-Old Deal Importing Israeli Watermelons Into Gaza. \n Demand for ‘Himalayan Viagra’ Fungu
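As the output shows, the labels in `trdata` are strings (`'1'`), because `np.array([text, y])` coerces both entries to a common string dtype. A sketch of converting them back before training, on a hypothetical two-row stand-in for `trdata`; mapping -1/+1 to class indices 0/1 is an assumption that fits `nn.NLLLoss`, which expects non-negative class indices:

```python
import numpy as np

# hypothetical stand-in with the same layout as trdata: [joined titles, label as string]
trdata_demo = np.array([['headline a. \n headline b', '1'],
                        ['headline c. \n headline d', '-1']])

texts = trdata_demo[:, 0].tolist()
y_signed = trdata_demo[:, 1].astype(int)    # back to -1 / +1
y_class = (y_signed > 0).astype(np.int64)   # 0 = down, 1 = up
print(y_signed, y_class)  # [ 1 -1] [1 0]
```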

In [88]:
# save the training data in a pickle
import os
import pickle as pkl

os.makedirs('./pandas_trdata', exist_ok=True)  # make sure the output directory exists
with open('./pandas_trdata/allstocks_trdata_8k.pkl', 'wb') as f:
    pkl.dump(trdata, f)
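A quick round-trip check, sketched on a tiny hypothetical stand-in array inside a temporary directory. Because `trdata` is a fixed-width unicode array rather than an object array, `np.save`/`np.load` with `allow_pickle=False` is an alternative that avoids unpickling when the file is read back.

```python
import os
import pickle as pkl
import tempfile

import numpy as np

trdata_demo = np.array([['headline a. \n headline b', '1'], ['headline c', '-1']])

with tempfile.TemporaryDirectory() as d:
    # pickle round trip, as in the cell above
    path = os.path.join(d, 'trdata.pkl')
    with open(path, 'wb') as f:
        pkl.dump(trdata_demo, f)
    with open(path, 'rb') as f:
        loaded = pkl.load(f)

    # a unicode array can also be stored with np.save and loaded without pickle
    npy_path = os.path.join(d, 'trdata.npy')
    np.save(npy_path, trdata_demo)
    loaded_npy = np.load(npy_path, allow_pickle=False)

print(np.array_equal(loaded, trdata_demo), np.array_equal(loaded_npy, trdata_demo))  # True True
```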


## Part 2: Start doing machine learning
Let's start training a GRU to perform this task and see how well it does.

[NOT MUCH ML HERE, SEE "KAGGLE NEWS GRU" NOTEBOOK]

In [90]:
import torch
from torch import nn

'''
source: https://blog.floydhub.com/gru-with-pytorch/
pytorch documentation (not very good): https://pytorch.org/docs/master/nn.html#gru
'''


def create_emb_layer(num_embeddings, embedding_dim, emb_wts=None, non_trainable=False):
    if emb_wts is None:
        # randomly initialized embeddings
        embedding = nn.Embedding(num_embeddings, embedding_dim)
    else:
        # pretrained embeddings (e.g. GloVe): emb_wts is a (vocab_size, embedding_dim) matrix,
        # not a state dict, so build the layer from the matrix directly;
        # freeze=True keeps the weights fixed during training
        embedding = nn.Embedding.from_pretrained(torch.as_tensor(emb_wts, dtype=torch.float),
                                                 freeze=non_trainable)
    return embedding

# create a GRU class
class GRUNet(nn.Module):

    '''
        input parameters:
            input_dim: size of the input embeddings (~50?); must equal embedding_dim
            hidden_dim: size of the hidden state (50)
            output_dim: size of the output, i.e. the number of classes (2 for up/down)
            n_layers_gru: number of stacked GRU layers (1 or 2); nn.GRU also takes bidirectional=True

            embedding_dim: should be the same as input_dim
            emb_wts: matrix of pretrained embeddings (GloVe)
            num_embeddings: total vocab size
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers_gru, drop_prob=0.2,
                 num_embeddings=0, embedding_dim=0, emb_wts=None, non_trainable=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers_gru

        # embedding layer -> GRU -> ReLU -> fully connected layer -> log-softmax
        self.embedding = create_emb_layer(num_embeddings, embedding_dim, emb_wts, non_trainable)
        # batch_first=True: inputs and outputs are (batch, seq_len, features);
        # GRU dropout only acts between stacked layers, so disable it for a single layer
        self.gru = nn.GRU(input_dim, hidden_dim, n_layers_gru, batch_first=True,
                          dropout=drop_prob if n_layers_gru > 1 else 0.0)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_dim, output_dim)
        # log-probabilities over classes; pair with nn.NLLLoss
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, h):
        x = self.embedding(x)                 # (batch, seq_len) ids -> (batch, seq_len, embedding_dim)
        out, h = self.gru(x, h)
        out = self.fc(self.relu(out[:, -1]))  # use the output at the last time step only
        out = self.softmax(out)
        return out, h

    def init_hidden(self, batch_size, device):
        # zero initial hidden state of shape (n_layers, batch_size, hidden_dim)
        return torch.zeros(self.n_layers, batch_size, self.hidden_dim, device=device)
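To settle the "needs batchfirst??" question: `batch_first=True` changes the layout of the input and output tensors only, while the hidden state stays `(num_layers, batch, hidden)`. A small shape check with arbitrary sizes:

```python
import torch
from torch import nn

batch, seq_len, emb_dim, hidden_dim, n_layers = 2, 5, 50, 50, 2
gru = nn.GRU(emb_dim, hidden_dim, n_layers, batch_first=True)

x = torch.randn(batch, seq_len, emb_dim)       # batch_first=True: (batch, seq, feature)
h0 = torch.zeros(n_layers, batch, hidden_dim)  # hidden state layout is unaffected by batch_first
out, h = gru(x, h0)

print(out.shape)  # torch.Size([2, 5, 50]): top-layer output at every time step
print(h.shape)    # torch.Size([2, 2, 50]): final hidden state of every layer
# for a unidirectional GRU, the last time step of out is the top layer's final hidden state
print(torch.allclose(out[:, -1], h[-1]))  # True
```

Passing `bidirectional=True` doubles the last dimension of `out` and the first dimension of `h`, so `fc` would then need `2 * hidden_dim` input features.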

    

In [None]:
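# let's try to run this on some sample data...

# first, create our input examples:
y = [0, 1]
x_sentlist = [["a", "quick", "brown", "fox", "jumps"], ["a", "lazy", "brown", "dog", "barks"]]

# need a tokenizer to map each word to a unique integer; a plain dict is enough for this toy data
# (a library tokenizer, such as the one in Keras, could replace it later)
vocab = {w: i for i, w in enumerate(sorted({w for s in x_sentlist for w in s}))}
x = torch.tensor([[vocab[w] for w in s] for s in x_sentlist])  # (batch=2, seq_len=5)

gru = GRUNet(input_dim=50, hidden_dim=50, output_dim=2, n_layers_gru=1,
             num_embeddings=len(vocab), embedding_dim=50)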
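Real titles have different lengths, so the token ids have to be padded before they can be stacked into one batch. A sketch using `torch.nn.utils.rnn.pad_sequence`; the vocabulary layout (0 reserved for padding, 1 for unknown words) is an assumption for illustration:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

titles = [["a", "quick", "brown", "fox", "jumps"], ["a", "lazy", "dog"]]

# hypothetical toy vocabulary: 0 is reserved for padding, 1 for unknown words
vocab = {"<pad>": 0, "<unk>": 1}
for sent in titles:
    for w in sent:
        vocab.setdefault(w, len(vocab))


def encode(sent):
    return torch.tensor([vocab.get(w, vocab["<unk>"]) for w in sent])


seqs = [encode(s) for s in titles]
lengths = torch.tensor([len(s) for s in seqs])
batch = pad_sequence(seqs, batch_first=True, padding_value=0)  # (batch, max_len)
print(batch)    # tensor([[2, 3, 4, 5, 6], [2, 7, 8, 0, 0]])
print(lengths)  # tensor([5, 3])
```

Creating the embedding with `padding_idx=0` keeps the pad vector at zero, and `lengths` can be passed to `pack_padded_sequence` so the GRU skips the padded steps.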

In [None]:
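## next, we need a function that will train this stuff

# first, create a dataset & dataloader


# loss function: the model already outputs log-softmax, so use NLLLoss.
# NLLLoss does NOT apply softmax itself; nn.CrossEntropyLoss combines LogSoftmax + NLLLoss
# and expects raw logits instead.
criterion = nn.NLLLoss()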
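A quick check of the loss question above, plus the dataset/dataloader step, sketched on random toy tensors (the shapes and batch size are arbitrary assumptions). `NLLLoss` on log-softmax outputs gives the same value as `CrossEntropyLoss` on the raw logits:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
logits = torch.randn(4, 2)
targets = torch.tensor([0, 1, 1, 0])

# NLLLoss expects log-probabilities; it does not apply softmax itself
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
# CrossEntropyLoss = LogSoftmax + NLLLoss, applied to raw logits
ce = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(nll, ce))  # True

# minimal dataset/dataloader over already-tokenized, padded titles (toy tensors)
x = torch.randint(1, 100, (10, 12))  # 10 examples, 12 token ids each
y = torch.randint(0, 2, (10,))       # class indices 0/1
loader = DataLoader(TensorDataset(x, y), batch_size=4, shuffle=True)
sizes = [xb.shape[0] for xb, yb in loader]
print(sizes)  # [4, 4, 2]
```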

