In [2]:
import json
import csv
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from tqdm import tqdm_notebook
import torch.optim as optim

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
with open('../../data/processed/masked_full_transcripts.json', 'r') as inp:
    transcripts = json.load(inp)

## Split Data

In [None]:
np.random.shuffle(transcripts)

In [None]:
train_split = len(transcripts)*8//10
valid_split = len(transcripts)*1//10
train = transcripts[:train_split]
valid = transcripts[train_split:train_split+valid_split]
test = transcripts[train_split+valid_split:]

train_csv = [('id', 'transcript', 'post_high')]
train_stock = []
for example in train:
    train_csv.append((example['id'], example['transcript'], float(example['post_high_low'][0])))
    historical = map(lambda x: x[1], sorted(example['historical_info'], key=lambda x: x[0]))
    train_stock.append(list(map(lambda x: list(map(float, x)), historical)))
    train_stock[-1] += [example['market_cap'], float(example['post_high_low'][0])]

valid_csv = [('id', 'transcript', 'post_high')]
valid_stock = []
for example in valid:
    valid_csv.append((example['id'], example['transcript'], float(example['post_high_low'][0])))
    historical = map(lambda x: x[1], sorted(example['historical_info'], key=lambda x: x[0]))
    valid_stock.append(list(map(lambda x: list(map(float, x)), historical)))
    valid_stock[-1] += [example['market_cap'], float(example['post_high_low'][0])]

test_csv = [('id', 'transcript', 'post_high')]
test_stock = []
for example in test:
    test_csv.append((example['id'], example['transcript'], float(example['post_high_low'][0])))
    historical = map(lambda x: x[1], sorted(example['historical_info'], key=lambda x: x[0]))
    test_stock.append(list(map(lambda x: list(map(float, x)), historical)))
    test_stock[-1] += [example['market_cap'], float(example['post_high_low'][0])]

In [None]:
with open('../../data/processed/splits/train/transcripts.csv', 'w') as out:
    writer = csv.writer(out)
    header = None
    for row in train_csv:
        if header is None:
            writer.writerow(row)
            header = row
        else:
            [id, transcript, label] = row
            writer.writerow([id, ' '.join(transcript), label])

with open('../../data/processed/splits/valid/transcripts.csv', 'w') as out:
    writer = csv.writer(out)
    header = None
    for row in valid_csv:
        if header is None:
            writer.writerow(row)
            header = row
        else:
            [id, transcript, label] = row
            writer.writerow([id, ' '.join(transcript), label])

with open('../../data/processed/splits/test/transcripts.csv', 'w') as out:
    writer = csv.writer(out)
    header = None
    for row in test_csv:
        if header is None:
            writer.writerow(row)
            header = row
        else:
            [id, transcript, label] = row
            writer.writerow([id, ' '.join(transcript), label])

In [None]:
with open('../../data/processed/splits/train/stock_data.json', 'w') as out:
    json.dump(train_stock, out)
with open('../../data/processed/splits/valid/stock_data.json', 'w') as out:
    json.dump(valid_stock, out)
with open('../../data/processed/splits/test/stock_data.json', 'w') as out:
    json.dump(test_stock, out)

In [None]:
with open('../../data/processed/splits/valid/stock_data.csv', 'w') as out:
    writer = csv.writer(out)
    for row in valid_stock:
        writer.writerow(row)

## Load Data

In [5]:
with open('../../data/processed/splits/train/stock_data.json', 'r') as inp:
    train_stock = json.load(inp)
with open('../../data/processed/splits/valid/stock_data.json', 'r') as inp:
    valid_stock = json.load(inp)
with open('../../data/processed/splits/test/stock_data.json', 'r') as inp:
    test_stock = json.load(inp)

In [6]:
train_transcripts = []
with open('../../data/processed/splits/train/transcripts.csv', 'r') as inp:
    reader = csv.reader(inp)
    headers = None
    for row in reader:
        if headers is None:
            headers = row
        else:
            row[1] = row[1].split(' ')
            train_transcripts.append(row)

## Create Dataset

In [6]:
class StockDataset(data.Dataset):
    def __init__(self, examples):
        examples = np.array(examples)
        self.labels = examples[:,-1]
        self.market_cap = examples[:,-2]
        self.examples = np.array(examples[:,:-2].tolist())

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.examples)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Load data and get label
        X = torch.tensor(self.examples[index])
        auxiliary = self.market_cap[index]
        y = self.labels[index]

        return X, auxiliary, y

In [7]:
valid_dataset = StockDataset(valid_stock)

In [8]:
params = {'batch_size': 64,
          'shuffle': True,
          'num_workers': 6}
valid_generator = data.DataLoader(valid_dataset, **params)

## Train Model

In [9]:
class BaselineStockPredictor(nn.Module):
    """
    Model that will read in plain stock ticker values over time and decide whether to buy, sell, or hold at the current price.
    """
    def __init__(self, num_series_features=1, num_auxiliary_features=1, hidden_size=128, output_size=1):
        """
        Attributes:
            num_series_features: the size of the feature set for an individual
                                 stock price example (e.g. if we include high,
                                 low, average, num_series_features will equal 3
            num_auxiliary_features: the number of auxiliary (not dependent on time)
                                    features we are adding (e.g. if we include the 1yr
                                    high and the market capitalization, num_auxiliary_features
                                    would equal 2
            output_size: the size of the outputted vector. For evaluation, we would use a
                         size of 1 (stock price) or 3 (buy, sell, hold classification).
                         For use in the looking glass model, we want an encoding so we might
                         use a size of 128 to feed into the model.
        """
        super().__init__()
        self.recurrent = nn.LSTM(
            input_size=num_series_features,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
            dropout=0.5
        )
        # concatenate LSTM output with auxiliary features
        # output predicted price
        self.linear = nn.Linear(hidden_size+num_auxiliary_features, output_size)
        self.init_weights()

    def init_weights(self):
        """
        Initializes the weights of the model
        """
        for layer in [self.linear]:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, X_series, X_auxiliary):
        """
        Moves the model through each layer
        Parameters:
            X_series: an [N, num_series_examples, num_series_features] size vector
                      where N is the batch size, num_series_examples is how many stock prices
                      we are providing per example (e.g. weekly for the last 3 months), and
                      num_series_features is the same as described in __init__
            X_auxiliary: an [N, num_auxiliary_features] vector
        """
        recurrent_output,_ = self.recurrent(X_series)
        recurrent_output = torch.mean(recurrent_output, 1)
        # We might need this
        # recurrent_output = torch.squeeze(1) 
        aux_combined = torch.cat([recurrent_output, X_auxiliary], dim=1)
        output = self.linear(aux_combined)

        return output

In [10]:
def get_stock_iterator(input_data, batch_size, train=True, shuffle=True):
    dataset = StockDataset(input_data)
    iterator = data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=5)
    return iterator
    
def train_model(train, valid, num_epochs=200, learning_rate=0.003, existing_model=None):
    batch_size = 128
    train_iterator = get_stock_iterator(train, batch_size)
    valid_iterator = get_stock_iterator(valid, batch_size)
    
    model = None
    
    if existing_model is None:
        model = BaselineStockPredictor(num_series_features=2, hidden_size=64)
    else:
        model = existing_model
    model = model.float().cuda()
    
    criterion = nn.MSELoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), learning_rate)
    
    losses = []
    valid_scores = []
    
    min_mse = float('inf')
    delay = 0
    MAX_INC = 20
    
    for epoch in tqdm_notebook(range(num_epochs)):
        model.train()
        iter_losses = []
        for batch_series, batch_aux, batch_labels in train_iterator:
            batch_aux = torch.reshape(batch_aux, (-1,1))
            optimizer.zero_grad()
            outputs = model(batch_series.float().cuda(), batch_aux.float().cuda())
            batch_labels = torch.reshape(batch_labels, (-1,1))
            loss = criterion(outputs, batch_labels.float().cuda())
            loss.backward()
            optimizer.step()
            iter_losses.append(loss.item())
        iter_losses = np.array(iter_losses)
        losses.append(np.mean(iter_losses))
        
        valid_mse = []
        model.eval()

        for valid_batch_series, valid_batch_aux, valid_batch_labels in valid_iterator:
            valid_batch_aux = torch.reshape(valid_batch_aux, (-1,1))
            outputs = model(valid_batch_series.float().cuda(), valid_batch_aux.float().cuda())
            valid_batch_labels = torch.reshape(valid_batch_labels, (-1,1))
            loss = criterion(outputs, valid_batch_labels.float().cuda())
            valid_mse.append(loss.item())
        valid_mse = np.mean(valid_mse)
        print(f'Completed epoch {epoch}. Valid MSE: {valid_mse}')


        if valid_mse < min_mse:
            min_mse = valid_mse
            delay = 0
            torch.save(model, 'model.ckpt')
        else:
            delay += 1
    
    return model, losses

In [None]:
model, losses = train_model(train_stock, valid_stock, num_epochs=10000)

  "num_layers={}".format(dropout, num_layers))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

Completed epoch 0. Valid MSE: 32013.314453125


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Completed epoch 1. Valid MSE: 31528.376953125
Completed epoch 2. Valid MSE: 31025.55078125
Completed epoch 3. Valid MSE: 30445.822265625
Completed epoch 4. Valid MSE: 29766.986328125
Completed epoch 5. Valid MSE: 29127.77734375
Completed epoch 6. Valid MSE: 28701.177734375
Completed epoch 7. Valid MSE: 28284.650390625
Completed epoch 8. Valid MSE: 27831.666015625
Completed epoch 9. Valid MSE: 27359.376953125
Completed epoch 10. Valid MSE: 26891.03515625
Completed epoch 11. Valid MSE: 26474.033203125
Completed epoch 12. Valid MSE: 26082.15625
Completed epoch 13. Valid MSE: 25684.43359375
Completed epoch 14. Valid MSE: 25303.71875
Completed epoch 15. Valid MSE: 24948.28515625
Completed epoch 16. Valid MSE: 24620.185546875
Completed epoch 17. Valid MSE: 24306.05859375
Completed epoch 18. Valid MSE: 23996.466796875
Completed epoch 19. Valid MSE: 23703.619140625
Completed epoch 20. Valid MSE: 23427.94921875
Completed epoch 21. Valid MSE: 23163.13671875
Completed epoch 22. Valid MSE: 22913.3

Completed epoch 176. Valid MSE: 13800.33203125
Completed epoch 177. Valid MSE: 13794.4501953125
Completed epoch 178. Valid MSE: 13777.908203125
Completed epoch 179. Valid MSE: 13702.8359375
Completed epoch 180. Valid MSE: 13680.2197265625
Completed epoch 181. Valid MSE: 13674.9443359375
Completed epoch 182. Valid MSE: 13624.259765625
Completed epoch 183. Valid MSE: 13601.234375
Completed epoch 184. Valid MSE: 13547.6943359375
Completed epoch 185. Valid MSE: 13545.5517578125
Completed epoch 186. Valid MSE: 13454.64453125
Completed epoch 187. Valid MSE: 13449.7109375
Completed epoch 188. Valid MSE: 13516.3388671875
Completed epoch 189. Valid MSE: 13554.9287109375
Completed epoch 190. Valid MSE: 13371.466796875
Completed epoch 191. Valid MSE: 13316.8642578125
Completed epoch 192. Valid MSE: 13288.8193359375
Completed epoch 193. Valid MSE: 13336.9775390625
Completed epoch 194. Valid MSE: 13289.2138671875
Completed epoch 195. Valid MSE: 13183.990234375
Completed epoch 196. Valid MSE: 13140.

Completed epoch 348. Valid MSE: 8766.77734375
Completed epoch 349. Valid MSE: 8872.3896484375
Completed epoch 350. Valid MSE: 8618.6875
Completed epoch 351. Valid MSE: 8625.7158203125
Completed epoch 352. Valid MSE: 8721.6396484375
Completed epoch 353. Valid MSE: 8640.826171875
Completed epoch 354. Valid MSE: 8528.97265625
Completed epoch 355. Valid MSE: 8600.7724609375
Completed epoch 356. Valid MSE: 8550.91015625
Completed epoch 357. Valid MSE: 8533.328125
Completed epoch 358. Valid MSE: 8503.1572265625
Completed epoch 359. Valid MSE: 8568.861328125
Completed epoch 360. Valid MSE: 8520.5009765625
Completed epoch 361. Valid MSE: 8483.0576171875
Completed epoch 362. Valid MSE: 8453.763671875
Completed epoch 363. Valid MSE: 8376.0595703125
Completed epoch 364. Valid MSE: 8392.1904296875
Completed epoch 365. Valid MSE: 8528.9853515625
Completed epoch 366. Valid MSE: 8375.078125
Completed epoch 367. Valid MSE: 8303.7080078125
Completed epoch 368. Valid MSE: 8266.54296875
Completed epoch 3

Completed epoch 520. Valid MSE: 5870.5576171875
Completed epoch 521. Valid MSE: 5805.90283203125
Completed epoch 522. Valid MSE: 5779.8232421875
Completed epoch 523. Valid MSE: 5851.97509765625
Completed epoch 524. Valid MSE: 5727.14794921875
Completed epoch 525. Valid MSE: 5796.298828125
Completed epoch 526. Valid MSE: 5745.3759765625
Completed epoch 527. Valid MSE: 5751.966796875
Completed epoch 528. Valid MSE: 5733.169921875
Completed epoch 529. Valid MSE: 5725.5869140625
Completed epoch 530. Valid MSE: 5687.408203125
Completed epoch 531. Valid MSE: 5674.13623046875
Completed epoch 532. Valid MSE: 5704.22412109375
Completed epoch 533. Valid MSE: 5649.203125
Completed epoch 534. Valid MSE: 5609.21923828125
Completed epoch 535. Valid MSE: 5664.92626953125
Completed epoch 536. Valid MSE: 5602.74365234375
Completed epoch 537. Valid MSE: 5611.1640625
Completed epoch 538. Valid MSE: 5603.26123046875
Completed epoch 539. Valid MSE: 5640.4453125
Completed epoch 540. Valid MSE: 5541.79638671

Completed epoch 691. Valid MSE: 3852.796142578125
Completed epoch 692. Valid MSE: 3871.322998046875
Completed epoch 693. Valid MSE: 3836.216552734375
Completed epoch 694. Valid MSE: 3896.484130859375
Completed epoch 695. Valid MSE: 3913.3701171875
Completed epoch 696. Valid MSE: 3937.36669921875
Completed epoch 697. Valid MSE: 3854.525634765625
Completed epoch 698. Valid MSE: 3857.8486328125
Completed epoch 699. Valid MSE: 3821.9169921875
Completed epoch 700. Valid MSE: 3883.4443359375
Completed epoch 701. Valid MSE: 3766.822998046875
Completed epoch 702. Valid MSE: 3774.73876953125
Completed epoch 703. Valid MSE: 3801.283935546875
Completed epoch 704. Valid MSE: 3757.179931640625
Completed epoch 705. Valid MSE: 3861.23193359375
Completed epoch 706. Valid MSE: 3689.995849609375
Completed epoch 707. Valid MSE: 3819.53125
Completed epoch 708. Valid MSE: 3792.283935546875
Completed epoch 709. Valid MSE: 3690.72802734375
Completed epoch 710. Valid MSE: 3832.3125
Completed epoch 711. Valid 

Completed epoch 859. Valid MSE: 2740.330810546875
Completed epoch 860. Valid MSE: 2727.54931640625
Completed epoch 861. Valid MSE: 2729.172119140625
Completed epoch 862. Valid MSE: 2647.260498046875
Completed epoch 863. Valid MSE: 2692.14306640625
Completed epoch 864. Valid MSE: 2665.68017578125
Completed epoch 865. Valid MSE: 2667.65673828125
Completed epoch 866. Valid MSE: 2705.825439453125
Completed epoch 867. Valid MSE: 2682.549072265625
Completed epoch 868. Valid MSE: 2696.625732421875
Completed epoch 869. Valid MSE: 2607.1201171875
Completed epoch 870. Valid MSE: 2651.806640625
Completed epoch 871. Valid MSE: 2608.539306640625
Completed epoch 872. Valid MSE: 2586.0556640625
Completed epoch 873. Valid MSE: 2662.758056640625
Completed epoch 874. Valid MSE: 2636.532958984375
Completed epoch 875. Valid MSE: 2580.832275390625
Completed epoch 876. Valid MSE: 2714.036376953125
Completed epoch 877. Valid MSE: 2721.005615234375
Completed epoch 878. Valid MSE: 2610.40478515625
Completed ep

Completed epoch 1025. Valid MSE: 2022.6533203125
Completed epoch 1026. Valid MSE: 2070.04052734375
Completed epoch 1027. Valid MSE: 2012.9674072265625
Completed epoch 1028. Valid MSE: 2032.5123291015625
Completed epoch 1029. Valid MSE: 1984.341552734375
Completed epoch 1030. Valid MSE: 2099.9853515625
Completed epoch 1031. Valid MSE: 2020.5081787109375
Completed epoch 1032. Valid MSE: 2004.750732421875
Completed epoch 1033. Valid MSE: 2029.965576171875
Completed epoch 1034. Valid MSE: 2023.6690673828125
Completed epoch 1035. Valid MSE: 2016.10693359375
Completed epoch 1036. Valid MSE: 1976.387451171875
Completed epoch 1037. Valid MSE: 1971.634521484375
Completed epoch 1038. Valid MSE: 1925.837158203125
Completed epoch 1039. Valid MSE: 1938.1982421875
Completed epoch 1040. Valid MSE: 1959.4937744140625
Completed epoch 1041. Valid MSE: 1965.589599609375
Completed epoch 1042. Valid MSE: 1921.203125
Completed epoch 1043. Valid MSE: 1953.8182373046875
Completed epoch 1044. Valid MSE: 1904.2

Completed epoch 1186. Valid MSE: 1658.136474609375
Completed epoch 1187. Valid MSE: 1732.524169921875
Completed epoch 1188. Valid MSE: 1639.914794921875
Completed epoch 1189. Valid MSE: 1715.532958984375
Completed epoch 1190. Valid MSE: 1681.013916015625
Completed epoch 1191. Valid MSE: 1636.0999755859375
Completed epoch 1192. Valid MSE: 1676.8409423828125
Completed epoch 1193. Valid MSE: 1679.380615234375
Completed epoch 1194. Valid MSE: 1646.8232421875
Completed epoch 1195. Valid MSE: 1694.2364501953125
Completed epoch 1196. Valid MSE: 1705.206787109375
Completed epoch 1197. Valid MSE: 1637.1380615234375
Completed epoch 1198. Valid MSE: 1664.8934326171875
Completed epoch 1199. Valid MSE: 1667.46630859375
Completed epoch 1200. Valid MSE: 1631.9642333984375
Completed epoch 1201. Valid MSE: 1680.9986572265625
Completed epoch 1202. Valid MSE: 1665.3033447265625
Completed epoch 1203. Valid MSE: 1624.7808837890625
Completed epoch 1204. Valid MSE: 1709.2174072265625
Completed epoch 1205. Va

Completed epoch 1346. Valid MSE: 1448.96142578125
Completed epoch 1347. Valid MSE: 1449.1627197265625
Completed epoch 1348. Valid MSE: 1441.665283203125
Completed epoch 1349. Valid MSE: 1452.02587890625
Completed epoch 1350. Valid MSE: 1466.71142578125
Completed epoch 1351. Valid MSE: 1485.425048828125
Completed epoch 1352. Valid MSE: 1463.8218994140625
Completed epoch 1353. Valid MSE: 1443.3658447265625
Completed epoch 1354. Valid MSE: 1426.986328125
Completed epoch 1355. Valid MSE: 1367.936279296875
Completed epoch 1356. Valid MSE: 1392.190673828125
Completed epoch 1357. Valid MSE: 1370.7039794921875
Completed epoch 1358. Valid MSE: 1398.533447265625
Completed epoch 1359. Valid MSE: 1389.9530029296875
Completed epoch 1360. Valid MSE: 1410.117919921875
Completed epoch 1361. Valid MSE: 1420.25048828125
Completed epoch 1362. Valid MSE: 1423.0479736328125
Completed epoch 1363. Valid MSE: 1432.817138671875
Completed epoch 1364. Valid MSE: 1423.50732421875
Completed epoch 1365. Valid MSE: 

Completed epoch 1508. Valid MSE: 1102.7664794921875
Completed epoch 1509. Valid MSE: 1094.5611572265625
Completed epoch 1510. Valid MSE: 1076.8907470703125
Completed epoch 1511. Valid MSE: 1091.836669921875
Completed epoch 1512. Valid MSE: 1082.3487548828125
Completed epoch 1513. Valid MSE: 1062.9674072265625
Completed epoch 1514. Valid MSE: 1056.98974609375
Completed epoch 1515. Valid MSE: 1081.2989501953125
Completed epoch 1516. Valid MSE: 1094.3095703125
Completed epoch 1517. Valid MSE: 1064.3935546875
Completed epoch 1518. Valid MSE: 1050.22265625
Completed epoch 1519. Valid MSE: 1078.0423583984375
Completed epoch 1520. Valid MSE: 1080.98974609375
Completed epoch 1521. Valid MSE: 1077.1390380859375
Completed epoch 1522. Valid MSE: 1036.764892578125
Completed epoch 1523. Valid MSE: 1020.1644287109375
Completed epoch 1524. Valid MSE: 1041.6900634765625
Completed epoch 1525. Valid MSE: 1047.681396484375
Completed epoch 1526. Valid MSE: 1049.4447021484375
Completed epoch 1527. Valid MS

Completed epoch 1670. Valid MSE: 891.5576171875
Completed epoch 1671. Valid MSE: 875.91259765625
Completed epoch 1672. Valid MSE: 881.6649169921875
Completed epoch 1673. Valid MSE: 874.0770263671875
Completed epoch 1674. Valid MSE: 896.9890747070312
Completed epoch 1675. Valid MSE: 894.3906860351562
Completed epoch 1676. Valid MSE: 888.6007080078125
Completed epoch 1677. Valid MSE: 902.966552734375
Completed epoch 1678. Valid MSE: 894.1608276367188
Completed epoch 1679. Valid MSE: 866.2897338867188
Completed epoch 1680. Valid MSE: 808.2437744140625
Completed epoch 1681. Valid MSE: 840.9613647460938
Completed epoch 1682. Valid MSE: 834.7062377929688
Completed epoch 1683. Valid MSE: 854.2156372070312
Completed epoch 1684. Valid MSE: 987.9962768554688
Completed epoch 1685. Valid MSE: 858.302978515625
Completed epoch 1686. Valid MSE: 792.0970458984375
Completed epoch 1687. Valid MSE: 834.4796752929688
Completed epoch 1688. Valid MSE: 829.4351806640625
Completed epoch 1689. Valid MSE: 785.3

Completed epoch 1833. Valid MSE: 968.541748046875
Completed epoch 1834. Valid MSE: 977.2929077148438
Completed epoch 1835. Valid MSE: 991.4530639648438
Completed epoch 1836. Valid MSE: 938.1244506835938
Completed epoch 1837. Valid MSE: 852.9744873046875
Completed epoch 1838. Valid MSE: 800.3720703125
Completed epoch 1839. Valid MSE: 858.7359619140625
Completed epoch 1840. Valid MSE: 844.109375
Completed epoch 1841. Valid MSE: 852.056396484375
Completed epoch 1842. Valid MSE: 827.3729858398438
Completed epoch 1843. Valid MSE: 826.6755981445312
Completed epoch 1844. Valid MSE: 832.0620727539062
Completed epoch 1845. Valid MSE: 833.3077392578125
Completed epoch 1846. Valid MSE: 845.06396484375
Completed epoch 1847. Valid MSE: 846.177978515625
Completed epoch 1848. Valid MSE: 845.856689453125
Completed epoch 1849. Valid MSE: 927.3645629882812
Completed epoch 1850. Valid MSE: 940.4053955078125
Completed epoch 1851. Valid MSE: 874.6256103515625
Completed epoch 1852. Valid MSE: 807.7547607421

Completed epoch 1996. Valid MSE: 671.634765625
Completed epoch 1997. Valid MSE: 741.5770874023438
Completed epoch 1998. Valid MSE: 920.4305419921875
Completed epoch 1999. Valid MSE: 911.6782836914062
Completed epoch 2000. Valid MSE: 909.0791015625
Completed epoch 2001. Valid MSE: 790.7883911132812
Completed epoch 2002. Valid MSE: 767.7434692382812
Completed epoch 2003. Valid MSE: 786.1893920898438
Completed epoch 2004. Valid MSE: 700.6474609375
Completed epoch 2005. Valid MSE: 599.4196166992188
Completed epoch 2006. Valid MSE: 529.2246704101562
Completed epoch 2007. Valid MSE: 554.7487182617188
Completed epoch 2008. Valid MSE: 633.7501831054688
Completed epoch 2009. Valid MSE: 584.96875
Completed epoch 2010. Valid MSE: 544.6824951171875
Completed epoch 2011. Valid MSE: 561.4859619140625
Completed epoch 2012. Valid MSE: 576.0030517578125
Completed epoch 2013. Valid MSE: 617.8872680664062
Completed epoch 2014. Valid MSE: 648.9454956054688
Completed epoch 2015. Valid MSE: 651.78564453125


Completed epoch 2158. Valid MSE: 989.7212524414062
Completed epoch 2159. Valid MSE: 1005.6240234375
Completed epoch 2160. Valid MSE: 970.1360473632812
Completed epoch 2161. Valid MSE: 972.6998901367188
Completed epoch 2162. Valid MSE: 968.0787353515625
Completed epoch 2163. Valid MSE: 966.4549560546875
Completed epoch 2164. Valid MSE: 997.157470703125
Completed epoch 2165. Valid MSE: 976.0463256835938
Completed epoch 2166. Valid MSE: 934.32666015625
Completed epoch 2167. Valid MSE: 970.15869140625
Completed epoch 2168. Valid MSE: 955.4344482421875
Completed epoch 2169. Valid MSE: 933.7378540039062
Completed epoch 2170. Valid MSE: 948.9771728515625
Completed epoch 2171. Valid MSE: 963.5892944335938
Completed epoch 2172. Valid MSE: 856.5687255859375
Completed epoch 2173. Valid MSE: 891.4345703125
Completed epoch 2174. Valid MSE: 888.7779541015625
Completed epoch 2175. Valid MSE: 881.4786376953125
Completed epoch 2176. Valid MSE: 919.8724975585938
Completed epoch 2177. Valid MSE: 902.6747

Completed epoch 2320. Valid MSE: 594.8217163085938
Completed epoch 2321. Valid MSE: 580.0476684570312
Completed epoch 2322. Valid MSE: 584.0440063476562
Completed epoch 2323. Valid MSE: 624.3782348632812
Completed epoch 2324. Valid MSE: 593.0185546875
Completed epoch 2325. Valid MSE: 598.1475830078125
Completed epoch 2326. Valid MSE: 570.9481811523438
Completed epoch 2327. Valid MSE: 557.9219360351562
Completed epoch 2328. Valid MSE: 557.3020629882812
Completed epoch 2329. Valid MSE: 560.8330688476562
Completed epoch 2330. Valid MSE: 545.9576416015625
Completed epoch 2331. Valid MSE: 566.433837890625
Completed epoch 2332. Valid MSE: 546.0640869140625
Completed epoch 2333. Valid MSE: 590.482177734375
Completed epoch 2334. Valid MSE: 665.0724487304688
Completed epoch 2335. Valid MSE: 752.6282958984375
Completed epoch 2336. Valid MSE: 954.5031127929688
Completed epoch 2337. Valid MSE: 922.0534057617188
Completed epoch 2338. Valid MSE: 854.5374145507812
Completed epoch 2339. Valid MSE: 854

Completed epoch 2482. Valid MSE: 1295.4420166015625
Completed epoch 2483. Valid MSE: 1302.529052734375
Completed epoch 2484. Valid MSE: 1270.8475341796875
Completed epoch 2485. Valid MSE: 1236.93603515625
Completed epoch 2486. Valid MSE: 1232.9757080078125
Completed epoch 2487. Valid MSE: 1235.04150390625
Completed epoch 2488. Valid MSE: 1249.0848388671875
Completed epoch 2489. Valid MSE: 1249.7279052734375
Completed epoch 2490. Valid MSE: 1255.102294921875
Completed epoch 2491. Valid MSE: 1256.5662841796875
Completed epoch 2492. Valid MSE: 1265.621826171875
Completed epoch 2493. Valid MSE: 1275.4688720703125
Completed epoch 2494. Valid MSE: 1276.353759765625
Completed epoch 2495. Valid MSE: 1279.060546875
Completed epoch 2496. Valid MSE: 1273.6695556640625
Completed epoch 2497. Valid MSE: 1284.868408203125
Completed epoch 2498. Valid MSE: 1278.418212890625
Completed epoch 2499. Valid MSE: 1290.75341796875
Completed epoch 2500. Valid MSE: 1297.11083984375
Completed epoch 2501. Valid MS

Completed epoch 2644. Valid MSE: 554.3786010742188
Completed epoch 2645. Valid MSE: 571.2216186523438
Completed epoch 2646. Valid MSE: 553.0385131835938
Completed epoch 2647. Valid MSE: 553.4739379882812
Completed epoch 2648. Valid MSE: 568.2039184570312
Completed epoch 2649. Valid MSE: 565.4383544921875
Completed epoch 2650. Valid MSE: 550.084716796875
Completed epoch 2651. Valid MSE: 557.3701171875
Completed epoch 2652. Valid MSE: 579.7656860351562
Completed epoch 2653. Valid MSE: 561.5717163085938
Completed epoch 2654. Valid MSE: 577.8606567382812
Completed epoch 2655. Valid MSE: 574.4667358398438
Completed epoch 2656. Valid MSE: 580.28662109375
Completed epoch 2657. Valid MSE: 618.9525756835938
Completed epoch 2658. Valid MSE: 587.6002807617188
Completed epoch 2659. Valid MSE: 583.2356567382812
Completed epoch 2660. Valid MSE: 585.6729125976562
Completed epoch 2661. Valid MSE: 569.1050415039062
Completed epoch 2662. Valid MSE: 573.1358032226562
Completed epoch 2663. Valid MSE: 593.

Completed epoch 2806. Valid MSE: 506.850341796875
Completed epoch 2807. Valid MSE: 491.83697509765625
Completed epoch 2808. Valid MSE: 540.2608642578125
Completed epoch 2809. Valid MSE: 470.1847229003906
Completed epoch 2810. Valid MSE: 412.9174499511719
Completed epoch 2811. Valid MSE: 447.5611877441406
Completed epoch 2812. Valid MSE: 432.9775695800781
Completed epoch 2813. Valid MSE: 396.5976257324219
Completed epoch 2814. Valid MSE: 358.8153381347656
Completed epoch 2815. Valid MSE: 350.9983825683594
Completed epoch 2816. Valid MSE: 388.1081237792969
Completed epoch 2817. Valid MSE: 464.0162353515625
Completed epoch 2818. Valid MSE: 602.8421630859375
Completed epoch 2819. Valid MSE: 532.9993286132812
Completed epoch 2820. Valid MSE: 579.38232421875
Completed epoch 2821. Valid MSE: 836.3275756835938
Completed epoch 2822. Valid MSE: 743.2916870117188
Completed epoch 2823. Valid MSE: 749.203369140625
Completed epoch 2824. Valid MSE: 545.0176391601562
Completed epoch 2825. Valid MSE: 8

Completed epoch 2967. Valid MSE: 403.3105163574219
Completed epoch 2968. Valid MSE: 468.4688720703125
Completed epoch 2969. Valid MSE: 489.848388671875
Completed epoch 2970. Valid MSE: 401.99835205078125
Completed epoch 2971. Valid MSE: 401.0466003417969
Completed epoch 2972. Valid MSE: 393.8580627441406
Completed epoch 2973. Valid MSE: 407.8066101074219
Completed epoch 2974. Valid MSE: 436.9112854003906
Completed epoch 2975. Valid MSE: 443.3142395019531
Completed epoch 2976. Valid MSE: 433.5289611816406
Completed epoch 2977. Valid MSE: 446.49652099609375
Completed epoch 2978. Valid MSE: 451.7575378417969
Completed epoch 2979. Valid MSE: 444.0293273925781
Completed epoch 2980. Valid MSE: 439.80389404296875
Completed epoch 2981. Valid MSE: 453.53826904296875
Completed epoch 2982. Valid MSE: 436.6923522949219
Completed epoch 2983. Valid MSE: 441.8773498535156
Completed epoch 2984. Valid MSE: 451.7403259277344
Completed epoch 2985. Valid MSE: 451.0227355957031
Completed epoch 2986. Valid 

Completed epoch 3130. Valid MSE: 410.34515380859375
Completed epoch 3131. Valid MSE: 399.5864562988281
Completed epoch 3132. Valid MSE: 459.12969970703125
Completed epoch 3133. Valid MSE: 576.970458984375
Completed epoch 3134. Valid MSE: 590.9109497070312
Completed epoch 3135. Valid MSE: 492.8280334472656
Completed epoch 3136. Valid MSE: 427.2388000488281
Completed epoch 3137. Valid MSE: 449.7169189453125
Completed epoch 3138. Valid MSE: 371.6346740722656
Completed epoch 3139. Valid MSE: 465.5673522949219
Completed epoch 3140. Valid MSE: 551.53173828125
Completed epoch 3141. Valid MSE: 619.4243774414062
Completed epoch 3142. Valid MSE: 839.8316650390625
Completed epoch 3143. Valid MSE: 909.4290161132812
Completed epoch 3144. Valid MSE: 1555.5496826171875
Completed epoch 3145. Valid MSE: 1717.129638671875
Completed epoch 3146. Valid MSE: 1320.8052978515625
Completed epoch 3147. Valid MSE: 1173.49365234375
Completed epoch 3148. Valid MSE: 1157.147705078125
Completed epoch 3149. Valid MSE

Completed epoch 3291. Valid MSE: 1275.9410400390625
Completed epoch 3292. Valid MSE: 1264.27490234375
Completed epoch 3293. Valid MSE: 1244.831298828125
Completed epoch 3294. Valid MSE: 1246.2486572265625
Completed epoch 3295. Valid MSE: 1254.75
Completed epoch 3296. Valid MSE: 1262.3486328125
Completed epoch 3297. Valid MSE: 1293.2357177734375
Completed epoch 3298. Valid MSE: 1298.110595703125
Completed epoch 3299. Valid MSE: 1313.9674072265625
Completed epoch 3300. Valid MSE: 1312.33203125
Completed epoch 3301. Valid MSE: 1305.97998046875
Completed epoch 3302. Valid MSE: 1293.6884765625
Completed epoch 3303. Valid MSE: 1292.08154296875
Completed epoch 3304. Valid MSE: 1284.3074951171875
Completed epoch 3305. Valid MSE: 1284.739501953125
Completed epoch 3306. Valid MSE: 1276.7734375
Completed epoch 3307. Valid MSE: 1280.5450439453125
Completed epoch 3308. Valid MSE: 1287.9964599609375
Completed epoch 3309. Valid MSE: 1284.8790283203125
Completed epoch 3310. Valid MSE: 1291.39709472656

Completed epoch 3452. Valid MSE: 1265.160888671875
Completed epoch 3453. Valid MSE: 1297.1136474609375
Completed epoch 3454. Valid MSE: 1316.9112548828125
Completed epoch 3455. Valid MSE: 1332.919921875
Completed epoch 3456. Valid MSE: 1310.5438232421875
Completed epoch 3457. Valid MSE: 1350.717529296875
Completed epoch 3458. Valid MSE: 1387.279541015625
Completed epoch 3459. Valid MSE: 1407.8057861328125
Completed epoch 3460. Valid MSE: 1406.966064453125
Completed epoch 3461. Valid MSE: 1387.14111328125
Completed epoch 3462. Valid MSE: 1385.273193359375
Completed epoch 3463. Valid MSE: 1362.458740234375
Completed epoch 3464. Valid MSE: 1362.4609375
Completed epoch 3465. Valid MSE: 1384.458984375
Completed epoch 3466. Valid MSE: 1384.6326904296875
Completed epoch 3467. Valid MSE: 1399.4676513671875
Completed epoch 3468. Valid MSE: 1388.5208740234375
Completed epoch 3469. Valid MSE: 1382.841552734375
Completed epoch 3470. Valid MSE: 1395.0560302734375
Completed epoch 3471. Valid MSE: 13

Completed epoch 3613. Valid MSE: 279.84765625
Completed epoch 3614. Valid MSE: 222.54147338867188
Completed epoch 3615. Valid MSE: 218.0221405029297
Completed epoch 3616. Valid MSE: 213.38140869140625
Completed epoch 3617. Valid MSE: 266.519775390625
Completed epoch 3618. Valid MSE: 274.6725769042969
Completed epoch 3619. Valid MSE: 275.2107238769531
Completed epoch 3620. Valid MSE: 363.4243469238281
Completed epoch 3621. Valid MSE: 292.19256591796875
Completed epoch 3622. Valid MSE: 280.3585510253906
Completed epoch 3623. Valid MSE: 602.7603149414062
Completed epoch 3624. Valid MSE: 626.1961669921875
Completed epoch 3625. Valid MSE: 747.1832275390625
Completed epoch 3626. Valid MSE: 587.4716186523438
Completed epoch 3627. Valid MSE: 294.8828125
Completed epoch 3628. Valid MSE: 322.4263916015625
Completed epoch 3629. Valid MSE: 380.8372497558594
Completed epoch 3630. Valid MSE: 350.3824768066406
Completed epoch 3631. Valid MSE: 287.6815490722656
Completed epoch 3632. Valid MSE: 214.957

Completed epoch 3775. Valid MSE: 1334.734619140625
Completed epoch 3776. Valid MSE: 1335.2208251953125
Completed epoch 3777. Valid MSE: 1355.972412109375
Completed epoch 3778. Valid MSE: 1352.589111328125
Completed epoch 3779. Valid MSE: 1367.1348876953125
Completed epoch 3780. Valid MSE: 1389.0318603515625
Completed epoch 3781. Valid MSE: 1390.455078125
Completed epoch 3782. Valid MSE: 1379.979736328125
Completed epoch 3783. Valid MSE: 1381.7979736328125
Completed epoch 3784. Valid MSE: 1377.22509765625
Completed epoch 3785. Valid MSE: 1371.81787109375
Completed epoch 3786. Valid MSE: 1380.68994140625
Completed epoch 3787. Valid MSE: 1395.903076171875
Completed epoch 3788. Valid MSE: 1399.060302734375
Completed epoch 3789. Valid MSE: 1386.0531005859375
Completed epoch 3790. Valid MSE: 1398.267578125
Completed epoch 3791. Valid MSE: 1402.08154296875
Completed epoch 3792. Valid MSE: 1410.999755859375
Completed epoch 3793. Valid MSE: 1407.206787109375
Completed epoch 3794. Valid MSE: 140

Completed epoch 3936. Valid MSE: 1773.5521240234375
Completed epoch 3937. Valid MSE: 1792.14501953125
Completed epoch 3938. Valid MSE: 1796.0767822265625
Completed epoch 3939. Valid MSE: 1786.5218505859375
Completed epoch 3940. Valid MSE: 1781.714111328125
Completed epoch 3941. Valid MSE: 2006.259521484375
Completed epoch 3942. Valid MSE: 2047.8350830078125
Completed epoch 3943. Valid MSE: 2081.353759765625
Completed epoch 3944. Valid MSE: 2018.0086669921875
Completed epoch 3945. Valid MSE: 2044.0164794921875
Completed epoch 3946. Valid MSE: 2056.863037109375
Completed epoch 3947. Valid MSE: 2197.50927734375
Completed epoch 3948. Valid MSE: 2097.998291015625
Completed epoch 3949. Valid MSE: 2199.673583984375
Completed epoch 3950. Valid MSE: 2634.516845703125
Completed epoch 3951. Valid MSE: 2770.625244140625
Completed epoch 3952. Valid MSE: 2466.562255859375
Completed epoch 3953. Valid MSE: 2241.672119140625
Completed epoch 3954. Valid MSE: 1590.0245361328125
Completed epoch 3955. Vali

Completed epoch 4098. Valid MSE: 688.8971557617188
Completed epoch 4099. Valid MSE: 714.9573974609375
Completed epoch 4100. Valid MSE: 697.1425170898438
Completed epoch 4101. Valid MSE: 698.6249389648438
Completed epoch 4102. Valid MSE: 698.9046020507812
Completed epoch 4103. Valid MSE: 688.8692626953125
Completed epoch 4104. Valid MSE: 687.0731201171875
Completed epoch 4105. Valid MSE: 691.5723876953125
Completed epoch 4106. Valid MSE: 671.7390747070312
Completed epoch 4107. Valid MSE: 674.1051025390625
Completed epoch 4108. Valid MSE: 689.2382202148438
Completed epoch 4109. Valid MSE: 681.8084106445312
Completed epoch 4110. Valid MSE: 688.8060913085938
Completed epoch 4111. Valid MSE: 686.3260498046875
Completed epoch 4112. Valid MSE: 665.6329956054688
Completed epoch 4113. Valid MSE: 669.8432006835938
Completed epoch 4114. Valid MSE: 659.8409423828125
Completed epoch 4115. Valid MSE: 617.3056640625
Completed epoch 4116. Valid MSE: 628.5809326171875
Completed epoch 4117. Valid MSE: 7

Completed epoch 4260. Valid MSE: 1451.4207763671875
Completed epoch 4261. Valid MSE: 1451.77001953125
Completed epoch 4262. Valid MSE: 1412.6385498046875
Completed epoch 4263. Valid MSE: 1400.6405029296875
Completed epoch 4264. Valid MSE: 1405.8670654296875
Completed epoch 4265. Valid MSE: 1383.732666015625
Completed epoch 4266. Valid MSE: 1372.4912109375
Completed epoch 4267. Valid MSE: 1421.7344970703125
Completed epoch 4268. Valid MSE: 1395.0457763671875
Completed epoch 4269. Valid MSE: 1401.2415771484375
Completed epoch 4270. Valid MSE: 1391.9718017578125
Completed epoch 4271. Valid MSE: 1369.8619384765625
Completed epoch 4272. Valid MSE: 1409.048828125
Completed epoch 4273. Valid MSE: 1393.16943359375
Completed epoch 4274. Valid MSE: 1432.943603515625
Completed epoch 4275. Valid MSE: 1472.8712158203125
Completed epoch 4276. Valid MSE: 1461.4736328125
Completed epoch 4277. Valid MSE: 1455.13623046875
Completed epoch 4278. Valid MSE: 1476.86962890625
Completed epoch 4279. Valid MSE:

Completed epoch 4421. Valid MSE: 1968.14794921875
Completed epoch 4422. Valid MSE: 2243.072509765625
Completed epoch 4423. Valid MSE: 2170.485595703125
Completed epoch 4424. Valid MSE: 2023.8292236328125
Completed epoch 4425. Valid MSE: 1905.935302734375
Completed epoch 4426. Valid MSE: 1893.4560546875
Completed epoch 4427. Valid MSE: 1857.0538330078125
Completed epoch 4428. Valid MSE: 1757.0716552734375
Completed epoch 4429. Valid MSE: 1710.4493408203125
Completed epoch 4430. Valid MSE: 1554.169921875
Completed epoch 4431. Valid MSE: 1540.675048828125
Completed epoch 4432. Valid MSE: 1542.17626953125
Completed epoch 4433. Valid MSE: 1501.481201171875
Completed epoch 4434. Valid MSE: 1583.2388916015625
Completed epoch 4435. Valid MSE: 1918.5587158203125
Completed epoch 4436. Valid MSE: 1861.27783203125
Completed epoch 4437. Valid MSE: 1873.9630126953125
Completed epoch 4438. Valid MSE: 1959.7652587890625
Completed epoch 4439. Valid MSE: 1979.35205078125
Completed epoch 4440. Valid MSE:

Completed epoch 4584. Valid MSE: 2003.302978515625
Completed epoch 4585. Valid MSE: 2014.3397216796875
Completed epoch 4586. Valid MSE: 2038.10107421875
Completed epoch 4587. Valid MSE: 2052.9169921875
Completed epoch 4588. Valid MSE: 2062.168701171875
Completed epoch 4589. Valid MSE: 2014.135009765625
Completed epoch 4590. Valid MSE: 1987.712890625
Completed epoch 4591. Valid MSE: 1972.9776611328125
Completed epoch 4592. Valid MSE: 1982.655029296875
Completed epoch 4593. Valid MSE: 1973.0093994140625
Completed epoch 4594. Valid MSE: 1915.175537109375
Completed epoch 4595. Valid MSE: 1804.545166015625
Completed epoch 4596. Valid MSE: 1666.7723388671875
Completed epoch 4597. Valid MSE: 1570.4166259765625
Completed epoch 4598. Valid MSE: 1496.2276611328125
Completed epoch 4599. Valid MSE: 1397.2535400390625
Completed epoch 4600. Valid MSE: 1333.20849609375
Completed epoch 4601. Valid MSE: 1315.3917236328125
Completed epoch 4602. Valid MSE: 1254.9388427734375
Completed epoch 4603. Valid M

Completed epoch 4746. Valid MSE: 1626.6470947265625
Completed epoch 4747. Valid MSE: 1628.2667236328125
Completed epoch 4748. Valid MSE: 1631.7940673828125
Completed epoch 4749. Valid MSE: 1634.8450927734375
Completed epoch 4750. Valid MSE: 1644.40380859375
Completed epoch 4751. Valid MSE: 1654.5677490234375
Completed epoch 4752. Valid MSE: 1664.2200927734375
Completed epoch 4753. Valid MSE: 1660.7269287109375
Completed epoch 4754. Valid MSE: 1654.900390625
Completed epoch 4755. Valid MSE: 1656.402587890625
Completed epoch 4756. Valid MSE: 1667.998779296875
Completed epoch 4757. Valid MSE: 1672.341796875
Completed epoch 4758. Valid MSE: 1654.9373779296875
Completed epoch 4759. Valid MSE: 1639.38720703125
Completed epoch 4760. Valid MSE: 1649.78662109375
Completed epoch 4761. Valid MSE: 1633.7130126953125
Completed epoch 4762. Valid MSE: 1606.9141845703125
Completed epoch 4763. Valid MSE: 1600.58056640625
Completed epoch 4764. Valid MSE: 1579.3309326171875
Completed epoch 4765. Valid MS

In [36]:
def eval_model(model, test):
    batch_size = 128
    test_iterator = get_stock_iterator(test, batch_size, shuffle=False)
        
    criterion = nn.MSELoss()
    
    model.eval()
    losses = []
    for batch_series, batch_aux, batch_labels in test_iterator:
        batch_aux = torch.reshape(batch_aux, (-1,1))
        outputs = model(batch_series.float().cuda(), batch_aux.float().cuda())
        batch_labels = torch.reshape(batch_labels, (-1,1))
        loss = criterion(outputs, batch_labels.float().cuda())
        losses.append(loss.item())
    return np.mean(losses)

In [41]:
eval_model(updated_model, test_stock)

2511.2421875

In [1]:
print('hi')

hi
