## Setup

In [1]:
import pandas as pd
import numpy as np
import os
import sys
from langchain.document_loaders import PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader
from langchain.document_loaders import PyPDFDirectoryLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pathlib import Path
import random
import logging

# Change to project directory
os.chdir("..")

from llmgrapher.helpers.df_helpers import df2Graph, graph2Df, documents2Dataframe

[32m2024-05-12 23:08:41.257[0m | [1mINFO    [0m | [36mllmgrapher.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/kleptotrace/Projects/git/llmgrapher[0m
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Logger setup
# FileHandler does not create missing folders, so make sure the log directory exists.
log_dir = Path("logs")
log_dir.mkdir(parents=True, exist_ok=True)

# Use the real module name (the original passed the literal string "__name__").
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) # only log the current logger in DEBUG mode
logging.basicConfig(
    level=logging.CRITICAL,  # show only CRITICAL messages from other loggers
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[   # output log messages to both file and stdout
        # basicConfig ignores its own `encoding` argument when `handlers` is given,
        # so the encoding has to be set on the FileHandler itself.
        logging.FileHandler(log_dir / "extract_graph_notebook.log", encoding="utf-8"),
        logging.StreamHandler(sys.stdout)
    ]
)

In [3]:
# Directory setup: both locations are derived from the dataset name.
data_dir = "ag_news"
out_dir = data_dir
# Folder the input documents are loaded from
inputdirectory = Path("./data/raw/input") / data_dir
# Folder the output csv files will be written to
outputdirectory = Path("./data/interim/output") / out_dir

In [4]:
# Read with header=None, so the columns are labelled by integer position (0, 1, 2, ...).
ag_news = pd.read_csv(Path("data/raw") / "ag_news_train.csv", header=None)

In [5]:
# Only the first 50 rows of column 2 are used, joined one entry per line.
# (presumably column 2 holds the article text -- TODO confirm against the csv)
# To use every row instead: ag_news[2].str.cat(sep="\n")
sample_texts = ag_news[2].head(50)
ag_news_txt = sample_texts.str.cat(sep="\n")

In [6]:
# Write the sampled text into the folder the document loader reads from.
# The path is derived from `inputdirectory` so the two cannot drift apart.
ag_news_txt_path = inputdirectory / "ag_news.txt"
# open() does not create missing folders.
ag_news_txt_path.parent.mkdir(parents=True, exist_ok=True)
# An explicit encoding keeps the file identical across platforms.
with open(ag_news_txt_path, "w", encoding="utf-8") as f:
    f.write(ag_news_txt)

## Load Documents

In [7]:
# Load every file found in the input directory.
# For PDF inputs the imported PyPDFDirectoryLoader(inputdirectory) or
# PyPDFLoader(<path to a single pdf>) could be swapped in here instead.
loader = DirectoryLoader(inputdirectory, show_progress=True)
documents = loader.load()

# Split the documents into chunks of at most 1500 characters (measured with len),
# with 150 characters of overlap between neighbouring chunks.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=1500,
    chunk_overlap=150,
    length_function=len,
    is_separator_regex=False,
)

pages = splitter.split_documents(documents)
print("Number of chunks = ", len(pages))
# Preview the fourth chunk; fall back to the last one when fewer than four
# chunks exist, and print nothing for an empty corpus (instead of raising IndexError).
if pages:
    preview_index = min(3, len(pages) - 1)
    print(pages[preview_index].page_content)


100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]

Number of chunks =  7

The cost of buying both new and second hand cars fell sharply over the past five years, a new survey has found.

South Korea's central bank cuts interest rates by a quarter percentage point to 3.5 in a bid to drive growth in the economy.

An auction of shares in Google, the web search engine which could be floated for as much as \$36bn, takes place on Friday.

Hewlett-Packard shares fall after disappointing third-quarter profits, while the firm warns the final quarter will also fall short of expectations.

One of the oldest textile operators on the Indian Ocean island of Mauritius last week shut seven factories and cut 900 jobs.

Chad asks the IMF for a loan to pay for looking after more than 100,000 refugees from conflict-torn Darfur in western Sudan.

The company running the Japanese nuclear plant hit by a fatal accident is to close its reactors for safety checks.

Trevor Baylis, the veteran inventor famous for creating the Freeplay clockwork radio, is planning




## Create a dataframe of all the chunks

In [8]:
# Turn the list of chunks into a dataframe (one row per chunk, judging by the
# displayed output: text, source, chunk_id).
df = documents2Dataframe(pages)
row_count, column_count = df.shape
print((row_count, column_count))
df.head()

(7, 3)


Unnamed: 0,text,source,chunk_id
0,"Reuters - Short-sellers, Wall Street's dwindli...",data/raw/input/ag_news/ag_news.txt,000c2851894849018c528551954ebed9
1,Forbes.com - After earning a PH.D. in Sociolog...,data/raw/input/ag_news/ag_news.txt,a92d82dfe6594b7d97fb0d5255f590f1
2,NEW YORK (Reuters) - The dollar tumbled broadl...,data/raw/input/ag_news/ag_news.txt,5c9e02945d5444808f8ea2a7875f73d1
3,Interest rates are trimmed to 7.5 by the South...,data/raw/input/ag_news/ag_news.txt,961b9911316e4e588f64e9695ee05830
4,A group led by the UAE's Etisalat plans to spe...,data/raw/input/ag_news/ag_news.txt,468b3cc22d40417d8f5de4b45def836d


## Extract Concepts

If `regenerate` is set to True, the dataframes are regenerated and both of them are written to csv files so we don't have to calculate them again.

        dfg1 = dataframe of edges

        df = dataframe of chunks


Otherwise the edge dataframe is read back from the output directory.

In [9]:
## To regenerate the graph with LLM, set this to True
regenerate = True

graph_csv = outputdirectory / "graph.csv"
chunks_csv = outputdirectory / "chunks.csv"

if regenerate:
    # Extract the concept relations from every chunk and cache both frames on disk.
    concepts_list = df2Graph(df, model='zephyr:latest')
    dfg1 = graph2Df(concepts_list)
    outputdirectory.mkdir(parents=True, exist_ok=True)

    dfg1.to_csv(graph_csv, sep="|", index=False)
    df.to_csv(chunks_csv, sep="|", index=False)
else:
    # Reuse the edge list cached by a previous run.
    dfg1 = pd.read_csv(graph_csv, sep="|")

## Relations with an empty or missing node/edge are discarded.
## Increasing the weight of the relation to 4.
## We will assign the weight of 1 when later the contextual proximity will be calculated.
dfg1 = (
    dfg1.replace("", np.nan)
    .dropna(subset=["node_1", "node_2", 'edge'])
    .assign(count=4)
)
print(dfg1.shape)
dfg1.head()

100%|█████████████████████████████████████████████| 7/7 [00:52<00:00,  7.49s/it]

Possibly due to JSON Decode Error, 2 chunks have been skipped
(34, 5)





Unnamed: 0,node_1,node_2,edge,chunk_id,count
0,short-sellers,well-timed and occasionally controversial play...,"Carlyle Group, a private investment firm with ...",000c2851894849018c528551954ebed9,4
1,oil export flows,infrastructure,Authorities have halted oil export flows from ...,000c2851894849018c528551954ebed9,4
2,oil prices,wallets,"Tearaway world oil prices, toppling records an...",000c2851894849018c528551954ebed9,4
3,stock market,summer doldrums,Soaring crude prices plus worries about the ec...,000c2851894849018c528551954ebed9,4
4,retail money market mutual funds,assets,Assets of the nation's retail money market mut...,000c2851894849018c528551954ebed9,4


## Calculating contextual proximity

In [10]:
def contextual_proximity(df: pd.DataFrame) -> pd.DataFrame:
    """Build co-occurrence edges between nodes mentioned in the same chunk.

    Every node mention of a chunk (taken from ``node_1`` and ``node_2``) is
    paired with every other mention of that chunk. The pairings are then
    grouped per ``(node_1, node_2)`` and counted over all chunks.

    Args:
        df: Edge list with at least the columns ``node_1``, ``node_2`` and
            ``chunk_id``. It is not modified.

    Returns:
        A new dataframe with the columns ``node_1``, ``node_2``, ``chunk_id``
        (comma-joined chunk ids, one per pairing), ``count`` (number of
        pairings) and ``edge`` (always ``"contextual proximity"``). Because the
        join is symmetric, each pair appears in both directions. Pairs with a
        count of exactly 1 and pairs with an empty node name are left out.
    """
    ## Melt the dataframe into a list of nodes: one row per (chunk, node) mention
    node_mentions = pd.melt(
        df, id_vars=["chunk_id"], value_vars=["node_1", "node_2"], value_name="node"
    ).drop(columns=["variable"])
    # Self join with chunk id as the key will create a link between terms occurring in the same text chunk.
    node_pairs = pd.merge(
        node_mentions, node_mentions, on="chunk_id", suffixes=("_1", "_2")
    )
    # drop self loops
    node_pairs = node_pairs[node_pairs["node_1"] != node_pairs["node_2"]].reset_index(
        drop=True
    )
    ## Group and count edges.
    pair_counts = (
        node_pairs.groupby(["node_1", "node_2"])
        .agg(chunk_id=("chunk_id", ",".join), count=("chunk_id", "count"))
        .reset_index()
        .replace("", np.nan)
        .dropna(subset=["node_1", "node_2"])
    )
    # Drop edges with 1 count. `assign` returns a new frame, so the label is never
    # written onto a filtered slice (which triggers SettingWithCopyWarning).
    return pair_counts.loc[pair_counts["count"] != 1].assign(
        edge="contextual proximity"
    )


# Derive the "contextual proximity" edges from the extracted edge list and
# show the last few rows of the result.
dfg2 = contextual_proximity(dfg1)
dfg2.tail()

Unnamed: 0,node_1,node_2,chunk_id,count,edge
688,short-sellers,stock market,"000c2851894849018c528551954ebed9,a92d82dfe6594...",2,contextual proximity
735,stock market,non-opec oil exporters,"a92d82dfe6594b7d97fb0d5255f590f1,a92d82dfe6594...",2,contextual proximity
744,stock market,short-sellers,"000c2851894849018c528551954ebed9,a92d82dfe6594...",2,contextual proximity
790,united states,central square in lynn,"5374f6ca1e304e358499f64345b6b9aa,5374f6ca1e304...",2,contextual proximity
803,us securities regulators,non-opec oil exporters,"a92d82dfe6594b7d97fb0d5255f590f1,a92d82dfe6594...",2,contextual proximity


### Merge both the dataframes

In [11]:
# Stack the extracted edges and the contextual-proximity edges, then collapse
# duplicate node pairs: chunk ids and edge labels are comma-joined, counts are summed.
stacked_edges = pd.concat([dfg1, dfg2], axis=0)
dfg = (
    stacked_edges.groupby(["node_1", "node_2"])
    .agg(
        chunk_id=("chunk_id", ",".join),
        edge=("edge", ",".join),
        count=("count", "sum"),
    )
    .reset_index()
)
dfg

Unnamed: 0,node_1,node_2,chunk_id,edge,count
0,central square in lynn,dell,"5374f6ca1e304e358499f64345b6b9aa,5374f6ca1e304...",contextual proximity,2
1,central square in lynn,gary winnick,"5374f6ca1e304e358499f64345b6b9aa,5374f6ca1e304...",contextual proximity,2
2,central square in lynn,gateway artisan block,"5374f6ca1e304e358499f64345b6b9aa,5374f6ca1e304...",The Gateway Artisan Block is a key area of Cen...,6
3,central square in lynn,global crossing,"5374f6ca1e304e358499f64345b6b9aa,5374f6ca1e304...",contextual proximity,2
4,central square in lynn,government,"5374f6ca1e304e358499f64345b6b9aa,5374f6ca1e304...",contextual proximity,2
...,...,...,...,...,...
66,united states,central square in lynn,"5374f6ca1e304e358499f64345b6b9aa,5374f6ca1e304...",contextual proximity,2
67,united states,government,5374f6ca1e304e358499f64345b6b9aa,Stein proposes that the government sell titles...,4
68,us securities regulators,non-opec oil exporters,"a92d82dfe6594b7d97fb0d5255f590f1,a92d82dfe6594...",contextual proximity,2
69,us trade deficit,oil costs,5c9e02945d5444808f8ea2a7875f73d1,relationship between the record US trade defic...,4


## Calculate the NetworkX Graph

In [12]:
# Distinct node names over both endpoint columns (node_1 values first, then node_2).
nodes = dfg.melt(value_vars=["node_1", "node_2"])["value"].unique()
nodes.shape

(63,)

In [13]:
import networkx as nx

G = nx.Graph()

## Add nodes to the graph (names are stored as strings)
G.add_nodes_from(str(node) for node in nodes)

## Add edges to the graph: the edge text becomes the title and the
## count is scaled down by 4 to give the edge weight.
G.add_edges_from(
    (str(source), str(target), {"title": label, "weight": count / 4})
    for source, target, label, count in zip(
        dfg["node_1"], dfg["node_2"], dfg["edge"], dfg["count"]
    )
)

### Calculate communities for coloring the nodes

In [14]:
# girvan_newman yields successively finer partitions of the graph; the first
# partition is read only to advance to the second one, which is the one used.
community_levels = nx.community.girvan_newman(G)
top_level_communities = next(community_levels)
next_level_communities = next(community_levels)
# Sort the members of each community, then the communities themselves.
communities = sorted(sorted(members) for members in next_level_communities)
print("Number of Communities = ", len(communities))
print(communities)

Number of Communities =  23
[['assets', 'retail money market mutual funds'], ['back-to-school season', 'kids'], ['central square in lynn', 'dell', 'gary winnick', 'global crossing', 'government', 'kevin b. rollins', 'large loss', 'lynn', 'oil', 'quality distribution', 'russia', 'united states'], ['dollar', "economy's recovery"], ['domestic and corporate spending', 'japan'], ['e-mail', "internet's killer application"], ['elderly relatives', 'finances'], ['eurozone economy', 'slowdown later in the year'], ['gateway artisan block'], ['google', 'public offering'], ['google inc.', 'ipo', 'us securities regulators'], ['growth mutual funds', 'value-focused mutual funds'], ['heart disease drugs', 'nitromed inc.'], ['indian engineers', 'network rail'], ['infrastructure', 'oil export flows'], ['july', 'retail sales'], ['massachusetts', 'sales tax holiday'], ['non-opec oil exporters', 'opec', 'record crude prices', 'reuters', 'scorching oil prices', 'short-sellers', 'stock market', 'summer doldru

### Create a dataframe for community colors

In [15]:
import seaborn as sns
palette = "hls"

## Now add these colors to communities and make another dataframe
def colors2Community(communities) -> pd.DataFrame:
    """Assign one colour and one group number to every community.

    Args:
        communities: Sequence of communities, each an iterable of node names.

    Returns:
        Dataframe with one row per node and the columns ``node``, ``color``
        (hex string taken from the module-level ``palette``) and ``group``
        (1-based position of the node's community).
    """
    ## One colour per community, in random order
    shades = sns.color_palette(palette, len(communities)).as_hex()
    random.shuffle(shades)
    # Colours are handed out from the end of the shuffled list towards the front.
    records = [
        {"node": member, "color": shade, "group": group_id}
        for group_id, (members, shade) in enumerate(
            zip(communities, reversed(shades)), start=1
        )
        for member in members
    ]
    return pd.DataFrame(records)


colors = colors2Community(communities)
colors

Unnamed: 0,node,color,group
0,assets,#db5771,1
1,retail money market mutual funds,#db5771,1
2,back-to-school season,#57db6a,2
3,kids,#57db6a,2
4,central square in lynn,#abdb57,3
...,...,...,...
58,south african central bank,#bb57db,21
59,saudi arabia mobile phone licences,#db8157,22
60,uae's etisalat,#db8157,22
61,shell,#db57d9,23


### Add colors to the graph

In [16]:
# Copy group and colour onto each graph node; the node size is its degree.
for node_name, group_id, hex_color in zip(
    colors["node"], colors["group"], colors["color"]
):
    node_attrs = G.nodes[node_name]
    node_attrs["group"] = group_id
    node_attrs["color"] = hex_color
    node_attrs["size"] = G.degree[node_name]

In [17]:
from pyvis.network import Network

# Despite its name this is the path of the html file handed to net.show().
graph_output_directory = "./reports/graphs/graph.html"
# Create the target folder up front so writing the report does not fail on a
# fresh checkout where ./reports/graphs does not exist yet.
Path(graph_output_directory).parent.mkdir(parents=True, exist_ok=True)

net = Network(
    notebook=False,
    # bgcolor="#1a1a1a",       # optional dark background
    cdn_resources="remote",
    height="900px",
    width="100%",
    select_menu=True,
    # font_color="#cccccc",    # optional light font for the dark background
    filter_menu=False,
)

# Import nodes, edges and their attributes from the NetworkX graph.
net.from_nx(G)
# Layout: force_atlas_2based is the active choice; the two commented calls are
# alternative physics settings kept for experimentation.
# net.repulsion(node_distance=150, spring_length=400)
net.force_atlas_2based(central_gravity=0.015, gravity=-31)
# net.barnes_hut(gravity=-18100, central_gravity=5.05, spring_length=380)
net.show_buttons(filter_=["physics"])

net.show(graph_output_directory)

                                       
[Select a Node by ID                                                    ]
Reset Selection
