In [2]:
# Third-party and standard-library dependencies used across the notebook.
from sqlalchemy import create_engine
import pandas as pd
import numpy as np
import logging
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from tabulate import tabulate
import os
import sys
from pyspark.sql import SparkSession


# Put ../apps (resolved against the current working directory) on the import
# path; this must run before the project-local imports below.
sys.path.append(os.path.abspath(os.path.join('..', 'apps')))
# NOTE(review): this same import is repeated in a later cell of the notebook.
from data_loader import load_data_spark

# Project-local Plotly helper; what it configures is defined in custom_plotly,
# which is not visible from this notebook.
from custom_plotly import set_custom_template
set_custom_template()

# NOTE(review): a later cell defines its own `analyze_keywords` function, which
# re-binds this name and shadows the imported version once that cell runs.
from keyword_analysis import analyze_keywords

24/10/27 11:44:39 INFO SharedState: Setting hive.metastore.warehouse.dir ('null') to the value of spark.sql.warehouse.dir.
24/10/27 11:44:39 INFO SharedState: Warehouse path is 'file:/home/tron/git/project_gemma/utilities_ipynb/spark-warehouse'.


In [3]:
# Define custom layout parameters
def get_custom_layout():
    """Return the layout overrides for the dark notebook theme.

    Returns:
        dict: A freshly built mapping of Plotly layout properties -- black
        plot/paper backgrounds, white 14pt Arial text, a centred title, tight
        margins, and per-axis styling. Each call builds new nested
        dictionaries, so callers may modify the result freely.
    """

    def axis_style(grid_colour):
        # Settings common to both axes; only the grid colour differs.
        return {
            "title_font": {"size": 12},
            "tickfont": {"size": 12},
            "showline": True,   # draw the axis line itself
            "zeroline": False,  # no extra line at the zero value
            "gridcolor": grid_colour,
        }

    background = "black"
    layout = dict(
        plot_bgcolor=background,   # plotting-area background
        paper_bgcolor=background,  # background surrounding the plotting area
        font=dict(family="Arial", size=14, color="#FFFFFF"),
        title=dict(x=0.5),  # centre the title horizontally
        margin=dict(t=60, b=40, l=40, r=40),
    )
    # The x grid is black on a black background; the y grid is gray.
    layout["xaxis"] = axis_style("black")
    layout["yaxis"] = axis_style("gray")
    return layout

# Build a custom template derived from 'plotly_dark' with the layout overrides.
custom_layout = {
    "layout": get_custom_layout()
}

# Register the template under its own name.
# Template.update() modifies the object it is called on and returns that same
# object, so calling it directly on pio.templates["plotly_dark"] would alter the
# built-in template for the whole session. Copy it first, then apply overrides.
custom_dark_template = go.layout.Template(pio.templates["plotly_dark"])
custom_dark_template.update(custom_layout)
pio.templates["custom_dark"] = custom_dark_template
pio.templates.default = "custom_dark"  # Use custom template as default
pio.renderers.default = 'notebook'  # Render inline in notebook

In [4]:
# Make the sibling ``apps`` directory importable. The membership check keeps
# sys.path free of duplicate entries when this cell is re-run.
apps_dir = os.path.abspath(os.path.join('..', 'apps'))  # Adjust the path based on your structure
if apps_dir not in sys.path:
    sys.path.append(apps_dir)
from data_loader import load_data_spark  # Now you can import your function

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Local-development defaults for the database connection.
# setdefault() only fills in variables that are not already set, so real
# credentials exported in the environment are never overwritten by these
# placeholder values.
for env_name, default_value in (
    ('DB_USER', 'postgres'),
    ('DB_PASSWORD', 'password'),
    ('DB_HOST', 'localhost'),
    ('DB_PORT', '5432'),
    ('DB_NAME', 'project_gemma'),
):
    os.environ.setdefault(env_name, default_value)

# Keyword-frequency analysis with a count-threshold slider.
def _count_keyword_occurrences(paragraphs, keywords):
    """Count case-insensitive occurrences of each keyword across paragraphs.

    Args:
        paragraphs: Iterable of strings to search.
        keywords: Iterable of keyword strings; duplicates are counted once.

    Returns:
        pandas.DataFrame with columns ``keyword`` and ``count``, restricted to
        keywords that occur at least once and sorted by ascending count.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order, so a
    # keyword listed twice is not counted twice.
    totals = dict.fromkeys(keywords, 0)
    needles = {keyword: keyword.lower() for keyword in totals}
    for paragraph in paragraphs:
        lowered = paragraph.lower()  # lower-case once per paragraph
        for keyword, needle in needles.items():
            # NOTE: str.count matches substrings, so 'son' is also counted
            # inside words such as 'person'.
            totals[keyword] += lowered.count(needle)

    counts_df = pd.DataFrame(list(totals.items()), columns=['keyword', 'count'])
    counts_df = counts_df[counts_df['count'] > 0]
    return counts_df.sort_values(by='count').reset_index(drop=True)


def _build_threshold_figure(keyword_counts_df):
    """Build the animated bar chart with a "max count threshold" slider.

    Args:
        keyword_counts_df: Non-empty DataFrame with ``keyword`` and ``count``
            columns.

    Returns:
        plotly.graph_objects.Figure with one frame per slider step.
    """
    max_count = int(keyword_counts_df['count'].max())

    # The set of visible bars only changes at thresholds equal to an observed
    # count, so one frame per distinct count (plus 0 for the empty chart) is
    # enough. One frame per integer up to max_count would produce thousands of
    # identical frames for frequent keywords.
    thresholds = [0] + sorted({int(count) for count in keyword_counts_df['count']})

    def bar_trace(subset_df):
        return go.Bar(
            x=subset_df['keyword'],
            y=subset_df['count'],
            text=subset_df['count'],
            marker=dict(
                color=subset_df['count'],
                coloraxis="coloraxis"  # Link to the global color axis
            )
        )

    frames = [
        go.Frame(
            data=[bar_trace(keyword_counts_df[keyword_counts_df['count'] <= threshold])],
            name=str(threshold)
        )
        for threshold in thresholds
    ]

    fig = go.Figure(data=[bar_trace(keyword_counts_df)], frames=frames)

    fig.update_layout(
        title="Keyword Frequency with Adjustable Count Threshold and Jet Color Scale",
        xaxis_title="Keyword",
        yaxis_title="Count",
        font=dict(family="Arial", size=14, color="white"),
        plot_bgcolor="black",
        paper_bgcolor="black",
        coloraxis=dict(
            colorscale="Jet",  # Jet color scale for the color axis
            colorbar=dict(title="Count", tickvals=[0, max_count])
        ),
        sliders=[{
            # Start on the last step, which shows every keyword.
            "active": len(thresholds) - 1,
            "currentvalue": {"prefix": "Max Count Threshold: ", "font": {"color": "white", "size": 14}},
            "pad": {"t": 50},
            "steps": [
                {"method": "animate", "label": str(threshold),
                 "args": [[str(threshold)], {"frame": {"duration": 2000, "redraw": True}, "mode": "immediate", "transition": {"duration": 200, "easing": "cubic-in-out"}}]}
                for threshold in thresholds
            ]
        }]
    )

    fig.update_traces(texttemplate='%{text}', textposition='outside')
    return fig


def analyze_keywords(keywords=None):
    """Count keyword occurrences in the stored paragraphs and plot them.

    Loads the ``cagliostro_gutenberg`` table through ``load_data_spark``,
    counts case-insensitive substring occurrences of each keyword in the
    ``paragraph`` column, shows a bar chart with a threshold slider, and
    returns the counts.

    Args:
        keywords: Iterable of keyword strings. ``None`` or an empty iterable
            yields an empty result without touching the database.

    Returns:
        pandas.DataFrame with columns ``keyword`` and ``count`` (keywords with
        at least one occurrence, ascending by count). An empty DataFrame is
        returned when there are no keywords, the data could not be loaded,
        the ``paragraph`` column is unusable, or nothing matched.
    """
    keywords = list(keywords) if keywords is not None else []
    if not keywords:
        logging.warning("No keywords supplied; nothing to analyze.")
        return pd.DataFrame()

    # Build the JDBC URL from the DB_* environment variables, falling back to
    # the values this notebook has always used.
    jdbc_url = 'jdbc:postgresql://{host}:{port}/{name}'.format(
        host=os.environ.get('DB_HOST', 'localhost'),
        port=os.environ.get('DB_PORT', '5432'),
        name=os.environ.get('DB_NAME', 'project_gemma'),
    )
    table_name = 'cagliostro_gutenberg'
    # NOTE(review): the checks below treat the result as a pandas-style
    # DataFrame (.empty, column indexing) -- confirm against load_data_spark.
    df = load_data_spark(jdbc_url, table_name)

    if df is None or df.empty:
        logging.warning("Loaded DataFrame is empty or None.")
        return pd.DataFrame()

    # Ensure 'paragraph' column exists
    if 'paragraph' not in df.columns or df['paragraph'].isnull().all():
        logging.error("'paragraph' column is missing or contains only NaN values.")
        return pd.DataFrame()

    keyword_counts_df = _count_keyword_occurrences(df['paragraph'].dropna(), keywords)

    # With no matches the max count is NaN and no chart can be built.
    if keyword_counts_df.empty:
        logging.warning("None of the keywords occur in the loaded paragraphs.")
        return keyword_counts_df

    fig = _build_threshold_figure(keyword_counts_df)
    fig.show()
    logging.info("Keyword frequency analysis with native slider and jet color scale is complete.")
    # Return the counts so callers can tabulate or reuse them.
    return keyword_counts_df

# Usage
keywords = [
    'Cagliostro',
    #'brother',
    'son',
    'father',
    #'poor',
    'die',
    'rich',
    'very',
]

keyword_analysis = analyze_keywords(keywords)

# tabulate() raises on None, so only print a table when there are rows to show
# (the analysis can yield nothing, e.g. when the data load fails).
if keyword_analysis is None or keyword_analysis.empty:
    print("No keyword counts to display.")
else:
    print(tabulate(keyword_analysis, headers='keys', tablefmt='psql'))

24/10/27 11:44:44 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
2024-10-27 11:44:44,889 - ERROR - Error loading data: An error occurred while calling o29.jdbc.
: java.lang.ClassNotFoundException: org.postgresql.Driver
	at java.base/java.net.URLClassLoader.findClass(URLClassLoader.java:476)
	at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:594)
	at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:527)
	at org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry$.register(DriverRegistry.scala:46)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.$anonfun$driverClass$1(JDBCOptions.scala:103)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.$anonfun$driverClass$1$adapted(JDBCOptions.scala:103)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.<init>(JDBCOptions.scala:103)
	at org.apache.spark.sql.execution.datasources.jdb


