In [None]:
# Uncomment this cell to install dependencies
# !pip install plotly==4.14.1
# !pip install python-igraph

In [None]:
# Standard library
import ast
import csv
import random
from platform import python_version

# Third-party
import igraph
import plotly
import plotly.graph_objects as go
from igraph import Graph, EdgeSeq

# Record the interpreter version next to the notebook's outputs
print(python_version())

In [None]:
# Installed plotly version (the commented-out install cell pins 4.14.1)
plotly.__version__

In [None]:
# Installed python-igraph version
igraph.__version__

In [None]:
# Load the category table. Change CSV_PATH to load a different file.
# Row 0 is the header; every later row is one category record, kept as a
# list of strings.
CSV_PATH = "./data/Books.csv"
# newline="" is what the csv module requires for files handed to csv.reader;
# an explicit encoding keeps the load independent of the machine's locale.
with open(CSV_PATH, "r", newline="", encoding="utf-8") as f:
    file_content = list(csv.reader(f))
# The `with` statement has already closed the file here.
if not file_content:
    raise ValueError(f"{CSV_PATH} is empty: expected at least a header row")
print(len(file_content))
header = file_content[0]
data = file_content[1:]
print("all data: ", len(data))
print(header)

In [None]:
def get_index(name):
    """Return the position of column `name` in the CSV `header`, or -1 if absent.

    The first matching column wins when a name is repeated.
    NOTE(review): callers use the result directly as a row index, so a missing
    column (-1) silently selects the last column instead of failing.
    """
    return next((pos for pos, column in enumerate(header) if column == name), -1)

In [None]:
# One can change the focus node id here. It is used directly as a row index
# into `data` (data[focus_node_id]), so it must be a valid row number.
focus_node_id = 2
focus_node_name = data[focus_node_id][get_index('name')]
print(focus_node_name)

In [None]:
def filter_alsos(all_alsos, keep_all_up_to=10, min_common=2, max_keep=20):
    """Trim an "also" list down to the entries worth plotting.

    Each entry is indexed as entry[0] = category id and entry[1] = number of
    products in common with the focus category.

    - A list with at most `keep_all_up_to` entries is returned unchanged.
    - Otherwise, keep only entries with at least `min_common` (>=) common
      products, highest count first, capped at `max_keep` entries.

    The defaults reproduce the notebook's original rule: <=10 entries -> all,
    else the top 20 with >= 2 common products.

    Args:
        all_alsos: sequence of (category_id, common_count)-like entries.
        keep_all_up_to: size at or below which nothing is filtered.
        min_common: minimum common-product count to keep an entry.
        max_keep: maximum number of entries returned when filtering.

    Returns:
        The input itself when it is short enough, otherwise a new list.
    """
    if len(all_alsos) <= keep_all_up_to:
        return all_alsos
    # Sort by count here instead of assuming the input is already sorted.
    # Stopping at the first low-count entry would drop every later entry when
    # the input is unsorted. The sort is stable (also with reverse=True), so
    # already-sorted input keeps its order exactly.
    by_count = sorted(all_alsos, key=lambda entry: entry[1], reverse=True)
    filtered = [entry for entry in by_count if entry[1] >= min_common][:max_keep]
    print(len(filtered))
    return filtered

In [None]:
# The "also" column stores a Python-literal list; ast.literal_eval parses it
# without executing arbitrary code (unlike eval).
all_alsos = ast.literal_eval(data[focus_node_id][get_index('also')])
filt_alsos = filter_alsos(all_alsos)
filt_alsos

In [None]:
# Function to get path of nodes for plotting
def get_path_nodes(nodeid):
    """Return the ancestor chain of category `nodeid`, nearest parent first.

    Follows the 'parent' column of `data` (the row index is the category id,
    as elsewhere in this notebook) until node 0, which is treated as the root.
    For a non-root node the result ends with 0; for the root it is empty.
    For example, with 5 -> 2 -> 0, get_path_nodes(5) == [2, 0].

    Walks iteratively, so deep hierarchies cannot hit Python's recursion limit.

    Raises:
        KeyError: the header has no 'parent' column.
        ValueError: the parent links contain a cycle (the walk would
            otherwise never reach the root).
    """
    if nodeid == 0:
        return []
    parent_col = get_index('parent')
    if parent_col < 0:
        # get_index signals a missing column with -1, which would otherwise
        # silently read the last column of each row.
        raise KeyError("CSV header has no 'parent' column")
    path = []
    visited = {nodeid}
    current = nodeid
    while current != 0:
        parent = int(data[current][parent_col])
        if parent in visited:
            raise ValueError(f"cycle in parent links at category {parent}")
        path.append(parent)
        visited.add(parent)
        current = parent
    return path

In [None]:
# Collect every category to draw: the focus node, each filtered "also" node,
# and all of their ancestors. all_paths maps each of those start nodes to its
# ancestor chain (nearest parent first).
all_paths = {}
all_nodes = [focus_node_id]
all_paths[focus_node_id] = get_path_nodes(focus_node_id)

for also_entry in filt_alsos:
    also_id = also_entry[0]
    ancestors = get_path_nodes(also_id)
    all_nodes += [also_id] + ancestors
    all_paths[also_id] = ancestors

print("All nodes with duplicates: ", len(all_nodes))
all_nodes = list(set(all_nodes))
print("All unique nodes: ", len(all_nodes))

In [None]:
# Display the unique category ids that will become graph vertices
all_nodes

In [None]:
# Depth of each node = number of ancestors on its path (the root, 0, has depth 0)
all_depth = [len(get_path_nodes(node)) for node in all_nodes]

In [None]:
def get_common_count(nodeid):
    """Return how many products category `nodeid` shares with the focus category.

    Parses the focus row's 'also' column (a Python-literal list, read with
    ast.literal_eval) whose entries hold a category id at [0] and a
    common-product count at [1]. Returns the count of the first entry whose
    id equals `nodeid`, or 0 if there is none.
    """
    also_entries = ast.literal_eval(data[focus_node_id][get_index('also')])
    return next(
        (entry[1] for entry in also_entries if entry[0] == nodeid),
        0,
    )

In [None]:
# Products in common with the focus node, one entry per node in all_nodes
all_com_counts = [get_common_count(node) for node in all_nodes]

In [None]:
# Total product count of each node (CSV values are strings, hence int())
all_product_counts = [int(data[node][get_index('productCount')]) for node in all_nodes]

In [None]:
# Display name of each node, aligned with all_nodes
category_names = [data[node][get_index('name')] for node in all_nodes]

In [None]:
# Display each start node's ancestor path (nearest parent first)
all_paths

In [None]:
# Sanity check: every path's start node and every ancestor on it must be one
# of the collected nodes, i.e. a vertex that will exist in the graph.
for path_start, ancestors in all_paths.items():
    assert path_start in all_nodes
    for ancestor in ancestors:
        assert ancestor in all_nodes

In [None]:
# Instantiate graph and add vertices: one vertex per entry of all_nodes.
# Vertex k stands for all_nodes[k] (stored as its "cat_id" attribute below).
g = Graph()
g.add_vertices(len(all_nodes))
print(g)

In [None]:
# Attach per-vertex attributes. Every list is aligned with all_nodes, so
# vertex k receives the k-th entry of each list.
vertex_attributes = (
    ("cat_id", all_nodes),
    ("depth", all_depth),
    ("count", all_com_counts),
    ("nprod", all_product_counts),
    ("cat_name", category_names),
)
for attr_name, attr_values in vertex_attributes:
    g.vs[attr_name] = attr_values

In [None]:
# Add edges: for every start node in all_paths, connect it to its parent, the
# parent to the grandparent, and so on up to the root. Edges are stored as
# (child, parent) vertex-index pairs.
# Different start nodes often share ancestors (two siblings both yield the
# parent -> grandparent edge). Each pair is kept only once; otherwise the graph
# gets duplicate multi-edges and the same line is drawn several times.
vertex_index = {v["cat_id"]: v.index for v in g.vs}  # cat_id -> igraph vertex index
all_edges = []
seen_edges = set()
for node in all_paths:
    seq = [node] + list(all_paths[node])
    for child_id, parent_id in zip(seq, seq[1:]):
        edge = (vertex_index[child_id], vertex_index[parent_id])
        if edge not in seen_edges:
            seen_edges.add(edge)
            all_edges.append(edge)

In [None]:
# Insert the collected (child, parent) vertex-index pairs, then print igraph's summary of the graph
g.add_edges(all_edges)
print(g)

In [None]:
# Depth of the focus vertex and of the deepest vertex; both drive the layout below.
# max_depth stays -1 for a graph with no vertices.
focus_index = g.vs.find(cat_id=focus_node_id).index
focus_depth = g.vs[focus_index]["depth"]
max_depth = max(g.vs["depth"], default=-1)

In [None]:
# Here the layout is computed: one horizontal row per depth level.
# To change the horizontal spacing, edit the counter's start value (60000) and its step (40000).
# position maps igraph vertex index -> [x, y].
position = {}
for d in range(max_depth+1):
    # Vertices at this depth, mapped to their common-product count
    cat_count = {}
    for v in g.vs:
        if v["depth"]==d:
            cat_count[v.index] = v["count"]
    # Highest common count first, so those vertices are placed leftmost in the row
    cat_count = dict(sorted(cat_count.items(), key=lambda item: item[1], reverse=True))
    print(cat_count)
    # Dicts preserve insertion order in Python 3.7+
    counter = 60000
    for cat in cat_count:
        # Defensive check: a vertex should never be positioned twice
        if cat in position:
            print("Some trouble")
        if g.vs[cat]["cat_id"]==focus_node_id:
            # The focus node is pinned to the origin
            print("Focus node positioned")
            position[cat] = [0,0]
        else:
            # y = 0 at the focus node's depth; shallower rows get positive y (drawn above)
            position[cat] = [counter*1.0 , (focus_depth - d)*70000]
            counter +=40000

In [None]:
es = EdgeSeq(g)  # sequence of edges (not referenced again below)
E = [edge.tuple for edge in g.es]  # (source, target) vertex-index pairs

# Node coordinates in vertex-index order
L = len(position)
node_xy = [position[k] for k in range(L)]
Xn = [xy[0] for xy in node_xy]
Yn = [xy[1] for xy in node_xy]

# Edge coordinates flattened as start, end, None per edge (None breaks the line)
Xe, Ye = [], []
for src, dst in E:
    Xe.extend((position[src][0], position[dst][0], None))
    Ye.extend((position[src][1], position[dst][1], None))

In [None]:
# NOTE(review): `labels` is only an alias of all_nodes and no later cell reads it;
# the figure's text comes from maybe_text and the annotations instead.
labels = all_nodes

In [None]:
# Marker sizes: min_size plus the common-product count divided by 0.2 (change
# the divisor to grow or shrink nodes). The focus node is then enlarged to
# 1.2x the largest size so it always stands out.
min_size = 10
all_annots = [vertex["count"] for vertex in g.vs]
all_sizes = [(count / 0.2) + min_size for count in all_annots]
all_sizes[g.vs.find(cat_id=focus_node_id).index] = 1.2*(max(all_sizes))

In [None]:
# Function to make annotations
def make_annotations(pos, also_count, font_size=10, font_color='rgb(250,250,250)'):
    """Build one plotly text annotation per graph vertex.

    Args:
        pos: mapping from vertex index to [x, y], read as pos[k] for k in 0..len-1.
        also_count: per-vertex labels, same length and order as `pos`.
        font_size, font_color: font used for every annotation.

    Uses the module-level graph `g` and `focus_node_id`. The vertex whose
    "cat_id" equals the focus node is labelled "Focus Node"; every other
    vertex is labelled with str(also_count[k]).

    Returns:
        List of annotation dicts suitable for fig.update_layout(annotations=...).

    Raises:
        ValueError: `pos` and `also_count` differ in length.
    """
    L = len(pos)
    if len(also_count) != L:
        raise ValueError('The lists pos and also_count must have the same len')
    annotations = []
    for k in range(L):
        is_focus = g.vs[k]["cat_id"] == focus_node_id
        annotations.append(
            dict(
                # Annotation text is a string property, so counts are converted explicitly
                text="Focus Node" if is_focus else str(also_count[k]),
                x=pos[k][0], y=pos[k][1],
                xref='x1', yref='y1',
                font=dict(color=font_color, size=font_size),
                showarrow=False)
        )
    return annotations

In [None]:
# Sanity check: the vertex attributes line up with the lists they were built from
for idx, vertex in enumerate(g.vs):
    assert all_product_counts[idx] == vertex["nprod"]
    assert category_names[idx] == vertex["cat_name"]

In [None]:
# Compute percentage similarity: the share of each category's products that it
# has in common with the focus node (0.0 for categories with no products, to
# avoid dividing by zero).
all_sim_ratios = [
    (all_annots[idx] / nprod) * 100 if nprod != 0 else 0.0
    for idx, nprod in enumerate(all_product_counts)
]
all_sim_ratios

In [None]:
# Quick look at the spread of product counts. `temp` is just another name for
# all_product_counts, and sorted() returns a new list, so the original order is left intact.
temp = all_product_counts
sorted(temp)

In [None]:
# Create a pool of random hex colours, one per edge; the trace cell picks from
# it with a per-depth counter. A dedicated, seeded RNG makes the colours
# identical on every run (change EDGE_COLOR_SEED for another palette) and
# leaves the global `random` state untouched.
EDGE_COLOR_SEED = 0
number_of_colors = len(all_edges)
_edge_color_rng = random.Random(EDGE_COLOR_SEED)
edge_color = ["#" + "".join(_edge_color_rng.choice('0123456789ABCDEF') for _ in range(6))
              for _ in range(number_of_colors)]
print(edge_color)

In [None]:
# Create hover text: category name, a line break, then its total product count
assert len(category_names) == len(all_product_counts)
maybe_text = [
    "<br>".join((name, str(nprod)))
    for name, nprod in zip(category_names, all_product_counts)
]

In [None]:
# Display the marker sizes computed for each vertex
all_sizes

In [None]:
# Add traces and create visualization.
# Plotly draws traces in the order they are added, so the edge lines go in
# first and the node markers are drawn on top of them. Adding nodes first
# would leave the lines drawn across the circles.
color_counters = [0] * max_depth

fig = go.Figure()

# One trace per edge so each edge can have its own colour. Edges are grouped
# by the depth of their shallower endpoint, and each group walks through
# edge_color from its start.
for i, (node1, node2) in enumerate(all_edges):
    d1 = g.vs[node1]["depth"]
    d2 = g.vs[node2]["depth"]
    edge_level = min(d1, d2)
    # Xe/Ye hold [start, end, None] per edge in g.es order. That order matches
    # all_edges because the graph's edges were added from that list.
    fig.add_trace(go.Scatter(x=Xe[3*i:3*i+3],
                             y=Ye[3*i:3*i+3],
                             mode='lines',
                             line=dict(color=edge_color[color_counters[edge_level]], width=0.7),
                             hoverinfo='none',
                             showlegend=False,
                             ))
    color_counters[edge_level] += 1

# All category nodes: sized by all_sizes, coloured by % of products in common
fig.add_trace(go.Scatter(x=Xn,
                         y=Yn,
                         mode='markers',
                         name='Category',
                         marker=dict(symbol='circle-dot',
                                     size=all_sizes,
                                     color=all_sim_ratios,
                                     colorbar=dict(
                                         title="% common"
                                         ),
                                     colorscale="bluered_r",  # sequential colorscale
                                     line=dict(color='rgb(0,0,0)', width=0.3),
                                     opacity=1.0
                                     ),
                         hovertemplate='<b>%{text}</b><extra></extra>',
                         text=maybe_text,
                         showlegend=False,
                         ))

# Add focus node separately (black), on top of the other nodes
index_focus_node = g.vs.find(cat_id=focus_node_id).index
fig.add_trace(go.Scatter(x=[Xn[index_focus_node]],
                         y=[Yn[index_focus_node]],
                         mode='markers',
                         name='Category',
                         marker=dict(symbol='circle-dot',
                                     size=[all_sizes[index_focus_node]],
                                     color='rgb(0,0,0)',
                                     line=dict(color='rgb(0,0,0)', width=0.1),
                                     opacity=1.0
                                     ),
                         hovertemplate='<b>%{text}</b><extra></extra>',
                         text=[maybe_text[index_focus_node]],
                         showlegend=False,
                         ))

In [None]:
# Hide both axes entirely: the coordinates are layout positions, not data values
axis = dict(showline=False, zeroline=False, showgrid=False, showticklabels=False)

fig.update_layout(
    title='My Visualization',
    annotations=make_annotations(position, all_annots),
    xaxis=axis,
    yaxis=axis,
    hoverlabel_align='left',
)
fig.show()

In [None]:
# Save the interactive figure as a standalone HTML file, which is often easier
# to inspect than the inline notebook output
fig.write_html("temp.html")

In [None]:
# Print igraph's summary of the final graph
print(g)