In [1]:
import json
import networkx as nx
import matplotlib.pyplot as plt
import pickle
import pandas as pd


In [7]:
# Input files, given as absolute paths on the author's machine.
# JSON Lines file that is read with load_data() further down in this notebook.
all_fever_data = '/home/ashrafs/projects/dragon/data/fever/enriched_feverous_dev.jsonl'

# Per-split statement files (train / test / dev).
# NOTE(review): these three names are not referenced in the visible part of the notebook.
train_fever_data = '/home/ashrafs/projects/dragon/data/fever/statement/train.statement.jsonl'
test_fever_data = '/home/ashrafs/projects/dragon/data/fever/statement/test.statement.jsonl'
dev_fever_data = '/home/ashrafs/projects/dragon/data/fever/statement/dev.statement.jsonl'

In [8]:
# Output locations: construct_graphs() pickles one list of per-claim graph
# records to each of these paths (train / test / dev).
train_fever_graph = '/home/ashrafs/projects/dragon/data/fever/graph/train_graph.pickle'
test_fever_graph = '/home/ashrafs/projects/dragon/data/fever/graph/test_graph.pickle'
dev_fever_graph = '/home/ashrafs/projects/dragon/data/fever/graph/dev_graph.pickle'

In [9]:
def load_data(filename):
    """Load a JSON Lines file into a list of parsed records.

    Parameters
    ----------
    filename : str or path-like
        Path to a text file holding one JSON document per line.

    Returns
    -------
    list
        The decoded JSON value of every non-blank line, in file order.

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    """
    records = []
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(filename, 'r', encoding='utf-8') as file:
        for line in file:
            # Blank or whitespace-only lines (e.g. a trailing empty line) carry
            # no record; json.loads would raise on them, so skip them.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

# Read every record of the enriched FEVEROUS dev file into memory.
data_list = load_data(all_fever_data)

# data_list now holds one parsed JSON record per line of the file;
# the number of records is taken as the number of claims.
number_of_claims = len(data_list)


In [10]:
# Display the number of records (claims) that were loaded.
number_of_claims

7890

In [11]:
def count_statements_per_claim(data_list):
    """Pair every claim with the number of statements attached to it.

    Parameters
    ----------
    data_list : iterable of dict
        Records with an optional 'CLAIM' entry and an optional
        'ENTITY_STATEMENTS' mapping; each value of that mapping is a dict
        that may hold a 'statements' list.

    Returns
    -------
    list of tuple
        One ``(claim_text, statement_count)`` tuple per record, in input
        order. A record without a 'CLAIM' key gets the text
        'No Claim Found'; missing or null statement containers count as 0.
    """
    claim_data = []
    for data in data_list:
        claim_text = data.get('CLAIM', 'No Claim Found')

        # `or {}` covers both a missing key and an explicit JSON null.
        entity_statements = data.get('ENTITY_STATEMENTS') or {}

        # Sum the statements over all entities of this claim. A null entity
        # entry or a null 'statements' value contributes zero instead of
        # raising TypeError.
        statement_count = sum(
            len((details or {}).get('statements') or [])
            for details in entity_statements.values()
        )

        claim_data.append((claim_text, statement_count))

    return claim_data


In [12]:
# Count the statements attached to every claim in the loaded data.
claim_data = count_statements_per_claim(data_list)

# Tabulate the (claim text, statement count) pairs. pandas is already
# imported as `pd` in the notebook's first cell, so it is not re-imported here.
df = pd.DataFrame(claim_data, columns=['Claim', 'Statement Counts'])


In [18]:
# Keep only the rows whose 'Statement Counts' is greater than 0,
# i.e. the claims that have at least one statement.
claims_with_statements = df[df['Statement Counts'] > 0]

# Report how many such claims there are.
print("Claims with Non-Zero Statements:")
print(len(claims_with_statements))


Claims with Non-Zero Statements:
5395


In [25]:
# Tunable knobs for the sample that is turned into graphs.
MIN_STATEMENTS = 10  # a claim needs at least this many statements to be kept
SPLIT_SIZE = 10      # number of claims placed in each of train / test / dev

# Keep the records whose entities carry at least MIN_STATEMENTS statements in
# total. `or {}` makes a missing or null 'ENTITY_STATEMENTS' count as zero
# statements, so such records are dropped.
filtered_data_list = [
    data for data in data_list
    if sum(
        len(details.get('statements', []))
        for details in (data.get('ENTITY_STATEMENTS') or {}).values()
    ) >= MIN_STATEMENTS
]

# Split the *filtered* records into consecutive, non-overlapping train, test
# and dev slices of SPLIT_SIZE records each. A slice is shorter (possibly
# empty) when fewer than 3 * SPLIT_SIZE records survive the filter.
train_data = filtered_data_list[:SPLIT_SIZE]
test_data = filtered_data_list[SPLIT_SIZE:2 * SPLIT_SIZE]
dev_data = filtered_data_list[2 * SPLIT_SIZE:3 * SPLIT_SIZE]


In [26]:

def is_fully_connected(G):
    """Tell whether `G` forms a single connected piece when edge direction is ignored.

    Graphs with zero or one node are reported as not connected (False).
    For larger graphs the result is that of ``nx.is_connected`` applied to
    the undirected version of `G`.
    """
    # Trivial graphs are treated as "not connected" by convention here.
    if G.number_of_nodes() <= 1:
        return False

    undirected = G.to_undirected()
    return nx.is_connected(undirected)

def _build_claim_graph(data, max_nodes):
    """Build the directed statement graph of one record, never exceeding `max_nodes` nodes."""
    G = nx.DiGraph()  # Using a directed graph
    # `or {}` covers both a missing key and an explicit null value.
    entity_statements = data.get('ENTITY_STATEMENTS') or {}

    for details in entity_statements.values():
        for statement in details.get('statements', []):
            subject, relation, obj = statement

            current = G.number_of_nodes()
            # Endpoints that add_edge would have to create (a self-loop
            # contributes at most one).
            new_nodes = sum(1 for node in {subject, obj} if node not in G)

            # Stop for good once the cap is reached, or when this edge would
            # push the graph past it (e.g. two new endpoints at cap - 1).
            # Returning here also ends the iteration over the remaining entities.
            if current >= max_nodes or current + new_nodes > max_nodes:
                return G

            # The relation is stored as an edge attribute. In a DiGraph a
            # repeated (subject, obj) pair updates the existing edge, so the
            # latest relation wins.
            G.add_edge(subject, obj, relation=relation)

    return G


def construct_graphs(data_subset, pickle_file_path, max_nodes=200):
    """Build one directed graph per record and pickle the results.

    Parameters
    ----------
    data_subset : iterable of dict
        Records with a 'CLAIM' entry and an 'ENTITY_STATEMENTS' mapping whose
        values hold a 'statements' list of (subject, relation, object) triples.
    pickle_file_path : str or path-like
        Destination file; it is overwritten.
    max_nodes : int, optional
        Upper bound on the number of nodes per graph (default 200).
        Statements are added in order until the next one would exceed it.

    Returns
    -------
    None
        The output is the pickle file: a list with one dict per record, with
        keys 'Claim', 'Num_Nodes', 'Is_Fully_Connected' and 'Graph'.
    """
    graphs_stats = []

    for data in data_subset:
        G = _build_claim_graph(data, max_nodes)
        graphs_stats.append({
            'Claim': data.get('CLAIM'),
            'Num_Nodes': G.number_of_nodes(),
            'Is_Fully_Connected': is_fully_connected(G),
            'Graph': G,
        })

    # Save the list of graph records to a file.
    with open(pickle_file_path, "wb") as f:
        pickle.dump(graphs_stats, f)

# Build the graphs of each split and pickle them to that split's output path.
for split_records, graph_path in (
    (train_data, train_fever_graph),
    (test_data, test_fever_graph),
    (dev_data, dev_fever_graph),
):
    construct_graphs(split_records, graph_path)
