# Generate Synthetic Memory Graph Data

This notebook generates synthetic memory graphs for training the MemoryGNN model.

We create realistic patient memory graphs at different Alzheimer's stages (0-7).

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch_geometric.data import Data
import pickle
import os

print("✓ Libraries imported successfully")

## Function to Generate One Patient Graph

In [None]:
def generate_patient_memory_graph(num_memories=20, stage=2, patient_id=0):
    """
    Generate a synthetic memory graph for one patient.

    Node features (10 dimensions):
    - type_one_hot (5): person, place, event, skill, routine
    - recall_strength (1): 10-100 scaled to 0.1-1.0
    - emotional_weight (1): 0-1
    - importance (1): integer 3-10 scaled to 0.3-1.0
    - age_days (1): memory age, scaled so 10 years == 1.0 and capped at 1.0
    - access_freq (1): how often accessed (0-1)

    Args:
        num_memories: Number of memory nodes in the graph (must be >= 1).
        stage: Alzheimer's stage (0-7); higher stages lower recall strength
            and speed up decay.
        patient_id: Identifier stored on the returned graph.

    Returns:
        A ``torch_geometric.data.Data`` object with:
        - ``x``: float tensor ``(num_memories, 10)`` of node features
        - ``edge_index``: long tensor ``(2, num_edges)`` of directed edges
        - ``edge_attr``: float tensor ``(num_edges, 1)`` of edge similarities
        - ``y_node``: float tensor ``(num_memories, 3)`` with the predicted
          recall strength (0-1) after 30, 90 and 180 days
        - ``y_graph``: float tensor ``(1,)`` with the patient risk score (0-1)
        - ``patient_id`` and ``stage`` as passed in

    Raises:
        ValueError: If ``num_memories`` is smaller than 1.
    """
    if num_memories < 1:
        raise ValueError("num_memories must be at least 1")

    memory_types = {'person': 0.35, 'place': 0.20, 'event': 0.20, 'skill': 0.15, 'routine': 0.10}
    # Loop-invariant, so computed once instead of once per node.
    type_probs = list(memory_types.values())

    node_features = []
    initial_strengths = []

    for _ in range(num_memories):
        # One-hot encode memory type
        mem_type_idx = np.random.choice(len(type_probs), p=type_probs)
        type_onehot = [0.0] * len(type_probs)
        type_onehot[mem_type_idx] = 1.0

        # Recall strength decreases with Alzheimer's stage
        base_strength = np.random.uniform(60, 100) if stage <= 2 else np.random.uniform(30, 80)
        stage_decay = stage * 8
        recall_strength = max(10, base_strength - stage_decay + np.random.normal(0, 10))

        # Emotional memories are stronger
        emotional_weight = np.random.beta(2, 5)
        if emotional_weight > 0.8:
            recall_strength += 10

        # Noise and the emotional bonus can push the value past the 0-100
        # scale; cap it so the scaled feature stays within [0, 1].
        recall_strength = min(100.0, recall_strength)

        importance = np.random.randint(3, 11) / 10.0
        age_days = np.random.exponential(scale=365 * 5) / (365 * 10)
        access_freq = np.random.beta(2, 8)

        features = type_onehot + [
            recall_strength / 100.0,
            emotional_weight,
            importance,
            min(age_days, 1.0),
            access_freq
        ]

        node_features.append(features)
        initial_strengths.append(recall_strength)

    x = torch.tensor(node_features, dtype=torch.float)

    # Generate edges
    edges = []
    edge_weights = []

    for i in range(num_memories):
        # Sampling without replacement cannot draw more targets than there
        # are nodes, so cap the count for very small graphs.
        num_connections = min(np.random.randint(2, 6), num_memories)
        targets = np.random.choice(num_memories, size=num_connections, replace=False)

        for j in targets:
            if i != j:  # no self-loops
                edges.append([i, int(j)])
                # Edge weight: closeness of the two nodes' recall strengths
                # (feature index 5).
                similarity = 1.0 - abs(node_features[i][5] - node_features[j][5])
                edge_weights.append(similarity)

    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # A single-node graph has no edges; keep the (2, 0) layout anyway.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    edge_attr = torch.tensor(edge_weights, dtype=torch.float).unsqueeze(1)

    # Target: memory decay predictions
    # NOTE(review): decay is linear in time, so whenever decay_rate * days >= 1
    # the target collapses to 0. With decay_rate ~ 0.01 * (stage + 1) that is
    # the case for most 90- and 180-day targets at stage >= 1 - confirm this
    # is the intended label distribution.
    horizons_days = (30, 90, 180)
    emotional_bonus = (5, 10, 15)
    y_node = []
    for features, strength in zip(node_features, initial_strengths):
        decay_rate = 0.01 * (stage + 1) + np.random.normal(0, 0.005)

        remaining = [max(0, strength - (decay_rate * days * strength)) for days in horizons_days]

        # Strongly emotional memories (feature index 6) decay more slowly.
        if features[6] > 0.8:
            remaining = [value + bonus for value, bonus in zip(remaining, emotional_bonus)]

        # Cap at 100 in every case: a negative sampled decay_rate would
        # otherwise produce targets above 1.0 after scaling.
        y_node.append([min(100, value) / 100.0 for value in remaining])

    y_node = torch.tensor(y_node, dtype=torch.float)

    # Graph-level risk score
    avg_strength = np.mean(initial_strengths)
    risk_score = 1.0 - (avg_strength / 100.0)
    risk_score = min(1.0, max(0.0, risk_score + stage * 0.05))
    y_graph = torch.tensor([risk_score], dtype=torch.float)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y_node=y_node,
        y_graph=y_graph,
        patient_id=patient_id,
        stage=stage
    )

    return data

print("✓ Function defined")

## Test: Generate One Sample Graph

In [None]:
# Generate a sample graph (15 memories, stage 2) as a quick sanity check
sample_graph = generate_patient_memory_graph(num_memories=15, stage=2, patient_id=0)

# Report the graph size and the shapes of its features and targets
print(f"Number of nodes: {sample_graph.x.size(0)}")
print(f"Number of edges: {sample_graph.edge_index.size(1)}")
print(f"Node features shape: {sample_graph.x.shape}")
print(f"Target shape (decay): {sample_graph.y_node.shape}")
print(f"Risk score: {sample_graph.y_graph.item():.4f}")

## Generate Full Dataset (1000 patients)

In [None]:
def generate_dataset(num_patients=1000, output_dir='../data/synthetic', seed=None):
    """
    Generate synthetic patient graphs, split them 70/15/15 and pickle them.

    Args:
        num_patients: Total number of patient graphs to generate.
        output_dir: Directory that receives ``train.pkl``, ``val.pkl`` and
            ``test.pkl``; created if it does not exist.
        seed: Optional seed for NumPy's global random generator. When given,
            the same call produces the same dataset; ``None`` leaves the
            generator state untouched.

    Returns:
        A tuple ``(train_data, val_data, test_data)`` of lists of graphs.
    """
    if seed is not None:
        np.random.seed(seed)

    os.makedirs(output_dir, exist_ok=True)

    # Stage distribution: most patients fall in the early-to-middle stages.
    stage_values = [0, 1, 2, 3, 4, 5, 6, 7]
    stage_probs = [0.05, 0.15, 0.25, 0.25, 0.15, 0.08, 0.05, 0.02]

    print(f"Generating {num_patients} synthetic patient memory graphs...")

    graphs = []
    for i in range(num_patients):
        # Convert NumPy scalars to plain ints so the values stored on each
        # graph are ordinary Python numbers.
        stage = int(np.random.choice(stage_values, p=stage_probs))
        num_memories = int(np.random.randint(10, 41))

        graphs.append(
            generate_patient_memory_graph(num_memories=num_memories, stage=stage, patient_id=i)
        )

        if (i + 1) % 100 == 0:
            print(f"Generated {i + 1}/{num_patients}")

    # 70-15-15 split: shuffle once and slice, so the proportions hold exactly
    # (up to integer rounding) instead of only on average.
    order = np.random.permutation(num_patients)
    train_end = (num_patients * 70) // 100
    val_end = (num_patients * 85) // 100

    train_data = [graphs[idx] for idx in order[:train_end]]
    val_data = [graphs[idx] for idx in order[train_end:val_end]]
    test_data = [graphs[idx] for idx in order[val_end:]]

    # Save
    splits = {'train': train_data, 'val': val_data, 'test': test_data}
    for split_name, split_graphs in splits.items():
        with open(os.path.join(output_dir, f'{split_name}.pkl'), 'wb') as f:
            pickle.dump(split_graphs, f)

    print(f"\nDataset saved to {output_dir}/")
    print(f"Train: {len(train_data)} | Val: {len(val_data)} | Test: {len(test_data)}")

    return train_data, val_data, test_data

# Generate the dataset: 1000 patient graphs, written as train/val/test pickles
# to the default output_dir ('../data/synthetic'); the splits are also kept in
# memory for the statistics below.
train_data, val_data, test_data = generate_dataset(num_patients=1000)

## Visualize Dataset Statistics

In [None]:
# Pool all three splits so the statistics describe the whole dataset.
all_data = train_data + val_data + test_data

# Per-graph summary values
num_nodes = [graph.x.size(0) for graph in all_data]
num_edges = [graph.edge_index.size(1) for graph in all_data]
stages = [graph.stage for graph in all_data]
risk_scores = [graph.y_graph.item() for graph in all_data]

# 2x2 grid: three histograms plus one bar chart for the stage counts
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Each histogram differs only in its axis, data, colour and labels.
# A colour of None keeps matplotlib's default.
histogram_specs = [
    (axes[0, 0], num_nodes, None,
     'Distribution of Number of Memories per Patient', 'Number of Memories'),
    (axes[0, 1], num_edges, 'orange',
     'Distribution of Memory Connections', 'Number of Connections'),
    (axes[1, 1], risk_scores, 'red',
     'Distribution of Risk Scores', 'Risk Score (0-1)'),
]
for ax, values, color, title, xlabel in histogram_specs:
    ax.hist(values, bins=30, edgecolor='black', alpha=0.7, color=color)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')

# Stage distribution: one bar per stage 0-7
stage_ids = range(8)
stage_counts = [stages.count(stage_id) for stage_id in stage_ids]
stage_ax = axes[1, 0]
stage_ax.bar(stage_ids, stage_counts, edgecolor='black', alpha=0.7, color='green')
stage_ax.set_title('Distribution of Alzheimer Stages')
stage_ax.set_xlabel('Stage')
stage_ax.set_ylabel('Number of Patients')
stage_ax.set_xticks(stage_ids)

plt.tight_layout()
plt.show()

# Text summary of the same statistics
mean_nodes = np.mean(num_nodes)
mean_edges = np.mean(num_edges)
mean_risk = np.mean(risk_scores)

print(f"\nDataset Statistics:")
print(f"Average nodes per graph: {mean_nodes:.1f}")
print(f"Average edges per graph: {mean_edges:.1f}")
print(f"Average risk score: {mean_risk:.3f}")

## Next Steps

Now that we have synthetic data, proceed to:
- **Notebook 02**: Train MemoryGNN model
- **Notebook 03**: Evaluate and visualize results