# MISATA-LLM: Causal DAG Extraction using Groq

## Novel Contribution
First framework using LLMs to extract **causal structure** (not just content) for synthetic data generation.

**Uses the Groq API** for fast inference with Llama 3.3 70B.

In [None]:
# Install the third-party packages this notebook relies on (-q keeps pip's output quiet).
!pip install -q groq numpy pandas matplotlib networkx

In [None]:
import os
import json
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
import matplotlib.pyplot as plt
from dataclasses import dataclass
import time

# Set your Groq API key
# Get FREE API key from: https://console.groq.com/
GROQ_API_KEY = "YOUR_GROQ_API_KEY_HERE"  # <-- REPLACE THIS

# Or use Kaggle secrets
try:
    from kaggle_secrets import UserSecretsClient
except ImportError:
    # Not running on Kaggle: keep whatever key was configured above.
    pass
else:
    try:
        user_secrets = UserSecretsClient()
        _kaggle_key = user_secrets.get_secret("GROQ_API_KEY")
    except Exception as exc:
        # Best effort only: a missing secret or unreachable secrets backend
        # must not abort the notebook, but it should not be hidden either.
        print(f"Kaggle secret 'GROQ_API_KEY' unavailable ({type(exc).__name__}); trying other key sources")
    else:
        # Ignore an empty/None secret so the environment assignment below
        # never receives a non-string value.
        if _kaggle_key:
            GROQ_API_KEY = _kaggle_key

# If no key was supplied above, reuse one already exported in the environment
# instead of overwriting it with the placeholder.
if GROQ_API_KEY == "YOUR_GROQ_API_KEY_HERE":
    GROQ_API_KEY = os.environ.get('GROQ_API_KEY') or GROQ_API_KEY

os.environ['GROQ_API_KEY'] = GROQ_API_KEY
print(f"API Key configured: {'Yes' if GROQ_API_KEY != 'YOUR_GROQ_API_KEY_HERE' else 'No - please set your key'}")

## Part 1: Define Domain Description

In [None]:
# Natural-language specification of the domain: the variables to model and the
# business rules from which the causal relationships are to be extracted.
DOMAIN_DESCRIPTION = """
Domain: Credit Card Fraud Detection

Variables:
- income: Customer's annual income (numerical)
- credit_limit: Maximum credit allowed (numerical)
- spending_rate: Average daily spending (numerical)
- transaction_amount: Individual transaction value (numerical)
- distance_from_home: How far the transaction is from home (numerical)
- fraud_probability: Likelihood of fraud (0 to 1)

Business Rules and Causal Relationships:
1. Higher income leads to higher credit limits (banks trust wealthy customers more)
2. Higher income leads to higher spending rates (wealthy people spend more)
3. Spending rate determines typical transaction amounts
4. Transactions far from home are more likely to be fraudulent
5. Very high transaction amounts relative to spending rate indicate potential fraud
6. Low-income customers with high spending have higher fraud rates
"""

# Echo the heading and the description, each followed by a newline.
print("Domain Description:", DOMAIN_DESCRIPTION, sep="\n")

## Part 2: LLM Causal DAG Extraction using Groq

In [None]:
@dataclass
class CausalEdge:
    """A directed causal relationship between two variables.

    Attributes:
        source: Name of the cause variable.
        target: Name of the effect variable.
        coefficient: Signed effect strength (the extraction prompt asks the
            model for a value between -1.0 and 1.0).
        description: Short explanation of why the relationship exists.
    """
    source: str
    target: str
    coefficient: float
    description: str


def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence, if present."""
    text = text.strip()
    if not text.startswith('```'):
        return text
    body = text[3:]
    if body.endswith('```'):
        body = body[:-3]
    body = body.strip()
    # Drop an optional "json" language tag left over from the opening fence.
    if body[:4].lower() == 'json':
        body = body[4:]
    return body.strip()


def _load_edge_array(response_text: str) -> list:
    """Parse the model reply into the raw JSON array of edge objects."""
    cleaned = _strip_code_fence(response_text)
    if not cleaned:
        raise ValueError("LLM returned an empty response")
    try:
        payload = json.loads(cleaned)
    except json.JSONDecodeError:
        # The model sometimes wraps the array in prose; retry on the
        # outermost [...] span before giving up.
        start, end = cleaned.find('['), cleaned.rfind(']')
        if start == -1 or end <= start:
            raise
        payload = json.loads(cleaned[start:end + 1])
    if not isinstance(payload, list):
        raise ValueError(f"expected a JSON array of edges, got {type(payload).__name__}")
    return payload


def _build_edge(raw, position: int) -> CausalEdge:
    """Validate one raw JSON object and convert it into a CausalEdge."""
    import math

    if not isinstance(raw, dict):
        raise ValueError(f"edge #{position} is not a JSON object")
    try:
        source, target = raw['source'], raw['target']
        coefficient = float(raw['coefficient'])
    except KeyError as err:
        raise ValueError(f"edge #{position} is missing required field {err}") from err
    if not (isinstance(source, str) and isinstance(target, str)):
        raise ValueError(f"edge #{position} must use string variable names")
    # json.loads accepts NaN/Infinity; such weights would poison generated data.
    if not math.isfinite(coefficient):
        raise ValueError(f"edge #{position} has a non-finite coefficient")
    return CausalEdge(
        source=source,
        target=target,
        coefficient=coefficient,
        description=raw.get('description', ''),
    )


def _ensure_acyclic(edges: List[CausalEdge]) -> None:
    """Raise ValueError if *edges* contain a directed cycle (self-loops included)."""
    children = {}
    indegree = {}
    for edge in edges:
        children.setdefault(edge.source, []).append(edge.target)
        children.setdefault(edge.target, [])
        indegree.setdefault(edge.source, 0)
        indegree[edge.target] = indegree.get(edge.target, 0) + 1

    # Kahn's algorithm: repeatedly remove nodes that have no remaining parents.
    ready = [node for node, degree in indegree.items() if degree == 0]
    ordered = 0
    while ready:
        node = ready.pop()
        ordered += 1
        for child in children[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                ready.append(child)

    if ordered != len(indegree):
        stuck = sorted(node for node, degree in indegree.items() if degree > 0)
        raise ValueError(
            "extracted graph is not acyclic; could not order: " + ", ".join(stuck)
        )


def extract_causal_dag_groq(domain_description: str, client=None) -> List[CausalEdge]:
    """
    Extract causal DAG from domain description using Groq API.
    Uses Llama 3.3 70B for fast, high-quality inference.

    Args:
        domain_description: Free-text description of the variables and rules.
        client: Optional pre-built client exposing
            ``chat.completions.create(...)``. When omitted, a Groq client is
            created from the ``GROQ_API_KEY`` environment variable.

    Returns:
        The extracted edges, in the order the model listed them.

    Raises:
        ValueError: If the reply is empty, is not valid JSON (JSONDecodeError
            is a ValueError subclass), is not an array of well-formed edge
            objects, or describes a graph that contains a cycle.
    """
    if client is None:
        from groq import Groq

        client = Groq(api_key=os.environ.get('GROQ_API_KEY'))

    # Build prompt without f-string issues
    prompt = """You are a causal inference expert. Extract causal relationships from this domain description.

For each causal relationship, provide:
- source: the cause variable (use snake_case, must be one of the variables listed)
- target: the effect variable (use snake_case, must be one of the variables listed)
- coefficient: estimated effect strength from -1.0 to 1.0 (positive = positive effect, negative = negative effect)
- description: brief explanation of why this causal relationship exists

IMPORTANT: 
- Only include DIRECT causal relationships (A causes B directly)
- Do NOT include correlations, only causation
- The graph must be acyclic (no circular dependencies)

Return ONLY a valid JSON array with objects containing: source, target, coefficient, description.
Do not include any markdown formatting or code blocks.

Domain Description:
""" + domain_description

    response = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[
            {
                "role": "system",
                "content": "You are a causal inference expert. You extract causal relationships from domain descriptions and return them as a JSON array. Only output valid JSON, no markdown."
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        temperature=0.1,
        max_tokens=2000
    )

    # The message content can be None; treat that like an empty reply.
    response_text = response.choices[0].message.content or ""

    edges = [
        _build_edge(raw, position)
        for position, raw in enumerate(_load_edge_array(response_text), start=1)
    ]
    # The prompt asks for a DAG, but the model's answer is not guaranteed to be one.
    _ensure_acyclic(edges)
    return edges


# Extract DAG using Groq
print("Extracting causal DAG using Groq (Llama 3.3 70B)...")
print()

start_time = time.time()
try:
    causal_edges = extract_causal_dag_groq(DOMAIN_DESCRIPTION)
    # An empty edge list would make every later step run on an empty dataset,
    # so treat it as a failed extraction and use the fallback below.
    if not causal_edges:
        raise ValueError("LLM returned no causal edges")
except Exception as e:
    # Deliberate catch-all at the notebook boundary: any API, network or
    # parsing failure falls back to a hand-written DAG so the rest still runs.
    print(f"Groq API error: {e}")
    print("\nFalling back to mock extraction...")
    causal_edges = [
        CausalEdge("income", "credit_limit", 0.7, "Higher income leads to higher credit limits"),
        CausalEdge("income", "spending_rate", 0.5, "Higher income leads to higher spending"),
        CausalEdge("spending_rate", "transaction_amount", 0.8, "Spending rate determines transactions"),
        CausalEdge("distance_from_home", "fraud_probability", 0.3, "Distance increases fraud risk"),
        CausalEdge("transaction_amount", "fraud_probability", 0.2, "High transactions increase fraud"),
        CausalEdge("income", "fraud_probability", -0.2, "Higher income decreases fraud")
    ]
    LLM_SUCCESS = False
else:
    extraction_time = time.time() - start_time
    print(f"Extraction completed in {extraction_time:.2f}s")
    LLM_SUCCESS = True

print("\nExtracted Causal Relationships:")
print("=" * 60)
for edge in causal_edges:
    # The "+" format flag prints the real sign, so a zero coefficient is no
    # longer shown as negative.
    print(f"  {edge.source} → {edge.target} ({edge.coefficient:+.1f})")
    print(f"    Reason: {edge.description}")
    print()

## Part 3: Convert DAG to Agent Rules

In [None]:
class CausalAgentModel:
    """Agent model generated from LLM-extracted causal DAG.

    Attributes:
        edges: The edges the model was built from (stored as given).
        variables: Sorted names of every variable that appears as a source or
            a target of an edge.
        adj_matrix: Maps each variable to the list of ``(parent, coefficient)``
            pairs of the edges pointing at it, in edge order.
    """

    def __init__(self, edges: List["CausalEdge"]):
        self.edges = edges
        self.variables = self._extract_variables()
        self.adj_matrix = self._build_adjacency()

    def _extract_variables(self) -> List[str]:
        """Return the sorted, de-duplicated names used by any edge."""
        names = set()
        for edge in self.edges:
            names.add(edge.source)
            names.add(edge.target)
        return sorted(names)

    def _build_adjacency(self) -> Dict[str, List[Tuple[str, float]]]:
        """Return, for every variable, its incoming ``(parent, coefficient)`` pairs."""
        adj = {v: [] for v in self.variables}
        for edge in self.edges:
            adj[edge.target].append((edge.source, edge.coefficient))
        return adj

    def topological_order(self) -> List[str]:
        """Return the variables ordered so that every parent precedes its children.

        Raises:
            ValueError: If the edges contain a cycle (including a self-loop),
                in which case no such ordering exists.
        """
        order = []
        finished = set()
        in_progress = set()  # nodes on the current DFS path

        def visit(node):
            if node in finished:
                return
            if node in in_progress:
                # Reaching a node that is still being expanded means the
                # parent chain loops back on itself.
                raise ValueError(f"causal graph contains a cycle through '{node}'")
            in_progress.add(node)
            for source, _ in self.adj_matrix[node]:
                visit(source)
            in_progress.discard(node)
            finished.add(node)
            # Post-order append: all parents are already in `order`.
            order.append(node)

        for v in self.variables:
            visit(v)

        return order


# Build the model from the extracted edges and report its structure.
causal_model = CausalAgentModel(causal_edges)

print("Variables:", causal_model.variables)
print("\nTopological order:", causal_model.topological_order())
print("\nAdjacency (effects):")
for child in causal_model.adj_matrix:
    incoming = causal_model.adj_matrix[child]
    if not incoming:
        # Variables without parents have nothing to list.
        continue
    print(f"  {child} ← {incoming}")

## Part 4: Generate Synthetic Data

In [None]:
def generate_from_dag(model: "CausalAgentModel", n_samples: int, seed: int = 42,
                      noise_scale: float = 0.3) -> pd.DataFrame:
    """Generate synthetic data following the causal DAG.

    Variables are produced in the model's topological order. A variable with
    no parents is drawn from a standard normal distribution; every other
    variable is the coefficient-weighted sum of its parents plus Gaussian
    noise.

    Args:
        model: Object providing ``topological_order()`` and ``adj_matrix``
            (variable -> list of ``(parent, coefficient)`` pairs).
        n_samples: Number of rows to generate.
        seed: Seed for the NumPy random generator.
        noise_scale: Multiplier applied to the standard-normal noise added to
            each non-root variable.

    Returns:
        A DataFrame with one column per variable, in topological order.
    """
    rng = np.random.default_rng(seed)
    order = model.topological_order()

    data = {}

    for var in order:
        parents = model.adj_matrix[var]

        if not parents:
            # Root variable: pure exogenous noise.
            data[var] = rng.standard_normal(n_samples)
        else:
            value = np.zeros(n_samples)
            for src, coef in parents:
                value += coef * data[src]
            # The noise is always drawn (even when noise_scale is 0) so the
            # random stream, and therefore every later column, stays the same.
            value += noise_scale * rng.standard_normal(n_samples)
            data[var] = value

    return pd.DataFrame(data)


print("Generating synthetic data from LLM-extracted causal DAG...")
n_samples = 10000

# perf_counter is a monotonic, high-resolution clock; time.time() can report
# zero elapsed time for a run this short on coarse system clocks.
start = time.perf_counter()
df_dag = generate_from_dag(causal_model, n_samples)
gen_time = time.perf_counter() - start

# Guard the division so a zero elapsed time cannot raise ZeroDivisionError.
rows_per_sec = len(df_dag) / gen_time if gen_time > 0 else float("inf")

print(f"Generated {len(df_dag):,} samples in {gen_time:.3f}s")
print(f"Throughput: {rows_per_sec:,.0f} rows/sec")
print("\nSample data:")
print(df_dag.head())

## Part 5: Validate Causal Structure

In [None]:
# Compare the sign of each edge's coefficient with the sign of the observed
# correlation between its two variables in the generated data.
banner = "=" * 60
print(banner)
print("CAUSAL STRUCTURE VALIDATION")
print(banner)

validation_results = []

for edge in causal_edges:
    # Skip edges whose variables are not present in the generated frame.
    if edge.source not in df_dag.columns or edge.target not in df_dag.columns:
        continue

    actual_corr = df_dag[edge.source].corr(df_dag[edge.target])

    if edge.coefficient > 0:
        expected_sign = "+"
    else:
        expected_sign = "-"
    if actual_corr > 0:
        actual_sign = "+"
    else:
        actual_sign = "-"
    match = expected_sign == actual_sign

    label = f"{edge.source} → {edge.target}"
    validation_results.append({
        'edge': label,
        'expected': edge.coefficient,
        'actual': actual_corr,
        'match': match
    })

    print(f"  {label}:")
    print(f"    Expected: {expected_sign}{abs(edge.coefficient):.2f}")
    print(f"    Actual:   {actual_corr:+.2f} {'✓' if match else '✗'}")
    print()

# Share of checked edges whose sign matched; 0 when nothing was checked.
if validation_results:
    matched = sum(1 for r in validation_results if r['match'])
    accuracy = matched / len(validation_results)
else:
    accuracy = 0
print(f"\nCausal Structure Accuracy: {accuracy:.0%}")

In [None]:
# Visualize DAG
try:
    import networkx as nx

    # Directed graph; each edge stores its coefficient as the 'weight' attribute.
    G = nx.DiGraph()
    for link in causal_edges:
        G.add_edge(link.source, link.target, weight=link.coefficient)

    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

    nx.draw_networkx_nodes(G, pos, node_size=3000, node_color='lightblue', alpha=0.9)
    nx.draw_networkx_labels(G, pos, font_size=9, font_weight='bold')

    # Per-edge styling derived from the weight: colour by sign, width by
    # magnitude, label showing the signed value.
    edge_colors, edge_widths, edge_labels = [], [], {}
    for u, v in G.edges():
        weight = G[u][v]['weight']
        edge_colors.append('green' if weight > 0 else 'red')
        edge_widths.append(abs(weight) * 4)
        edge_labels[(u, v)] = f"{weight:+.1f}"

    nx.draw_networkx_edges(
        G, pos,
        edge_color=edge_colors, width=edge_widths,
        arrows=True, arrowsize=25, connectionstyle='arc3,rad=0.1',
    )
    nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=10)

    # The title records whether the edges came from the LLM or from the fallback.
    if LLM_SUCCESS:
        title_suffix = ' (via Groq/Llama 3.3 70B)'
    else:
        title_suffix = ' (Mock - set GROQ_API_KEY for real extraction)'
    title = 'LLM-Extracted Causal DAG' + title_suffix

    plt.title(title, fontsize=14, fontweight='bold')
    plt.figtext(
        0.5, 0.02, 'Green = positive effect, Red = negative effect',
        ha='center', fontsize=10, style='italic',
    )
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('causal_dag_groq.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("\n✓ Saved causal_dag_groq.png")
except ImportError:
    print("networkx not available")

In [None]:
# Save results
results = {
    'method': 'MISATA-LLM (Groq)',
    'llm_model': 'llama-3.3-70b-versatile',
    'n_samples': n_samples,
    'n_causal_edges': len(causal_edges),
    'causal_accuracy': accuracy,
    'generation_time': gen_time,
    'llm_success': LLM_SUCCESS
}

# One-row summary of the run.
pd.DataFrame([results]).to_csv('misata_llm_groq_results.csv', index=False)

# One row per causal edge.
edges_df = pd.DataFrame([
    {'source': e.source, 'target': e.target, 'coefficient': e.coefficient, 'description': e.description}
    for e in causal_edges
])
edges_df.to_csv('extracted_causal_dag.csv', index=False)

# A very fast run can measure zero elapsed time; guard the division so the
# summary cannot fail with ZeroDivisionError after the files were written.
generation_speed = n_samples / gen_time if gen_time > 0 else float("inf")

print("\n" + "=" * 60)
print("EXPERIMENT COMPLETE")
print("=" * 60)
print("\nKey Results:")
print(f"  LLM Success: {LLM_SUCCESS}")
print(f"  Causal Edges Extracted: {len(causal_edges)}")
print(f"  Causal Structure Accuracy: {accuracy:.0%}")
print(f"  Generation Speed: {generation_speed:,.0f} rows/sec")
print("\nNovel Contribution:")
print("  → LLM extracts CAUSAL STRUCTURE from natural language")
print("  → Structure is interpretable, editable, and validatable")
print("\nFiles generated:")
print("  - causal_dag_groq.png")
print("  - misata_llm_groq_results.csv")
print("  - extracted_causal_dag.csv")