In [1]:
# import packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import StandardScaler

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings 

warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)

device = ("cpu")

In [2]:
# define a Dataset class
class MyDataset(Dataset):
    def __init__(self, features, target):
        self.y = target
        self.X = features
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx, :], self.y[idx]

In [3]:
df_train = pd.read_csv('../kaggle/train_feats.csv')
df_train = df_train.fillna(0)

df_train.head()

Unnamed: 0,stock_id,date_id,seconds_in_bucket,imbalance_size,imbalance_buy_sell_flag,reference_price,matched_size,far_price,near_price,bid_price,bid_size,ask_price,ask_size,wap,target,session_label,volume,mid_price,liquidity_imbalance,matched_imbalance,size_imbalance,imbalance_intensity,matched_intensity,reference_price_far_price_imb,reference_price_near_price_imb,reference_price_ask_price_imb,reference_price_bid_price_imb,reference_price_wap_imb,far_price_near_price_imb,far_price_ask_price_imb,far_price_bid_price_imb,far_price_wap_imb,near_price_ask_price_imb,near_price_bid_price_imb,near_price_wap_imb,ask_price_bid_price_imb,ask_price_wap_imb,bid_price_wap_imb,price_spread,market_urgency,depth_pressure,price_pressure,imbalance_with_flag,imbalance_momentum,spread_intensity,matched_size_ret_1,matched_size_ret_2,matched_size_ret_3,matched_size_ret_4,matched_size_ret_5,matched_size_ret_6,imbalance_size_ret_1,imbalance_size_ret_2,imbalance_size_ret_3,imbalance_size_ret_4,imbalance_size_ret_5,imbalance_size_ret_6,bid_size_ret_1,bid_size_ret_2,bid_size_ret_3,bid_size_ret_4,bid_size_ret_5,bid_size_ret_6,ask_size_ret_1,ask_size_ret_2,ask_size_ret_3,ask_size_ret_4,ask_size_ret_5,ask_size_ret_6,reference_price_ret_1,reference_price_ret_2,reference_price_ret_3,reference_price_ret_4,reference_price_ret_5,reference_price_ret_6,ask_price_ret_1,ask_price_ret_2,ask_price_ret_3,ask_price_ret_4,ask_price_ret_5,ask_price_ret_6,bid_price_ret_1,bid_price_ret_2,bid_price_ret_3,bid_price_ret_4,bid_price_ret_5,bid_price_ret_6,market_urgency_ret_1,market_urgency_ret_2,market_urgency_ret_3,market_urgency_ret_4,market_urgency_ret_5,market_urgency_ret_6,imbalance_momentum_ret_1,imbalance_momentum_ret_2,imbalance_momentum_ret_3,imbalance_momentum_ret_4,imbalance_momentum_ret_5,imbalance_momentum_ret_6,size_imbalance_ret_1,size_imbalance_ret_2,size_imbalance_ret_3,size_imbalance_ret_4,size_imbalance_ret_5,size_imbalance_ret_6,imbalance_flag_diff,all_prices_mean,all_sizes_mean,all_prices_std,all_sizes_std,all_prices_skew,all_sizes_skew,all_prices_kurt,all_sizes_kurt,ask_price_bid_price_wap_imb2,ask_price_bid_price_reference_price_imb2,ask_price_wap_reference_price_imb2,bid_price_wap_reference_price_imb2,matched_size_bid_size_ask_size_imb2,matched_size_bid_size_imbalance_size_imb2,matched_size_ask_size_imbalance_size_imb2,bid_size_ask_size_imbalance_size_imb2,dow,seconds,minute,global_median_size,global_std_size,global_ptp_size,global_median_price,global_std_price,global_ptp_price
0,0,0,0,3180602.8,1,0.999812,13380277.0,0.0,0.0,0.999812,60651.5,1.000026,8493.03,1.0,-3.029704,1,69144.53,0.999919,0.75434,-0.61589,7.141326,45.99934,193.51172,0.0,0.0,-0.000107,0.0,-9.4e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000107,1.3e-05,-9.400884e-05,0.000214,0.000161,0.0,680.649,3180602.8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.999912,4157506.0,0.000117,6324881.0,0.042932,1.695159,0.0,2.775961,0.138298,-1927541000000.0,0.138298,-1693354000000.0,255.36841,3.269177,3.215423,59.81677,0,0,0,39479.49,113601.55,5898989.5,1.999734,0.003419,0.017414
1,1,0,0,166603.9,-1,0.999896,1642214.2,0.0,0.0,0.999896,3233.04,1.00066,20605.09,1.0,-5.519986,1,23838.13,1.000278,-0.72875,-0.815787,0.156905,6.988967,68.89023,0.0,0.0,-0.000382,0.0,-5.2e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000382,0.00033,-5.20027e-05,0.000764,-0.000557,0.0,127.285385,-166603.9,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.000113,458164.06,0.000368,792759.4,1.897537,1.94999,3.617292,3.819817,6.346154,-6881500000000.0,6.346154,936748700000.0,93.34587,9.032273,10.107002,8.404237,0,0,0,25830.525,64297.402,511677.72,1.999944,0.005564,0.026874
2,2,0,0,302879.88,-1,0.999561,1819368.0,0.0,0.0,0.999403,37956.0,1.000298,18995.0,1.0,-8.38995,1,56951.0,0.99985,0.332935,-0.714567,1.99821,5.318254,31.9462,0.0,0.0,-0.000369,7.9e-05,-0.00022,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000448,0.000149,-0.0002985891,0.000895,0.000298,0.0,271.07748,-302879.88,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.999815,544799.75,0.000409,859536.8,0.31125,1.869559,-2.954224,3.507308,0.499162,4.664557,0.678816,2.778481,93.95138,5.724241,5.341913,13.972041,0,0,0,25387.9,68032.62,1069837.6,2.000193,0.00524,0.033413
3,3,0,0,11917682.0,-1,1.000171,18389746.0,0.0,0.0,0.999999,2324.9,1.000214,479032.4,1.0,-4.0102,1,481357.3,1.000107,-0.99034,-0.213547,0.004853,24.758495,38.20394,0.0,0.0,-2.1e-05,8.6e-05,8.5e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000107,0.000107,-5.000002e-07,0.000215,-0.000213,0.0,2562.3018,-11917682.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.000096,7697196.5,0.000113,9008441.0,0.124239,0.424923,0.0,-3.574718,214.0,0.25,0.251462,171.0,37.571705,0.54317,0.565807,23.995111,0,0,0,41327.156,93654.266,1928848.2,1.999947,0.002946,0.018551
4,4,0,0,447549.97,-1,0.999532,17860614.0,0.0,0.0,0.999394,16485.54,1.000016,434.1,1.0,-7.349849,1,16919.64,0.999705,0.948687,-0.951109,37.976364,26.451506,1055.6144,0.0,0.0,-0.000242,6.9e-05,-0.000234,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000311,8e-06,-0.0003030918,0.000622,0.00059,0.0,278.37607,-447549.97,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.999735,4581271.0,0.00032,8855318.0,-0.156643,1.996737,-5.078671,3.988887,0.026403,3.507246,0.034188,3.391304,1111.6841,40.395504,38.94531,26.855186,0,0,0,33488.0,80309.36,1604065.5,1.999842,0.003812,0.017379


In [4]:
df_valid = pd.read_csv('../kaggle/valid_feats.csv')
df_valid = df_valid.fillna(0)

df_valid.head()

Unnamed: 0,stock_id,date_id,seconds_in_bucket,imbalance_size,imbalance_buy_sell_flag,reference_price,matched_size,far_price,near_price,bid_price,bid_size,ask_price,ask_size,wap,target,session_label,volume,mid_price,liquidity_imbalance,matched_imbalance,size_imbalance,imbalance_intensity,matched_intensity,reference_price_far_price_imb,reference_price_near_price_imb,reference_price_ask_price_imb,reference_price_bid_price_imb,reference_price_wap_imb,far_price_near_price_imb,far_price_ask_price_imb,far_price_bid_price_imb,far_price_wap_imb,near_price_ask_price_imb,near_price_bid_price_imb,near_price_wap_imb,ask_price_bid_price_imb,ask_price_wap_imb,bid_price_wap_imb,price_spread,market_urgency,depth_pressure,price_pressure,imbalance_with_flag,imbalance_momentum,spread_intensity,matched_size_ret_1,matched_size_ret_2,matched_size_ret_3,matched_size_ret_4,matched_size_ret_5,matched_size_ret_6,imbalance_size_ret_1,imbalance_size_ret_2,imbalance_size_ret_3,imbalance_size_ret_4,imbalance_size_ret_5,imbalance_size_ret_6,bid_size_ret_1,bid_size_ret_2,bid_size_ret_3,bid_size_ret_4,bid_size_ret_5,bid_size_ret_6,ask_size_ret_1,ask_size_ret_2,ask_size_ret_3,ask_size_ret_4,ask_size_ret_5,ask_size_ret_6,reference_price_ret_1,reference_price_ret_2,reference_price_ret_3,reference_price_ret_4,reference_price_ret_5,reference_price_ret_6,ask_price_ret_1,ask_price_ret_2,ask_price_ret_3,ask_price_ret_4,ask_price_ret_5,ask_price_ret_6,bid_price_ret_1,bid_price_ret_2,bid_price_ret_3,bid_price_ret_4,bid_price_ret_5,bid_price_ret_6,market_urgency_ret_1,market_urgency_ret_2,market_urgency_ret_3,market_urgency_ret_4,market_urgency_ret_5,market_urgency_ret_6,imbalance_momentum_ret_1,imbalance_momentum_ret_2,imbalance_momentum_ret_3,imbalance_momentum_ret_4,imbalance_momentum_ret_5,imbalance_momentum_ret_6,size_imbalance_ret_1,size_imbalance_ret_2,size_imbalance_ret_3,size_imbalance_ret_4,size_imbalance_ret_5,size_imbalance_ret_6,imbalance_flag_diff,all_prices_mean,all_sizes_mean,all_prices_std,all_sizes_std,all_prices_skew,all_sizes_skew,all_prices_kurt,all_sizes_kurt,ask_price_bid_price_wap_imb2,ask_price_bid_price_reference_price_imb2,ask_price_wap_reference_price_imb2,bid_price_wap_reference_price_imb2,matched_size_bid_size_ask_size_imb2,matched_size_bid_size_imbalance_size_imb2,matched_size_ask_size_imbalance_size_imb2,bid_size_ask_size_imbalance_size_imb2,dow,seconds,minute,global_median_size,global_std_size,global_ptp_size,global_median_price,global_std_price,global_ptp_price
0,0,436,0,0.0,0,1.000268,12874820.0,0.0,0.0,0.999911,11182.0,1.000089,11184.0,1.0,6.630421,1,22366.0,1.0,-8.9e-05,-1.0,0.999821,0.0,575.6425,0.0,0.0,8.9e-05,0.000178,0.000134,0.0,0.0,0.0,0.0,0.0,0.0,0.0,8.9e-05,4.4e-05,-4.5e-05,0.000178,-1.591702e-08,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.000067,3224296.5,0.000152,6433684.5,0.759262,1.999996,0.0,3.999987,1.0,1.005618,2.011236,3.011236,6431818.0,1150.388,1150.182,0.000179,1,0,0,39479.49,113601.55,5898989.5,1.999734,0.003419,0.017414
1,1,436,0,1378667.5,1,0.999853,2806215.2,0.0,0.0,0.999853,801.2,1.000552,3006.6,1.0,-4.77016,1,3807.8,1.000203,-0.57918,-0.34112,0.26648,362.06403,736.96497,0.0,0.0,-0.000349,0.0,-7.4e-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000349,0.000276,-7.4e-05,0.000699,-0.0004048465,0.0,963.68854,1378667.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.000064,1047172.6,0.000332,1340310.4,1.756207,0.884692,3.037954,-1.132242,3.755102,6296032000000.0,3.755102,-1324058000000.0,1271.066,1.036057,1.037718,623.7693,1,0,0,25830.525,64297.402,511677.72,1.999944,0.005564,0.026874
2,2,436,0,0.0,0,0.999373,4873746.5,0.0,0.0,0.999373,8005.46,1.000132,1686.64,1.0,-4.829764,1,9692.1,0.999753,0.651956,-1.0,4.746395,0.0,502.8576,0.0,0.0,-0.00038,0.0,-0.000314,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.00038,6.6e-05,-0.000314,0.000759,0.0004948344,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.999719,1220859.6,0.000404,2435260.2,0.091753,1.999988,-5.472621,3.99996,0.210526,6836464000000.0,0.210526,-5647514000000.0,770.0395,607.8028,2888.6187,3.746395,1,0,0,25387.9,68032.62,1069837.6,2.000193,0.00524,0.033413
3,3,436,0,7863030.0,1,0.999576,46879784.0,0.0,0.0,0.999576,21946.1,1.000027,1397.2,1.0,1.900196,1,23343.3,0.999802,0.880291,-0.712728,15.7072,336.8431,2008.2759,0.0,0.0,-0.000226,0.0,-0.000212,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000226,1.3e-05,-0.000212,0.000451,0.0003970113,0.0,3546.2266,7863030.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.999795,13691540.0,0.000253,22432930.0,0.009867,1.844701,-5.94306,3.4059,0.063679,-2031123000000.0,0.063679,0.0,2280.309,4.975939,4.962933,381.58167,1,0,0,41327.156,93654.266,1928848.2,1.999947,0.002946,0.018551
4,4,436,0,4599490.0,-1,1.000425,15193377.0,0.0,0.0,0.999952,169.01,1.000721,2537.1,1.0,5.090237,1,2706.11,1.000336,-0.87509,-0.535238,0.066615,1699.6686,5614.471,0.0,0.0,-0.000148,0.000236,0.000212,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.000384,0.00036,-2.4e-05,0.000769,-0.0006729442,0.0,3537.008,-4599490.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.000274,4948893.5,0.000366,7165376.0,0.52038,1.509326,-2.931223,1.943178,15.020833,0.6257928,0.696471,8.854167,6414.807,2.303359,2.304545,1941.2072,1,0,0,33488.0,80309.36,1604065.5,1.999842,0.003812,0.017379


In [12]:
feat_cols = [col for col in df_train.columns if col not in ['target', 'date_id', 'dow']]

standarize_cols = [col for col in feat_cols if col not in ['stock_id', 'seconds_in_bucket', 'imbalance_buy_sell_flag', 'session_label', 'seconds', 'minute']]

df_train_feat = df_train[feat_cols]
scaler = StandardScaler()
df_train_feat[standarize_cols] = scaler.fit_transform(df_train_feat[standarize_cols])
train_dataset = MyDataset(torch.from_numpy(df_train_feat.values), torch.from_numpy(df_train['target'].values))

df_valid_feat = df_valid[feat_cols]
df_valid_feat[standarize_cols] = scaler.transform(df_valid_feat[standarize_cols])
valid_dataset = MyDataset(torch.from_numpy(df_valid_feat.values), torch.from_numpy(df_valid['target'].values))

print(train_dataset.X.shape, train_dataset.y.shape)
print(valid_dataset.X.shape, valid_dataset.y.shape)

torch.Size([4742893, 128]) torch.Size([4742893])
torch.Size([494999, 128]) torch.Size([494999])


In [18]:
# define a neural network model of 2 layers
# first layer is a non-linear layer with m neurons
# second layer is a linear layer with n neuron

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    def forward(self, x):
        x = self.flatten(x)
        output = self.linear_relu_stack(x)
        return output

model = NeuralNetwork()

In [19]:
from sklearn.metrics import mean_absolute_error

# define a function to train the model
def train(model, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.float())
        loss = criterion(output, target.float())
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

# define a function for validation
def validation(model, val_loader, criterion):
    model.eval()
    validation_loss = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data.float())
            validation_loss += criterion(output, target.float()).item()
            y_true.extend(target.tolist())
            y_pred.extend(output.tolist())
            
    validation_loss /= len(val_loader.dataset)
    print(f"\nValidation set: Average loss: {validation_loss:.6f}")
    print(f"\nValidation set: Mean Average Error: {mean_absolute_error(y_true, y_pred):.6f}\n")

# define a function to predict the target
def predict(model, pred_loader):
    model.eval()
    y_pred = []
    with torch.no_grad():
        for data, target in pred_loader:
            output = model(data.float())
            y_pred.extend(output.tolist())
            
    return y_pred

In [20]:
# print number of data points and number of features
print(f"Number of data points: {train_dataset.X.shape[0]}")
print(f"Number of features: {train_dataset.X.shape[1]}")

# train the model with train_dataset
train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=2048, shuffle=True)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

for epoch in range(1, 3):
    train(model, train_loader, optimizer, criterion, epoch)
    validation(model, val_loader, criterion)


Number of data points: 4742893
Number of features: 128

Validation set: Average loss: 0.039049

Validation set: Mean Average Error: 5.960048


Validation set: Average loss: 0.039082

Validation set: Mean Average Error: 5.963173



In [21]:
# save the model
torch.save(model.state_dict(), "model_fcn_2layer.pth")