In [None]:
#!pip install ipywidgets...

In [None]:
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output
import random
import yfinance as yf
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

In [None]:
def _resolve_tickers(user_input, available_tickers, max_tickers, group_name):
    """Translate the raw text-field value into a list of ticker symbols.

    Args:
        user_input: Text typed by the user (symbols separated by commas,
            ``Random`` or ``Random N``; case-insensitive).
        available_tickers: List of symbols the user may choose from.
        max_tickers: Upper bound on how many symbols may be selected.
        group_name: Name of the stock group, used only in messages.

    Returns:
        A ``(tickers, error)`` pair. On success ``error`` is ``None``; on
        failure ``tickers`` is ``None`` and ``error`` is the message to show.
    """
    no_input = "\U0000274C No input inserted. Insert symbols or type 'Random'/'Random N'."
    user_input = user_input.strip()
    if not user_input:
        return None, no_input

    words = user_input.split()
    # Only an exact first word "random" triggers random selection; a looser
    # prefix test would silently swallow inputs such as "random,AAPL".
    if words[0].lower() == "random":
        bad_format = (
            f"\U0000274C Invalid format! Use 'Random' or 'Random N' "
            f"where N is between 1 and {max_tickers}."
        )
        if len(words) > 2:
            return None, bad_format
        if len(words) == 2:
            try:
                n_random = int(words[1])
            except ValueError:
                return None, bad_format
            if not 1 <= n_random <= max_tickers:
                return None, f"\U0000274C Number of random tickers must be between 1 and {max_tickers}."
        else:
            n_random = max_tickers

        if len(available_tickers) < n_random:
            return None, (
                f"\U0000274C Sorry, but the CSV contains less than {n_random} symbols. "
                "Please choose different group."
            )
        return random.sample(available_tickers, n_random), None

    # Explicit symbols: drop empty tokens ("AAPL,,MSFT") and duplicates while
    # keeping the order in which the user typed them.
    tickers = list(dict.fromkeys(
        token.strip().upper() for token in user_input.split(",") if token.strip()
    ))
    if not tickers:
        return None, no_input
    if len(tickers) > max_tickers:
        return None, f"\U0000274C Maximum of {max_tickers} tickers can be used."

    known = set(available_tickers)
    invalid = [t for t in tickers if t not in known]
    if invalid:
        return None, (
            f"\U0000274C The following tickers are not included in {group_name}: {', '.join(invalid)}"
            f"\n \nYou can remove them, replace them with valid symbols from {group_name}_tickers.csv, "
            "or go back and try different group."
        )
    return tickers, None


def _extract_close(data, ticker):
    """Return the non-missing close prices of ``ticker`` from a download frame.

    The frame has either flat columns (one ticker) or ``(ticker, field)``
    MultiIndex columns, so the layout is detected from the frame itself
    rather than guessed from the number of requested tickers. An empty
    series is returned when the ticker has no ``Close`` column.
    """
    try:
        if isinstance(data.columns, pd.MultiIndex):
            close = data[ticker]["Close"]
        else:
            close = data["Close"]
    except KeyError:
        return pd.Series(dtype=float)
    return close.dropna()


def _make_axes_grid(n_plots, ncols=2, min_rows=2):
    """Create a figure with at least ``n_plots`` axes and hide the unused ones.

    The grid is never smaller than ``min_rows`` x ``ncols`` (2x2 by default)
    and grows by whole rows when more panels are needed.
    """
    nrows = max(min_rows, -(-n_plots // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 4 * nrows))
    axes = axes.flatten()
    for unused in axes[n_plots:]:
        unused.axis("off")
    return fig, axes


def _plot_close_panels(price_dict, tickers, label, suptitle, last_days=None):
    """Plot one close-price panel per ticker, optionally only the last days."""
    fig, axes = _make_axes_grid(len(tickers))
    for ax, t in zip(axes, tickers):
        series = price_dict[t]
        if last_days is not None:
            end_date = series.index.max()
            start_date = end_date - pd.Timedelta(days=last_days)
            series = series.loc[start_date:end_date]
        ax.plot(series)
        ax.set_title(f"{t} \u2013 Close ({label})")
        ax.set_xlabel("Date")
        ax.set_ylabel("Price")
        ax.tick_params(axis='x', rotation=90)

    fig.suptitle(suptitle, fontsize=14)
    fig.tight_layout()
    plt.show()


def _plot_return_distributions(returns_dict, tickers):
    """Plot a density histogram with a KDE overlay of log returns per ticker."""
    fig, axes = _make_axes_grid(len(tickers))
    for ax, t in zip(axes, tickers):
        r = returns_dict[t].values
        ax.hist(r, bins=40, density=True, alpha=0.6)

        # The KDE cannot be fitted to degenerate data (e.g. constant
        # returns); in that case only the histogram is shown.
        try:
            kde = gaussian_kde(r)
        except (ValueError, np.linalg.LinAlgError):
            kde = None
        if kde is not None:
            x_grid = np.linspace(r.min(), r.max(), 300)
            ax.plot(x_grid, kde(x_grid), linewidth=2)

        ax.set_title(f"{t} \u2013 Log Returns Distribution")
        ax.set_xlabel("Return")
        ax.set_ylabel("Density")

    fig.suptitle("Distribution of Daily Log Returns (Histogram + KDE)", fontsize=14)
    fig.tight_layout()
    plt.show()


def stock_selector(selected_group, max_tickers: int = 4) -> None:
    """
    Show a text field for selecting stocks and run a basic analysis on them.

    The user can either:
    - Enter up to `max_tickers` symbols separated by commas
    - Type "Random" to select `max_tickers` tickers randomly from the CSV
    - Type "Random N" (1 <= N <= `max_tickers`) to select N random tickers

    Input is validated against the 'Symbol' column of the file
    "<stock_group>_tickers.csv" in the working directory. After a valid
    selection, one year of daily prices is downloaded with yfinance and the
    following is rendered into the output area below the button:
    last-month close prices, last-year close prices, a table of descriptive
    statistics of daily log returns, and their histogram with a KDE line.

    Args:
        selected_group: Mapping with the key 'stock_group' holding the name
            of the chosen group (used to locate the CSV file).
        max_tickers: Maximum number of symbols that can be analysed at once.

    Returns:
        None. The widgets are displayed as a side effect; if the CSV file is
        missing or has no 'Symbol' column, a message is printed instead.
    """

    group_name = selected_group['stock_group']

    # Load the ticker list that belongs to the selected group
    csv_file = f"{group_name}_tickers.csv"
    try:
        tickers_df = pd.read_csv(csv_file)
    except FileNotFoundError:
        print("\U0000274C CSV file for this group was not found. Make sure you have successfully selected the desired group.")
        return

    if 'Symbol' not in tickers_df.columns:
        print(f"\U0000274C {csv_file} has no 'Symbol' column, so no tickers can be offered.")
        return

    # Drop missing symbols and duplicates so a random draw can neither pick a
    # NaN (which would break the message formatting) nor the same ticker twice.
    available_tickers = list(dict.fromkeys(
        tickers_df['Symbol'].dropna().astype(str).str.strip()
    ))

    # Widgets
    input_text = widgets.Text(
        value='',
        placeholder=f"Enter up to {max_tickers} tickers, separated by commas, or type 'Random'/'Random N'.",
        description="Symbols for the analysis:",
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='65%')
    )

    button = widgets.Button(
        description="Choose symbols",
        button_style='success',
        layout=widgets.Layout(width='200px', height='40px', margin='10px 0 10px 0'),
        style={'font_weight': 'bold', 'font_size': '16px'}
    )

    output = widgets.Output()

    # Callback executed when the button is clicked
    def on_button_click(b):
        with output:
            output.clear_output()

            selected_tickers, error = _resolve_tickers(
                input_text.value, available_tickers, max_tickers, group_name
            )
            if error is not None:
                print(error)
                return

            print(f"\U00002705 Stock symbols selected for the analysis are: {', '.join(selected_tickers)}")

            # Basic analysis: one year of daily, dividend/split-adjusted prices
            data = yf.download(
                selected_tickers,
                period="1y",
                interval="1d",
                auto_adjust=True,
                group_by="ticker",
                progress=False
            )

            if data is None or data.empty:
                print("\U0000274C No data downloaded.")
                return

            # Close prices and daily log returns per ticker
            price_dict = {t: _extract_close(data, t) for t in selected_tickers}
            returns_dict = {
                t: np.log(close / close.shift(1)).dropna()
                for t, close in price_dict.items()
            }

            # Statistics and the KDE need at least two returns per ticker
            skipped = [t for t in selected_tickers if len(returns_dict[t]) < 2]
            analysed = [t for t in selected_tickers if t not in skipped]
            if not analysed:
                print("\U0000274C No data downloaded.")
                return

            print("\n\U0001F4E5 Price data successfully downloaded.\n")
            if skipped:
                print(
                    f"\U000026A0 Not enough price data for: {', '.join(skipped)}. "
                    "These symbols are left out of the analysis.\n"
                )

            # 1) Close prices over the last month
            _plot_close_panels(
                price_dict, analysed, "Last Month",
                "Daily Close Prices \u2013 Last Month", last_days=30
            )

            # 2) Close prices over the whole downloaded year
            _plot_close_panels(
                price_dict, analysed, "1 Year",
                "Daily Close Prices \u2013 Last Year"
            )

            # 3) Descriptive statistics of daily log returns
            stats_table = pd.DataFrame({
                "Mean":   [returns_dict[t].mean() for t in analysed],
                "Max":    [returns_dict[t].max() for t in analysed],
                "Min":    [returns_dict[t].min() for t in analysed],
                "Q1":     [returns_dict[t].quantile(0.25) for t in analysed],
                "Median": [returns_dict[t].median() for t in analysed],
                "Q3":     [returns_dict[t].quantile(0.75) for t in analysed],
                "Std. Dev": [returns_dict[t].std() for t in analysed]
            }, index=analysed)

            display(stats_table.style.format("{:.4f}"))

            # 4) Distribution of daily log returns
            _plot_return_distributions(returns_dict, analysed)

    button.on_click(on_button_click)

    # UI layout
    ui = widgets.VBox(
        [input_text, button, output],
        layout=widgets.Layout(align_items='center')
    )

    display(ui)