In [2]:
import pandas as pd
import json
import os
import networkx as nx
import importlib
import matplotlib.pyplot as plt

import helper_functions as hf
importlib.reload(hf)

<module 'helper_functions' from '/Users/tiril/Documents/IndividualProject/nuclear_repo/knowledge_graphs/helper_functions.py'>

In [3]:
# Get statistical properties of articles
# Company lookup table. Each key is a canonical company name; the code below passes
# every key to get_articles(), which uses it as the column label in `articles` that
# holds the per-article mention frequency for that company.
# Each value is a list of alias strings. The aliases are not read anywhere in the
# code visible in this notebook -- presumably they were used upstream to compute the
# frequency columns (TODO confirm).
# NOTE(review): 'Terrestial' looks like a misspelling of 'Terrestrial'. It is left
# as-is because the key has to match the column name in the data and the hard-coded
# skip lists further down -- confirm before renaming.
name_dict = {
    'ARC': ['ARC'],
    'Babcock and Wilcox': ['Babcock'],
    'Berkeley': ['Berkeley'],
    'BWX': ['BWX'],
    'Elysium': ['Elysium'],
    'Flibe': ['Flibe'],
    'Framatome': ['Framatome'],
    'GE Hitachi': ['Hitachi', 'GEH'],
    'General Atomics': ['General Atomics'],
    'HolosGen': ['HolosGen', 'Holos'],
    'Holtec International': ['Holtec'],
    'Hyperion Power': ['Hyperion'],
    'Kairos Power': ['Kairos'],
    'Moltex Energy': ['Moltex'],
    'NANO Nuclear': ['NANO', 'NNE'],
    'NuScale': ['NuScale'],
    'Oak Ridge National Laboratory': ['Oak Ridge National Laboratory', 'ORNL'],
    'Oklo': ['Oklo'],
    'StarCore Nuclear': ['StarCore'],
    'TerraPower': ['TerraPower'],
    'Terrestial': ['Terrestial'],
    'ThorCon': ['ThorCon'],
    'Ultra Safe Nuclear Corporation': ['Ultra Safe Nuclear Corporation', 'USNC'],
    'Westinghouse': ['Westinghouse', 'WEC'],
    'X-Energy': ['X-energy']
}

# Load the article table (one row per article, one frequency column per company).
filepath = 'data/articles_with_frequency.json'
# Parse from the handle that is opened here. Previously the handle was opened and
# then ignored, because pd.read_json was given the path and opened the file again.
# The encoding is explicit so the result does not depend on the platform locale.
with open(filepath, 'r', encoding='utf-8') as file:
    articles = pd.read_json(file)

# Root directory under which get_articles() writes per-company subsets when save=True.
base_directory = 'data/unfiltered'

def get_articles(name, articles, base_directory, save=False):
    """Summarise how often one company is mentioned across the article table.

    Prints a single summary line with the number (and share) of articles that
    mention the company, the highest per-article frequency, the total frequency
    and the URL of the first article that reaches the highest frequency.

    Args:
        name: Column label in ``articles`` that holds the per-article mention
            frequency for the company.
        articles: DataFrame with ``url``, ``title`` and ``text`` columns plus a
            frequency column called ``name``. May be empty.
        base_directory: Directory under which the subset is written when
            ``save`` is true.
        save: When true, write every article with a positive frequency to
            ``{base_directory}/{slug}/{slug}.json``, where ``slug`` is ``name``
            lower-cased with spaces replaced by underscores.

    Returns:
        None. The function only prints and, optionally, writes a file.
    """
    n_articles = len(articles)

    # One int per article. Skipped for an empty table so the column need not exist.
    freqs = [int(value) for value in articles[name]] if n_articles else []

    tot_freq = sum(freqs)
    # Floor at 0 so that a table without any positive frequency reports 0.
    max_freq = max([0, *freqs])

    # URL of the first article that reaches the maximum (None if nothing matched).
    url = None
    if max_freq > 0:
        url = articles.iloc[freqs.index(max_freq)]['url']

    subset = []
    for position, freq in enumerate(freqs):
        if freq > 0:
            article = articles.iloc[position]
            subset.append({
                'url': article['url'],
                'title': article['title'],
                'text': article['text'],
                'frequency': freq
            })

    subset_length = len(subset)
    # An empty table used to raise ZeroDivisionError here; report 0% instead.
    share = (subset_length / n_articles) * 100 if n_articles else 0.0

    if save:
        slug = name.lower().replace(' ', '_')
        filepath = f"{base_directory}/{slug}/{slug}.json"
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding
        # rather than relying on the platform default.
        with open(filepath, 'w', encoding='utf-8') as file:
            json.dump(subset, file, indent=4, ensure_ascii=False)

    print(f'{name} is present in {subset_length} articles ({share:.2f}%), Highest frequency: {max_freq}, Total frequency: {tot_freq} (in {url})')

# Print the article statistics for every company listed in name_dict
# (only the keys are needed; the alias lists are not used here).
for key in name_dict:
    get_articles(key, articles, base_directory)

ARC is present in 33 articles (0.92%), Highest frequency: 23, Total frequency: 213 (in https://www.neimagazine.com/news/arc-clean-technology-signs-agreement-on-deployment-of-smrs-in-alberta-10708319/)
Babcock and Wilcox is present in 13 articles (0.36%), Highest frequency: 4, Total frequency: 20 (in https://www.newcivilengineer.com/latest/pm-announces-national-endeavour-to-build-nuclear-workforce-and-warns-of-china-threat-26-03-2024/)
Berkeley is present in 21 articles (0.58%), Highest frequency: 9, Total frequency: 47 (in https://www.neimagazine.com/news/first-concrete-box-of-radioactive-waste-safely-transferred-from-berkeley-npp-10531964/)
BWX is present in 54 articles (1.50%), Highest frequency: 19, Total frequency: 191 (in https://www.neimagazine.com/news/bwxt-to-evaluate-microreactor-deployment-in-wyoming-11146865/)
Elysium is present in 1 articles (0.03%), Highest frequency: 1, Total frequency: 1 (in https://www.neimagazine.com/news/more-funding-for-us-companies-under-does-gain-p

In [4]:
# Functions to get basic statistics of the KGs
def get_basic_stats(df):
    """Count how often each value occurs in the entity-type and relation columns.

    Args:
        df: DataFrame with the columns ``h_ent``, ``t_ent`` and ``relation``.

    Returns:
        A dict with the keys ``head_counts``, ``tail_counts`` and ``relations``,
        each mapping a distinct column value to its number of occurrences
        (the result of ``value_counts()`` on the corresponding column).
    """
    # Output key -> source column; evaluated in this order.
    column_for = {
        'head_counts': 'h_ent',
        'tail_counts': 't_ent',
        'relations': 'relation',
    }
    return {
        label: df[column].value_counts().to_dict()
        for label, column in column_for.items()
    }

def print_stats(stats):
    """Print the entity and relation tables for a ``get_basic_stats``-style dict.

    The entity table has one row per entity type found in either the head or
    the tail counts, with its head count, tail count, their sum and that sum as
    a percentage of all head and tail occurrences. Rows are sorted by total in
    descending order; rows with equal totals are ordered by entity name so the
    output is the same on every run (iteration order of the underlying set is
    not stable across interpreter runs). The relation table lists the relations
    in the order stored in ``stats['relations']``.

    Args:
        stats: Dict with the keys ``head_counts``, ``tail_counts`` and
            ``relations``, each mapping a label to a count.

    Returns:
        None. Everything is written to standard output.
    """
    head_counts = stats['head_counts']
    tail_counts = stats['tail_counts']
    row_format = "{:<15} | {:>10} | {:>10} | {:>10} | {:>9.2f}%"

    print("Entity Counts:")
    print("{:<15} | {:>10} | {:>10} | {:>10} | {:>10}".format("Entity", "Head", "Tail", "Total", "Percent"))
    print("-" * 67)

    entity_totals = []
    for entity in set(head_counts) | set(tail_counts):
        head_count = head_counts.get(entity, 0)
        tail_count = tail_counts.get(entity, 0)
        entity_totals.append((entity, head_count, tail_count, head_count + tail_count))

    total_head = sum(row[1] for row in entity_totals)
    total_tail = sum(row[2] for row in entity_totals)
    grand_total = total_head + total_tail

    # Largest total first; the name breaks ties deterministically.
    entity_totals.sort(key=lambda row: (-row[3], str(row[0])))

    for entity, head, tail, total in entity_totals:
        percent = (total / grand_total) * 100
        print(row_format.format(entity, head, tail, total, percent))
    print("-" * 67)
    print(row_format.format("Total", total_head, total_tail, grand_total, 100.00))

    print("\nRelation Counts:")
    print("{:<20} | {:>10}".format("Relation", "Count"))
    print("-" * 32)
    total_relations = sum(stats['relations'].values())
    for relation, count in stats['relations'].items():
        print("{:<20} | {:>10}".format(relation, count))
    print("-" * 32)
    print("{:<20} | {:>10}".format("Total", total_relations))

In [51]:
# Choose which triplet set to analyse.
# False -> final triplets from the Excel sheet (the default)
# True  -> all (i.e. raw) triplets from the unfiltered JSON dump
USE_RAW_TRIPLETS = False

if USE_RAW_TRIPLETS:
    # All (i.e. raw) triplets
    with open('data/unfiltered/all_triplets.json', 'r', encoding='utf-8') as file:
        raw_triplets = json.load(file)

    # Flatten the nested head/tail records into the columns used by get_basic_stats.
    transformed_data = [{
        'head': item['head']['word'],
        'h_ent': item['head']['entity'],
        'relation': item['relation'],
        'tail': item['tail']['word'],
        't_ent': item['tail']['entity']
    } for item in raw_triplets]
    data = pd.DataFrame(transformed_data)
else:
    # Final triplets
    data = pd.read_excel('data/final_triplets.xlsx')

# Get final statistical properties
stats = get_basic_stats(data)
print_stats(stats)

Entity Counts:
Entity          |       Head |       Tail |      Total |    Percent
-------------------------------------------------------------------
LOC             |       4169 |       4020 |       8189 |     46.30%
ORG             |       1259 |       3911 |       5170 |     29.23%
PER             |       3053 |        441 |       3494 |     19.75%
MISC            |        362 |        470 |        832 |      4.70%
REACTOR         |          1 |          0 |          1 |      0.01%
FUEL            |          0 |          1 |          1 |      0.01%
SMR             |          0 |          1 |          1 |      0.01%
-------------------------------------------------------------------
Total           |       8844 |       8844 |      17688 |    100.00%

Relation Counts:
Relation             |      Count
--------------------------------
contains             |       4212
company              |       2183
nationality          |        787
company location     |        719
residence       

In [11]:
# Get basic statistics per KG at different prune levels
# `results` is keyed by root company name; each value is handed to hf.make_graph below.
# JSON is read as UTF-8 explicitly so the result does not depend on the platform locale.
with open('data/triplets_no_cutoff/graphs.json', 'r', encoding='utf-8') as file:
    results = json.load(file)

def count_nodes(G):
    """Tally the nodes of a graph by their ``type`` attribute.

    Args:
        G: Graph-like object whose ``nodes(data=True)`` yields
            ``(node, attribute_dict)`` pairs.

    Returns:
        A dict with the keys ``PER``, ``ORG``, ``LOC``, ``MISC``, ``Unknown``,
        ``Other`` and ``Total``. A node without a ``type`` attribute is counted
        under ``Other``. A node whose type is not one of the keys is counted in
        ``Total`` only.
    """
    tally = dict.fromkeys(
        ('PER', 'ORG', 'LOC', 'MISC', 'Unknown', 'Other', 'Total'), 0
    )

    for _node, attributes in G.nodes(data=True):
        kind = attributes.get('type', 'Other')
        # Recognised types get their own bucket ...
        if kind in tally:
            tally[kind] += 1
        # ... and every node, recognised or not, adds to the grand total.
        tally['Total'] += 1

    return tally

# Node-type counts per root company: one entry for the full graph ('all') and one
# per pruning depth ('depth=1' ... 'depth=5').
statistics = {}
depths = [1, 2, 3, 4, 5]
for root in results:
    root_node = (root, 'ORG')
    try:
        G = hf.make_graph(results[root])
        statistics[root] = {}
        statistics[root]['all'] = count_nodes(G)

        for depth in depths:
            # Prune a copy so every depth starts from the full graph.
            G_copy = G.copy()
            pruned_G = hf.prune_graph_by_depth(G_copy, root_node, depth, 'bidirectional')
            pruned_count = count_nodes(pruned_G)
            statistics[root][f'depth={depth}'] = pruned_count
    except Exception as e:
        # Best effort: report the failure and move on. A root that fails part-way
        # keeps the entries stored before the failure (e.g. only 'all').
        print(f"Error processing {root}: {e}")
        continue

# Make sure the output directory exists, and pin the encoding because
# ensure_ascii=False writes non-ASCII characters verbatim.
os.makedirs('results', exist_ok=True)
with open('results/basic_stats.json', 'w', encoding='utf-8') as file:
    json.dump(statistics, file, indent=4, ensure_ascii=False)

Error processing Elysium: Source ('Elysium', 'ORG') is not in G
Error processing HolosGen: Source ('HolosGen', 'ORG') is not in G
Error processing Hyperion Power: Source ('Hyperion Power', 'ORG') is not in G
Error processing StarCore Nuclear: Source ('StarCore Nuclear', 'ORG') is not in G
Error processing Terrestial: Source ('Terrestial', 'ORG') is not in G


In [23]:
import matplotlib.pyplot as plt
import json

def plot_selected_histograms(statistics, depth, show_percent=False, show_other=False, headroom=1.1):
    """Save a 5x4 grid of per-company node-type histograms as a PNG.

    One panel is drawn per company in ``statistics`` (in iteration order, at
    most 20), skipping a fixed list of companies. All panels share the same
    y-axis range. The figure is written to
    ``results/histograms_<labels>_<other>_<depth>.png`` and then closed.

    Args:
        statistics: Mapping ``company -> {depth_key -> counts}``, where
            ``counts`` maps node-type tags (and ``'Total'``) to node counts.
        depth: Key selecting which counts to plot, e.g. ``'all'`` or
            ``'depth=1'``.
        show_percent: Label each bar with its share of the plotted bars in
            that panel instead of the raw count.
        show_other: Add an ``Other`` bar holding the nodes that are not PER,
            ORG, LOC or MISC, computed as ``counts['Total']`` minus the four
            known counts. If ``'Total'`` is missing, the sum of all counts is
            used as the total.
        headroom: Factor applied to the tallest bar to obtain the shared
            y-axis limit, leaving room for the labels above the bars.

    Returns:
        None.

    Raises:
        KeyError: If a non-skipped company has no entry for ``depth``.
    """
    skip_companies = ['Elysium', 'HolosGen', 'Hyperion Power', 'StarCore Nuclear', 'Terrestial']
    entity_colors = {
        'LOC': '#ffb347',
        'ORG': '#87ceeb',
        'PER': '#90ee90',
        'MISC': '#dda0dd',
        'Other': '#d3d3d3'
    }
    known_tags = ['PER', 'ORG', 'LOC', 'MISC']
    max_panels = 20  # the grid is 5 x 4

    def bars_for(counts):
        """Return (tags, heights) for one panel."""
        tags = list(known_tags)
        heights = [counts.get(tag, 0) for tag in known_tags]
        if show_other:
            # "Other" is whatever the node total contains beyond the known tags.
            # It used to be derived from the known tags alone and so was always 0.
            node_total = counts['Total'] if 'Total' in counts else sum(counts.values())
            tags.append('Other')
            heights.append(max(node_total - sum(heights), 0))
        return tags, heights

    # Collect the data first so a missing depth key fails before a figure is opened.
    panels = []
    for company, data in statistics.items():
        if company in skip_companies:
            continue
        tags, heights = bars_for(data[depth])
        panels.append((company, tags, heights))
    panels = panels[:max_panels]

    # Common y-axis: tallest plotted bar (including "Other") plus headroom.
    tallest = max((height for _, _, heights in panels for height in heights), default=0)
    y_max = tallest * headroom if tallest > 0 else 1

    fig, axes = plt.subplots(5, 4, figsize=(12, 16))
    axes = axes.flatten()

    for ax, (company, tags, heights) in zip(axes, panels):
        panel_total = sum(heights)

        ax.bar(tags, heights, color=[entity_colors[tag] for tag in tags])
        ax.set_title(company)
        ax.set_ylim(0, y_max)  # Use common y-axis height

        for j, height in enumerate(heights):
            if show_percent:
                share = (height / panel_total) * 100 if panel_total > 0 else 0
                label = f'{share:.0f}%'
            else:
                label = str(height)
            ax.text(j, height, label, ha='center', va='bottom', fontsize=10, color='black')

    plt.tight_layout()
    plt.savefig(f'results/histograms_{"with_percent" if show_percent else "with_numbers"}_{"with_Other" if show_other else ""}_{depth}.png', bbox_inches='tight')
    plt.close(fig)

In [24]:
# Create node distribution histograms
# JSON files are read as UTF-8 explicitly so the result does not depend on the platform locale.
with open('data/triplets_no_cutoff/graphs.json', 'r', encoding='utf-8') as file:
    results = json.load(file)
with open('results/basic_stats.json', 'r', encoding='utf-8') as file:
    statistics = json.load(file)

depth = 'all'
plot_selected_histograms(statistics, depth, show_percent=False, show_other=True)

# Set to True to also render every depth / label combination (without the "Other" bar).
PLOT_ALL_HISTOGRAM_VARIANTS = False
if PLOT_ALL_HISTOGRAM_VARIANTS:
    for depth in ['depth=1','depth=2','depth=3','all']:
        for show_percent in [True, False]:
            plot_selected_histograms(statistics, depth, show_percent=show_percent)

In [4]:
# Create grid of KGs for different depths
# JSON is read as UTF-8 explicitly so the result does not depend on the platform locale.
with open('data/triplets_no_cutoff/graphs.json', 'r', encoding='utf-8') as file:
    results = json.load(file)
    
def generate_kg_grid_for_depth(results, save_directory, direction, max_depth, skip_companies, seed=None):
    """Draw up to 20 pruned knowledge graphs in a 5x4 grid and optionally save it.

    For every root in ``results`` (in iteration order) that is not in
    ``skip_companies``, a graph is built with ``hf.make_graph``, pruned with
    ``hf.prune_graph_by_depth`` around the node ``(root, 'ORG')`` and drawn in
    the next free panel with a spring layout. A root whose processing raises is
    reported on standard output and skipped without consuming a panel. Panels
    that end up unused are blanked.

    Args:
        results: Mapping ``root name -> graph data`` accepted by
            ``hf.make_graph``.
        save_directory: Directory for the PNG. Created if missing. A falsy
            value disables saving.
        direction: Passed through to ``hf.prune_graph_by_depth``; also part of
            the output file name.
        max_depth: Passed through to ``hf.prune_graph_by_depth``; also part of
            the output file name.
        skip_companies: Root names to leave out.
        seed: Seed forwarded to ``nx.spring_layout``. The default ``None``
            keeps the layout random; pass an int for reproducible figures.

    Returns:
        None. The figure is closed before returning.
    """
    # Colour per entity type; the root and unlisted types get their own colours.
    entity_colors = {
        'LOC': '#ffb347',  # orange
        'ORG': '#87ceeb',  # light blue
        'PER': '#90ee90',  # light green
        'MISC': '#dda0dd',  # plum
    }
    root_color = '#d90000'  # red
    fallback_color = '#c08aed'  # purple for any undefined types
    node_size = 50  # same size for every node, root included

    fig, axes = plt.subplots(5, 4, figsize=(12, 16))
    axes = axes.flatten()

    index = 0
    for root in results:
        if root in skip_companies:
            continue
        if index >= 20:
            break

        try:
            root_node = (root, 'ORG')
            G = hf.make_graph(results[root])
            G_pruned = hf.prune_graph_by_depth(G, root_node, max_depth, direction)

            ax = axes[index]
            pos = nx.spring_layout(G_pruned, k=2, iterations=100, seed=seed)

            node_colors = []
            for node in G_pruned.nodes():
                # Nodes are (name, type) pairs; anything else is reported as an error.
                _entity_name, entity_type = node
                if node == root_node:
                    node_colors.append(root_color)
                else:
                    node_colors.append(entity_colors.get(entity_type, fallback_color))

            nx.draw(G_pruned, pos, node_size=node_size, node_color=node_colors,
                    font_size=0, arrows=True, arrowsize=5, edge_color='gray', width=0.2, ax=ax)

            ax.set_title(f"{root}", fontsize=12)
            index += 1

        except Exception as e:
            print(f"Error processing {root} at depth {max_depth}: {e}")
            continue

    # Hide the empty frames of panels that were never drawn into.
    for unused_ax in axes[index:]:
        unused_ax.axis('off')

    plt.tight_layout()

    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        filepath = os.path.join(save_directory, f"kg_grid_depth_{max_depth}_{direction}.png")
        plt.savefig(filepath, format='png', bbox_inches='tight')

    plt.close(fig)

# Settings for the KG grid figures.
save_directory = 'results'
direction = 'bidirectional_and'
# Choose the depth level(s) you want to visualize; one figure is produced per entry.
# NOTE(review): the meaning of -1 is defined by hf.prune_graph_by_depth -- presumably "no depth limit"; confirm.
max_depths = [-1]
skip_companies = [
    'Elysium',
    'HolosGen',
    'Hyperion Power',
    'StarCore Nuclear',
    'Terrestial',
]

for max_depth in max_depths:
    generate_kg_grid_for_depth(
        results,
        save_directory,
        direction,
        max_depth,
        skip_companies,
    )

In [44]:
# Get advanced KG statistics
# For every root company: total, in- and out-degree of its (root, 'ORG') node and
# the in/out ratio, written to results/degree_stats.json.
filepath = 'data/triplets_no_cutoff/graphs.json'

# JSON is read as UTF-8 explicitly so the result does not depend on the platform locale.
with open(filepath, 'r', encoding='utf-8') as file:
    results = json.load(file)

statistics = {}

for root in results:
    root_node = (root, 'ORG')
    result = results[root]

    # Skip empty graphs
    if not result:
        continue

    # Otherwise, continue on
    G = hf.make_graph(result)

    # A non-empty graph may still lack the root node; degree lookups are only
    # meaningful for a node that is actually in the graph, so report and skip.
    if root_node not in G:
        print(f"Skipping {root}: {root_node} is not in G")
        continue

    total_degree = G.degree(root_node)
    in_degree = G.in_degree(root_node)
    out_degree = G.out_degree(root_node)
    # The ratio is reported as 0 when the root has no outgoing edges.
    in_div_out = in_degree/out_degree if out_degree > 0 else 0

    graph_stats = {
        'total_degree': total_degree,
        'in_degree': in_degree,
        'out_degree': out_degree,
        'in/out': in_div_out
    }

    statistics[root] = graph_stats

# Make sure the output directory exists before writing.
os.makedirs('results', exist_ok=True)
with open('results/degree_stats.json', 'w', encoding='utf-8') as outfile:
    json.dump(statistics, outfile, indent=4)