In [None]:
import os
import re
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import semopy
import statsmodels.api as sm
from scipy import stats

In [None]:
# Function to parse genetic correlation results
# LDSC --rg logs report e.g. "Genetic Correlation: 0.0413 (0.0304)" followed by
# "Z-score: ..." and "P: ..." lines; values may be "nan" when h2 is out of bounds.
_RG_LINE = re.compile(r"Genetic Correlation:\s*([^\s(]+)\s*\(\s*([^)\s]+)\s*\)")
_RG_P_LINE = re.compile(r"^\s*P(?:[-_ ]?val(?:ue)?)?\s*:\s*(\S+)", re.IGNORECASE)
_GC_COLUMNS = ['trait1', 'trait2', 'rg', 'se', 'p']


def _parse_rg_log(lines):
    """Return ``{'rg', 'se', 'p'}`` from the lines of one LDSC --rg log, or None if unparsable."""
    for i, line in enumerate(lines):
        rg_match = _RG_LINE.search(line)
        if rg_match is None:
            continue
        try:
            rg = float(rg_match.group(1))
            se = float(rg_match.group(2))
        except ValueError:
            return None
        # The P value is not on the next line (that is the Z-score), so scan forward for it.
        for following in lines[i + 1:]:
            p_match = _RG_P_LINE.match(following)
            if p_match is not None:
                try:
                    return {'rg': rg, 'se': se, 'p': float(p_match.group(1))}
                except ValueError:
                    return None
        return None
    return None


def parse_genetic_correlation_results(trait_names=None, results_folder=None,
                                      base_trait="sporadic_miscarriage"):
    """Collect LDSC genetic correlations between ``base_trait`` and every other trait.

    Parameters
    ----------
    trait_names : iterable of str, optional
        Traits to pair with ``base_trait``; defaults to the notebook-level
        ``traits`` list (defined in a cell not shown here).
    results_folder : str, optional
        Folder containing ``<base_trait>_<trait>.log`` files; defaults to the
        notebook-level ``genetic_correlation_results_folder``.
    base_trait : str, default "sporadic_miscarriage"
        Trait placed in the ``trait1`` column of every row.

    Returns
    -------
    pandas.DataFrame
        One row per successfully parsed log, with columns
        ``trait1, trait2, rg, se, p``. The columns are present even when no log
        could be read, so callers can filter on ``p`` unconditionally.
    """
    if trait_names is None:
        trait_names = traits
    if results_folder is None:
        results_folder = genetic_correlation_results_folder

    genetic_corr_data = []
    for trait in trait_names:
        if trait == base_trait:
            continue
        result_file = f"{results_folder}/{base_trait}_{trait}.log"
        try:
            with open(result_file, 'r') as f:
                lines = f.readlines()
        except FileNotFoundError:
            print(f"Warning: Results for {base_trait} and {trait} not found")
            continue

        parsed = _parse_rg_log(lines)
        if parsed is None:
            # Malformed/truncated logs are reported and skipped instead of crashing the cell.
            print(f"Warning: Could not parse genetic correlation for {base_trait} and {trait}")
            continue
        genetic_corr_data.append({'trait1': base_trait, 'trait2': trait, **parsed})

    return pd.DataFrame(genetic_corr_data, columns=_GC_COLUMNS)

# Function to extract heritability results
# Preferred LDSC summary line, e.g. "Total Observed scale h2: 0.1234 (0.0123)"
# (also matches "Total Liability scale h2"); the SE in parentheses is optional here.
_TOTAL_H2_LINE = re.compile(r"Total \w+ scale h2:\s*([^\s(]+)(?:\s*\(\s*([^)\s]+)\s*\))?")
# Fallback for any other "h2: <value> (<se>)" report format.
_ANY_H2_LINE = re.compile(r"h2:\s*([^\s(]+)\s*\(\s*([^)\s]+)\s*\)")


def _parse_h2_log(lines, trait):
    """Return ``{'h2', 'se'}`` from the lines of one h2 log, or None if nothing parses.

    A "Total <scale> scale h2" line takes priority over other "h2: X (SE)" lines.
    Values are parsed with ``float`` so scientific notation (e.g. ``1.2e-03``) works;
    a missing SE becomes NaN rather than a misleading 0.0.
    """
    for pattern in (_TOTAL_H2_LINE, _ANY_H2_LINE):
        for line in lines:
            match = pattern.search(line)
            if match is None:
                continue
            try:
                h2 = float(match.group(1))
                se = float(match.group(2)) if match.group(2) is not None else float('nan')
            except ValueError as e:
                print(f"Error parsing h2 for {trait}: {e}")
                print(f"Problematic line: {line}")
                continue
            return {'h2': h2, 'se': se}
    return None


def extract_heritability_results(trait_names=None, results_folder=None,
                                 placeholder=(0.1, 0.05)):
    """Read one LDSC heritability log per trait into a DataFrame.

    Parameters
    ----------
    trait_names : iterable of str, optional
        Traits to load; defaults to the notebook-level ``traits`` list
        (defined in a cell not shown here).
    results_folder : str, optional
        Folder holding ``<trait>.log``; defaults to the notebook-level
        ``heritability_results_folder``.
    placeholder : tuple of (float, float) or None, default (0.1, 0.05)
        ``(h2, se)`` recorded when a log is missing or has no parsable estimate.
        These are invented values; pass ``None`` to record NaN instead.

    Returns
    -------
    pandas.DataFrame
        Indexed by trait, with columns ``h2`` and ``se`` (columns present even
        when ``trait_names`` is empty).
    """
    if trait_names is None:
        trait_names = traits
    if results_folder is None:
        results_folder = heritability_results_folder

    h2_results = {}
    for trait in trait_names:
        result_file = f"{results_folder}/{trait}.log"
        try:
            with open(result_file, 'r') as f:
                lines = f.readlines()
        except FileNotFoundError:
            print(f"Warning: Heritability results for {trait} not found")
            lines = []

        parsed = _parse_h2_log(lines, trait)
        if parsed is None:
            print(f"Using placeholder for {trait} heritability")
            if placeholder is None:
                parsed = {'h2': float('nan'), 'se': float('nan')}
            else:
                parsed = {'h2': placeholder[0], 'se': placeholder[1]}
        h2_results[trait] = parsed

    return pd.DataFrame.from_dict(h2_results, orient='index').reindex(columns=['h2', 'se'])

# Create SEM model based on genetic correlations and heritability
_DEFAULT_MODEL_SPEC = """
    # Measurement model
    LatentalFactor =~ age_at_first_birth + number_of_children_born + income
    
    # Structural model
    sporadic_miscarriage ~ LatentalFactor
    """


def _observed_variables(model_spec):
    """Return the observed (non-latent) variable names referenced in a semopy spec.

    Only lines containing ``~`` (``=~``, ``~``, ``~~``) are inspected, and names on the
    left of ``=~`` are treated as latent factors. This is a lightweight heuristic for a
    sanity check, not a full semopy parser.
    """
    identifier = re.compile(r"[A-Za-z_]\w*")
    latent, referenced = set(), set()
    for raw_line in model_spec.splitlines():
        line = raw_line.split("#", 1)[0]  # drop comments
        if "~" not in line:
            continue
        if "=~" in line:
            lhs, rhs = line.split("=~", 1)
            latent.update(identifier.findall(lhs))
            referenced.update(identifier.findall(rhs))
        else:
            referenced.update(identifier.findall(line))
    return referenced - latent


def build_sem_model(p_threshold=0.05, model_spec=None):
    """Load LDSC results and pair them with an SEM specification.

    Parameters
    ----------
    p_threshold : float, default 0.05
        Genetic correlations with ``p`` strictly below this value are kept.
    model_spec : str, optional
        semopy model description. Defaults to a one-factor model in which
        ``LatentalFactor`` loads on age_at_first_birth, number_of_children_born
        and income and predicts sporadic_miscarriage.

    Returns
    -------
    tuple
        ``(model_spec, sig_corr, h2_data)``. ``sig_corr`` holds the significant
        rows of ``parse_genetic_correlation_results()`` (always with columns
        trait1, trait2, rg, se, p). ``h2_data`` is the output of
        ``extract_heritability_results()``.
    """
    gc_data = parse_genetic_correlation_results()
    h2_data = extract_heritability_results()

    if 'p' in gc_data.columns:
        sig_corr = gc_data[gc_data['p'] < p_threshold].copy()
    else:
        # Nothing parsed: a column-less frame would raise KeyError on 'p' here and downstream.
        sig_corr = pd.DataFrame(columns=['trait1', 'trait2', 'rg', 'se', 'p'])

    if model_spec is None:
        model_spec = _DEFAULT_MODEL_SPEC

    print("Significant genetic correlations:")
    print(sig_corr)

    if len(sig_corr) == 0:
        print("No significant genetic correlations found; the SEM will have no variables to fit.")

    # run_sem_analysis simulates data only for traits appearing in sig_corr, so any other
    # observed variable in the spec will be missing when semopy fits the model.
    available = set(sig_corr['trait1']) | set(sig_corr['trait2'])
    missing = sorted(_observed_variables(model_spec) - available)
    if missing:
        print("Warning: these model variables have no significant genetic correlation "
              f"and will be absent from the simulated data: {missing}")

    return model_spec, sig_corr, h2_data

In [None]:
# Run SEM analysis
def _genetic_correlation_matrix(sig_corr):
    """Build a symmetric rg matrix (unit diagonal) over the traits named in ``sig_corr``.

    Returns ``(traits_list, corr_matrix)``. ``traits_list`` is sorted so the column
    order, and hence the seeded simulation, is the same on every kernel run. Pairs
    with no row in ``sig_corr`` (e.g. two non-base traits) stay at 0. When a pair
    appears more than once, the first row wins.
    """
    traits_list = sorted(set(sig_corr['trait1']) | set(sig_corr['trait2']))
    position = {name: k for k, name in enumerate(traits_list)}
    corr_matrix = np.eye(len(traits_list))
    seen_pairs = set()
    for trait1, trait2, rg in zip(sig_corr['trait1'], sig_corr['trait2'], sig_corr['rg']):
        pair = frozenset((trait1, trait2))
        if trait1 == trait2 or pair in seen_pairs:
            continue
        seen_pairs.add(pair)
        i, j = position[trait1], position[trait2]
        corr_matrix[i, j] = corr_matrix[j, i] = rg
    return traits_list, corr_matrix


def _as_valid_correlation(corr_matrix, tol=1e-10):
    """Return ``corr_matrix`` if positive semi-definite, else an eigenvalue-clipped version.

    Sampling from a non-PSD covariance yields invalid draws (numpy only warns), so
    negative eigenvalues are clipped to 0 and the result is rescaled to a unit diagonal.
    """
    if corr_matrix.size == 0:
        return corr_matrix
    eigvals, eigvecs = np.linalg.eigh(corr_matrix)
    if eigvals.min() >= -tol:
        return corr_matrix
    print("Warning: genetic correlation matrix is not positive semi-definite; "
          "clipping negative eigenvalues before simulation.")
    repaired = (eigvecs * np.clip(eigvals, 0.0, None)) @ eigvecs.T
    scale = np.sqrt(np.diag(repaired))
    scale[scale == 0] = 1.0
    repaired = repaired / np.outer(scale, scale)
    repaired = (repaired + repaired.T) / 2
    np.fill_diagonal(repaired, 1.0)
    return repaired


def _show_sem_diagram(model, plot_path):
    """Render the path diagram to ``plot_path`` via semopy/Graphviz and display it inline.

    ``semopy.semplot`` writes an image file rather than drawing on a matplotlib
    figure, so the saved image is read back and shown with ``imshow``.
    Assumes ``plot_path`` has an extension matplotlib can read, such as ``.png``.
    """
    semopy.semplot(model, plot_path)
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(plt.imread(plot_path))
    ax.set_axis_off()
    ax.set_title("Structural Equation Model")
    fig.tight_layout()
    plt.show()


def run_sem_analysis(n_samples=1000, seed=42, plot_path="sem_model.png"):
    """Simulate trait data from significant genetic correlations and fit the SEM.

    Parameters
    ----------
    n_samples : int, default 1000
        Number of multivariate-normal draws to simulate.
    seed : int, default 42
        Seed for a local ``np.random.RandomState``. It produces the same stream as
        ``np.random.seed(seed)`` without resetting the notebook's global RNG.
    plot_path : str, default "sem_model.png"
        File the path diagram is rendered to.

    Returns
    -------
    semopy.Model or None
        The fitted model. Returns None when fewer than 3 traits are available or
        fitting fails. A diagram rendering failure only prints a warning.
    """
    model_spec, sig_corr, _h2_data = build_sem_model()

    traits_list, corr_matrix = _genetic_correlation_matrix(sig_corr)

    # Checked before simulating: an empty trait list cannot be sampled at all.
    if len(traits_list) < 3:
        print("Not enough significant correlations to build a meaningful SEM model.")
        return None

    rng = np.random.RandomState(seed)
    sim_data = pd.DataFrame(
        rng.multivariate_normal(
            mean=np.zeros(len(traits_list)),
            cov=_as_valid_correlation(corr_matrix),
            size=n_samples,
        ),
        columns=traits_list,
    )

    print("Simulated data correlation matrix:")
    print(sim_data.corr())

    try:
        model = semopy.Model(model_spec)
        model.fit(sim_data)
        print("\nSEM Model Summary:")
        print(model.inspect())
    except Exception as e:
        print(f"Error running SEM: {e}")
        print("Consider modifying the model specification based on your data.")
        return None

    # A missing Graphviz install should not discard a successfully fitted model.
    try:
        _show_sem_diagram(model, plot_path)
    except Exception as e:
        print(f"Warning: could not render the SEM diagram: {e}")

    return model