In [1]:
# declare a list tasks whose products you want to use as inputs
upstream = ['combine_fred_yahoo']

In [2]:
# Parameters
upstream = {
    "combine_fred_yahoo": {
        "nb": "/home/vgaurav/market_watch/output/notebooks/combine_fred_yahoo.ipynb",
        "data": "/home/vgaurav/market_watch/output/data/raw/fred_yahoo.xlsx",
    }
}
product = {"nb": "/home/vgaurav/market_watch/output/notebooks/train_model.ipynb"}


In [3]:
import os
import ptan
import pathlib
import gym.wrappers
import numpy as np
import pandas as pd

import torch
import torch.optim as optim

from ignite.engine import Engine
from ignite.contrib.handlers import tensorboard_logger as tb_logger

from src.models.lib import environ, data, models, common, validation

SAVES_DIR = pathlib.Path("saves")
BATCH_SIZE = 32
BARS_COUNT = 10

EPS_START = 1.0
EPS_FINAL = 0.1
EPS_STEPS = 1000000

GAMMA = 0.99

REPLAY_SIZE = 100000
REPLAY_INITIAL = 1000
REWARD_STEPS = 2
LEARNING_RATE = 0.0001
STATES_TO_EVALUATE = 1000

In [4]:
cuda = torch.cuda.is_available()
print(f"GPU support is {'enabled' if cuda else 'not available'}")
cuda

GPU support is enabled


True

In [5]:
run = 'test'
# args = parser.parse_args()
device = torch.device("cuda" if cuda else "cpu")

saves_path = SAVES_DIR / f"conv-{run}"
saves_path.mkdir(parents=True, exist_ok=True)


features = [
    'High', 'Close', 'Low',
    'Open' # open is broken. has shape (558, 415) vs (692, 403)
    ]

stock_data_path = upstream['combine_fred_yahoo']['data']
dfs = pd.read_excel(stock_data_path, sheet_name=features)

cols = None
for feature in features:
    if cols is None:
        cols = dfs[feature].columns
    else:
        cols = np.intersect1d(cols, dfs[feature].columns)

cols = [c for c in cols if 'date' not in c.lower()]
data = np.array([
    dfs[feature][cols] for feature in features
]).astype(np.float32)

env = environ.MarketWatchStocksEnv(data, bars_count=BARS_COUNT, state_1d=True)
env_tst = environ.MarketWatchStocksEnv(data, bars_count=BARS_COUNT, state_1d=True)


env = gym.wrappers.TimeLimit(env, max_episode_steps=1000)
env_val = env_tst

  deprecation(
  deprecation(


In [6]:
net = models.DQNConv1DMarketWatch(env.observation_space.shape, env.action_space.n).to(device)
tgt_net = ptan.agent.TargetNet(net)

selector = ptan.actions.EpsilonGreedyActionSelector(EPS_START)
eps_tracker = ptan.actions.EpsilonTracker(
    selector, EPS_START, EPS_FINAL, EPS_STEPS)
agent = ptan.agent.DQNAgent(net, selector, device=device, preprocessor=lambda x: common.state_preprocessor(x, device=device))
exp_source = ptan.experience.ExperienceSourceFirstLast(
    env, agent, GAMMA, steps_count=REWARD_STEPS)
buffer = ptan.experience.ExperienceReplayBuffer(
    exp_source, REPLAY_SIZE)
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)


def process_batch(engine, batch):
    optimizer.zero_grad()
    loss_v = common.calc_loss(
        batch, net, tgt_net.target_model,
        gamma=GAMMA ** REWARD_STEPS, device=device)
    loss_v.backward()
    optimizer.step()
    eps_tracker.frame(engine.state.iteration)

    if getattr(engine.state, "eval_states", None) is None:
        eval_states = buffer.sample(STATES_TO_EVALUATE)
        eval_states = [np.array(transition.state, copy=False)
                       for transition in eval_states]
        engine.state.eval_states = np.array(eval_states, copy=False)

    return {
        "loss": loss_v.item(),
        "epsilon": selector.epsilon,
    }


engine = Engine(process_batch)
tb = common.setup_ignite(engine, exp_source, f"conv-{run}",
                         extra_metrics=('values_mean',))


@engine.on(ptan.ignite.PeriodEvents.ITERS_10_COMPLETED)
def sync_eval(engine: Engine):
    tgt_net.sync()

    mean_val = common.calc_values_of_states(
        engine.state.eval_states, net, device=device)
    engine.state.metrics["values_mean"] = mean_val
    is_first = False
    if getattr(engine.state, "best_mean_val", None) is None:
        engine.state.best_mean_val = mean_val
        is_first = True

    if engine.state.best_mean_val < mean_val or is_first:
        print("%d: Best mean value updated %.3f -> %.3f" % (
            engine.state.iteration, engine.state.best_mean_val,
            mean_val))
        path = saves_path / ("mean_value_%.3f.data" % mean_val)
        torch.save(net.state_dict(), path)
        engine.state.best_mean_val = mean_val
    else:
        print(f'mean_val ${mean_val}, less than best {engine.state.best_mean_val}')

def validate(engine: Engine):
    res = validation.validation_run(env_tst, net, device=device)
    print("%d: tst: %s" % (engine.state.iteration, res))
    for key, val in res.items():
        engine.state.metrics[key + "_tst"] = val
    res = validation.validation_run(env_val, net, device=device)
    print("%d: val: %s" % (engine.state.iteration, res))
    for key, val in res.items():
        engine.state.metrics[key + "_val"] = val
    val_reward = res['episode_reward']
    if getattr(engine.state, "best_val_reward", None) is None:
        engine.state.best_val_reward = val_reward
    if engine.state.best_val_reward < val_reward:
        print("Best validation reward updated: %.3f -> %.3f, model saved" % (
            engine.state.best_val_reward, val_reward
        ))
        engine.state.best_val_reward = val_reward
        path = saves_path / ("val_reward-%.3f.data" % val_reward)
        torch.save(net.state_dict(), path)


event = ptan.ignite.PeriodEvents.ITERS_100_COMPLETED
tst_metrics = [m + "_tst" for m in validation.METRICS]
tst_handler = tb_logger.OutputHandler(
    tag="test", metric_names=tst_metrics)
tb.attach(engine, log_handler=tst_handler, event_name=event)

val_metrics = [m + "_val" for m in validation.METRICS]
val_handler = tb_logger.OutputHandler(
    tag="validation", metric_names=val_metrics)
tb.attach(engine, log_handler=val_handler, event_name=event)

engine.run(common.batch_generator(buffer, REPLAY_INITIAL, BATCH_SIZE))

Episode 100: reward=-4, steps=8, speed=0.0 f/s, elapsed=0:00:02


10: Best mean value updated -0.062 -> -0.062


mean_val $-0.10762025555595756, less than best -0.06197773676831275


mean_val $-0.1520497816381976, less than best -0.06197773676831275


mean_val $-0.1957495182286948, less than best -0.06197773676831275


mean_val $-0.23817035928368568, less than best -0.06197773676831275


mean_val $-0.27940057567320764, less than best -0.06197773676831275


mean_val $-0.3318737333174795, less than best -0.06197773676831275


mean_val $-0.38813579617999494, less than best -0.06197773676831275


mean_val $-0.4255446675233543, less than best -0.06197773676831275


mean_val $-0.4694188041612506, less than best -0.06197773676831275


mean_val $-0.5165102048777044, less than best -0.06197773676831275


mean_val $-0.5551581759937108, less than best -0.06197773676831275


mean_val $-0.594031204469502, less than best -0.06197773676831275


mean_val $-0.6389567963778973, less than best -0.06197773676831275


mean_val $-0.6728829219937325, less than best -0.06197773676831275


mean_val $-0.7064649816602468, less than best -0.06197773676831275


mean_val $-0.7402018303982913, less than best -0.06197773676831275


mean_val $-0.7743290541693568, less than best -0.06197773676831275


mean_val $-0.8103747190907598, less than best -0.06197773676831275


mean_val $-0.8568783137015998, less than best -0.06197773676831275


mean_val $-0.9102857243269682, less than best -0.06197773676831275


mean_val $-0.9213538193143904, less than best -0.06197773676831275


Episode 200: reward=-2, steps=12, speed=4.4 f/s, elapsed=0:00:55


mean_val $-0.9512658584862947, less than best -0.06197773676831275


mean_val $-0.9958063955418766, less than best -0.06197773676831275


mean_val $-1.0365040972828865, less than best -0.06197773676831275


mean_val $-1.054327073507011, less than best -0.06197773676831275


mean_val $-1.0643138522282243, less than best -0.06197773676831275


mean_val $-1.0614749658852816, less than best -0.06197773676831275


mean_val $-1.053081982769072, less than best -0.06197773676831275


mean_val $-1.0648898351937532, less than best -0.06197773676831275


mean_val $-1.0799097158014774, less than best -0.06197773676831275


mean_val $-1.0741623472422361, less than best -0.06197773676831275


mean_val $-1.0539975063875318, less than best -0.06197773676831275


mean_val $-1.0445513194426894, less than best -0.06197773676831275


mean_val $-1.031144494190812, less than best -0.06197773676831275


mean_val $-0.9897136585786939, less than best -0.06197773676831275


mean_val $-0.9599058963358402, less than best -0.06197773676831275


mean_val $-0.9644982344470918, less than best -0.06197773676831275


mean_val $-0.9626489486545324, less than best -0.06197773676831275


mean_val $-0.9549096487462521, less than best -0.06197773676831275


mean_val $-0.9419256481342018, less than best -0.06197773676831275


mean_val $-0.9292441117577255, less than best -0.06197773676831275


mean_val $-0.912672049831599, less than best -0.06197773676831275


mean_val $-0.8983630365692079, less than best -0.06197773676831275


mean_val $-0.9140937807969749, less than best -0.06197773676831275


mean_val $-0.9138155076652765, less than best -0.06197773676831275


mean_val $-0.8981135110370815, less than best -0.06197773676831275


mean_val $-0.8735243920236826, less than best -0.06197773676831275


mean_val $-0.8726231697946787, less than best -0.06197773676831275


mean_val $-0.8462350438348949, less than best -0.06197773676831275


mean_val $-0.8225297895260155, less than best -0.06197773676831275


mean_val $-0.8164832219481468, less than best -0.06197773676831275


mean_val $-0.8089371458627284, less than best -0.06197773676831275


mean_val $-0.7890307675115764, less than best -0.06197773676831275


mean_val $-0.7840071767568588, less than best -0.06197773676831275


mean_val $-0.7888493184000254, less than best -0.06197773676831275


mean_val $-0.7789020361378789, less than best -0.06197773676831275


mean_val $-0.7786035449244082, less than best -0.06197773676831275


mean_val $-0.7513212906196713, less than best -0.06197773676831275


mean_val $-0.7369994926266372, less than best -0.06197773676831275


mean_val $-0.7247247258201241, less than best -0.06197773676831275


mean_val $-0.7410803423263133, less than best -0.06197773676831275


mean_val $-0.7348032100126147, less than best -0.06197773676831275


mean_val $-0.7391571491025388, less than best -0.06197773676831275


mean_val $-0.7362648365087807, less than best -0.06197773676831275


mean_val $-0.7154776542447507, less than best -0.06197773676831275


mean_val $-0.7092593987472355, less than best -0.06197773676831275


mean_val $-0.7079924964345992, less than best -0.06197773676831275


mean_val $-0.7061227867379785, less than best -0.06197773676831275


mean_val $-0.6890785819850862, less than best -0.06197773676831275


mean_val $-0.6773346434347332, less than best -0.06197773676831275


mean_val $-0.6675650058314204, less than best -0.06197773676831275


mean_val $-0.6478525637649, less than best -0.06197773676831275


mean_val $-0.6139469046611339, less than best -0.06197773676831275


mean_val $-0.595665839035064, less than best -0.06197773676831275


mean_val $-0.5898755246307701, less than best -0.06197773676831275


Episode 300: reward=-6, steps=9, speed=4.4 f/s, elapsed=0:02:55


mean_val $-0.5905420950148255, less than best -0.06197773676831275


mean_val $-0.576495393877849, less than best -0.06197773676831275


mean_val $-0.5632956847548485, less than best -0.06197773676831275


mean_val $-0.5472545572556555, less than best -0.06197773676831275


mean_val $-0.5462254805024713, less than best -0.06197773676831275


mean_val $-0.5327047514729202, less than best -0.06197773676831275


mean_val $-0.5184470107778907, less than best -0.06197773676831275


mean_val $-0.5265924646519125, less than best -0.06197773676831275


mean_val $-0.5499458655249327, less than best -0.06197773676831275


mean_val $-0.5392888321075588, less than best -0.06197773676831275


mean_val $-0.5164421242661774, less than best -0.06197773676831275


mean_val $-0.5031698786187917, less than best -0.06197773676831275


mean_val $-0.5056373388506472, less than best -0.06197773676831275


mean_val $-0.5057244130875915, less than best -0.06197773676831275


mean_val $-0.48804541816934943, less than best -0.06197773676831275


mean_val $-0.49077134067192674, less than best -0.06197773676831275


mean_val $-0.5039822605904192, less than best -0.06197773676831275


mean_val $-0.5145560712553561, less than best -0.06197773676831275


mean_val $-0.5211662848014385, less than best -0.06197773676831275


mean_val $-0.5291597684845328, less than best -0.06197773676831275


mean_val $-0.540014500496909, less than best -0.06197773676831275


mean_val $-0.5463650734163821, less than best -0.06197773676831275


mean_val $-0.5202517749276012, less than best -0.06197773676831275


mean_val $-0.4928248368669301, less than best -0.06197773676831275


mean_val $-0.4707974884659052, less than best -0.06197773676831275


mean_val $-0.47665486903861165, less than best -0.06197773676831275


mean_val $-0.4668595613911748, less than best -0.06197773676831275


mean_val $-0.43184090475551784, less than best -0.06197773676831275


mean_val $-0.42086603469215333, less than best -0.06197773676831275


mean_val $-0.43352222302928567, less than best -0.06197773676831275


mean_val $-0.440894415602088, less than best -0.06197773676831275


mean_val $-0.4711053187493235, less than best -0.06197773676831275


mean_val $-0.4986420455388725, less than best -0.06197773676831275


mean_val $-0.4997203720267862, less than best -0.06197773676831275


mean_val $-0.5014599554706365, less than best -0.06197773676831275


mean_val $-0.4985346272587776, less than best -0.06197773676831275


mean_val $-0.48677383945323527, less than best -0.06197773676831275


mean_val $-0.48754516849294305, less than best -0.06197773676831275


mean_val $-0.49042097106575966, less than best -0.06197773676831275


mean_val $-0.5083414446562529, less than best -0.06197773676831275


mean_val $-0.5128406779840589, less than best -0.06197773676831275


mean_val $-0.5291198715567589, less than best -0.06197773676831275


mean_val $-0.5493549280799925, less than best -0.06197773676831275


mean_val $-0.5616327775642276, less than best -0.06197773676831275


mean_val $-0.5769732627086341, less than best -0.06197773676831275


mean_val $-0.5699534108862281, less than best -0.06197773676831275


mean_val $-0.561770292930305, less than best -0.06197773676831275


mean_val $-0.5672024036757648, less than best -0.06197773676831275


mean_val $-0.5693754181265831, less than best -0.06197773676831275


mean_val $-0.5690118186175823, less than best -0.06197773676831275


mean_val $-0.5936087383888662, less than best -0.06197773676831275


mean_val $-0.6174152768217027, less than best -0.06197773676831275


mean_val $-0.619238103274256, less than best -0.06197773676831275


mean_val $-0.6187323434278369, less than best -0.06197773676831275


mean_val $-0.5924100927077234, less than best -0.06197773676831275


mean_val $-0.5482840840704739, less than best -0.06197773676831275


mean_val $-0.5224225362762809, less than best -0.06197773676831275


mean_val $-0.5194428246468306, less than best -0.06197773676831275


mean_val $-0.49951288430020213, less than best -0.06197773676831275


mean_val $-0.4865885558538139, less than best -0.06197773676831275


mean_val $-0.5084908157587051, less than best -0.06197773676831275


Episode 400: reward=-4, steps=16, speed=4.4 f/s, elapsed=0:05:12


mean_val $-0.5275624562054873, less than best -0.06197773676831275


mean_val $-0.5259444797411561, less than best -0.06197773676831275


mean_val $-0.5249588028527796, less than best -0.06197773676831275


mean_val $-0.49665688211098313, less than best -0.06197773676831275


mean_val $-0.4862912120297551, less than best -0.06197773676831275


mean_val $-0.48234792333096266, less than best -0.06197773676831275


mean_val $-0.4843769818544388, less than best -0.06197773676831275


mean_val $-0.4859811053611338, less than best -0.06197773676831275


mean_val $-0.4982148688286543, less than best -0.06197773676831275


mean_val $-0.5053602079860866, less than best -0.06197773676831275


mean_val $-0.47256032517179847, less than best -0.06197773676831275


mean_val $-0.47039866354316473, less than best -0.06197773676831275


mean_val $-0.48232091683894396, less than best -0.06197773676831275


mean_val $-0.48968149768188596, less than best -0.06197773676831275


mean_val $-0.475736279040575, less than best -0.06197773676831275


mean_val $-0.47531788190826774, less than best -0.06197773676831275


mean_val $-0.4819168224930763, less than best -0.06197773676831275


mean_val $-0.491887750569731, less than best -0.06197773676831275


mean_val $-0.5112470807507634, less than best -0.06197773676831275


mean_val $-0.5222430503927171, less than best -0.06197773676831275


mean_val $-0.5291584888473153, less than best -0.06197773676831275


mean_val $-0.5438268659636378, less than best -0.06197773676831275


mean_val $-0.548888027202338, less than best -0.06197773676831275


mean_val $-0.5425974447280169, less than best -0.06197773676831275


mean_val $-0.5263679656200111, less than best -0.06197773676831275


mean_val $-0.5129117397591472, less than best -0.06197773676831275


mean_val $-0.5178625117987394, less than best -0.06197773676831275


mean_val $-0.5317338649183512, less than best -0.06197773676831275


mean_val $-0.5479961237870157, less than best -0.06197773676831275


mean_val $-0.5562099060043693, less than best -0.06197773676831275


mean_val $-0.5347268916666508, less than best -0.06197773676831275


mean_val $-0.52580535877496, less than best -0.06197773676831275


mean_val $-0.527734749019146, less than best -0.06197773676831275


mean_val $-0.533217357005924, less than best -0.06197773676831275


mean_val $-0.5014846106059849, less than best -0.06197773676831275


mean_val $-0.49462484614923596, less than best -0.06197773676831275


mean_val $-0.494599437341094, less than best -0.06197773676831275


mean_val $-0.49113920889794827, less than best -0.06197773676831275


mean_val $-0.47497384482994676, less than best -0.06197773676831275


mean_val $-0.48349957074970007, less than best -0.06197773676831275


mean_val $-0.48856264911592007, less than best -0.06197773676831275
