# In [None]:
# G-Retriever: GraphRAG without a GAT encoder.
# Uses a dummy (identity) GNN in place of a trained graph encoder.
from sentence_transformers import SentenceTransformer
import torch
from torch_geometric.nn.models import GRetriever
from torch_geometric.nn.nlp import LLM
from torch.nn import Module
import warnings

# Hide UserWarnings attributed to PyG's LLM wrapper module (for example its
# CPU-usage notice) so the notebook output stays readable. `module` is a
# regular expression matched against the warning's module name.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="torch_geometric.nn.nlp.llm",
)

# Define DummyGNN class with out_channels
class DummyGNN(Module):
    def __init__(self, out_channels):
        super().__init__()
        self.out_channels = out_channels  # Define the expected attribute

    def forward(self, x, edge_index=None, edge_attr=None):
        return x  # Pass through input without modification

# Sentence encoder that turns each node description into a feature vector.
model = SentenceTransformer('all-MPNet-base-v2')

# Node descriptions. A description's position in this list is its node id,
# which is what `edge_pairs` below refers to.
node_descriptions = [
    "Climate Change Policies: Policies aimed at reducing greenhouse gas emissions and mitigating climate change effects.",
    "Renewable Energy Adoption: Transition to renewable energy sources like solar, wind, and hydroelectric power.",
    "Fossil Fuel Consumption: Use of fossil fuels such as coal, oil, and natural gas for energy.",
    "Electric Vehicles: Vehicles powered by electric motors using energy stored in rechargeable batteries.",
    "Carbon Emissions: Release of carbon dioxide into the atmosphere from various sources.",
    "Air Pollution: Contamination of air by harmful substances including gases and particulates.",
    "Public Health: Health outcomes of populations influenced by environmental and social factors.",
    "Economic Growth: Increase in the market value of goods and services produced by an economy over time.",
    "Job Creation: Generation of new employment opportunities in an economy.",
    "Energy Efficiency: Using less energy to perform the same task, reducing energy waste.",
    "Energy Policy: Government policy regarding the production, distribution, and consumption of energy.",
    "International Agreements: Agreements between nations to cooperate on issues like climate change.",
    "Technological Innovation: Development of new technologies or improvements to existing ones.",
    "Infrastructure Investment: Allocation of funds to build or improve physical infrastructure.",
    "Sustainable Agriculture: Farming practices that meet current food needs without compromising future generations.",
    "Deforestation: Removal of a forest or stand of trees from land which is then converted to non-forest use.",
    "Biodiversity Loss: Decline in the number and variety of species in a given area.",
    "Climate Refugees: People forced to leave their homes due to the impacts of climate change.",
    "Water Scarcity: Lack of sufficient available water resources to meet the demands within a region.",
    "Disaster Risk Management: Strategies to reduce the damage caused by natural disasters.",
    # Add more nodes if needed; each new node takes the next free id and can
    # then be connected in `edge_pairs`.
]

# One embedding row per node, in the same order as `node_descriptions`.
node_features = torch.tensor(model.encode(node_descriptions), dtype=torch.float)

# Directed edges as (source, target) node-id pairs, grouped by source node.
# Writing pairs instead of two parallel rows keeps every source visibly tied
# to its target. NOTE: DummyGNN returns `x` unchanged and ignores
# `edge_index`, so this structure does not currently affect the result.
edge_pairs = [
    (0, 1), (0, 3), (0, 9), (0, 11), (0, 19),  # from Climate Change Policies
    (1, 2), (1, 8), (1, 13),                   # from Renewable Energy Adoption
    (2, 4), (2, 5), (2, 6),                    # from Fossil Fuel Consumption
    (3, 2), (3, 12),                           # from Electric Vehicles
    (4, 5), (4, 6),                            # from Carbon Emissions
    (5, 6), (5, 7),                            # from Air Pollution
    (6, 7),                                    # from Public Health
    (7, 8), (7, 13), (7, 14),                  # from Economic Growth
    (8, 7), (8, 10),                           # from Job Creation
    (9, 2), (9, 10),                           # from Energy Efficiency
    (10, 0), (10, 6),                          # from Energy Policy
    (11, 0),                                   # from International Agreements
    (12, 1), (12, 3),                          # from Technological Innovation
    (13, 2),                                   # from Infrastructure Investment
    (14, 15),                                  # from Sustainable Agriculture
    (15, 16),                                  # from Deforestation
    (16, 17),                                  # from Biodiversity Loss
    (17, 6),                                   # from Climate Refugees
    (18, 6),                                   # from Water Scarcity
]
# Shape [2, num_edges]: row 0 holds sources, row 1 holds targets.
edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()

num_nodes = len(node_descriptions)
# Fail fast on ids outside [0, num_nodes), e.g. after editing the node list.
if edge_index.numel() > 0 and (
    int(edge_index.min()) < 0 or int(edge_index.max()) >= num_nodes
):
    raise ValueError(
        f"edge_index contains node ids outside [0, {num_nodes}); "
        "check edge_pairs against node_descriptions"
    )

# Batch vector for a single graph: every node belongs to graph 0.
batch = torch.zeros(num_nodes, dtype=torch.long)

# Check if GPU is available and set device accordingly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Identity "GNN" whose reported width matches the embedding size.
dummy_gnn = DummyGNN(out_channels=node_features.size(1)).to(device)

# NOTE(review): PyG's LLM wrapper chooses its own placement while loading the
# weights (GPU(s) via a device map, or CPU when VRAM looks insufficient) and
# records it in `llm.device`. It is therefore deliberately NOT moved with
# `llm.to(device)`: that would relocate the parameters without updating
# `llm.device`, and could push a CPU-fallback 7B model onto a GPU that lacks
# the memory. Verify against the installed PyG version.
llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7)

# Size the projector output to the LLM's token-embedding width rather than
# hard-coding 4096 (which only matches 7B-class Llama models), so switching
# `model_name` keeps the shapes consistent.
# NOTE(review): the projector is presumably freshly initialized and untrained
# here, so the graph token it feeds the LLM is not yet meaningful.
g_retriever = GRetriever(
    llm=llm,
    gnn=dummy_gnn,
    mlp_out_channels=llm.word_embedding.embedding_dim,
)

# Complex query without additional context.
query = [
    "Evaluate the potential long-term economic and environmental impacts of shifting from fossil fuel consumption to renewable energy adoption, considering factors such as job creation, technological innovation, energy policy, and public health. Additionally, discuss how this transition affects international agreements and climate change policies, and identify potential challenges in infrastructure investment and sustainable agriculture."
]

# Place inputs on `llm.device`, the device the LLM wrapper actually reports,
# rather than on the `device` guessed above.
node_features = node_features.to(llm.device)
edge_index = edge_index.to(llm.device)
batch = batch.to(llm.device)

# Perform inference without additional text context.
with torch.no_grad():
    subgraph_context = g_retriever.inference(
        question=query,
        x=node_features,          # Node embeddings
        edge_index=edge_index,    # Edge connections (unused by DummyGNN)
        batch=batch,
        max_out_tokens=1024,
    )

# `inference` presumably returns the LLM's generated text, one entry per
# question, rather than a literal subgraph; the label below is kept as-is.
print("Subgraph context retrieved (Without GNN):", subgraph_context)