# In [None]:
"""
CCA Alignment for X-ray and Optical GRB Data
=============================================

WHAT THIS DOES (Big Picture):
-----------------------------
We have GRB (gamma-ray burst) observations from two different instruments:
- X-ray telescope: gives us 10 features per GRB
- Optical telescope: gives us 4 features per GRB

Some GRBs were observed by BOTH telescopes (~87 "paired" GRBs).
Many GRBs were observed by only ONE telescope.

PROBLEM: We want to predict redshift (z) using ALL available GRBs,
but the X-ray and optical features live in different spaces.

SOLUTION: Use Canonical Correlation Analysis (CCA) to find a "shared latent space"
where both X-ray and optical features can be projected. Then train ONE model
on this shared space using ALL GRBs.

HOW CCA WORKS (Intuition):
--------------------------
CCA finds two projection matrices:
- Wx: projects X-ray features (10-dim) → latent space (3-dim)
- Wo: projects Optical features (4-dim) → latent space (3-dim)

These projections are chosen so that for PAIRED GRBs (same object observed 
by both telescopes), the X-ray projection and optical projection land in 
SIMILAR places in the latent space.

The "canonical correlations" tell us how well aligned the two projections are:
- Correlation = 1.0 means perfect alignment
- Correlation = 0.5 means moderate alignment
- Correlation = 0.0 means no relationship

WORKFLOW:
---------
1. Load X-ray and optical data
2. Find paired GRBs (observed by both telescopes)
3. Fit CCA on paired GRBs to learn Wx and Wo
4. Project ALL GRBs into latent space:
   - Paired: average of X-ray and optical projections
   - X-ray only: use Wx projection
   - Optical only: use Wo projection
5. Output combined dataset for head model training
"""

# =============================================================================
# IMPORTS
# =============================================================================
import pandas as pd
import numpy as np
from scipy.linalg import svd
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')


# =============================================================================
# DATA LOADING FUNCTIONS
# =============================================================================

def load_xray_data(filepath='/Users/peterlin/Desktop/GRB_combination_proj/data/x-ray_data.csv'):
    """
    Read the X-ray GRB table from a CSV file.

    Parameters
    ----------
    filepath : str
        Location of the CSV file. Its first column is used as the
        DataFrame index.

    Returns
    -------
    pandas.DataFrame
        The parsed table, indexed by the file's first column.
    """
    # ---------- First CSV column becomes the row index ----------
    return pd.read_csv(filepath, index_col=0)


def load_optical_data(filepath='/Users/peterlin/Desktop/GRB_combination_proj/data/optical_data.txt'):
    """
    Load optical GRB data from a tab-separated text file.

    The first line of the file is treated as a header and skipped. For every
    other line, the first tab-separated field is the GRB identifier and the
    LAST 17 fields are taken as the data columns; only the first 11 of those
    17 are kept (see ``tail_names`` below). Lines with fewer than 17 fields,
    or whose numeric fields cannot be parsed as floats, are skipped.

    Parameters
    ----------
    filepath : str
        Location of the tab-separated file.

    Returns
    -------
    pandas.DataFrame
        One row per successfully parsed line with columns
        GRB, z, T90, Class, logFa, logFaErr, logTa, logTaErr, Alpha,
        AlphaErr, beta, betaErr. ``Class`` is kept as a string, every other
        data column is a float. The columns are present even when no line
        could be parsed, so callers can always index e.g. ``df['GRB']``.
    """

    # ---------- Names of the first 11 of the trailing 17 fields, in order ----------
    tail_names = ['z', 'T90', 'Class', 'logFa', 'logFaErr', 'logTa',
                  'logTaErr', 'Alpha', 'AlphaErr', 'beta', 'betaErr']
    columns = ['GRB'] + tail_names

    data_rows = []

    with open(filepath, 'r') as f:

        # ---------- Skip header (no-op on a completely empty file) ----------
        next(f, None)

        # ---------- Parse line by line instead of loading the whole file ----------
        for line in f:

            # ---------- Split by tab and clean whitespace ----------
            parts = [p.strip() for p in line.strip().split('\t')]

            # ---------- Need at least the 17 trailing data fields ----------
            if len(parts) < 17:
                continue

            # (File has messy formatting, but last 17 columns are consistent)
            numeric_parts = parts[-17:]

            row = {'GRB': parts[0]}
            try:
                for name, raw in zip(tail_names, numeric_parts):
                    # 'Class' is a label, everything else must be numeric
                    row[name] = raw if name == 'Class' else float(raw)
            except ValueError:
                # Deliberate best-effort: a line with a non-numeric value in a
                # numeric field is dropped rather than aborting the load.
                continue

            data_rows.append(row)

    # ---------- Explicit columns keep the schema stable for empty results ----------
    return pd.DataFrame(data_rows, columns=columns)


def normalize_grb_id(grb_id):
    """
    Reduce a GRB identifier to a canonical form so the two catalogues can be joined.

    The identifier is converted to a string, upper-cased, every occurrence of
    "GRB" is removed and surrounding whitespace is stripped. If what remains
    is a run of digits followed by a single "A", that "A" is dropped as well,
    so "GRB050319", "050319A" and "GRB050319A" all map to "050319".
    Other letter suffixes (B, C, ...) are left in place.

    Parameters
    ----------
    grb_id : any
        Raw identifier; converted with ``str()``.

    Returns
    -------
    str
        The normalized identifier.
    """

    # ---------- Upper-case, drop "GRB", trim whitespace ----------
    token = str(grb_id).upper().replace('GRB', '').strip()

    # ---------- Separate the final character from everything before it ----------
    stem, last_char = token[:-1], token[-1:]

    # ---------- Only a trailing 'A' after pure digits is removed ----------
    return stem if last_char == 'A' and stem.isdigit() else token


# =============================================================================
# CCA IMPLEMENTATION
# =============================================================================

def plain_cca(X, O, r=3):
    """
    Canonical Correlation Analysis via SVD of the whitened cross-covariance.

    Parameters
    ----------
    X : array-like, shape (n_samples, px)
        First view (X-ray features), one row per paired sample.
    O : array-like, shape (n_samples, po)
        Second view (optical features), rows aligned with ``X``.
    r : int
        Requested number of components; capped at min(px, po).

    Returns
    -------
    Wx : ndarray, shape (px, r_eff)
        Projection matrix for ``X``.
    Wo : ndarray, shape (po, r_eff)
        Projection matrix for ``O``.
    corrs : ndarray, shape (r_eff,)
        Leading singular values of the whitened cross-covariance
        (the canonical correlations), largest first.
    """

    dim_x = X.shape[1]
    dim_o = O.shape[1]
    jitter = 1e-8  # tiny diagonal loading for numerical stability

    def whitener(cov):
        # Symmetric inverse square root from the eigendecomposition;
        # eigenvalues are floored so the division stays finite.
        vals, vecs = np.linalg.eigh(cov)
        vals = np.maximum(vals, 1e-12)
        return vecs @ np.diag(1.0 / np.sqrt(vals)) @ vecs.T

    # ---------- Within-view covariances (np.cov returns 0-d for one feature) ----------
    cov_x = np.atleast_2d(np.cov(X, rowvar=False))
    cov_o = np.atleast_2d(np.cov(O, rowvar=False))

    # ---------- Cross-covariance: upper-right block of the joint covariance ----------
    cov_xo = np.cov(X, O, rowvar=False)[:dim_x, dim_x:]

    # ---------- Whitening transforms of the lightly regularized covariances ----------
    white_x = whitener(cov_x + jitter * np.eye(dim_x))
    white_o = whitener(cov_o + jitter * np.eye(dim_o))

    # ---------- SVD of the whitened cross-covariance ----------
    left, sing_vals, right_t = svd(white_x @ cov_xo @ white_o, full_matrices=False)

    # ---------- Number of components actually available ----------
    n_keep = min(r, len(sing_vals), dim_x, dim_o)

    # ---------- Map the singular vectors back to the original feature spaces ----------
    Wx = white_x @ left[:, :n_keep]
    Wo = white_o @ right_t[:n_keep, :].T

    return Wx, Wo, sing_vals[:n_keep]


def ridge_cca(X, O, r=3, lam_x=0.2, lam_o=0.2):
    """
    CCA with ridge shrinkage of the within-view covariances.

    Each within-view covariance S is replaced by
    ``S + lam * mean(diag(S)) * I`` before whitening, so the penalty scales
    with the average feature variance of that view. Intended as a fallback
    when plain CCA is unstable.

    Parameters
    ----------
    X : array-like, shape (n_samples, px)
        First view, one row per paired sample.
    O : array-like, shape (n_samples, po)
        Second view, rows aligned with ``X``.
    r : int
        Requested number of components; capped at min(px, po).
    lam_x, lam_o : float
        Relative ridge strengths for the X and O covariances.

    Returns
    -------
    Wx : ndarray, shape (px, r_eff)
    Wo : ndarray, shape (po, r_eff)
    corrs : ndarray, shape (r_eff,)
        Leading singular values of the ridge-whitened cross-covariance.
    """

    dim_x, dim_o = X.shape[1], O.shape[1]

    # ---------- Sample covariances (np.cov returns 0-d for one feature) ----------
    cov_x = np.atleast_2d(np.cov(X, rowvar=False))
    cov_o = np.atleast_2d(np.cov(O, rowvar=False))
    cov_xo = np.cov(X, O, rowvar=False)[:dim_x, dim_x:]

    def ridge_whitener(cov, lam, dim):
        # Shrink towards a multiple of the identity, then take the symmetric
        # inverse square root (eigenvalues floored to keep it finite).
        avg_var = np.mean(np.diag(cov))
        shrunk = cov + (lam * avg_var) * np.eye(dim)
        vals, vecs = np.linalg.eigh(shrunk)
        vals = np.maximum(vals, 1e-12)
        return vecs @ np.diag(1.0 / np.sqrt(vals)) @ vecs.T

    white_x = ridge_whitener(cov_x, lam_x, dim_x)
    white_o = ridge_whitener(cov_o, lam_o, dim_o)

    # ---------- SVD of the whitened cross-covariance ----------
    whitened_cross = white_x @ cov_xo @ white_o
    left, sing_vals, right_t = svd(whitened_cross, full_matrices=False)

    # ---------- Number of components actually available ----------
    n_keep = min(r, len(sing_vals), dim_x, dim_o)

    # ---------- Projection matrices in the original feature spaces ----------
    Wx = white_x @ left[:, :n_keep]
    Wo = white_o @ right_t[:n_keep, :].T

    return Wx, Wo, sing_vals[:n_keep]


# =============================================================================
# CCA ALIGNMENT CLASS
# =============================================================================

class CCAAlignment:
    """
    Wrapper class for CCA alignment.
    Handles standardization, fitting, and transforming.

    Typical use: call ``fit()`` on paired samples, then ``transform_xray()``,
    ``transform_optical()`` or ``transform_both()`` to obtain latent
    coordinates. Calling a transform or ``get_correlation_report()`` before
    ``fit()`` raises ``RuntimeError``.
    """

    def __init__(self, r=3, use_ridge=False, lam_x=0.2, lam_o=0.2):
        """
        Parameters
        ----------
        r : int
            Requested number of latent components.
        use_ridge : bool
            If True, ``fit()`` uses ``ridge_cca``; otherwise ``plain_cca``.
        lam_x, lam_o : float
            Ridge strengths forwarded to ``ridge_cca`` (ignored otherwise).
        """
        self.r = r
        self.use_ridge = use_ridge
        self.lam_x = lam_x
        self.lam_o = lam_o

        # ---------- These get set during fit() ----------
        self.Wx = None
        self.Wo = None
        self.correlations = None
        self.x_scaler = None
        self.o_scaler = None
        self.x_features = None
        self.o_features = None

    @staticmethod
    def _to_array(data):
        """Return the underlying values of a DataFrame; pass anything else through."""
        return data.values if isinstance(data, pd.DataFrame) else data

    def _require_fitted(self, *attrs):
        """Raise RuntimeError if any of the named fit() outputs is still None."""
        missing = [name for name in attrs if getattr(self, name) is None]
        if missing:
            raise RuntimeError(
                "CCAAlignment is not fitted yet (missing: "
                + ", ".join(missing)
                + "); call fit() first."
            )

    def fit(self, X_paired, O_paired, x_features=None, o_features=None):
        """
        Fit CCA on paired GRB data.

        Parameters
        ----------
        X_paired, O_paired : DataFrame or array-like
            Row-aligned feature matrices of the two views.
        x_features, o_features : list or None
            Feature names, stored only for reporting.

        Returns
        -------
        CCAAlignment
            ``self``, to allow chaining.

        Raises
        ------
        ValueError
            If the two inputs do not have the same number of rows.
        """

        self.x_features = x_features
        self.o_features = o_features

        # ---------- Convert DataFrames to arrays ----------
        X = self._to_array(X_paired)
        O = self._to_array(O_paired)

        # ---------- Paired data must be row-aligned ----------
        if len(X) != len(O):
            raise ValueError(
                f"X_paired and O_paired must have the same number of rows "
                f"(got {len(X)} and {len(O)})"
            )

        # ---------- Standardize features (mean=0, std=1) ----------
        self.x_scaler = StandardScaler()
        self.o_scaler = StandardScaler()
        X_std = self.x_scaler.fit_transform(X)
        O_std = self.o_scaler.fit_transform(O)

        # ---------- Fit CCA ----------
        if self.use_ridge:
            self.Wx, self.Wo, self.correlations = ridge_cca(
                X_std, O_std, r=self.r, lam_x=self.lam_x, lam_o=self.lam_o
            )
        else:
            self.Wx, self.Wo, self.correlations = plain_cca(
                X_std, O_std, r=self.r
            )

        print(f"CCA fitted with r={len(self.correlations)} components")
        print(f"Canonical correlations: {self.correlations}")

        return self

    def transform_xray(self, X):
        """Project X-ray data into latent space: Z = standardize(X) @ Wx"""
        self._require_fitted('x_scaler', 'Wx')
        X_std = self.x_scaler.transform(self._to_array(X))
        return X_std @ self.Wx

    def transform_optical(self, O):
        """Project optical data into latent space: Z = standardize(O) @ Wo"""
        self._require_fitted('o_scaler', 'Wo')
        O_std = self.o_scaler.transform(self._to_array(O))
        return O_std @ self.Wo

    def transform_both(self, X, O):
        """Project paired data: Z = average of X-ray and optical projections"""
        Zx = self.transform_xray(X)
        Zo = self.transform_optical(O)
        return 0.5 * (Zx + Zo)

    def get_correlation_report(self):
        """
        Return summary of CCA fit.

        Returns
        -------
        dict
            Keys: n_components, canonical_correlations (list), x_features,
            o_features, Wx_shape, Wo_shape.
        """
        self._require_fitted('Wx', 'Wo', 'correlations')
        return {
            'n_components': len(self.correlations),
            'canonical_correlations': self.correlations.tolist(),
            'x_features': self.x_features,
            'o_features': self.o_features,
            'Wx_shape': self.Wx.shape,
            'Wo_shape': self.Wo.shape
        }


# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == '__main__':

    # ---------- Single location for every input and output file ----------
    DATA_DIR = '/Users/peterlin/Desktop/GRB_combination_proj/data'

    # =========================================================================
    # STEP 1: LOAD DATA
    # =========================================================================
    print("Loading data...")
    xray = load_xray_data(f'{DATA_DIR}/x-ray_data.csv')
    opt = load_optical_data(f'{DATA_DIR}/optical_data.txt')

    # =========================================================================
    # STEP 2: NORMALIZE GRB IDs FOR MATCHING
    # =========================================================================
    # ---------- Add normalized ID column to each dataset ----------
    xray['norm_id'] = [normalize_grb_id(g) for g in xray.index]
    opt['norm_id'] = [normalize_grb_id(g) for g in opt['GRB']]

    # =========================================================================
    # STEP 3: MERGE TO FIND PAIRED GRBs
    # =========================================================================
    # ---------- Inner join on normalized ID ----------
    # rename_axis names the index before it is turned into a column, so the
    # 'xray_id' column exists whether or not the CSV index had a header.
    # Columns present in both tables get the suffixes '_x' / '_o'.
    paired = pd.merge(
        xray.rename_axis('xray_id').reset_index(),
        opt,
        on='norm_id',
        suffixes=('_x', '_o')
    )

    # ---------- Print counts ----------
    # NOTE: the "only" counts are simple differences; they are exact only if
    # each normalized ID occurs at most once in each dataset.
    print(f"\nNumber of paired GRBs: {len(paired)}")
    print(f"Number of X-ray only GRBs: {len(xray) - len(paired)}")
    print(f"Number of Optical only GRBs: {len(opt) - len(paired)}")

    # =========================================================================
    # STEP 4: DEFINE FEATURE COLUMNS
    # =========================================================================
    # ---------- X-ray features (10 total), named as in the merged table ----------
    xray_features = ['log10Fa', 'log10Ta', 'Alpha_x', 'Beta', 'Gamma',
                     'T90_x', 'log10Fluence', 'log10PeakFlux', 'PhotonIndex', 'log10NH']

    # ---------- Optical features (4 total), named as in the merged table ----------
    opt_features = ['logFa', 'logTa', 'Alpha_o', 'beta']

    # =========================================================================
    # STEP 5: CHECK FOR MISSING VALUES
    # =========================================================================
    print("\n=== Checking for missing values in paired data ===")
    for feature in xray_features + opt_features:
        if feature in paired.columns:
            n_missing = paired[feature].isna().sum()
            if n_missing > 0:
                print(f"  {feature}: {n_missing} missing")

    # =========================================================================
    # STEP 6: DROP ROWS WITH MISSING VALUES
    # =========================================================================
    paired_clean = paired.dropna(subset=xray_features + opt_features)
    print(f"\nPaired GRBs after removing missing values: {len(paired_clean)}")

    # =========================================================================
    # STEP 7: EXTRACT FEATURE MATRICES
    # =========================================================================
    X_paired = paired_clean[xray_features]
    O_paired = paired_clean[opt_features]

    print(f"\nX-ray feature matrix shape: {X_paired.shape}")
    print(f"Optical feature matrix shape: {O_paired.shape}")

    # =========================================================================
    # STEP 8: FIT CCA ON PAIRED DATA
    # =========================================================================
    print("FITTING CCA ALIGNMENT")

    # ---------- Create CCA object with r=3 latent dimensions ----------
    cca = CCAAlignment(r=3, use_ridge=False)

    # ---------- Fit on paired data ----------
    cca.fit(X_paired, O_paired, x_features=xray_features, o_features=opt_features)

    # =========================================================================
    # STEP 9: PRINT CORRELATION COEFFICIENTS
    # =========================================================================
    print("\nCorrelation Coefficients for Mapping")
    for i, corr in enumerate(cca.correlations):
        print(f"  Component {i+1}: {corr:.4f}")
        if corr < 0.5:
            print(f"    WARNING: Correlation < 0.5 - check if variables are important predictors")

    # =========================================================================
    # STEP 10: PROJECT PAIRED DATA
    # =========================================================================
    # Rows of Z_paired are in the same order as the rows of paired_clean.
    Z_paired = cca.transform_both(X_paired, O_paired)
    print(f"\nLatent representation shape: {Z_paired.shape}")

    # =========================================================================
    # STEP 11: SAVE PAIRED GRB LATENT FEATURES
    # =========================================================================
    print("SAVING RESULTS")

    # ---------- Create output dataframe ----------
    # The optical redshift column is 'z' unless the merge had to suffix it.
    z_col = 'z' if 'z' in paired_clean.columns else 'z_o'
    output_paired = paired_clean[['xray_id', 'norm_id', 'Redshift_crosscheck', z_col]].copy()
    output_paired.columns = ['xray_id', 'grb_id', 'z_xray', 'z_optical']

    # ---------- Add latent features as columns ----------
    for i in range(Z_paired.shape[1]):
        output_paired[f'Z{i+1}'] = Z_paired[:, i]

    # ---------- Save to CSV ----------
    output_paired.to_csv(f'{DATA_DIR}/paired_grb_latent_features.csv', index=False)
    print(f"Saved: paired_grb_latent_features.csv ({len(output_paired)} GRBs)")

    # =========================================================================
    # STEP 12: SAVE CCA PARAMETERS
    # =========================================================================
    np.savez(f'{DATA_DIR}/cca_parameters.npz',
             Wx=cca.Wx,
             Wo=cca.Wo,
             correlations=cca.correlations,
             x_center=cca.x_scaler.mean_,
             x_scale=cca.x_scaler.scale_,
             o_center=cca.o_scaler.mean_,
             o_scale=cca.o_scaler.scale_,
             x_features=xray_features,
             o_features=opt_features)
    print("Saved: cca_parameters.npz")

    # =========================================================================
    # STEP 13: DEFINE ORIGINAL FEATURE NAMES (without _x/_o suffixes)
    # =========================================================================
    print("PROJECTING ALL GRBs INTO LATENT SPACE")

    # Same features, in the same order as xray_features / opt_features, but
    # under the names they carry in the un-merged tables.
    xray_orig_features = ['log10Fa', 'log10Ta', 'Alpha', 'Beta', 'Gamma',
                          'T90', 'log10Fluence', 'log10PeakFlux', 'PhotonIndex', 'log10NH']
    opt_orig_features = ['logFa', 'logTa', 'Alpha', 'beta']

    all_results = []
    r = len(cca.correlations)
    latent_cols = [f'Z{i+1}' for i in range(r)]

    def add_latent_rows(grb_ids, xray_ids, opt_ids, redshifts, Z, modality):
        """Append one record (metadata + latent coordinates) per GRB to all_results."""
        for grb, xid, oid, z_val, latent in zip(grb_ids, xray_ids, opt_ids, redshifts, Z):
            record = {
                'grb_id': grb,
                'xray_id': xid,
                'opt_id': oid,
                'z': z_val,
                'modality': modality
            }
            record.update(zip(latent_cols, latent))
            all_results.append(record)

    def finite_rows(frame, features):
        """Return (rows of frame whose features are all finite numbers, their float matrix)."""
        if frame.empty:
            return frame, np.empty((0, len(features)))
        # Non-numeric entries become NaN and are filtered out with inf/NaN.
        values = frame[features].apply(pd.to_numeric, errors='coerce').to_numpy(dtype=float)
        keep = np.isfinite(values).all(axis=1)
        return frame[keep], values[keep]

    # =========================================================================
    # STEP 14: PROJECT PAIRED GRBs (use both modalities, average projections)
    # =========================================================================
    paired_norm_ids = set(paired_clean['norm_id'])
    print(f"\nProjecting {len(paired_norm_ids)} paired GRBs...")

    # ---------- Reuse the averaged projection computed in STEP 10 ----------
    add_latent_rows(
        paired_clean['norm_id'],
        paired_clean['xray_id'],
        paired_clean['GRB'],
        paired_clean['Redshift_crosscheck'],
        Z_paired,
        'both'
    )

    # =========================================================================
    # STEP 15: PROJECT X-RAY ONLY GRBs
    # =========================================================================
    # ---------- Get X-ray GRBs not in paired set ----------
    xray_only = xray[~xray['norm_id'].isin(paired_norm_ids)].dropna(subset=xray_orig_features)
    print(f"Projecting {len(xray_only)} X-ray only GRBs...")

    # ---------- Drop rows the scaler cannot handle, and say so ----------
    xray_ok, X_only = finite_rows(xray_only, xray_orig_features)
    n_skipped = len(xray_only) - len(xray_ok)
    if n_skipped > 0:
        print(f"  Skipped {n_skipped} X-ray only GRBs with non-finite feature values")

    # ---------- Project all remaining rows in one call ----------
    if len(xray_ok) > 0:
        add_latent_rows(
            xray_ok['norm_id'],
            xray_ok.index,
            [None] * len(xray_ok),
            xray_ok['Redshift_crosscheck'],
            cca.transform_xray(X_only),
            'xray'
        )

    # =========================================================================
    # STEP 16: PROJECT OPTICAL ONLY GRBs
    # =========================================================================
    # ---------- Get optical GRBs not in paired set ----------
    opt_only = opt[~opt['norm_id'].isin(paired_norm_ids)].dropna(subset=opt_orig_features)
    print(f"Projecting {len(opt_only)} optical only GRBs...")

    # ---------- Drop rows the scaler cannot handle, and say so ----------
    opt_ok, O_only = finite_rows(opt_only, opt_orig_features)
    n_skipped = len(opt_only) - len(opt_ok)
    if n_skipped > 0:
        print(f"  Skipped {n_skipped} optical only GRBs with non-finite feature values")

    # ---------- Project all remaining rows in one call ----------
    if len(opt_ok) > 0:
        add_latent_rows(
            opt_ok['norm_id'],
            [None] * len(opt_ok),
            opt_ok['GRB'],
            opt_ok['z'],
            cca.transform_optical(O_only),
            'optical'
        )

    # =========================================================================
    # STEP 17: CREATE COMBINED DATAFRAME
    # =========================================================================
    combined = pd.DataFrame(all_results)

    # =========================================================================
    # STEP 18: ADD MODALITY FLAGS FOR HEAD MODEL
    # =========================================================================
    # ---------- One-hot encode modality (reference = 'both') ----------
    combined['mod_xray'] = (combined['modality'] == 'xray').astype(int)
    combined['mod_optical'] = (combined['modality'] == 'optical').astype(int)

    # =========================================================================
    # STEP 19: SAVE COMBINED OUTPUT
    # =========================================================================
    combined.to_csv(f'{DATA_DIR}/all_grb_latent_features.csv', index=False)

    # =========================================================================
    # STEP 20: PRINT SUMMARY
    # =========================================================================
    print("\nDistribution by modality")
    print(combined['modality'].value_counts())

    print("Summary")
    print(f"""
CCA Alignment Complete!

Data:
  - Paired GRBs used for alignment: {len(paired_clean)}
  - Total GRBs with latent features: {len(combined)}
    - Paired (both): {(combined['modality'] == 'both').sum()}
    - X-ray only: {(combined['modality'] == 'xray').sum()}
    - Optical only: {(combined['modality'] == 'optical').sum()}

CCA Results:
  - Latent dimensions: {len(cca.correlations)}
  - Canonical correlations: {[f'{c:.4f}' for c in cca.correlations]}

Output files:
  - all_grb_latent_features.csv: All GRBs with Z1, Z2, Z3, modality flags
  - cca_parameters.npz: Projection matrices and scalers

Head model formula:
  log(1+z) ~ Z1 + Z2 + Z3 + mod_xray + mod_optical

NEXT STEPS:
  1. Your colleague trains the head model on this combined dataset
  2. Target variable: log(1+z)
  3. Features: Z1, Z2, Z3, mod_xray, mod_optical
  4. Recommended model: Ridge regression
""")