### A. Bar Graph

In [190]:
import plotly.express as px
import pandas as pd

# Read the CSV file: one row per observation with a LABEL and a COUNT flag
# (expected to be 0 or 1 — checked below)
df = pd.read_csv('data/bar_assignment.csv')

# Transform 1 into "Yes" and 0 into "No"
df['legend'] = df['COUNT'].map({1: 'Yes', 0: 'No'})

# Any COUNT outside {0, 1} (including a missing value) maps to NaN and would be
# silently dropped by the crosstab below, so stop with a clear message instead.
unmapped = df['legend'].isna()
assert not unmapped.any(), (
    f"{int(unmapped.sum())} row(s) have a COUNT other than 0/1: "
    f"{df.loc[unmapped, 'COUNT'].unique().tolist()}"
)

# Count Yes/No occurrences per LABEL. reindex guarantees both outcome columns
# exist (filled with 0), so the melt below cannot raise KeyError when the data
# happens to contain only one outcome; reset_index turns LABEL back into a column.
df_crosstab = (
    pd.crosstab(df['LABEL'], df['legend'])
    .reindex(columns=['No', 'Yes'], fill_value=0)
    .reset_index()
)

# Melt to long format: one row per (LABEL, legend) pair with its COUNT
df_melted = df_crosstab.melt(id_vars='LABEL', value_vars=['No', 'Yes'],
                             var_name='legend', value_name='COUNT')

# Create a horizontal stacked bar chart (one colour per outcome)
fig = px.bar(df_melted, x='COUNT',
             y='LABEL', color='legend',
             orientation='h',
             title='Label vs Count',
             text='COUNT',
             color_discrete_map={'No': 'red', 'Yes': 'blue'}
             )

# Lay the legend out horizontally, centred just above the plot area
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="center",
        x=0.5
    )
)

# Update x and y labels
fig.update_xaxes(title_text='X-LABEL')
fig.update_yaxes(title_text='Y-LABEL')

# Display each segment's count inside its bar
fig.update_traces(texttemplate='%{text}', textposition='inside', textfont_size=14, textfont_color='white')

# Show the figure
fig.show()

# Static export (needs a plotly image engine such as kaleido, and
# `from IPython.display import Image` for the second line):
# fig.write_image("img/bar_assignment.png")
# Image("img/bar_assignment.png")


### B. Sankey

In [189]:
import plotly.graph_objects as go
import pandas as pd

# Load the dataset from CSV. Layout assumed by the code below: a LABEL column
# whose values become the middle nodes, source columns that flow into those
# labels, and the last N_END_TARGETS columns as the final destinations.
sankey_df = pd.read_csv('data/sankey_assignment.csv')

N_END_TARGETS = 3  # number of trailing columns treated as end targets

# Split the columns by role. LABEL is excluded by name rather than by position,
# so the source list does not silently depend on LABEL being the first column.
end_targets = sankey_df.columns[-N_END_TARGETS:].tolist()
sources = [col for col in sankey_df.columns[:-N_END_TARGETS] if col != 'LABEL']
labels = sankey_df['LABEL'].tolist()

nodes = sources + labels + end_targets
node_map = {node: i for i, node in enumerate(nodes)}

# node_map is keyed by name: a name appearing twice (e.g. a LABEL value equal to
# a column header, or a repeated LABEL) would collapse into a single node and
# distort every flow through it, so insist that all node names are distinct.
node_names = pd.Series(nodes)
assert not node_names.duplicated().any(), (
    "Sankey node names must be unique; duplicated: "
    f"{node_names[node_names.duplicated()].unique().tolist()}"
)

# Define hex colors and map them to each category
hex_colors = [
    "#ff7f0e", "#1f77b4", "#d62728", "#2ca02c", "#8c564b", "#9467bd", "#7f7f7f", "#e377c2",
    "#17becf", "#bcbd22", "#98df8a", "#ffbb78", "#c5b0d5", "#ff9896", "#f7b6d2", "#c49c94"
]

# Assign a colour to each node (cycling through the palette if there are more
# nodes than colours)
node_colors = {node: hex_colors[i % len(hex_colors)] for i, node in enumerate(nodes)}

# Build the links as two parallel lists: `links` (source/target/value dicts) and
# `colors` (the colour of each link, in the same order)
links = []
colors = []

# Iterate rows as plain dicts instead of `sankey_df[col][i]`: that label-based
# lookup only lines up with enumerate() positions while the index is 0..n-1.
rows = sankey_df.to_dict('records')

# First set of links (from sources to LABELS), coloured by the source node.
# Cells that are zero, negative or missing produce no link.
for source in sources:
    for row in rows:
        value = row[source]
        if value > 0:
            links.append({
                "source": node_map[source],
                "target": node_map[row['LABEL']],
                "value": value
            })
            colors.append(node_colors[source])

# Second set of links (from LABELS to final categories), coloured by the LABEL node
for row in rows:
    label = row['LABEL']
    for end_target in end_targets:
        value = row[end_target]
        if value > 0:
            links.append({
                "source": node_map[label],
                "target": node_map[end_target],
                "value": value
            })
            colors.append(node_colors[label])

# Generate Sankey diagram
fig = go.Figure(go.Sankey(
    node=dict(
        pad=10,  # Reduce padding between nodes
        thickness=15,  # Reduce thickness of nodes
        line=dict(color='black', width=0.3),
        label=nodes,
        color=[node_colors[node] for node in nodes]
    ),
    link=dict(
        source=[link['source'] for link in links],
        target=[link['target'] for link in links],
        value=[link['value'] for link in links],
        color=colors
    )
))

fig.update_layout(
    title_text="Sankey Diagram",
    font_size=14,  # Set font size
    margin=dict(l=20, r=20, t=50, b=20)
)
fig.show()

# Static export (needs a plotly image engine such as kaleido, and
# `from IPython.display import Image` for the second line):
# fig.write_image("img/sankey_assignment.png")
# Image("img/sankey_assignment.png")

### C. Network Graph