<a href="https://colab.research.google.com/github/passtock/stock-straregy/blob/main/stock200_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import yfinance as yf
import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots

def get_stock_price_data(stock_symbol, start_date, end_date):
    """Download historical price data for a single ticker with yfinance.

    Args:
        stock_symbol: Ticker symbol to download (e.g. 'TQQQ').
        start_date: Start of the range, forwarded as ``yf.download(start=...)``.
        end_date: End of the range, forwarded as ``yf.download(end=...)``.

    Returns:
        The DataFrame returned by ``yf.download``. Some yfinance releases label
        the columns of a single-ticker download with two levels
        (e.g. ``('Close', 'TQQQ')``). When the columns are two-level and the
        second level holds only one value, that level is dropped. This makes
        ``data['Close']`` a Series, which the indicator and backtest code in
        this script rely on. Flat columns are returned unchanged; an empty
        result is returned as-is.
    """
    stock_data = yf.download(stock_symbol, start=start_date, end=end_date)

    columns = stock_data.columns
    # Only flatten when the second level is a single ticker; with several
    # tickers, dropping it would create duplicate column names.
    if (isinstance(columns, pd.MultiIndex) and columns.nlevels == 2
            and columns.get_level_values(1).nunique() == 1):
        stock_data = stock_data.droplevel(1, axis=1)
    return stock_data

def calculate_rsi(data, window=14):
    """Compute the Relative Strength Index of the 'Close' column.

    Gains and losses are averaged with simple (unweighted) rolling means over
    ``window`` rows. The result is RSI = 100 - 100 / (1 + avg_gain / avg_loss).

    Args:
        data: DataFrame with a 'Close' column.
        window: Look-back length of the rolling averages.

    Returns:
        A Series aligned with ``data``. The first ``window - 1`` entries are NaN
        because the rolling window is not yet full. A window with gains but no
        losses yields 100; a window with no movement at all yields NaN (0 / 0).
    """
    # The first difference is undefined (NaN); count it as "no move" so the
    # first full window already ends at row ``window - 1``.
    change = data['Close'].diff(1).fillna(0)

    up_moves = change.clip(lower=0)
    down_moves = (-change).clip(lower=0)

    avg_gain = up_moves.rolling(window=window).mean()
    avg_loss = down_moves.rolling(window=window).mean()

    relative_strength = avg_gain / avg_loss
    return 100 - (100 / (1 + relative_strength))

def calculate_moving_average(data, window):
    """Return the simple moving average of the 'Close' column.

    Each entry is the mean of the last ``window`` closes. The leading entries
    are NaN until ``window`` values are available.
    """
    closes = data['Close']
    trailing = closes.rolling(window=window)
    return trailing.mean()

def simulate_trading(stock_data, short_window, long_window, initial_cash):
    """Backtest a long-only moving-average crossover strategy.

    The strategy goes all-in on the row where the short MA crosses above the
    long MA, and sells everything on the row where it crosses back below. Each
    buy and each sell keeps only 99.5% of the traded value (a 0.5% fee).

    Args:
        stock_data: DataFrame with a 'Close' column. It is copied, not modified.
        short_window: Window of the fast moving average.
        long_window: Window of the slow moving average.
        initial_cash: Starting cash balance.

    Returns:
        A copy of ``stock_data`` with these added columns:
          '3_day_MA' / '206_day_MA': the short / long moving averages. These
              names are fixed regardless of the window arguments.
          'Signal': 1 on buy rows, -1 on sell rows, 0 otherwise.
          'Total Value': cash plus holdings valued at that row's close,
              aligned row-for-row with the prices.
    """
    stock_data = stock_data.copy()  # protect the caller's DataFrame
    stock_data['3_day_MA'] = calculate_moving_average(stock_data, short_window)
    stock_data['206_day_MA'] = calculate_moving_average(stock_data, long_window)

    # Plain arrays avoid a pandas .iloc lookup for every element in the loop.
    closes = stock_data['Close'].to_numpy()
    short_ma = stock_data['3_day_MA'].to_numpy()
    long_ma = stock_data['206_day_MA'].to_numpy()
    row_count = len(stock_data)

    signals = [0] * row_count  # 1 for buy, -1 for sell
    position = 0  # 0 = no position, 1 = holding stock
    cash = initial_cash
    holdings = 0  # number of (fractional) shares held

    # Row 0 can never trade (a crossover needs the previous row), so its value
    # is the starting cash. Seeding it here keeps every later value on its own
    # row, instead of shifting the whole history one row early.
    total_values = [initial_cash] if row_count else []

    for i in range(1, row_count):
        price = closes[i]
        # Comparisons against NaN (MA warm-up rows) are False, so no signal fires.
        crossed_above = short_ma[i] > long_ma[i] and short_ma[i - 1] <= long_ma[i - 1]
        crossed_below = short_ma[i] < long_ma[i] and short_ma[i - 1] >= long_ma[i - 1]

        if crossed_above and position == 0:  # buy signal
            holdings = (cash / price) * 0.995  # subtract 0.5% fee
            cash = 0
            position = 1
            signals[i] = 1
        elif crossed_below and position == 1:  # sell signal
            cash = holdings * price * 0.995  # subtract 0.5% fee
            holdings = 0
            position = 0
            signals[i] = -1

        total_values.append(cash + holdings * price)

    stock_data['Signal'] = signals
    stock_data['Total Value'] = total_values
    return stock_data

def plot_interactive_chart(stock_data, rsi, moving_avg_20, moving_avg_50):
    """Build and display a three-panel interactive Plotly chart.

    Panels, top to bottom:
      1. Candlesticks, the two moving-average lines, and buy/sell markers at
         rows where 'Signal' is 1 / -1.
      2. The RSI line over a shaded rectangle spanning 30-70.
      3. The portfolio value from the 'Total Value' column.

    Args:
        stock_data: DataFrame with 'Open', 'High', 'Low', 'Close', 'Signal'
            and 'Total Value' columns. Its last index entry must support
            ``strftime`` (it is used in the title).
        rsi: Series drawn in the middle panel.
        moving_avg_20: Short moving average, shown as "3-Day MA".
        moving_avg_50: Long moving average, shown as "206-Day MA".
    """
    dates = stock_data.index
    fig = make_subplots(
        rows=3,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.2,
        row_heights=[0.5, 0.2, 0.3],
    )

    # Row 1: price candles.
    candles = go.Candlestick(
        x=dates,
        open=stock_data['Open'],
        high=stock_data['High'],
        low=stock_data['Low'],
        close=stock_data['Close'],
        name='Candlestick',
    )
    fig.add_trace(candles, row=1, col=1)

    # Row 1: moving averages. NOTE: the legend labels are fixed text and do
    # not follow the windows actually used to compute these series.
    average_lines = (
        (moving_avg_20, '3-Day MA', 'blue'),
        (moving_avg_50, '206-Day MA', 'orange'),
    )
    for values, label, color in average_lines:
        line_trace = go.Scatter(x=dates, y=values, mode='lines',
                                name=label, line=dict(color=color))
        fig.add_trace(line_trace, row=1, col=1)

    # Row 1: trade markers placed at the close of each signal row.
    trade_markers = (
        (1, 'Buy Signal', 'green', 'triangle-up'),
        (-1, 'Sell Signal', 'red', 'triangle-down'),
    )
    for code, label, color, symbol in trade_markers:
        hits = stock_data[stock_data['Signal'] == code]
        marker_trace = go.Scatter(
            x=hits.index,
            y=hits['Close'],
            mode='markers',
            name=label,
            marker=dict(color=color, size=10, symbol=symbol),
        )
        fig.add_trace(marker_trace, row=1, col=1)

    # Row 2: RSI with a shaded 30-70 band.
    fig.add_trace(
        go.Scatter(x=rsi.index, y=rsi, mode='lines', name='RSI',
                   line=dict(color='purple')),
        row=2, col=1,
    )
    fig.add_shape(
        type="rect",
        x0=dates[0], x1=dates[-1],
        y0=70, y1=30,
        line=dict(color="gray", width=2, dash='dash'),
        fillcolor="lightgray",
        opacity=0.5,
        row=2, col=1,
    )

    # Row 3: portfolio value over time.
    fig.add_trace(
        go.Scatter(x=dates, y=stock_data['Total Value'], mode='lines',
                   name='Portfolio Value', line=dict(color='black')),
        row=3, col=1,
    )

    last_date = dates[-1].strftime("%Y-%m-%d")
    fig.update_layout(
        title=f'{last_date} Stock Price and RSI with Portfolio Value',
        yaxis_title='Price',
        xaxis_title='Date',
        yaxis2_title='RSI',
        yaxis3_title='Portfolio Value',
        xaxis_rangeslider_visible=False,
        showlegend=True,
    )

    # Draw a black frame around every panel.
    frame = dict(showline=True, linewidth=1, linecolor='black', mirror=True)
    fig.update_yaxes(**frame)
    fig.update_xaxes(**frame)

    fig.show()

def main():
    """Backtest a moving-average crossover on one ticker and plot the result.

    Steps:
      1. Download prices.
      2. Compute the RSI and the short/long moving averages.
      3. Run ``simulate_trading`` with the same two windows.
      4. Show the interactive chart.
    """
    stock_symbol = 'TQQQ'  # Example ticker; replace with the symbol you want to test.
    start_date = '2011-01-01'  # Start date
    end_date = '2024-08-15'  # End date
    initial_cash = 10000  # Starting capital
    # Single source of truth for both windows, so the plotted averages always
    # match the ones the backtest trades on.
    short_window = 3  # fast moving-average length, in rows of price data
    long_window = 206  # slow moving-average length, in rows of price data

    stock_data = get_stock_price_data(stock_symbol, start_date, end_date)
    if stock_data.empty:
        # Nothing to backtest; plotting would otherwise fail on index[0].
        print(f'No price data returned for {stock_symbol} '
              f'between {start_date} and {end_date}.')
        return

    rsi = calculate_rsi(stock_data)
    short_ma = calculate_moving_average(stock_data, short_window)
    long_ma = calculate_moving_average(stock_data, long_window)

    stock_data = simulate_trading(stock_data, short_window, long_window, initial_cash)

    plot_interactive_chart(stock_data, rsi, short_ma, long_ma)


if __name__ == "__main__":
    main()


[*********************100%%**********************]  1 of 1 completed
