In [33]:
import random
import nltk
import scipy

import pandas as pd
import networkx as nx
import numpy as np

import matplotlib.pyplot as plt
import plotly.graph_objects as go

from nltk import word_tokenize, pos_tag, ne_chunk
from nltk.tree import Tree

import matplotlib.colors as mcolors
import random

nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('maxent_ne_chunker')
nltk.download('words')

import dash_bootstrap_components as dbc

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /Users/ramyaaprasath/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/ramyaaprasath/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package maxent_ne_chunker to
[nltk_data]     /Users/ramyaaprasath/nltk_data...
[nltk_data]   Package maxent_ne_chunker is already up-to-date!
[nltk_data] Downloading package words to
[nltk_data]     /Users/ramyaaprasath/nltk_data...
[nltk_data]   Package words is already up-to-date!


In [34]:
# Load the CSV files
df_ratings = pd.read_csv('data/Ratings.csv')
df_characters = pd.read_csv('data/Characters.csv', encoding='ISO-8859-1')

# Convert the 'name' column to a list
character_names = df_characters['name'].tolist()

# Function to extract names from a sentence
def extract_names(sentence):
    names = []
    for chunk in ne_chunk(pos_tag(word_tokenize(sentence))):
        if isinstance(chunk, Tree) and chunk.label() == 'PERSON':
            name = ' '.join([c[0] for c in chunk])
            # Only add the name if it's in the character_names list
            if name in character_names:
                names.append(name)
    return names

# Extract character names from the 'desc' column
df_ratings['names'] = df_ratings['desc'].apply(extract_names)

# Create a graph
G = nx.Graph()

# List of arcs and their episode ranges
arcs = [('East Blue Saga', range(1, 62+1)), 
    ('Alabasta Saga', range(62, 135+1)), 
    ('Sky Island Saga', range(135, 206+1)), 
    ('Water 7 Saga', range(206, 325+1)), 
    ('Thriller Bark Saga', range(325, 384+1)), 
    ('Summit War Saga', range(384, 516+1)), 
    ('Fish-Man Island Saga', range(516, 574+1)), 
    ('Dressrosa Saga', range(574, 746+1)), 
    ('Four Emperors Saga', range(746, 958+1)), 
    ('Wano Country Arc', range(958, 2000))]

# Create a dictionary mapping arcs to colors
arc_colors = {
    'East Blue Saga': '#ADD8E6',  # Light Blue
    'Alabasta Saga': '#FDBCB4',  # Pastel Orange
    'Sky Island Saga': '#D3D3D3',  # Whitish Gray
    'Water 7 Saga': '#00BFFF',  # Deep Sky Blue
    'Thriller Bark Saga': '#800080',  # Dark Purple
    'Summit War Saga': '#FF4500',  # Orange Red
    'Fish-Man Island Saga': '#800080',  # Grayish Purple
    'Dressrosa Saga': '#FF69B4',  # Hot Pink
    'Four Emperors Saga': '#40E0D0',  # Turquoise
    'Wano Country Arc': '#008000' # Green
}

# Create a dictionary mapping episode numbers to arcs
arc_dict = {ep: arc for arc, eps in arcs for ep in eps}

# Create a color map for nodes based on arcs
color_map = {node: arc_colors[arc_dict[node]] for node in G.nodes() if isinstance(node, int)}

# Add nodes for each character
for name in character_names:
    # Only add the character if they appear in at least one episode
    if df_ratings['names'].apply(lambda names: name in names).any():
        G.add_node(name)

# Add edges for each connection between a character and an episode
for _, row in df_ratings.iterrows():
    for name in row['names']:
        G.add_edge(name, row['episode_number'], color=row['year'])

# Get positions for the nodes in G
pos = nx.spring_layout(G)

# Adjust y position of nodes to represent timeline
for node in G.nodes():
    if isinstance(node, int):  # assuming episode numbers are integers
        pos[node] = (node, 0)  # place episode nodes on the x-axis
    else:  # assuming character names are strings
        pos[node] = (random.uniform(0, max(df_ratings['episode_number'])), random.uniform(0.1, 1))  # place character nodes above the x-axis

# Adjust y position of nodes to represent characters
for node in G.nodes():
    if isinstance(node, int):  # assuming episode numbers are integers
        pos[node] = (node, 0)  # place episode nodes on the x-axis
    else:  # assuming character names are strings
        # find the first episode where the character appears
        first_episode = df_ratings[df_ratings['names'].apply(lambda names: node in names)]['episode_number'].min()
        pos[node] = (first_episode, random.uniform(0.1, 1))  # place character nodes above the x-axis

def get_random_pastel_color():
    pastel_colors = mcolors.CSS4_COLORS
    pastel_colors = {name: mcolors.to_rgb(color) for name, color in pastel_colors.items() if "light" in name or "pastel" in name}
    pastel_colors = {name: color for name, color in pastel_colors.items() if color[0] < 0.8}  # Filter out colors that are too red
    return random.choice(list(pastel_colors.values()))


In [35]:
from dash import Dash
from dash.html import H1, H2, H3, Div
from dash_bootstrap_components import themes, Container, Row, Col, Button
from dash_core_components import Dropdown, Graph
from dash.dependencies import Input, Output, State

external_stylesheets = [
    themes.BOOTSTRAP,
    "https://fonts.googleapis.com/css2?family=Merriweather:ital,wght@0,300;0,400;0,700;0,900;1,300;1,400;1,700;1,900&display=swap"
]

app = Dash(__name__, external_stylesheets=external_stylesheets)

app.layout = Container([
    Row([
        Col([
            H2(
                'One Piece Character Connections', 
                style={
                    "font-family" : "'Merriweather', serif",
                    "font-weight" : "700",
                    "color" : "#313638",
                    "margin-top" : "4rem",
                    "margin-bottom" : "4rem",
                    "text-align" : "center"
                }
            ),
            Div([
                Graph(id='graph'),
            ], style={
                "width" : "80%",
                "margin" : "0 auto"
            }),
            H3(
                'Search a Character', 
                style={
                    "font-family" : "'Merriweather', serif",
                    "font-weight" : "700",
                    "color" : "#313638",
                    "margin-top" : "4rem",
                    "margin-bottom" : "2rem"
                }
            ),
            Div([
                Dropdown(
                    id='input-box',
                    options=[{'label': name, 'value': name} for name in character_names],
                    multi=True,
                    style={
                        "width" : "100%", 
                        "height" : "60px", 
                        "border-radius" : 0,
                        "border-color" : "#fcd5ce",
                    }
                ),
                Button(
                    'Submit', 
                    id='button',
                    style={
                        "font-family" : "'Merriweather', serif",
                        "font-weight" : "500",
                        "height" : "60px", 
                        "border-radius" : 0, 
                        "border-color" : "#ffb5a7", 
                        "background-color" : "#ffb5a7", 
                        "font-weight" : "bold",
                        "padding" : "0 4rem"
                    }
                )
            ], style={
                "display" : "flex", 
                "align-items" : 
                "center", "gap" : 
                "1rem"
            }),
        ]),
    ]),
], 
fluid=True, 
style={ 
    "width" : "100%", 
    "min-height" : "150vh", 
    "overflow" : "hidden",  
    "background-color" : "#f8edeb",  
    "padding" : "2rem 5rem"   
})

In [36]:
@app.callback(
    Output('graph', 'figure'),
    Input('button', 'n_clicks'),
    State('input-box', 'value'),
)
def update_output(n_clicks, value):
    # Create Figure
    fig = go.Figure()

    # Create a pastel color map for nodes
    colors = plt.cm.Pastel1(np.linspace(0, 1, len(G.nodes())))
    color_map = {node: colors[i] for i, node in enumerate(G.nodes())}

    if value == None or len(value) == 0:
        plotting_data = {
            'East Blue Saga': { 'X' : [], 'Y': [] },
            'Alabasta Saga': { 'X' : [], 'Y': [] },
            'Sky Island Saga': { 'X' : [], 'Y': [] },
            'Water 7 Saga': { 'X' : [], 'Y': [] },
            'Thriller Bark Saga': { 'X' : [], 'Y': [] },
            'Summit War Saga': { 'X' : [], 'Y': [] },
            'Fish-Man Island Saga': { 'X' : [], 'Y': [] },
            'Dressrosa Saga': { 'X' : [], 'Y': [] },
            'Four Emperors Saga': { 'X' : [], 'Y': [] },
            'Wano Country Arc': { 'X' : [], 'Y': [] },
        }

        for edge in G.edges():
            x0, y0 = pos[edge[0]]
            x1, y1 = pos[edge[1]]

            episode_number = edge[0] if isinstance(edge[0], int) else edge[1]
            character = edge[0] if isinstance(edge[0], str) else edge[1]

            for i in arcs:
                if episode_number in i[1]:
                    episode_name = i[0]
                    break
            
            plotting_data[episode_name]['X'] += [x0, x1]
            plotting_data[episode_name]['Y'] += [y0, y1]
        
        for episode_name in plotting_data:
            edge_color = arc_colors[episode_name]
            edge_width = 0.5

            fig.add_trace(
                go.Scatter(
                    x=plotting_data[episode_name]['X'], 
                    y=plotting_data[episode_name]['Y'], 
                    mode='lines', 
                    line=dict(color=edge_color, width=edge_width), 
                    name=episode_name
                )
            )
    else:
        for edge in G.edges():
            x0, y0 = pos[edge[0]]
            x1, y1 = pos[edge[1]]

            edge_x = [x0, x1]
            edge_y = [y0, y1]

            episode = edge[0] if isinstance(edge[0], int) else edge[1]
            character = edge[0] if isinstance(edge[0], str) else edge[1]

            if character in value:
                edge_color = 'rgba(255,0,0,0.5)'
                edge_width = 2
            else:
                edge_color = '#ced4da'
                edge_width = 0.5

            fig.add_trace(
                go.Scatter(
                    x=edge_x, 
                    y=edge_y, 
                    mode='lines', 
                    line=dict(color=edge_color, width=edge_width), 
                    hoverinfo='text', 
                    hovertext=f'Episode: {episode}, Character: {character}', 
                    showlegend=False
                )
            )

    # Add Nodes
    for node in G.nodes():
        node_x, node_y = pos[node]
        node_color = color_map[node]
        node_size = 2
        if value and node in value:
            node_color = mcolors.to_hex(get_random_pastel_color())
            node_size = 10
        if isinstance(node, int):  # assuming episode numbers are integers
            hovertext = f'Episode: {node}, Characters: {len(G[node])}'
        else:  # assuming character names are strings
            hovertext = f'Character: {node}, Episodes: {len(G[node])}'
        fig.add_trace(go.Scatter(x=[node_x], y=[node_y], mode='markers', hovertext=hovertext, marker=dict(color=node_color, size=node_size), showlegend=False))

    # Update layout
    fig.update_layout(
        font_family="Merriweather",
        font_color="#ffb5a7",
        showlegend=True,
        plot_bgcolor='#f8edeb',  # make plot background white
        paper_bgcolor='#f8edeb',  # make paper background white
        xaxis=dict(showgrid=True, gridcolor="#e3d5ca", showline=False, title='Episode Number'),  # remove x-axis grid lines and add title
        yaxis=dict(showgrid=True, gridcolor="#e3d5ca", showline=False),  # remove x-axis grid lines and add title
    )

    return fig

In [37]:
if __name__ == '__main__':
    app.run_server(debug=True, port=8080)