In [1]:
import enum
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from collections import namedtuple
import glob
import os
import csv
import random

In [2]:
class Actions(enum.IntEnum):
    Buy_100 = 0
    Buy_10 = 1
    Skip = 2
    Sell_10 = 3
    Sell_100 = 4

In [3]:
DEFAULT_BARS_COUNT = 10

class StocksEnv:
    
    def __init__(self, data):
        self.data = data
        self.bars_count = DEFAULT_BARS_COUNT
        self.linear_shape = self.bars_count*24+1
     
        
    def buyhold_reset(self, index):
        self.instrument = list(self.data.keys())[index]
        self.prices = self.data[self.instrument]
        self.buy_price = self.prices.price[9]
        self.sell_price = self.prices.price[-1]
        return self.buy_price, self.sell_price, self.instrument
    
    def linear_reset(self, index):
        self.fund = 100000
        self.hold = 0
        self.buy_price = []
        self.instrument = list(self.data.keys())[index]
        self.prices = self.data[self.instrument]
        self.offset = 9
        obs = self.linear_encode()
        return obs, self.instrument

    
    def linear_step(self, action):
        done = False
        close = self.prices.close[self.offset]
        price = self.prices.price[self.offset]
        self.affordable = self.fund // price
        if action == Actions.Buy_100 and self.affordable >= 100:
            self.hold += 100
            self.fund -= price * 100
            if int(price * 100 * 0.001425) == 0:
                tax = 1
            else:
                tax = int(price * 100 * 0.001425)
            self.fund -= tax
        elif action == Actions.Buy_10 and self.affordable >= 10:
            self.hold += 10
            self.fund -= price * 10
            if int(price * 10 * 0.001425) == 0:
                tax = 1
            else:
                tax = int(price * 10 * 0.001425)
            self.fund -= tax
        elif action == Actions.Sell_100 and self.hold >= 100:
            self.hold -= 100
            self.fund += price * 100
            if int(price * 100 * 0.001425) == 0:
                tax = 1
            else:
                tax = int(price * 100 * 0.001425)
            self.fund -= (tax + int(price*100*0.003))
        elif action == Actions.Sell_10 and self.hold >= 10:
            self.hold -= 10
            self.fund += price * 10
            if int(price * 10 * 0.001425) == 0:
                tax = 1
            else:
                tax = int(price * 10 * 0.001425)
            self.fund -= (tax + int(price*10*0.003))
        self.offset += 1
        prev_price = price
        price = self.prices.price[self.offset]
        if self.offset >= self.prices.close.shape[0]-1:
            done = True
             
        obs = self.linear_encode()
        
        return obs, done
    
    
    def linear_encode(self):
        obs = np.ndarray(shape=(self.linear_shape), dtype=np.float32)
        shift = 0
        for bar_idx in range(-self.bars_count+1, 1):
            obs[shift] = self.prices.open[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.high[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.low[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.close[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fa[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fafive[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.faten[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fatwenty[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fvolume[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fvolumefive[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fvolumeten[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fvolumetwenty[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fbb[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fbbten[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fbbtwenty[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fv[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fvfive[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fvten[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.frsv[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fk[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fd[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.frsi[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.fubb[self.offset + bar_idx]
            shift += 1
            obs[shift] = self.prices.flbb[self.offset + bar_idx]
            shift += 1
        obs[shift] = self.hold
        return obs
    

In [19]:
DEFAULT_test_data = "C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test"

Prices = namedtuple('prices', ('open', 'high', 'low', 'close', 'fa', 'fafive', 'faten', 'fatwenty', 'fvolume', 'fvolumefive', 'fvolumeten', 'fvolumetwenty', 'fbb', 'fbbten', 'fbbtwenty', 'fv', 'fvfive', 'fvten', 'frsv', 'fk', 'fd', 'frsi', 'fubb', 'flbb', 'price'))

def data_files(dir_name):
    result = []
    for path in glob.glob(os.path.join(dir_name, "*.csv")):
        result.append(path)
    return result

def read_csv(file_name, sep = ','):
    print("Reading", file_name)
    with open(file_name, 'rt', encoding='utf-8') as fd:
        reader = csv.reader(fd, delimiter=sep)
        h = next(reader)
        if 'Open' not in h and sep == ',':
            return read_csv(file_name, ';')
        indices = [h.index(s) for s in ('Open', 'High', 'Low', 'Close', 'A', 'A5', 'A10', 'A20', 'Capacity', 'Capacity5', 'Capacity10', 'Capacity20', 'BB', 'BB10', 'BB20', 'V', 'V5', 'V10', 'rsv', 'K', 'D', 'rsi', 'UBB', 'LBB')]
        o, h, l, c, a, afive, aten, atwenty, volume, volumefive, volumeten, volumetwenty, bb, bbten, bbtwenty, v, vfive, vten, rsv, k, d, rsi, ubb, lbb = [], [], [], [], [],[], [], [], [], [],[], [], [], [], [],[], [], [], [], [],[], [], [], []
        for row in reader:
            vals = list(map(float, [row[idx] for idx in indices]))
            po, ph, pl, pc,pa, pafive, paten, patwenty, pvolume, pvolumefive, pvolumeten, pvolumetwenty, pbb, pbbten, pbbtwenty, pv, pvfive, pvten, prsv, pk, pd, prsi, pubb, plbb  = vals
            o.append(po)
            c.append(pc)
            h.append(ph)
            l.append(pl)
            a.append(pa)
            afive.append(pafive)
            aten.append(paten)
            atwenty.append(patwenty)
            volume.append(pvolume)
            volumefive.append(pvolumefive)
            volumeten.append(pvolumeten)
            volumetwenty.append(pvolumetwenty)
            bb.append(pbb)
            bbten.append(pbbten)
            bbtwenty.append(pbbtwenty)
            v.append(pv)
            vfive.append(pvfive)
            vten.append(pvten)
            rsv.append(prsv)
            k.append(pk)
            d.append(pd)
            rsi.append(prsi)
            ubb.append(pubb)
            lbb.append(plbb)
    print("Read done")
    return Prices(open=np.array(o, dtype=np.float32),
                  high=np.array(h, dtype=np.float32),
                  low=np.array(l, dtype=np.float32),
                  close=np.array(c, dtype=np.float32),
                  fa=np.array(a, dtype=np.float32),
                  fafive=np.array(afive, dtype=np.float32),
                  faten=np.array(aten, dtype=np.float32),
                  fatwenty=np.array(atwenty, dtype=np.float32),
                  fvolume=np.array(volume, dtype=np.float32),
                  fvolumefive=np.array(volumefive, dtype=np.float32),
                  fvolumeten=np.array(volumeten, dtype=np.float32),
                  fvolumetwenty=np.array(volumetwenty, dtype=np.float32),
                  fbb=np.array(bb, dtype=np.float32),
                  fbbten=np.array(bbten, dtype=np.float32),
                  fbbtwenty=np.array(bbtwenty, dtype=np.float32),
                  fv=np.array(v, dtype=np.float32),
                  fvfive=np.array(vfive, dtype=np.float32),
                  fvten=np.array(vten, dtype=np.float32),
                  frsv=np.array(rsv, dtype=np.float32),
                  fk=np.array(k, dtype=np.float32),
                  fd=np.array(d, dtype=np.float32),
                  frsi=np.array(rsi, dtype=np.float32),
                  fubb=np.array(ubb, dtype=np.float32),
                  flbb=np.array(lbb, dtype=np.float32),
                  price=np.array(c, dtype=np.float32))

def prices_nlz(prices):
    o = (prices.open-prices.open.min()) / (prices.open.max()-prices.open.min())
    h = (prices.high-prices.high.min()) / (prices.high.max()-prices.high.min())
    l = (prices.low-prices.low.min()) / (prices.low.max()-prices.low.min())
    c = (prices.close-prices.close.min()) / (prices.close.max()-prices.close.min())
    a = (prices.fa-prices.fa.min()) / (prices.fa.max()-prices.fa.min())
    afive = (prices.fafive-prices.fafive.min()) / (prices.fafive.max()-prices.fafive.min())
    aten = (prices.faten-prices.faten.min()) / (prices.faten.max()-prices.faten.min())
    atwenty = (prices.fatwenty-prices.fatwenty.min()) / (prices.fatwenty.max()-prices.fatwenty.min())
    volume = (prices.fvolume-prices.fvolume.min()) / (prices.fvolume.max()-prices.fvolume.min())
    volumefive = (prices.fvolumefive-prices.fvolumefive.min()) / (prices.fvolumefive.max()-prices.fvolumefive.min())
    volumeten = (prices.fvolumeten-prices.fvolumeten.min()) / (prices.fvolumeten.max()-prices.fvolumeten.min())
    volumetwenty = (prices.fvolumetwenty-prices.fvolumetwenty.min()) / (prices.fvolumetwenty.max()-prices.fvolumetwenty.min())
    bb = (prices.fbb-prices.fbb.min()) / (prices.fbb.max()-prices.fbb.min())
    bbten = (prices.fbbten-prices.fbbten.min()) / (prices.fbbten.max()-prices.fbbten.min())
    bbtwenty = (prices.fbbtwenty-prices.fbbtwenty.min()) / (prices.fbbtwenty.max()-prices.fbbtwenty.min())
    v = (prices.fv-prices.fv.min()) / (prices.fv.max()-prices.fv.min())
    vfive = (prices.fvfive-prices.fvfive.min()) / (prices.fvfive.max()-prices.fvfive.min())
    vten = (prices.fvten-prices.fvten.min()) / (prices.fvten.max()-prices.fvten.min())
    rsv = (prices.frsv-prices.frsv.min()) / (prices.frsv.max()-prices.frsv.min())
    k = (prices.fk-prices.fk.min()) / (prices.fk.max()-prices.fk.min())
    d = (prices.fd-prices.fd.min()) / (prices.fd.max()-prices.fd.min())
    rsi = (prices.frsi-prices.frsi.min()) / (prices.frsi.max()-prices.frsi.min())
    ubb = (prices.fubb-prices.fubb.min()) / (prices.fubb.max()-prices.fubb.min())
    lbb = (prices.flbb-prices.flbb.min()) / (prices.flbb.max()-prices.flbb.min())
    p = prices.close
    return Prices(open=o, high=h, low=l, close=c, fa=a,fafive=afive, faten=aten, fatwenty=atwenty, fvolume=volume, fvolumefive=volumefive,fvolumeten=volumeten, fvolumetwenty=volumetwenty, fbb=bb, fbbten=bbten, fbbtwenty=bbtwenty,fv=v, fvfive=vfive, fvten=vten, frsv=rsv, fk=k,fd=d, frsi=rsi, fubb=ubb, flbb=lbb, price=p)

data_test = {file: prices_nlz(read_csv(file)) for file in data_files(DEFAULT_test_data)}

Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\1101test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\1216test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\1301test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\1326test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\1590test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2002test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2207test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2303test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2308test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2317test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2327test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2330test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2357test.csv
Read done
Reading C:/Users/yuanf/OneDrive/桌面/半導體股票資料/test\2379test.csv
Read done
Readin

In [20]:
class linear_DDQN(nn.Module):
    
    def __init__(self, in_n, out_n):
        super(linear_DDQN, self).__init__()
        
        self.fc = nn.Sequential(
            nn.Linear(in_n, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, out_n)
        )

    def forward(self, x):
        return self.fc(x)

In [35]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Enviroment:
    
    def __init__(self):
        self.env_test = StocksEnv(data_test)
        actions_n = len(Actions)
        self.linear_q_network = linear_DDQN(self.env_test.linear_shape, actions_n).to(device)
        
    def decide_action(self, state, network):
        network.eval()
        with torch.no_grad():
            action = network(state).max(1)[1].view(1, 1)
        return action
    
    def linear_run(self):
        
        self.linear_q_network.load_state_dict(torch.load(
             r"C:\Users\yuanf\result\-20000.pth"))
        k=0
        for i in range(len(data_test)):
            obs, stock_num = self.env_test.linear_reset(i)
            
            while True:
                state = torch.tensor([obs]).to(device)
                action = self.decide_action(state, self.linear_q_network)
                obs_next, done = self.env_test.linear_step(action.item())
                obs = obs_next
                if done:
                    price = self.env_test.prices.price[self.env_test.offset]
                    hold = self.env_test.hold
                    if hold:
                        self.env_test.fund += price * hold
                        if int(price * hold * 0.001425) == 0:
                            tax = 1
                        else:
                            tax = int(price * hold * 0.001425)
                        self.env_test.fund -= (tax + int(price*hold*0.003))
                        profit = (self.env_test.fund - 100000) / 100000 * 100
                        k = k+profit
                        i=i+1
                        q=k/i
                    break   
            print( profit)
            
    
            
    def buy_hold(self):
        
        for i in range(len(data_test)):
            fund = 100000
            buy_price, sell_price, stock_num = self.env_test.buyhold_reset(i)
            hold = fund // buy_price
            if int(buy_price * hold * 0.001425) == 0:
                tax = 1
            else:
                tax = int(buy_price * hold * 0.001425)
            fund -= tax
            if int(sell_price * hold * 0.001425) == 0:
                tax = 1
            else:
                tax = int(sell_price * hold * 0.001425)
            fund -= tax
            fund += ((sell_price-buy_price) * hold - int(sell_price * hold * 0.003))
            profit = (fund - 100000) / 100000 * 100
            
            print( profit)

In [36]:
Stocks = Enviroment()

In [37]:
Stocks.linear_run()

11.827000244140624
-7.529002716064453
-5.490997329711914
-1.821003219604492
162.751
-2.858503829956055
-2.858503829956055
40.129000930786134
-3.0669999999999997
-6.651999999999999
-18.102
2.003
10.780000000000001
4.501
13.150996109008789
8.459999999999999
-4.773004394531251
14.03349754333496
5.415
5.415
100.167
215.2590004119873
128.67699893951416
-2.7659996299743654
13.932005184173585
40.0740050201416
40.42899987792969
28.310499868392945
8.557996643066407
15.082497894287108
18.111496871948244
22.824004737854004
20.55549698448181
20.014504280090332
13.223999504089356
3.15
-17.730999999999998
34.07
79.09899967956542
4.020000930786133
5.2449975738525385
63.184
21.363503898620607
21.363503898620607
17.798004486083986
67.74799999999999
20.143


In [38]:
Stocks.buy_hold()

13.770148181915284
3.337997688293457
-2.70549845123291
-3.5470069885253905
262.6145
41.19819540405273
-5.212
32.598749999999995
-8.7045
-10.4865
-20.828
1.7309999999999999
30.252000000000002
29.881999999999998
10.773196411132812
9.035
-9.937998229980469
54.94639870262146
4.345000000000001
39.024
314.00075
415.70330653381353
356.767
-4.5203043098449704
17.44175
59.42000961303711
48.711401824951174
87.5955028772354
11.562996948242187
19.798296367645264
20.648
43.324208702087404
40.89525
31.001
15.731503623962404
2.2175000000000002
-13.938
34.012
129.83599999999998
0.3211015396118164
4.774298755645752
54.152
25.915251897811892
81.321
-3.7977969207763675
159.673
24.101
