# Sample the semantic structure graph feature for linguistic causal back-door intervention

- First, we process the QA pairs using the Stanza toolkit.
- Then, we encode the extracted components with CLIP, assemble them into graphs, and sample representative graphs using the k-means algorithm.

## 1. Deconstructing QA data

In [None]:
import stanza as snlp
import pandas as pd
import json

In [None]:
# Load the training split; only the question and answer columns are kept.
df = pd.read_csv("data/star/train.csv")[["question", "answer"]]

In [None]:
# One dict per dataframe row, keyed by column name.
data = df.to_dict("records")

# Build the Stanza pipeline with its default configuration.
# If the models are not installed yet, run `snlp.download('en')` once first.
nlp = snlp.Pipeline()

In [None]:
# Use Stanza (formerly StanfordNLP) for text analysis and graph-component extraction
def process_text(text):
    """Run the module-level ``nlp`` pipeline on *text* and flatten the result.

    Returns a list with one dict per word, in document order across all
    sentences. Each dict copies the word's ``text``, ``lemma``, ``upos``,
    ``xpos``, ``head`` and ``deprel`` attributes.
    """
    parsed = nlp(text)
    return [
        {
            'text': token.text,
            'lemma': token.lemma,
            'upos': token.upos,
            'xpos': token.xpos,
            'head': token.head,  # head index (1-based)
            'deprel': token.deprel,
        }
        for sent in parsed.sentences
        for token in sent.words
    ]

# Dependency labels that mark a word as a clause subject. UD v2 (the scheme
# Stanza's models follow) writes the passive subject as 'nsubj:pass';
# 'nsubjpass' is the older UD v1 spelling and is kept for compatibility.
_SUBJECT_DEPRELS = frozenset({'nsubj', 'nsubjpass', 'nsubj:pass'})


def _unique_in_order(items):
    """Return the distinct values of *items*, keeping first-occurrence order."""
    return list(dict.fromkeys(items))


def extract_components(question_data, answer_data):
    """Split a parsed question/answer pair into its semantic components.

    Args:
        question_data: Word dicts of the question; each needs the keys
            ``'text'``, ``'upos'`` and ``'deprel'``.
        answer_data: Word dicts of the answer, same format.

    Returns:
        A dict with the keys ``'question'``, ``'question_type'``,
        ``'subject'``, ``'verbs'``, ``'objects'``, ``'answer'``,
        ``'answer_verbs'`` and ``'answer_objects'``. ``'question_type'`` and
        ``'subject'`` are ``None`` when no word matches. The list-valued
        entries hold each distinct word once, in order of first appearance,
        so the output is identical from run to run (a plain ``set`` would
        order strings differently under hash randomization).
    """
    def is_subject(word):
        return word['deprel'] in _SUBJECT_DEPRELS

    # Question type: usually a question word, taken as the first
    # pronoun / adverb / determiner of the question.
    question_type = next(
        (w['text'] for w in question_data if w['upos'] in ('PRON', 'ADV', 'DET')),
        None,
    )

    # Subject of the question: first word with an (active or passive)
    # nominal-subject relation.
    subject = next((w['text'] for w in question_data if is_subject(w)), None)

    # All verbs in the question.
    verbs = _unique_in_order(
        w['text'] for w in question_data if w['upos'] == 'VERB'
    )

    # Objects in the question: nouns, proper nouns and adjectives that are
    # not subjects.
    objs = _unique_in_order(
        w['text'] for w in question_data
        if not is_subject(w) and w['upos'] in ('NOUN', 'PROPN', 'ADJ')
    )

    # All verbs in the answer.
    a_verbs = _unique_in_order(
        w['text'] for w in answer_data if w['upos'] == 'VERB'
    )

    # Objects in the answer: nouns and proper nouns that are not subjects.
    a_objs = _unique_in_order(
        w['text'] for w in answer_data
        if not is_subject(w) and w['upos'] in ('NOUN', 'PROPN')
    )

    return {
        'question': ' '.join(w['text'] for w in question_data),
        'question_type': question_type,
        'subject': subject,
        'verbs': verbs,
        'objects': objs,
        'answer': ' '.join(w['text'] for w in answer_data),
        'answer_verbs': a_verbs,
        'answer_objects': a_objs,
    }

In [None]:
# Parse every QA pair and collect the components extracted from it.
data_components = []
for i, qa_pair in enumerate(data):
    components = extract_components(
        process_text(qa_pair['question']),
        process_text(qa_pair['answer']),
    )
    data_components.append(components)
    # '\r' returns to the line start so the counter is rewritten in place.
    print(f"\rProgress: {i+1}/{len(data)}", end='')

In [None]:
# Save the result as a JSON file
import os

components_path = 'data/star/causal_feature/data_components.json'
# open() does not create missing folders, so make sure the output directory exists.
os.makedirs(os.path.dirname(components_path), exist_ok=True)
with open(components_path, 'w') as f:
    json.dump(data_components, f, indent=4)

## 2. Encoding the components

In [None]:
def encode_text_with_clip(text, model, processor):
    """Return the text features that *model* computes for *text*.

    *processor* is called with ``padding=True``, ``truncation=True`` and
    ``return_tensors="pt"``; its output is unpacked into
    ``model.get_text_features``. Gradient tracking is disabled for the
    forward pass.
    """
    tokenized = processor(
        text=text,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return model.get_text_features(**tokenized)

def encode_components(components, model, processor):
    """Encode every entry of a components dict with ``encode_text_with_clip``.

    List values are joined with single spaces before encoding; any other
    falsy value (``None`` or ``''``) is encoded as the empty string.
    Returns a new dict with the same keys mapped to the encoded features.
    """
    def as_text(value):
        if isinstance(value, list):
            return ' '.join(value)
        return value or ''

    return {
        name: encode_text_with_clip(as_text(value), model, processor)
        for name, value in components.items()
    }

In [None]:
import torch
# NOTE(review): `clip_model` and `clip_processor` are not created anywhere in
# the visible notebook; they must be loaded before this cell runs - confirm where.
encoded_data_components = []
# Iterate the component list itself (rather than indexing it by len(data)) so
# the loop bound and the progress total always match what is being encoded.
for i, components in enumerate(data_components):
    encoded_components = encode_components(components, clip_model, clip_processor)
    encoded_data_components.append(encoded_components)
    print(f"\rProgress: {i+1}/{len(data_components)}", end='')
# Save the results with torch.save (the file keeps its .npy extension)
torch.save(encoded_data_components, 'data/star/causal_feature/encoded_data_components.npy')

## 3. Constructing the Graph

In [None]:
import torch_geometric
from torch_geometric.data import Data

def build_knowledge_graph(encoded_components):
    """Assemble one sample's encoded components into a ``Data`` graph.

    Node features are the encoded components concatenated along dim 0 in this
    order: 0 question, 1 question_type, 2 subject, 3 verbs, 4 objects,
    5 answer, 6 answer_verbs, 7 answer_objects. The edge list is fixed and
    refers to these node positions.
    """
    node_keys = (
        'question',
        'question_type',
        'subject',
        'verbs',
        'objects',
        'answer',
        'answer_verbs',
        'answer_objects',
    )
    node_features = [encoded_components[key] for key in node_keys]

    # (source, target) pairs following the specified relationships
    node_pairs = [
        (0, 5),  # q--a
        (0, 1),  # q--q_type
        (0, 2),  # q--q_sub
        (0, 3),  # q--q_verb
        (0, 4),  # q--q_obj
        (2, 4),  # q_sub--q_obj
        (2, 3),  # q_sub--q_verb
        (4, 3),  # q_obj--q_verb
        (5, 6),  # a--a_verb
        (5, 7),  # a--a_obj
        (6, 7),  # a_verb--a_obj
        (2, 5),  # q_sub--a
    ]

    # Transpose the pair list into the 2 x num_edges layout
    edge_index = torch.tensor(node_pairs, dtype=torch.long).t().contiguous()

    # Stack the per-component features into one node-feature tensor
    x = torch.cat(node_features, dim=0)

    return Data(x=x, edge_index=edge_index)

In [None]:
# Build one graph per encoded QA sample.
graphs = [build_knowledge_graph(encoded) for encoded in encoded_data_components]

## 4. Clustering the graphs

In [None]:
from sklearn.cluster import MiniBatchKMeans
import numpy as np

def kmean(x, k=512):
    """Cluster flattened graph features with MiniBatchKMeans.

    Args:
        x: Features accepted by ``torch.as_tensor`` (array, tensor or nested
            list). They are reshaped to rows of ``768 * 8`` values.
        k: Number of clusters to fit.

    Returns:
        A tuple ``(cluster_centers, cluster_means)``. ``cluster_centers`` is
        the fitted ``cluster_centers_`` array. ``cluster_means`` is a list of
        ``k`` arrays, each the mean of the rows assigned to that cluster; a
        cluster with no assigned rows falls back to its fitted center instead
        of a NaN vector.
    """
    # as_tensor avoids the full copy torch.tensor() makes of the feature
    # matrix; detach()/cpu() keep the conversion valid for tensor inputs.
    x = torch.as_tensor(x).detach().cpu().numpy()
    x = x.reshape([-1, 768*8])
    # Apply K-means algorithm
    print("feature sample:", x.shape[0])
    kmeans = MiniBatchKMeans(n_clusters=k, random_state=43, verbose=True).fit(x)
    print("clustering done")
    # Get the cluster center point
    cluster_centers = kmeans.cluster_centers_
    print("Get centers")
    # Gets the cluster label for each point
    labels = kmeans.labels_

    # Calculate the feature mean of each cluster
    cluster_means = []
    for i in range(k):
        members = x[labels == i]
        if len(members) > 0:
            cluster_means.append(np.mean(members, axis=0))
        else:
            # Averaging zero rows would yield NaN (plus a RuntimeWarning),
            # so an empty cluster is represented by its center.
            cluster_means.append(cluster_centers[i].astype(x.dtype))
    print("Get mean")
    return cluster_centers, cluster_means
    
def cluster_knowledge_graphs(graphs, n_clusters=512, batch_size=100):
    """Cluster per-sample graphs and return two graphs of cluster prototypes.

    Args:
        graphs: Graph objects whose ``x`` attribute holds the node features;
            each ``x`` is flattened to a single row before clustering.
        n_clusters: Number of clusters, forwarded to ``kmean`` as ``k``.
        batch_size: Currently unused (``kmean`` takes no batch-size
            argument); kept so existing calls keep working.

    Returns:
        A tuple ``(clustered_centers_graphs, clustered_means_graphs)`` of
        ``Data`` objects sharing one fixed edge list. Their ``x`` holds the
        k-means centers and the per-cluster member means respectively, each
        reshaped to ``(-1, 8, 768)``.
    """
    # Flatten all node features into a single matrix, one row per graph
    all_features = torch.cat(
        [graph.x.reshape(1, -1) for graph in graphs], dim=0
    ).numpy()

    # Perform MiniBatchKMeans clustering. The requested cluster count is
    # passed through; it was previously ignored in favour of kmean's default.
    cluster_centers, cluster_means = kmean(all_features, k=n_clusters)
    cluster_centers = torch.from_numpy(cluster_centers).reshape(-1, 8, 768)
    # Stack into one array first: building a tensor directly from a list of
    # arrays converts element by element and is very slow.
    cluster_means = torch.from_numpy(np.asarray(cluster_means)).reshape(-1, 8, 768)

    # Reorganize graphs based on clustering results
    # Edges according to the specified relationships
    edge_pairs = [
        (0, 5),  # q--a
        (0, 1),  # q--q_type
        (0, 2),  # q--q_sub
        (0, 3),  # q--q_verb
        (0, 4),  # q--q_obj
        (2, 4),  # q_sub--q_obj
        (2, 3),  # q_sub--q_verb
        (4, 3),  # q_obj--q_verb
        (5, 6),  # a--a_verb
        (5, 7),  # a--a_obj
        (6, 7),  # a_verb--a_obj
        (2, 5),  # q_sub--a
    ]

    # Convert the pair list to a 2 x num_edges tensor
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()

    # Create PyTorch Geometric Data objects
    clustered_centers_graphs = Data(x=cluster_centers, edge_index=edge_index)
    clustered_means_graphs = Data(x=cluster_means, edge_index=edge_index)
    return clustered_centers_graphs, clustered_means_graphs

In [None]:
# Cluster the per-sample graphs; returns one graph built from the k-means
# centers and one built from the per-cluster member means.
clustered_centers_graphs, clustered_means_graphs = cluster_knowledge_graphs(
    graphs,
    n_clusters=512,
    batch_size=100,
)

In [None]:
# Persist both clustered graph variants together in one torch-serialized file.
qa_graphs = {"k_center": clustered_centers_graphs, "k_mean": clustered_means_graphs}
torch.save(qa_graphs, "data/star/causal_feature/qa_graphs.npy")