"""
Complete NV Center ODMR Prediction using Graph Neural Networks
with Synthetic Data Generation via QuTiP

This implementation:
1. Generates synthetic ODMR data using QuTiP with time-dependent Hamiltonian
2. Creates graph representations of NV center environments
3. Trains a GNN to predict ODMR spectra from spin bath configurations
"""


In [3]:

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data, Dataset
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
from qutip import *
import warnings
warnings.filterwarnings('ignore')

In [4]:


# ============================================================================
# PART 1: QuTiP-based Synthetic ODMR Data Generation
# ============================================================================

class NVCenterSimulator:
    """Simulate NV center ODMR using QuTiP with dissipation"""
    
    def __init__(self):
        # Physical constants (in MHz unless noted)
        self.D = 2870.0  # Zero-field splitting (MHz)
        self.gamma_e = 28.0  # Electron gyromagnetic ratio (MHz/mT)
        self.gamma_n = 0.0031  # 13C nuclear gyromagnetic ratio (MHz/mT)
        self.gamma_N = -0.0045  # 14N nuclear gyromagnetic ratio (MHz/mT)
        
        # Operators for S=1 electron spin
        self.Sx = jmat(1, 'x')
        self.Sy = jmat(1, 'y')
        self.Sz = jmat(1, 'z')
        
        # Operators for I=1/2 nuclear spins (13C)
        self.Ix = 0.5 * sigmax()
        self.Iy = 0.5 * sigmay()
        self.Iz = 0.5 * sigmaz()
        
        # Operators for I=1 nitrogen spin (14N)
        self.INx = jmat(1, 'x')
        self.INy = jmat(1, 'y')
        self.INz = jmat(1, 'z')
    
    def build_hamiltonian(self, B_field, nuclear_config, A_hyperfine, strain=0.0):
        """
        Build NV center Hamiltonian with nuclear spins
        
        Args:
            B_field: Magnetic field in mT (applied along NV axis)
            nuclear_config: List of nuclear spin configurations
            A_hyperfine: List of hyperfine coupling strengths (MHz)
            strain: Strain splitting (MHz)
        """
        # Number of 13C spins
        n_nuclear = len(nuclear_config)
        
        # Electron spin operators in full Hilbert space
        # Start with electron spin (3 levels)
        dims = [3] + [2] * n_nuclear  # 3 for electron, 2 for each nuclear spin
        
        # Zero-field splitting term: D * Sz^2
        H0 = self.D * self.Sz * self.Sz
        
        # Add strain if present
        if strain != 0:
            H0 += strain * (self.Sx * self.Sx - self.Sy * self.Sy)
        
        # Zeeman term: gamma_e * B * Sz
        H0 += self.gamma_e * B_field * self.Sz
        
        # Expand to full Hilbert space
        ops = [H0] + [qeye(2)] * n_nuclear
        H = tensor(ops)
        
        # Add hyperfine interactions for each nuclear spin
        for i, (A_parallel, A_perp) in enumerate(A_hyperfine):
            # Parallel component: A_parallel * Sz * Iz
            ops_parallel = [self.Sz] + [qeye(2)] * n_nuclear
            ops_parallel[i + 1] = self.Iz
            H += A_parallel * tensor(ops_parallel)
            
            # Perpendicular components: A_perp * (Sx*Ix + Sy*Iy)
            ops_x = [self.Sx] + [qeye(2)] * n_nuclear
            ops_x[i + 1] = self.Ix
            H += A_perp * tensor(ops_x)
            
            ops_y = [self.Sy] + [qeye(2)] * n_nuclear
            ops_y[i + 1] = self.Iy
            H += A_perp * tensor(ops_y)
            
            # Nuclear Zeeman term
            ops_nz = [qeye(3)] + [qeye(2)] * n_nuclear
            ops_nz[i + 1] = self.Iz
            H += self.gamma_n * B_field * tensor(ops_nz)
        
        return H
    
    def odmr_signal(self, freq_range, B_field, nuclear_config, A_hyperfine, 
                    T1=1e6, T2=1e3, rabi_freq=10.0, pulse_time=1000):
        """
        Simulate ODMR signal with dissipation
        
        Args:
            freq_range: Array of microwave frequencies (MHz)
            B_field: Magnetic field (mT)
            nuclear_config: Nuclear spin states
            A_hyperfine: Hyperfine couplings
            T1: Longitudinal relaxation time (ns)
            T2: Transverse relaxation time (ns)
            rabi_freq: Rabi frequency of MW drive (MHz)
            pulse_time: MW pulse duration (ns)
        """
        n_nuclear = len(nuclear_config)
        dims = [3] + [2] * n_nuclear
        
        # Initial state: |ms=0> ⊗ |nuclear_config>
        psi0_list = [basis(3, 1)]  # ms=0 state
        for config in nuclear_config:
            psi0_list.append(basis(2, config))
        psi0 = tensor(psi0_list)
        
        # Build static Hamiltonian
        H_static = self.build_hamiltonian(B_field, nuclear_config, A_hyperfine)
        
        # Measurement operator: population in ms=±1 states
        proj_minus1 = basis(3, 0) * basis(3, 0).dag()
        proj_plus1 = basis(3, 2) * basis(3, 2).dag()
        ops_measure = [proj_minus1 + proj_plus1] + [qeye(2)] * n_nuclear
        measure_op = tensor(ops_measure)
        
        odmr_spectrum = []
        
        for freq in freq_range:
            # Time-dependent Hamiltonian: H = H_static + Omega * cos(2*pi*freq*t) * Sx
            # Drive term
            ops_drive = [self.Sx] + [qeye(2)] * n_nuclear
            H_drive = rabi_freq * tensor(ops_drive)
            
            # For simplicity, use rotating wave approximation
            # Effective Hamiltonian in rotating frame
            detuning = freq - (self.D + self.gamma_e * B_field)
            ops_det = [self.Sz] + [qeye(2)] * n_nuclear
            H_eff = H_static - freq * tensor(ops_det) + H_drive
            
            # Collapse operators for dissipation
            # T1 relaxation
            gamma1 = 1.0 / T1
            ops_c1a = [basis(3, 1) * basis(3, 0).dag()] + [qeye(2)] * n_nuclear
            ops_c1b = [basis(3, 1) * basis(3, 2).dag()] + [qeye(2)] * n_nuclear
            c_ops = [
                np.sqrt(gamma1) * tensor(ops_c1a),
                np.sqrt(gamma1) * tensor(ops_c1b)
            ]
            
            # T2 dephasing (simplified)
            gamma2 = 1.0 / T2
            ops_deph = [self.Sz] + [qeye(2)] * n_nuclear
            c_ops.append(np.sqrt(gamma2) * tensor(ops_deph))
            
            # Time evolution
            times = np.linspace(0, pulse_time, 50)
            
            try:
                result = mesolve(H_eff, psi0, times, c_ops, [measure_op])
                
                # ODMR signal is population in excited states after pulse
                signal = result.expect[0][-1]
                odmr_spectrum.append(signal)
            except:
                # If simulation fails, use approximate value
                odmr_spectrum.append(0.5)
        
        return np.array(odmr_spectrum)
    
    def generate_sample(self, n_nuclear=3, B_field=None, add_noise=True):
        """Generate a single ODMR sample with random nuclear configuration"""
        
        # Random magnetic field if not specified
        if B_field is None:
            B_field = np.random.uniform(0.5, 5.0)  # 0.5-5 mT
        
        # Random nuclear spin configuration
        nuclear_config = np.random.randint(0, 2, n_nuclear)
        
        # Random nuclear positions (in Angstroms relative to NV)
        nuclear_positions = []
        A_hyperfine = []
        
        for _ in range(n_nuclear):
            # Distance from NV center (1-10 Angstroms)
            r = np.random.uniform(1.0, 10.0)
            
            # Random angle
            theta = np.random.uniform(0, np.pi)
            phi = np.random.uniform(0, 2 * np.pi)
            
            pos = np.array([
                r * np.sin(theta) * np.cos(phi),
                r * np.sin(theta) * np.sin(phi),
                r * np.cos(theta)
            ])
            nuclear_positions.append(pos)
            
            # Hyperfine coupling (scales as 1/r^3 for dipolar)
            A_parallel = 50.0 / (r ** 3) + np.random.normal(0, 1)
            A_perp = 25.0 / (r ** 3) + np.random.normal(0, 0.5)
            A_hyperfine.append((A_parallel, A_perp))
        
        # Frequency range around resonance
        center_freq = self.D + self.gamma_e * B_field
        freq_range = np.linspace(center_freq - 20, center_freq + 20, 50)
        
        # Random coherence times
        T1 = np.random.uniform(1e5, 1e7)  # 100 us to 10 ms
        T2 = np.random.uniform(1e2, 1e4)  # 100 ns to 10 us
        
        # Simulate ODMR
        odmr = self.odmr_signal(
            freq_range, B_field, nuclear_config, A_hyperfine,
            T1=T1, T2=T2
        )
        
        # Add experimental noise
        if add_noise:
            noise_level = 0.02
            odmr += np.random.normal(0, noise_level, len(odmr))
        
        return {
            'nuclear_positions': np.array(nuclear_positions),
            'nuclear_config': nuclear_config,
            'A_hyperfine': A_hyperfine,
            'B_field': B_field,
            'freq_range': freq_range,
            'odmr_spectrum': odmr,
            'T1': T1,
            'T2': T2
        }



In [5]:

# ============================================================================
# PART 2: Graph Neural Network Architecture
# ============================================================================

class NVGraphDataset(Dataset):
    """PyTorch Geometric Dataset for NV center graphs"""
    
    def __init__(self, num_samples=1000, n_nuclear=3):
        super().__init__()
        self.num_samples = num_samples
        self.n_nuclear = n_nuclear
        self.simulator = NVCenterSimulator()
        self.data_list = []
        
        print(f"Generating {num_samples} synthetic ODMR samples...")
        for i in range(num_samples):
            if i % 100 == 0:
                print(f"  Generated {i}/{num_samples}")
            sample = self.simulator.generate_sample(n_nuclear=n_nuclear)
            graph = self._create_graph(sample)
            self.data_list.append(graph)
        print("Dataset generation complete!")
    
    def _create_graph(self, sample):
        """Convert NV sample to graph representation"""
        nuclear_pos = sample['nuclear_positions']
        n_nuclear = len(nuclear_pos)
        
        # Node features
        # NV center node: [0, 0, 0, B_field, 1] (position + field + type)
        nv_node = torch.tensor([[0., 0., 0., sample['B_field'], 1.0]])
        
        # Nuclear nodes: [x, y, z, spin_state, 0] (position + state + type)
        nuclear_nodes = torch.cat([
            torch.tensor(nuclear_pos, dtype=torch.float32),
            torch.tensor(sample['nuclear_config'], dtype=torch.float32).unsqueeze(1),
            torch.zeros(n_nuclear, 1)
        ], dim=1)
        
        x = torch.cat([nv_node, nuclear_nodes], dim=0)
        
        # Edges: NV to all nuclear spins (bidirectional)
        edges = []
        edge_features = []
        
        for i in range(n_nuclear):
            # NV to nuclear
            edges.append([0, i + 1])
            
            # Nuclear to NV
            edges.append([i + 1, 0])
            
            # Edge features: [distance, A_parallel, A_perp]
            distance = np.linalg.norm(nuclear_pos[i])
            A_par, A_perp = sample['A_hyperfine'][i]
            edge_feat = [distance, A_par, A_perp]
            edge_features.extend([edge_feat, edge_feat])
        
        # Nuclear-nuclear edges (if close enough)
        for i in range(n_nuclear):
            for j in range(i + 1, n_nuclear):
                dist = np.linalg.norm(nuclear_pos[i] - nuclear_pos[j])
                if dist < 5.0:  # Angstroms
                    edges.append([i + 1, j + 1])
                    edges.append([j + 1, i + 1])
                    
                    # Dipolar coupling ~ 1/r^3
                    coupling = 10.0 / (dist ** 3)
                    edge_feat = [dist, coupling, 0.0]
                    edge_features.extend([edge_feat, edge_feat])
        
        edge_index = torch.tensor(edges, dtype=torch.long).t()
        edge_attr = torch.tensor(edge_features, dtype=torch.float32)
        
        # Target: ODMR spectrum
        y = torch.tensor(sample['odmr_spectrum'], dtype=torch.float32)
        
        # Additional info
        freq_range = torch.tensor(sample['freq_range'], dtype=torch.float32)
        coherence = torch.tensor([sample['T1'], sample['T2']], dtype=torch.float32)
        
        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y,
            freq_range=freq_range,
            coherence=coherence
        )
    
    def len(self):
        return self.num_samples
    
    def get(self, idx):
        return self.data_list[idx]


class NVMessagePassing(MessagePassing):
    """Custom message passing for NV-nuclear spin interactions"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * in_channels + 3, out_channels),  # +3 for edge features
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
            nn.LayerNorm(out_channels)
        )
        
        self.update_mlp = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.ReLU(),
            nn.LayerNorm(out_channels)
        )
    
    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)
    
    def message(self, x_i, x_j, edge_attr):
        msg = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.message_mlp(msg)
    
    def update(self, aggr_out, x):
        return self.update_mlp(torch.cat([x, aggr_out], dim=-1))


class NVGNN(nn.Module):
    """Graph Neural Network for ODMR spectrum prediction"""
    
    def __init__(self, node_features=5, edge_features=3, hidden_dim=128, 
                 num_layers=4, spectrum_length=50):
        super().__init__()
        
        self.node_encoder = nn.Sequential(
            nn.Linear(node_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_features, hidden_dim),
            nn.ReLU()
        )
        
        self.conv_layers = nn.ModuleList([
            NVMessagePassing(hidden_dim, hidden_dim)
            for _ in range(num_layers)
        ])
        
        # Attention mechanism for aggregation
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        
        # Spectrum prediction head
        self.spectrum_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, spectrum_length)
        )
        
        # Coherence time prediction head
        self.coherence_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2)  # T1, T2
        )
    
    def forward(self, data):
        x = self.node_encoder(data.x)
        edge_attr = data.edge_attr
        
        # Message passing
        for conv in self.conv_layers:
            x_new = conv(x, data.edge_index, edge_attr)
            x = x + x_new  # Residual connection
        
        # Extract NV center embedding (first node)
        batch_size = data.batch.max().item() + 1
        nv_embeddings = []
        
        for i in range(batch_size):
            mask = data.batch == i
            nv_idx = torch.where(mask)[0][0]  # First node in each graph
            nv_embeddings.append(x[nv_idx])
        
        nv_embedding = torch.stack(nv_embeddings)
        
        # Predict ODMR spectrum
        spectrum = self.spectrum_head(nv_embedding)
        
        # Predict coherence times
        coherence = self.coherence_head(nv_embedding)
        coherence = torch.exp(coherence)  # Ensure positive
        
        return {
            'spectrum': spectrum,
            'coherence': coherence
        }



In [6]:

# ============================================================================
# PART 3: Training and Evaluation
# ============================================================================

def train_model(model, train_loader, val_loader, epochs=50, device='cpu'):
    """Train the NV GNN model"""
    
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_losses = []
        
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            
            predictions = model(batch)
            
            # Loss: MSE for spectrum + MSE for coherence
            spectrum_loss = nn.functional.mse_loss(predictions['spectrum'], batch.y)
            coherence_loss = nn.functional.mse_loss(
                torch.log(predictions['coherence']), 
                torch.log(batch.coherence)
            )
            
            loss = spectrum_loss + 0.1 * coherence_loss
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_losses.append(loss.item())
        
        # Validation
        model.eval()
        val_losses = []
        
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                predictions = model(batch)
                
                spectrum_loss = nn.functional.mse_loss(predictions['spectrum'], batch.y)
                coherence_loss = nn.functional.mse_loss(
                    torch.log(predictions['coherence']), 
                    torch.log(batch.coherence)
                )
                loss = spectrum_loss + 0.1 * coherence_loss
                
                val_losses.append(loss.item())
        
        avg_train_loss = np.mean(train_losses)
        avg_val_loss = np.mean(val_losses)
        
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        
        scheduler.step(avg_val_loss)
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_nv_gnn.pt')
        
        if epoch % 5 == 0:
            print(f"Epoch {epoch:3d}: Train Loss = {avg_train_loss:.6f}, "
                  f"Val Loss = {avg_val_loss:.6f}")
    
    return history


def visualize_predictions(model, dataset, device='cpu', num_samples=3):
    """Visualize ODMR predictions vs ground truth"""
    
    model.eval()
    fig, axes = plt.subplots(num_samples, 2, figsize=(12, 4 * num_samples))
    
    with torch.no_grad():
        for i in range(num_samples):
            idx = np.random.randint(len(dataset))
            data = dataset[idx].to(device)
            
            # Add batch dimension
            data.batch = torch.zeros(data.x.size(0), dtype=torch.long)
            
            predictions = model(data)
            
            # ODMR spectrum
            ax = axes[i, 0] if num_samples > 1 else axes[0]
            freq = data.freq_range.cpu().numpy()
            true_spectrum = data.y.cpu().numpy()
            pred_spectrum = predictions['spectrum'][0].cpu().numpy()
            
            ax.plot(freq, true_spectrum, 'b-', label='True', linewidth=2)
            ax.plot(freq, pred_spectrum, 'r--', label='Predicted', linewidth=2)
            ax.set_xlabel('Frequency (MHz)', fontsize=12)
            ax.set_ylabel('ODMR Signal', fontsize=12)
            ax.set_title(f'Sample {idx}: ODMR Spectrum', fontsize=14)
            ax.legend()
            ax.grid(True, alpha=0.3)
            
            # Coherence times comparison
            ax = axes[i, 1] if num_samples > 1 else axes[1]
            true_coh = data.coherence.cpu().numpy()
            pred_coh = predictions['coherence'][0].cpu().numpy()
            
            x = np.arange(2)
            width = 0.35
            ax.bar(x - width/2, true_coh / 1e3, width, label='True', alpha=0.7)
            ax.bar(x + width/2, pred_coh / 1e3, width, label='Predicted', alpha=0.7)
            ax.set_xticks(x)
            ax.set_xticklabels(['T₁', 'T₂'])
            ax.set_ylabel('Time (μs)', fontsize=12)
            ax.set_title(f'Sample {idx}: Coherence Times', fontsize=14)
            ax.legend()
            ax.set_yscale('log')
            ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    return fig



In [7]:

# ============================================================================
# PART 4: Main Execution
# ============================================================================

if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Check for GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Generate dataset
    print("\n" + "="*60)
    print("GENERATING SYNTHETIC ODMR DATA")
    print("="*60)
    
    train_dataset = NVGraphDataset(num_samples=500, n_nuclear=3)
    val_dataset = NVGraphDataset(num_samples=100, n_nuclear=3)
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    
    # Create model
    print("\n" + "="*60)
    print("BUILDING GRAPH NEURAL NETWORK")
    print("="*60)
    
    model = NVGNN(
        node_features=5,
        edge_features=3,
        hidden_dim=128,
        num_layers=4,
        spectrum_length=50
    )
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Train model
    print("\n" + "="*60)
    print("TRAINING MODEL")
    print("="*60)
    
    history = train_model(
        model, train_loader, val_loader,
        epochs=30, device=device
    )
    
    # Visualize results
    print("\n" + "="*60)
    print("GENERATING PREDICTIONS")
    print("="*60)
    
    # Load best model
    model.load_state_dict(torch.load('best_nv_gnn.pt'))
    model = model.to(device)
    
    fig = visualize_predictions(model, val_dataset, device=device, num_samples=3)
    plt.savefig('nv_odmr_predictions.png', dpi=150, bbox_inches='tight')
    print("\nPredictions saved to 'nv_odmr_predictions.png'")
    
    # Training history
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(history['train_loss'], label='Train Loss', linewidth=2)
    ax.plot(history['val_loss'], label='Validation Loss', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Training History', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
    print("Training history saved to 'training_history.png'")
    
    print("\n" + "="*60)
    print("TRAINING COMPLETE!")
    print("="*60)
    print(f"Best validation loss: {min(history['val_loss']):.6f}")

Using device: cpu

GENERATING SYNTHETIC ODMR DATA
Generating 500 synthetic ODMR samples...
  Generated 0/500
  Generated 100/500
  Generated 200/500
  Generated 300/500
  Generated 400/500
Dataset generation complete!
Generating 100 synthetic ODMR samples...
  Generated 0/100
Dataset generation complete!

BUILDING GRAPH NEURAL NETWORK
Model parameters: 497,396

TRAINING MODEL


RuntimeError: The size of tensor a (50) must match the size of tensor b (800) at non-singleton dimension 1