In [None]:
import pandas as pd
import numpy as np
import yfinance as yf
import matplotlib.pyplot as plt
from mplfinance.original_flavor import candlestick_ohlc
import matplotlib.dates as mdates

class yfinance():
    def __init__(self, ticker: str, start: str=None, end: str=None):
        """Bind a ticker symbol, fetch its price history and wrap it in a ``Stock``.

        Args:
            ticker: Symbol handed to ``yf.Ticker``.
            start: Optional start date forwarded to ``Ticker.history``.
            end: Optional end date forwarded to ``Ticker.history``.

        When both ``start`` and ``end`` are ``None`` the date range is whatever
        ``Ticker.history`` picks by default.
        """
        self.ticker = ticker
        self.start = start
        self.end = end
        # download() stores the yf.Ticker handle on self.data_set and returns
        # self, so its return value must not be assigned back to data_set.
        self.download()
        # Fetch the OHLC table that Stock works on and keep it on the wrapper.
        self.data = self.data_set.history(start=self.start, end=self.end)
        self.stock = self.Stock(self.data)
    
    def download(self):
        """Create the ``yf.Ticker`` handle for ``self.ticker``.

        The handle is stored on ``self.data_set``.

        Returns:
            This instance, so calls can be chained.
        """
        ticker_handle = yf.Ticker(self.ticker)
        self.data_set = ticker_handle
        return self

    class Stock():
        def __init__(self, data: pd.DataFrame):
            self.data = data
            self.date = data.index
            self.prices = data['Close']
            self.models = None
            self.func = self.geometric_brownian_motion
        
        def close(self):
            self.prices = self.data['Close']
            return self
        def open(self):
            self.prices = self.data['Open']
            return self
        def high(self):
            self.prices = self.data['High']
            return self
        def low(self):
            self.prices = self.data['Low']
            return self
        
        def log(self):
            return  self.prices.apply(np.log)
        def diff(self):
            return self.prices.apply(np.diff)
        def log_diff(self):
            return self.prices.apply(np.log).apply(np.diff)
                
        def avg(self):
            return self.prices.mean().item()
        def std(self):
            return self.prices.std().item()
        def drift(self):
            return self.diff().dropna().mean().item()
        def vol_of_vol(self):
            return self.diff().dropna().std().item()
        
        def newest_price(self):
            return self.prices.iloc[-1].item()
        
        def simulation(self, n):
            models = []
            for i in range(n):
                S = self.func()
                models.append(S)
            self.models = models
            return self.models
        
        def plot_models(self):
            """Plot every path stored in ``self.models`` against ``self.date``.

            ``self.models`` must already be a list of paths (as set by
            ``simulation``); while it is still ``None`` the loop below raises
            ``TypeError``.
            """
            plt.figure(figsize=(12, 6))
            for model in self.models:
                plt.plot(self.date, model)
            plt.ylabel('price ($)')
            plt.grid(True)
            # show must actually be called; a bare ``plt.show`` reference
            # does nothing.
            plt.show()
        
        def candlestick(self, day=200, moving_average=(5, 10, 20, 50, 75, 100)):
            """Draw a candlestick chart of the latest rows with moving-average overlays.

            Args:
                day: Number of trailing rows of ``self.data`` to plot.
                moving_average: Window lengths (in rows) of the rolling means of
                    ``'Close'`` drawn over the candles. Any iterable of ints
                    works; the default is a tuple so it cannot be mutated
                    between calls.
            """
            # prepare data: reset_index returns a new frame, so self.data is
            # never modified; 'Date' becomes matplotlib date numbers for plotting
            candlestick_data = self.data.reset_index()
            candlestick_data['Date'] = candlestick_data['Date'].map(mdates.date2num)

            # Compute the moving averages on the full history BEFORE truncating,
            # so every line has values across the whole plotted window instead
            # of starting with (window - 1) missing points.
            ma_columns = []
            for ma in moving_average:
                column = f'MA{ma}'
                candlestick_data[column] = candlestick_data['Close'].rolling(ma).mean()
                ma_columns.append((ma, column))

            # keep only the most recent `day` rows for drawing
            candlestick_data = candlestick_data.iloc[-day:]

            fig, ax = plt.subplots(figsize=(12, 6))
            # illustrate a candlestick
            candlestick_ohlc(ax, candlestick_data[['Date', 'Open', 'High', 'Low', 'Close']].values,
                            width=1, colorup='g', colordown='r')
            # add a moving average
            for ma, column in ma_columns:
                ax.plot(candlestick_data['Date'], candlestick_data[column], label=f'{ma} day moving average')

            ax.xaxis_date()
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
            plt.title('candlestick')
            plt.xlabel('Date')
            plt.ylabel('price (USD)')
            plt.legend()
            plt.grid(True)
            plt.show()