In [None]:
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import random
import yfinance as yf
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde
import requests
from io import StringIO
from bs4 import BeautifulSoup

In [None]:
def stock_selector(selected_group, max_tickers=4):
    """
    Text field for selecting stocks for analysis, followed by a basic analysis.

    The user can either:
    - enter up to `max_tickers` symbols separated by commas, or
    - type "Random" to draw `max_tickers` symbols at random from the selected
      group, or "Random N" to draw N of them (1 <= N <= `max_tickers`).

    After the confirmation button is clicked, 2 years of daily auto-adjusted
    prices are downloaded with yfinance, one "<TICKER>.csv" file is saved per
    selected ticker, and the following are displayed: close-price plots (last
    2 months and 2 years), descriptive statistics of daily log returns, and
    histograms of daily log returns with a KDE curve. For the "SP500" group a
    table of company information scraped from Wikipedia is shown as well.

    Parameters
    ----------
    selected_group : dict
        Must contain the key 'stock_group'; available symbols are read from
        the 'Symbol' column of "<stock_group>_tickers.csv".
    max_tickers : int, default 4
        Maximum number of symbols that can be analysed at once.

    Returns
    -------
    dict or None
        ``{"tickers": [...]}``. The list is empty until the user confirms a
        valid selection, after which it holds the chosen symbols. ``None`` is
        returned when the group's CSV file does not exist.
    """

    # Input from the group choice
    group_name = selected_group['stock_group']

    # Use the CSV to extract available symbols (read as text so numeric-looking
    # symbols still compare equal to what the user types).
    csv_file = f"{group_name}_tickers.csv"
    try:
        tickers_df = pd.read_csv(csv_file, dtype={"Symbol": str})
    except FileNotFoundError:
        print("\U0000274C CSV file for this group was not found. Make sure you have successfully selected the desired group.")
        return

    # Skip missing/blank symbols and drop duplicates (keeping file order) so
    # that random.sample can never return NaN or the same symbol twice.
    symbols = tickers_df['Symbol'].dropna().astype(str).str.strip()
    available_tickers = list(dict.fromkeys(s for s in symbols if s))
    available_set = set(available_tickers)
    tickers_container = {"tickers": []}

    # Widgets (text field + confirmation button)
    input_text = widgets.Text(
        value='',
        placeholder=f"Enter up to {max_tickers} tickers, separated by commas, or type 'Random'/'Random N'.",
        description="Symbols for the analysis:",
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='630px', margin='10px 0 0 0')
    )

    button = widgets.Button(
        description="Choose symbols",
        button_style='success',
        layout=widgets.Layout(width='200px', height='40px', margin='10px 0 10px 0'),
        style={'font_weight': 'bold', 'font_size': '16px'}
    )

    output = widgets.Output()

    # Colors shared by all per-ticker plots (cycled when there are more tickers)
    colors = ['#2ca02c', '#1f77b4', '#9467bd', '#d62728']

    # ---------------------------------------------------------------------
    # Helper functions for the analysis
    # ---------------------------------------------------------------------

    def show_title(title):
        """Display a centered, uniformly styled section title."""
        display(HTML(f"""
            <h3 style="
                text-align:center;
                margin:30px 0 10px 0;
                color:#2c3e50;
                font-weight:bold;
                font-size:25px;
            ">
                {title}
            </h3>
        """))

    def make_axes(n, single_figsize):
        """Create a figure for n panels and return (fig, list of exactly n axes).

        One ticker gets a single panel of size `single_figsize`; otherwise a
        2-column grid is used (12x5 per row). This gives the same layouts as
        before for 2 (1x2) and 3-4 (2x2) tickers, and no longer runs out of
        axes when `max_tickers` is larger than 4.
        """
        if n == 1:
            fig, ax = plt.subplots(1, 1, figsize=single_figsize)
            return fig, [ax]
        ncols = 2
        nrows = (n + ncols - 1) // ncols
        fig, axes = plt.subplots(nrows, ncols, figsize=(12, 5 * nrows))
        axes = list(np.ravel(axes))
        # Remove unused panels (e.g. the 4th one when 3 tickers are shown)
        for extra in axes[n:]:
            fig.delaxes(extra)
        return fig, axes[:n]

    def ticker_frame(data, t, n_selected):
        """Return a copy of the downloaded columns for ticker `t`, or None if absent.

        With a single ticker the whole download belongs to it; with several
        tickers (group_by="ticker") the top column level is the ticker symbol.
        """
        if n_selected == 1:
            return data.copy()
        try:
            return data[t].copy()
        except KeyError:
            return None

    # Wiki info for SP500
    def sp500_company_info(selected_tickers):
        """Show Wikipedia's S&P 500 constituents rows for the selected tickers."""
        url = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
        headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"}

        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
        except requests.RequestException as e:
            print("\U0000274C Failed to download information from Wikipedia.")
            print(e)
            return

        # Get Wiki links from the constituents table
        soup = BeautifulSoup(response.text, "html.parser")
        table = soup.find("table", {"id": "constituents"})
        if table is None:
            # Page layout changed; don't crash the rest of the analysis.
            print("\u26A0\ufe0f Failed to get Wiki links for the selected stocks.")
            return

        records = []
        for row in table.find_all("tr")[1:]:
            cols = row.find_all("td")
            if len(cols) < 2:
                continue  # rows without data cells (header/spacer rows)
            symbol = cols[0].text.strip()
            link_tag = cols[1].find("a")

            if link_tag and link_tag.get("href"):
                wiki_link = "https://en.wikipedia.org" + link_tag.get("href")
            else:
                wiki_link = None

            records.append({"Symbol": symbol, "Wiki": wiki_link})

        links_df = pd.DataFrame(records, columns=["Symbol", "Wiki"])

        # Other columns, read from the same table the links came from
        sp500 = pd.read_html(StringIO(response.text), attrs={"id": "constituents"})[0]

        # Merge
        sp500 = sp500.merge(links_df, on="Symbol", how="left")

        # Wikipedia writes share classes with a dot (e.g. BRK.B) while Yahoo
        # symbols use a dash (BRK-B); compare both in the dash form.
        wanted = {t.replace(".", "-") for t in selected_tickers}
        symbol_key = sp500["Symbol"].astype(str).str.replace(".", "-", regex=False)
        sp500_selected = sp500[symbol_key.isin(wanted)].copy()

        if sp500_selected.empty:
            print("\u26A0\ufe0f Failed to get Wiki links for the selected stocks.")
            return

        # Make links clickable
        def make_clickable(url):
            if pd.isna(url):
                return ""
            return f'<a href="{url}" target="_blank">\U0001F517 Wiki link</a>'

        sp500_selected["Wiki"] = sp500_selected["Wiki"].apply(make_clickable)

        show_title("\U0001F3E2 Basic Information about Selected S&P 500 Stocks")

        display(sp500_selected.style.set_properties(**{"text-align": "center","font-size": "14px"})
            .set_table_styles([{"selector": "th","props": [("text-align", "center"),("font-size", "14px"),("font-weight", "bold")]}
            ])
            .hide(axis="index")
        )

        display(HTML("<div style='height:50px;'></div>"))

    # Stock data preparation
    def prepare_price_and_returns(data, selected_tickers):
        """Extract close prices and daily log returns for each selected ticker.

        Returns two dicts keyed by ticker: close prices with NA rows dropped,
        and log returns ln(P_t / P_{t-1}). A ticker with no usable price data
        maps to an empty Series in both dicts.
        """
        price_dict = {}
        returns_dict = {}

        for t in selected_tickers:
            df = ticker_frame(data, t, len(selected_tickers))
            if df is None:
                price_dict[t] = pd.Series(dtype=float)
                returns_dict[t] = pd.Series(dtype=float)
                continue

            if isinstance(df.columns, pd.MultiIndex):
                df.columns = df.columns.get_level_values(-1)

            close = df["Close"] if "Close" in df.columns else df["Adj Close"]

            # A ticker yfinance failed to fetch comes back all-NaN; report it
            # once in the caller instead of as "N days with missing values".
            if close.isna().all():
                price_dict[t] = pd.Series(dtype=float)
                returns_dict[t] = pd.Series(dtype=float)
                continue

            n_na_prices = close.isna().sum()
            if n_na_prices > 0:
                print(f"\n\u26A0\ufe0f For {t}: {n_na_prices} days with missing price values (NA) have been dropped. Take this into account when evaluating the results.\n")

            close = close.dropna()
            price_dict[t] = close
            returns_dict[t] = np.log(close / close.shift(1)).dropna()

        return price_dict, returns_dict

    # Price plots
    def plot_prices(series_dict, selected_tickers, title, date_filter=None):
        """Plot one price panel per ticker, optionally limited to the last `date_filter`."""
        show_title(f"{title}")

        _, axes = make_axes(len(selected_tickers), single_figsize=(8, 6))

        for i, (ax, t) in enumerate(zip(axes, selected_tickers)):
            s = series_dict[t]

            if date_filter is not None:
                end_date = s.index.max()
                s = s.loc[end_date - date_filter:end_date]

            ax.plot(s, color=colors[i % len(colors)])
            ax.set_title(t, fontsize=14, fontweight='bold')
            ax.set_xlabel("Date")
            ax.set_ylabel("Price")
            ax.tick_params(axis='x', rotation=90)

        plt.tight_layout()
        plt.show()

    # Descriptive stats
    def descr_stats_table(returns_dict, selected_tickers, title):
        """Display summary statistics of the returns (in percent) per ticker."""
        stats_table = pd.DataFrame({
            "Mean":   [returns_dict[t].mean() * 100 for t in selected_tickers],
            "Max":    [returns_dict[t].max() * 100 for t in selected_tickers],
            "Min":    [returns_dict[t].min() * 100 for t in selected_tickers],
            "Q1":     [returns_dict[t].quantile(0.25) * 100 for t in selected_tickers],
            "Median": [returns_dict[t].median() * 100 for t in selected_tickers],
            "Q3":     [returns_dict[t].quantile(0.75) * 100 for t in selected_tickers],
            "Std. Dev": [returns_dict[t].std() * 100 for t in selected_tickers]
        }, index=selected_tickers)

        show_title(f"{title}")
        display(stats_table.style.format("{:.3f} %").set_table_attributes("style='margin-left:auto; margin-right:auto; font-size:18px;'"))

    # Returns distribution plots
    def plot_returns_distribution(returns_dict, selected_tickers, title):
        """Plot a density histogram of returns per ticker with a Gaussian KDE overlay."""
        show_title(f"{title}")

        _, axes = make_axes(len(selected_tickers), single_figsize=(8, 5))

        for i, (ax, t) in enumerate(zip(axes, selected_tickers)):
            r = returns_dict[t].values

            ax.hist(r, bins=40, density=True, alpha=0.6, color=colors[i % len(colors)])

            # gaussian_kde fails for fewer than 2 points or zero variance;
            # in that case show the histogram without the KDE curve.
            if r.size > 1:
                try:
                    kde = gaussian_kde(r)
                except (np.linalg.LinAlgError, ValueError):
                    kde = None
                if kde is not None:
                    x_grid = np.linspace(r.min(), r.max(), 300)
                    ax.plot(x_grid, kde(x_grid), linewidth=2)

            ax.set_title(t, fontsize=14, fontweight='bold')
            ax.set_xlabel("Return")
            ax.set_ylabel("Density")

        plt.tight_layout()
        plt.show()

    # ---------------------------------------------------------------------
    # After clicking the confirmation button
    # ---------------------------------------------------------------------
    def on_button_click(b):
        with output:
            output.clear_output()
            user_input = input_text.value.strip()

            if not user_input:
                print("\U0000274C No input inserted. Insert symbols or type 'Random'/'Random N'.")
                return

            parts = user_input.split()

            # Random choice: the first word must be exactly "random" (any case),
            # so e.g. "Random5" is not silently treated as plain "Random".
            if parts[0].lower() == "random":
                if len(parts) == 1:
                    n_random = max_tickers
                elif len(parts) == 2:
                    try:
                        n_random = int(parts[1])
                    except ValueError:
                        print(f"\U0000274C Invalid format! Use 'Random' or 'Random N' where N is between 1 and {max_tickers}.")
                        return
                    if n_random < 1 or n_random > max_tickers:
                        print(f"\U0000274C Number of random tickers must be between 1 and {max_tickers}.")
                        return
                else:
                    print(f"\U0000274C Invalid format! Use 'Random' or 'Random N' where N is between 1 and {max_tickers}.")
                    return

                if len(available_tickers) < n_random:
                    print(f"\U0000274C Sorry, but the CSV contains less than {n_random} symbols. Please choose different group.")
                    return
                selected_tickers = random.sample(available_tickers, n_random)
            else:
                # Ignore empty entries (e.g. a trailing comma) and repeated
                # symbols, keeping the order in which they were typed.
                tickers = list(dict.fromkeys(
                    t.strip().upper() for t in user_input.split(",") if t.strip()
                ))

                if not tickers:
                    print("\U0000274C No input inserted. Insert symbols or type 'Random'/'Random N'.")
                    return

                if len(tickers) > max_tickers:
                    print(f"\U0000274C Maximum of {max_tickers} tickers can be used.")
                    return

                invalid = [t for t in tickers if t not in available_set]
                if invalid:
                    print(f"\U0000274C The following tickers are not included in {group_name}: {', '.join(invalid)}\n \nYou can remove them, replace them with valid symbols from {group_name}_tickers.csv, or go back and try different group.")
                    return

                selected_tickers = tickers

            print(f"                             \U00002705 Stock symbols selected for the analysis are: {', '.join(selected_tickers)}")

            tickers_container["tickers"] = selected_tickers

            # Basic analysis
            data = yf.download(
                selected_tickers,
                period="2y",
                interval="1d",
                auto_adjust=True,
                group_by="ticker",
                progress=False
            )

            if data.empty:
                print("\U0000274C A problem occurred - No data downloaded.")
                return

            print("\n                                            \U0001F4E5 Price data successfully obtained.\n \n                 \U0001F680 Start your analysis by checking basic graphs of price movements and logarithmic returns! \U0001F447")

            # Create .csv file for each ticker that has downloaded columns
            for t in selected_tickers:
                df_to_save = ticker_frame(data, t, len(selected_tickers))
                if df_to_save is not None:
                    df_to_save.to_csv(f"{t}.csv")
            print("\n                                  \U0001F4C1 Individual CSV files saved for each selected ticker.")

            price_dict, returns_dict = prepare_price_and_returns(data, selected_tickers)

            # Tickers without any price data would break the plots and the KDE;
            # report and exclude them from the charts and statistics.
            no_data = [t for t in selected_tickers if price_dict[t].empty]
            if no_data:
                print(f"\n\u26A0\ufe0f No price data available for: {', '.join(no_data)}. They are excluded from the analysis.\n")
            analysed = [t for t in selected_tickers if not price_dict[t].empty]
            if not analysed:
                print("\U0000274C A problem occurred - No data downloaded.")
                return

            # SP500 COMPANY INFO (only if SP500 selected)
            if group_name == "SP500":
                sp500_company_info(selected_tickers)

            # LAST 2 MONTHS PRICES
            plot_prices(
                price_dict,
                analysed,
                title="\U0001F4C9 Daily Close Prices – 2 Months",
                date_filter=pd.Timedelta(days=60)
            )
            display(HTML("<div style='height:60px;'></div>"))

            # LAST 2 YEAR PRICES
            plot_prices(
                price_dict,
                analysed,
                title="\U0001F4C8 Daily Close Prices – 2 Years"
            )
            display(HTML("<div style='height:60px;'></div>"))

            # DESCRIPTIVE STATS TABLE
            descr_stats_table(returns_dict, analysed, title="\U0001F4CB Descriptive Statistics of Log Returns – 2 Years")
            display(HTML("<div style='height:60px;'></div>"))

            # RETURNS DISTRIBUTION
            plot_returns_distribution(returns_dict, analysed, title="\U0001F4CA Distribution of Daily Log Returns - 2 Years")

    button.on_click(on_button_click)

    # Alignment
    ui = widgets.VBox(
        [input_text, button, output],
        layout=widgets.Layout(align_items='center')
    )

    display(ui)

    return tickers_container
    