# Detecting EMA Cross Traps with NEAT

### Import Libraries

In [1]:
import numpy as np
import pandas as pd
import numpy as np
import pandas_ta as ta
import seaborn as sns
import os

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [12, 6]
plt.rcParams['figure.dpi'] = 120
import warnings
warnings.filterwarnings('ignore')

In [2]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import neat

### Load Price Data

In [3]:
import os
from pathlib import Path

# Prefer a local copy of the price file in the current working directory and
# fall back to the public GitHub copy when it is not there.
notebook_path = os.getcwd()
current_dir = Path(notebook_path)
# Join with pathlib instead of concatenating a hard-coded "/" so the path is
# well-formed on every operating system.
csv_file = str(current_dir / 'VN30F1M_5minutes.csv')
is_file = os.path.isfile(csv_file)
if is_file:
    data_source = csv_file
else:
    # Report which local path was tried and that the remote copy is used.
    print(csv_file)
    print('remote')
    data_source = "https://raw.githubusercontent.com/zuongthaotn/vn-stock-data/main/VN30ps/VN30F1M_5minutes.csv"
# Single read call so both sources are parsed with identical options.
dataset = pd.read_csv(data_source, index_col='Date', parse_dates=True)

In [134]:
# Work on a copy so the freshly loaded `dataset` stays untouched and this
# cell can be re-run to reset `data`.
data = dataset.copy()

In [24]:
# Keep only bars after 2020-11-01. The explicit .copy() makes `data` an
# independent frame, so the column assignments in later cells do not write
# into a filtered view of the previous frame (chained-assignment warning).
data = data[data.index > '2020-11-01 00:00:00'].copy()

In [25]:
# 20-bar and 250-bar exponential moving averages of the close.
data["ema_fast"] = ta.ema(data["Close"], length=20)
data["ema_low"] = ta.ema(data["Close"], length=250)

# A cross happens on the bar where the fast EMA changes side relative to the
# 250-bar EMA compared with the previous bar, in either direction.
data["ema_cross"] = (
    (
        # cross up: fast is above now, was at or below on the previous bar
        (data["ema_fast"] > data["ema_low"])
        & (data["ema_fast"].shift(1) <= data["ema_low"].shift(1))
    )
    | (
        # cross down: fast is below now, was at or above on the previous bar
        (data["ema_fast"] < data["ema_low"])
        & (data["ema_fast"].shift(1) >= data["ema_low"].shift(1))
    )
)

## Calculate some common features

In [26]:
# Volatility feature: 14-bar Average True Range.
data["ATR"] = ta.atr(high=data["High"], low=data["Low"], close=data["Close"], length=14)
# Momentum feature: 14-bar Relative Strength Index of the close.
data["RSI"] = ta.rsi(close=data["Close"], length=14)

## TRAP labeling

In [27]:
def is_trap(r, threshold=3.5):
    """Label one row as an EMA-cross trap.

    Parameters
    ----------
    r : mapping-like row (e.g. a pandas Series)
        Must provide 'ema_cross', 'ema_fast', 'ema_low', 'Close',
        'min_low_1dlater' and 'max_high_1dlater'.
    threshold : float, default 3.5
        Adverse move, in price units, beyond which the cross counts as a trap.

    Returns
    -------
    '' when the row is not an EMA cross;
    1 when price moved against the cross by more than ``threshold``
    (low below ``Close - threshold`` after a cross up, high above
    ``Close + threshold`` after a cross down);
    0 when it did not;
    NaN when the look-ahead extreme (or the close) is missing, so the outcome
    is unknown. Previously such rows were silently labelled 0 because any
    comparison against NaN is False.
    """
    if r['ema_cross'] != True:
        return ''

    close = r['Close']
    if r['ema_fast'] > r['ema_low']:
        # Cross up: the adverse move is a drop below the close.
        extreme = r['min_low_1dlater']
        went_against = extreme < close - threshold
    else:
        # Cross down: the adverse move is a rise above the close.
        extreme = r['max_high_1dlater']
        went_against = extreme > close + threshold

    # Without a look-ahead value the label cannot be determined; return NaN so
    # a later dropna() can discard the row instead of training on a wrong 0.
    if pd.isna(extreme) or pd.isna(close):
        return float('nan')

    return 1 if went_against else 0

In [63]:
# Look-ahead extremes: shifting by -51 rows and then taking a 51-row rolling
# window leaves, on each row, the highest High / lowest Low of the following
# 51 rows. Rows without a complete window are NaN.
data['max_high_1dlater'] = data['High'].shift(-51).rolling(window=51).max()
data['min_low_1dlater'] = data['Low'].shift(-51).rolling(window=51).min()
# Row-wise labelling; rows that are not EMA crosses get an empty string.
data['trap'] = data.apply(is_trap, axis=1)

In [64]:
# Share of EMA-cross rows labelled 0 (not a trap) among all EMA-cross rows.
# Restored from commented-out code so the displayed ratio is reproducible.
cross_data = data[data.ema_cross == True]
len(cross_data[cross_data.trap == 0]) / len(cross_data['trap'])

0.33240997229916897

## 1. price_5 & price_20 & price_250 & rsi_5 & rsi & price_diff_ema & price_diff_ema_5 & atr

In [78]:
# Feature frame for experiment 1, built as one chain so the cell is
# idempotent: select base columns, add lagged features, keep only cross-up
# rows, then drop rows with missing values.
df1 = (
    data[['Close', 'RSI', 'ATR', 'ema_cross', 'ema_fast', 'ema_low', 'trap']]
    .assign(
        Close_s5=lambda d: d['Close'].shift(5),
        Close_s20=lambda d: d['Close'].shift(20),
        Close_s250=lambda d: d['Close'].shift(250),
        RSI_s5=lambda d: d['RSI'].shift(5),
        # Distance of the close from the fast EMA, now and 5 rows earlier.
        diff_ema=lambda d: d['Close'] - d['ema_fast'],
        diff_ema_s5=lambda d: d['diff_ema'].shift(5),
    )
    # Lags are computed on the full series before filtering, so shift(n)
    # refers to n bars earlier rather than n crosses earlier.
    .loc[lambda d: (d['ema_cross'] == True) & (d['ema_fast'] > d['ema_low'])]
    .dropna()
)

In [79]:
len(df1)

181

In [80]:
# X holds the nine input features plus the 'trap' target in its last column;
# the target is split off row by row when the networks are evaluated.
X = df1.loc[:, [
    "Close", "ATR", "RSI",
    "Close_s5", "Close_s20", "Close_s250",
    "RSI_s5", "diff_ema", "diff_ema_s5",
    "trap",
]]

# 75% / 25% train-test split with a fixed seed for reproducibility.
X_train, X_test = train_test_split(X, test_size=0.25, random_state=42)

In [81]:
len(X_train)

135

In [82]:
len(X_test)

46

In [106]:
def eval_genomes(genomes, config):
    """Set ``genome.fitness`` for every genome from its error on ``X_train``.

    Fitness starts at 4.0 and is reduced by the squared difference between
    the network's first output and the row's 'trap' label, summed over all
    rows of the module-level ``X_train`` frame.

    Parameters
    ----------
    genomes : iterable of (genome_id, genome) pairs
    config : the NEAT configuration used to build each network
    """
    feature_cols = ['Close', 'ATR', 'RSI', 'Close_s5', 'Close_s20',
                    'Close_s250', 'RSI_s5', 'diff_ema', 'diff_ema_s5']
    # Loop-invariant: extract inputs and labels once per call instead of
    # walking the frame with iterrows() again for every genome.
    samples = X_train[feature_cols].values.tolist()
    targets = X_train['trap'].tolist()

    for genome_id, genome in genomes:
        genome.fitness = 4.0
        net = neat.nn.FeedForwardNetwork.create(genome, config)
        for inputs, expected_output in zip(samples, targets):
            output = net.activate(inputs)
            genome.fitness -= (output[0] - expected_output) ** 2


def run(config_file, generations=100, verbose=False):
    """Evolve a population with NEAT and return the winning network.

    Parameters
    ----------
    config_file : path to the NEAT configuration file
    generations : int, default 100
        Upper bound on the number of generations passed to ``Population.run``.
    verbose : bool, default False
        When True, attach a stdout reporter and a statistics reporter to show
        progress. Off by default, matching the previous behaviour.

    Returns
    -------
    A ``neat.nn.FeedForwardNetwork`` built from the best genome, which is also
    printed.
    """
    # Load configuration.
    config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                         neat.DefaultSpeciesSet, neat.DefaultStagnation,
                         config_file)

    # Create the population, which is the top-level object for a NEAT run.
    p = neat.Population(config)

    if verbose:
        # Progress reporting; formerly kept as commented-out code.
        p.add_reporter(neat.StdOutReporter(True))
        p.add_reporter(neat.StatisticsReporter())

    # Run for up to `generations` generations.
    winner = p.run(eval_genomes, generations)

    # Display the winning genome.
    print('\nBest genome:\n{!s}'.format(winner))
    return neat.nn.FeedForwardNetwork.create(winner, config)

In [107]:
%%time
# The NEAT config file is looked up in `current_dir`.
config_path = os.path.join(current_dir, 'config-feedforward.cfg')
# Evolve the population and keep the network built from the best genome.
best_brain = run(config_path)


Best genome:
Key: 9214
Fitness: -24.93231142822125
Nodes:
	0 DefaultNodeGene(key=0, bias=-0.16333650222821477, response=1.0, activation=sigmoid, aggregation=sum)
	1489 DefaultNodeGene(key=1489, bias=2.838993402031513, response=1.0, activation=sigmoid, aggregation=sum)
	1728 DefaultNodeGene(key=1728, bias=2.2412179950954134, response=1.0, activation=sigmoid, aggregation=sum)
Connections:
	DefaultConnectionGene(key=(-7, 0), weight=2.4125998168672504, enabled=False)
	DefaultConnectionGene(key=(-4, 0), weight=1.4934156255455695, enabled=False)
	DefaultConnectionGene(key=(-3, 0), weight=-5.134599599180526, enabled=False)
	DefaultConnectionGene(key=(-2, 0), weight=0.12724749640024224, enabled=True)
	DefaultConnectionGene(key=(-1, 0), weight=0.5658190896488338, enabled=False)
	DefaultConnectionGene(key=(1728, 1489), weight=0.8379188857897881, enabled=True)
CPU times: user 1min 50s, sys: 20.9 ms, total: 1min 50s
Wall time: 1min 50s


In [108]:
best_brain

<neat.nn.feed_forward.FeedForwardNetwork at 0x7b30e83166d0>

In [121]:
# Collect the best network's predictions on the held-out test split
# (the previous comment said "training data", but the loop reads X_test).
outputs = []
for _, row in X_test.iterrows():
    inputs = [row['Close'], row['ATR'], row['RSI'], row['Close_s5'], row['Close_s20'], row["Close_s250"], row["RSI_s5"], row["diff_ema"], row["diff_ema_s5"]]
    output = best_brain.activate(inputs)
    # round() turns the first network output into an integer class label.
    outputs.append(round(output[0]))


Output:
input [1481.1, 3.6677949125676133, 52.673688732112076, 1479.0, 1484.9, 1496.3, 47.8618187282199, 1.8501635030086163, -0.46443081928418906], expected output 1, got [0.8200824344985203]
input [1559.0, 3.153000960658435, 73.61259615664022, 1548.2, 1542.9, 1514.9, 54.78525321641772, 10.342343798269667, 1.7709652355811158], expected output 1, got [0.7666292435177092]
input [1266.6, 1.7078263401697684, 67.6372430336863, 1263.0, 1260.5, 1257.8, 62.33627346434971, 3.1967907536827624, 2.0407728099348788], expected output 1, got [0.5670685069671452]
input [1250.0, 2.1878890987115573, 70.94663545981838, 1249.9, 1232.0, 1237.0, 76.13187489319358, 4.797286339511402, 8.206518332459382], expected output 1, got [0.6399931789837626]
input [1185.6, 1.4570417160990927, 65.26396267570391, 1184.0, 1181.3, 1185.8, 61.0604110986781, 2.9834419410678947, 3.076513021069104], expected output 0, got [0.527557555764704]
input [1373.0, 2.9172516584134156, 57.56048908781298, 1366.3, 1378.3, 1358.1, 41.28927

In [124]:
# Ground-truth labels of the test split, in the same row order as `outputs`.
expected_outputs = X_test['trap'].to_list()
# Evaluate performance: fraction of test rows whose rounded prediction
# matches the label.
print("Accuracy:", accuracy_score(expected_outputs, outputs))

Accuracy: 0.7608695652173914


In [125]:
X_test

Unnamed: 0_level_0,Close,ATR,RSI,Close_s5,Close_s20,Close_s250,RSI_s5,diff_ema,diff_ema_s5,trap
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2021-06-21 14:00:00,1481.1,3.667795,52.673689,1479.0,1484.9,1496.3,47.861819,1.850164,-0.464431,1
2021-11-30 09:00:00,1559.0,3.153001,73.612596,1548.2,1542.9,1514.9,54.785253,10.342344,1.770965,1
2024-04-12 09:20:00,1266.6,1.707826,67.637243,1263.0,1260.5,1257.8,62.336273,3.196791,2.040773,1
2022-07-04 09:55:00,1250.0,2.187889,70.946635,1249.9,1232.0,1237.0,76.131875,4.797286,8.206518,1
2024-01-26 09:40:00,1185.6,1.457042,65.263963,1184.0,1181.3,1185.8,61.060411,2.983442,3.076513,0
2021-05-18 13:50:00,1373.0,2.917252,57.560489,1366.3,1378.3,1358.1,41.289271,3.843971,-2.938293,1
2021-07-28 10:30:00,1416.3,2.414181,60.912946,1411.5,1409.0,1411.7,50.871388,3.561756,0.801948,1
2022-05-18 09:50:00,1284.5,5.628941,68.613205,1273.5,1270.8,1332.6,61.572338,12.132353,2.469311,1
2023-03-08 14:20:00,1032.0,3.222577,72.678563,1027.3,1017.3,1030.3,66.645504,7.567975,5.88758,0
2023-04-12 09:30:00,1076.2,1.340023,64.254888,1075.9,1068.8,1079.5,64.051463,2.115607,2.907097,1
