In [None]:
# Import libraries
import sys
import os
import numpy as np
import pandas as pd
import yfinance as yf
from sqlalchemy import create_engine, text

In [None]:
class DownloadData:
    """Download NSE equity prices into SQLite and load companion index / market-cap spreadsheets.

    Typical flow: ``scrape_wiki_table`` -> ``create_table`` -> ``fetch_and_store_all`` ->
    ``query_all_stocks``. ``read_indices`` and ``read_market_cap`` read the two Excel
    files whose paths are given to the constructor.
    """

    # Yahoo Finance suffix for NSE-listed symbols: appended when building download symbols,
    # stripped again when storing tickers / reading the market-cap sheet.
    EXCHANGE_SUFFIX = ".NS"

    def __init__(self, filename, filename2, filename3, echo=True):
        """Create the SQLite engine and remember the spreadsheet paths.

        Args:
            filename: path of the SQLite database file backing ``self.engine``.
            filename2: path of the index-price spreadsheet read by ``read_indices``.
            filename3: path of the market-cap spreadsheet read by ``read_market_cap``.
            echo: forwarded to ``create_engine``; True (the previous hard-coded value)
                logs every SQL statement.
        """
        self.symbols = None
        self.symbols_execute = None
        self.engine = create_engine(f"sqlite:///{filename}", echo=echo)
        self.file_index = filename2
        self.file_market_cap = filename3

    @staticmethod
    def _fill_and_floor(frame, floor=1e-10):
        """Interpolate each column from its first valid value onward, then floor small values.

        Leading NaNs (before a column's first observation) are left as NaN, so series that
        start later are not back-filled. Linear interpolation carries trailing NaNs forward
        from the last observation. Values below ``floor`` (zeros, negatives, tiny values)
        become ``floor``; NaNs stay NaN because ``NaN < floor`` is False.

        Args:
            frame: DataFrame of numeric columns; it is not modified.
            floor: lower bound applied after interpolation.

        Returns:
            A new DataFrame with the same index and columns.
        """
        result = frame.copy()
        for col in result.columns:
            first_valid = result[col].first_valid_index()
            if first_valid is None:
                continue  # all-NaN column: nothing to interpolate
            tail = result.loc[first_valid:, col]
            result.loc[first_valid:, col] = tail.interpolate(method="linear", limit_direction="both")
        return result.mask(result < floor, floor)

    def scrape_wiki_table(self, url):
        """Collect ticker symbols from the first HTML table at ``url`` that has a 'Symbol' column.

        Side effects: sets ``self.symbols`` (stripped symbols) and ``self.symbols_execute``
        (symbols with ``EXCHANGE_SUFFIX`` appended) and prints both.

        Returns:
            list[str] | None: the suffixed symbols, or None when no table on the page
            has a 'Symbol' column.
        """
        for table in pd.read_html(url):
            if "Symbol" not in table.columns:
                continue
            self.symbols = table["Symbol"].dropna().astype(str).str.strip().tolist()
            self.symbols_execute = [symbol + self.EXCHANGE_SUFFIX for symbol in self.symbols]
            print("Symbols:", self.symbols)
            print("Symbols for download:", self.symbols_execute)
            return self.symbols_execute  # only the first matching table is used
        print(f"No table with a 'Symbol' column found at {url}")
        return None

    def create_table(self):
        """(Re)create the ``stock_data`` table.

        WARNING: drops any existing ``stock_data`` table first, so previously stored rows
        are lost. The new table has a composite primary key on (Ticker, Date).
        """
        drop_table_query = "DROP TABLE IF EXISTS stock_data;"
        create_table_query = """
        CREATE TABLE IF NOT EXISTS stock_data (
                Ticker TEXT NOT NULL,
                Date TEXT NOT NULL,
                Open REAL,
                High REAL,
                Low REAL,
                Close REAL,
                Volume INTEGER,
                PRIMARY KEY (Ticker, Date)
        );
        """

        # engine.begin() commits on successful exit. A plain engine.connect() block is
        # rolled back when it closes (SQLAlchemy 2.0) unless conn.commit() is called.
        with self.engine.begin() as conn:
            conn.execute(text(drop_table_query))
            conn.execute(text(create_table_query))
        print("Table created successfully.")

    def fetch_and_store_all(self, list_index, start="2000-01-01", end="2024-12-31"):
        """Download daily OHLCV history for each symbol and append it to ``stock_data``.

        Per-symbol failures (download errors, duplicate-key inserts, missing columns) are
        printed and skipped, so one bad symbol does not abort the batch.

        Args:
            list_index: iterable of Yahoo symbols (e.g. the output of ``scrape_wiki_table``).
            start: first date requested from yfinance.
            end: end date passed to yfinance. yfinance documents ``end`` as exclusive,
                so this date itself is not downloaded.
        """
        for symbol in list_index:
            try:
                # Download historical data using yfinance
                df = yf.download(symbol, start=start, end=end, group_by=None)
                if df.empty:
                    print(f"No data for {symbol}")
                    continue

                # NOTE(review): with group_by=None, recent yfinance returns (Ticker, Price)
                # MultiIndex columns, so col[1] is the price field; plain string columns pass
                # through unchanged. Confirm against the installed yfinance version.
                df.columns = [col if isinstance(col, str) else col[1] for col in df.columns.to_flat_index()]
                df = df.reset_index()  # move the 'Date' index into a column
                df["Ticker"] = symbol.replace(self.EXCHANGE_SUFFIX, "").upper()
                df = df[["Ticker", "Date", "Open", "High", "Low", "Close", "Volume"]]

                # Appends into the table from create_table(); if that was not called,
                # to_sql creates a table without the primary key.
                df.to_sql("stock_data", self.engine, if_exists="append", index=False)
                print(f"Saved: {symbol}")

            except Exception as e:
                print(f"Error with {symbol}: {e}")

    def query_all_stocks(self):
        """Load stored close prices as a wide frame (rows = Date, columns = Ticker).

        Gaps are filled and small values floored by ``_fill_and_floor``.
        """
        query = text("""
                SELECT Ticker, Date, Close
                FROM stock_data
                ORDER BY Ticker, Date
        """)

        with self.engine.connect() as conn:
            rows = pd.read_sql(query, conn)

        rows["Date"] = pd.to_datetime(rows["Date"])
        closes = rows.pivot(index="Date", columns="Ticker", values="Close")
        return self._fill_and_floor(closes)

    def read_indices(self):
        """Load the index-price spreadsheet at ``self.file_index`` and clean it.

        Reads 'Sheet1', skipping file rows 1 and 2 (presumably the extra 'Ticker'/'Date'
        header rows of a yfinance frame exported to Excel; confirm against the file).
        The column named 'Price' is renamed to 'Date', parsed as datetime and used as
        the index. Gaps are filled and small values floored by ``_fill_and_floor``.
        """
        prices = pd.read_excel(self.file_index, sheet_name="Sheet1", skiprows=[1, 2])
        prices = prices.rename(columns={"Price": "Date"})
        prices["Date"] = pd.to_datetime(prices["Date"])
        return self._fill_and_floor(prices.set_index("Date"))

    def read_market_cap(self):
        """Load market caps from ``self.file_market_cap`` and add capitalisation weights.

        Expects 'Ticker' and 'MarketCap' columns on 'Sheet1'. The exchange suffix is
        removed from tickers. ``weight`` is MarketCap divided by the column total; NaN
        caps are excluded from the total and get NaN weights.

        NOTE(review): a zero total would yield inf/NaN weights. No guard exists here.
        """
        caps = pd.read_excel(self.file_market_cap, sheet_name="Sheet1")
        caps["Ticker"] = caps["Ticker"].str.replace(self.EXCHANGE_SUFFIX, "", regex=False)

        total_market_cap = caps["MarketCap"].sum()
        caps["weight"] = caps["MarketCap"] / total_market_cap
        return caps

In [None]:
# Disabled usage example, kept as an inert string literal: nothing inside it runs.
# NOTE(review): it would fail if re-enabled as written. DownloadData.__init__ requires three
# arguments (database path, index spreadsheet path, market-cap spreadsheet path), but the
# call below passes only `filename`, which raises TypeError. Before that call, the example
# also deletes the existing `filename` database from disk.
'''
if __name__ == "__main__":
    import os

    filename = 'donkey.db'
    url = 'https://en.wikipedia.org/wiki/NIFTY_50'

    try:
        os.remove(filename)
        print("Old DB removed.")
    except FileNotFoundError:
        print("No previous DB file.")
    except Exception as e:
        print(f"Error deleting DB: {e}")

    dd = DownloadData(filename)
    symbols = dd.scrape_wiki_table(url)
    dd.create_table()
    dd.fetch_and_store_all([symbols[0], symbols[1]])  # Limit to first 2 for quick test
    print (dd.query_all_stocks())

'''
