In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AD RNA-seq Data Exploration & Target Definition\n",
    "## Alzheimer's Disease Prediction from Blood RNA-seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup: imports and global plotting defaults\n",
    "import sys\n",
    "\n",
    "# Put the parent directory on sys.path so the local `scripts` package resolves.\n",
    "sys.path.append('..')\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scripts.download_manager_v2 import RobustDataManager\n",
    "\n",
    "# Global plot defaults: seaborn whitegrid theme and a 12x6 default figure size\n",
    "sns.set_style('whitegrid')\n",
    "plt.rcParams['figure.figsize'] = (12, 6)\n",
    "\n",
    "print(\"✅ Setup complete\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "manager = RobustDataManager(\n",
    "    use_aws=True,\n",
    "    bucket_name='ad-rnaseq-prediction-data'\n",
    ")\n",
    "\n",
    "# Load the expression matrix and the sample metadata for accession GSE63061.\n",
    "# NOTE(review): the shape report below and the row masks used in later cells\n",
    "# assume expr is laid out samples x genes - confirm against RobustDataManager.\n",
    "expr, meta = manager.load_from_s3('GSE63061')\n",
    "print(f\"📊 Loaded data: {expr.shape[0]} samples, {expr.shape[1]} genes\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explore the metadata: available columns, diagnosis counts and basic demographics\n",
    "print(\"=\"*60)\n",
    "print(\"🔍 METADATA EXPLORATION\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "print(\"\\n📋 Metadata columns:\")\n",
    "print(meta.columns.tolist())\n",
    "\n",
    "print(\"\\n🏥 Diagnosis distribution:\")\n",
    "print(meta['diagnosis'].value_counts())\n",
    "# Converters are the positive class of the primary target, so call them out explicitly.\n",
    "n_converters = (meta['diagnosis'] == 'MCI_converter').sum()\n",
    "print(f\"\\nKey finding: We have {n_converters} MCI converters!\")\n",
    "\n",
    "# mean()/std() skip missing values, so these summaries cover the reported entries only.\n",
    "print(\"\\n👥 Demographics:\")\n",
    "print(f\"Age: {meta['age'].mean():.1f} ± {meta['age'].std():.1f} years\")\n",
    "print(f\"Sex distribution: {meta['sex'].value_counts().to_dict()}\")\n",
    "print(f\"MMSE: {meta['MMSE'].mean():.1f} ± {meta['MMSE'].std():.1f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define your prediction targets\n",
    "print(\"=\"*60)\n",
    "print(\"🎯 DEFINING PREDICTION TARGETS\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "# PRIMARY TARGET 1: MCI to AD Conversion (Binary)\n",
    "print(\"\\n1️⃣ PRIMARY TARGET: MCI → AD Conversion\")\n",
    "mci_samples = meta[meta['diagnosis'].isin(['MCI', 'MCI_converter'])]\n",
    "n_stable = (mci_samples['diagnosis'] == 'MCI').sum()\n",
    "n_converters = (mci_samples['diagnosis'] == 'MCI_converter').sum()\n",
    "n_mci = len(mci_samples)\n",
    "# Guard the ratio so a cohort without MCI samples reports nan instead of dividing by zero.\n",
    "conversion_rate = n_converters / n_mci * 100 if n_mci else float('nan')\n",
    "print(f\"   MCI stable: {n_stable}\")\n",
    "print(f\"   MCI converters: {n_converters}\")\n",
    "print(f\"   Total MCI samples: {n_mci}\")\n",
    "print(f\"   Conversion rate: {conversion_rate:.1f}%\")\n",
    "\n",
    "# Create binary labels: 1 where diagnosis is 'MCI_converter', 0 for every other diagnosis.\n",
    "# The column is 0 for Control and AD rows as well, so it is only a conversion label\n",
    "# once the frame has been restricted to the MCI samples.\n",
    "meta['will_convert'] = (meta['diagnosis'] == 'MCI_converter').astype(int)\n",
    "\n",
    "# SECONDARY TARGET 2: Multi-class Classification\n",
    "print(\"\\n2️⃣ SECONDARY TARGET: Disease State Classification\")\n",
    "print(\"   Classes: Control, MCI, AD\")\n",
    "print(\"   Sample distribution:\")\n",
    "# Exact label match: a substring test would fold 'MCI_converter' into the 'MCI' count,\n",
    "# which disagrees with the multi-class dataset that is built from these three labels only.\n",
    "class_counts = meta['diagnosis'].value_counts()\n",
    "for diag in ['Control', 'MCI', 'AD']:\n",
    "    print(f\"   - {diag}: {class_counts.get(diag, 0)} samples\")\n",
    "\n",
    "# SECONDARY TARGET 3: Cognitive Score Prediction\n",
    "print(\"\\n3️⃣ ADDITIONAL TARGET: MMSE Score Prediction\")\n",
    "print(f\"   MMSE range: {meta['MMSE'].min():.0f} - {meta['MMSE'].max():.0f}\")\n",
    "print(f\"   Mean by group:\")\n",
    "# Skip a missing diagnosis: comparing against NaN selects no rows and would print 'nan: nan'.\n",
    "for diag in meta['diagnosis'].dropna().unique():\n",
    "    mean_mmse = meta.loc[meta['diagnosis'] == diag, 'MMSE'].mean()\n",
    "    print(f\"   - {diag}: {mean_mmse:.1f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the targets on a 2x3 grid of panels\n",
    "fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
    "\n",
    "# 1. Diagnosis distribution\n",
    "ax = axes[0, 0]\n",
    "meta['diagnosis'].value_counts().plot(kind='bar', ax=ax, color='skyblue')\n",
    "ax.set_title('Sample Distribution by Diagnosis', fontweight='bold')\n",
    "ax.set_xlabel('Diagnosis')\n",
    "ax.set_ylabel('Count')\n",
    "plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)\n",
    "\n",
    "# 2. Age distribution by diagnosis\n",
    "ax = axes[0, 1]\n",
    "# A missing diagnosis would only add an empty series to the legend, and missing ages\n",
    "# carry nothing to bin, so both are dropped before plotting.\n",
    "for diag in meta['diagnosis'].dropna().unique():\n",
    "    ages = meta.loc[meta['diagnosis'] == diag, 'age'].dropna()\n",
    "    ax.hist(ages, alpha=0.5, label=diag, bins=15)\n",
    "ax.set_title('Age Distribution by Diagnosis', fontweight='bold')\n",
    "ax.set_xlabel('Age')\n",
    "ax.set_ylabel('Count')\n",
    "ax.legend()\n",
    "\n",
    "# 3. MMSE by diagnosis\n",
    "ax = axes[0, 2]\n",
    "diagnoses = ['Control', 'MCI', 'MCI_converter', 'AD']\n",
    "# Missing scores are removed first; with a NaN in the group the box statistics can\n",
    "# come out as NaN and the box for that diagnosis is not drawn.\n",
    "mmse_data = [meta.loc[meta['diagnosis'] == d, 'MMSE'].dropna().values for d in diagnoses]\n",
    "# Tick labels are set on the axis instead of through boxplot(labels=...), which\n",
    "# Matplotlib 3.9 deprecated in favour of tick_labels; this form works on both sides.\n",
    "bp = ax.boxplot(mmse_data, patch_artist=True)\n",
    "ax.set_xticklabels(diagnoses, rotation=45)\n",
    "box_colors = ['green', 'yellow', 'orange', 'red']\n",
    "for patch, color in zip(bp['boxes'], box_colors):\n",
    "    patch.set_facecolor(color)\n",
    "    patch.set_alpha(0.5)\n",
    "ax.set_title('MMSE Scores by Diagnosis', fontweight='bold')\n",
    "ax.set_ylabel('MMSE Score')\n",
    "\n",
    "# 4. Conversion rate visualization (MCI samples only)\n",
    "ax = axes[1, 0]\n",
    "mci_data = meta[meta['diagnosis'].isin(['MCI', 'MCI_converter'])]\n",
    "conversion_counts = pd.Series({\n",
    "    'Stable MCI': (mci_data['diagnosis'] == 'MCI').sum(),\n",
    "    'MCI→AD': (mci_data['diagnosis'] == 'MCI_converter').sum()\n",
    "})\n",
    "pie_colors = ['#2ecc71', '#e74c3c']\n",
    "ax.pie(conversion_counts.values,\n",
    "       labels=conversion_counts.index,\n",
    "       autopct='%1.1f%%',\n",
    "       colors=pie_colors,\n",
    "       startangle=90)\n",
    "ax.set_title('MCI Conversion Rate', fontweight='bold')\n",
    "\n",
    "# 5. Sex distribution\n",
    "ax = axes[1, 1]\n",
    "sex_by_diag = meta.groupby(['diagnosis', 'sex']).size().unstack(fill_value=0)\n",
    "# NOTE(review): the two-colour palette assumes exactly two sex categories - confirm.\n",
    "sex_by_diag.plot(kind='bar', stacked=True, ax=ax, color=['lightblue', 'pink'])\n",
    "ax.set_title('Sex Distribution by Diagnosis', fontweight='bold')\n",
    "ax.set_xlabel('Diagnosis')\n",
    "ax.set_ylabel('Count')\n",
    "plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)\n",
    "ax.legend(title='Sex')\n",
    "\n",
    "# 6. Feature correlation heatmap\n",
    "ax = axes[1, 2]\n",
    "# 'will_convert' is the label column added to meta by the target-definition cell.\n",
    "clinical_features = meta[['age', 'MMSE', 'will_convert']].copy()\n",
    "# NOTE(review): assumes males are coded as 'M'; any other coding makes this column all zeros.\n",
    "clinical_features['sex_binary'] = (meta['sex'] == 'M').astype(int)\n",
    "corr = clinical_features.corr()\n",
    "sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm', center=0, ax=ax)\n",
    "ax.set_title('Clinical Feature Correlations', fontweight='bold')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create analysis-ready datasets\n",
    "import os\n",
    "\n",
    "print(\"=\"*60)\n",
    "print(\"📦 CREATING ANALYSIS-READY DATASETS\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "# Row masks are built from meta but applied to expr, so both frames must list the\n",
    "# same samples in the same order; otherwise features and labels would drift apart.\n",
    "assert expr.index.equals(meta.index), 'expr and meta must share the same sample index'\n",
    "\n",
    "# Dataset 1: MCI Conversion Prediction\n",
    "mci_indices = meta['diagnosis'].isin(['MCI', 'MCI_converter'])\n",
    "# .copy() detaches each subset so later column assignments on it do not raise\n",
    "# SettingWithCopyWarning.\n",
    "X_mci = expr.loc[mci_indices].copy()\n",
    "y_mci = meta.loc[mci_indices, 'will_convert'].copy()\n",
    "meta_mci = meta.loc[mci_indices].copy()\n",
    "\n",
    "print(f\"\\n1️⃣ MCI Conversion Dataset:\")\n",
    "print(f\"   X shape: {X_mci.shape}\")\n",
    "print(f\"   y distribution: {y_mci.value_counts().to_dict()}\")\n",
    "print(f\"   Class balance: {y_mci.mean()*100:.1f}% converters\")\n",
    "\n",
    "# Dataset 2: Multi-class (excluding converters for cleaner classes)\n",
    "multiclass_indices = meta['diagnosis'].isin(['Control', 'MCI', 'AD'])\n",
    "X_multi = expr.loc[multiclass_indices].copy()\n",
    "y_multi = meta.loc[multiclass_indices, 'diagnosis'].copy()\n",
    "meta_multi = meta.loc[multiclass_indices].copy()\n",
    "\n",
    "print(f\"\\n2️⃣ Multi-class Dataset:\")\n",
    "print(f\"   X shape: {X_multi.shape}\")\n",
    "print(f\"   Classes: {y_multi.value_counts().to_dict()}\")\n",
    "\n",
    "# Dataset 3: Full dataset for MMSE regression\n",
    "# These are references to the full frames, not copies.\n",
    "# NOTE(review): no filtering is applied, so any missing MMSE values remain in y_mmse.\n",
    "X_full = expr\n",
    "y_mmse = meta['MMSE']\n",
    "meta_full = meta\n",
    "\n",
    "print(f\"\\n3️⃣ MMSE Regression Dataset:\")\n",
    "print(f\"   X shape: {X_full.shape}\")\n",
    "print(f\"   MMSE range: {y_mmse.min():.0f} - {y_mmse.max():.0f}\")\n",
    "\n",
    "# Save the prepared MCI conversion dataset locally for quick access.\n",
    "# Only this dataset is written; the multi-class and MMSE datasets stay in memory.\n",
    "# NOTE(review): the path is relative to the kernel's working directory - confirm\n",
    "# that this is the intended output location.\n",
    "os.makedirs('data/processed', exist_ok=True)\n",
    "X_mci.to_csv('data/processed/X_mci_conversion.csv.gz', compression='gzip')\n",
    "y_mci.to_csv('data/processed/y_mci_conversion.csv')\n",
    "meta_mci.to_csv('data/processed/meta_mci_conversion.csv')\n",
    "\n",
    "print(\"\\n✅ Datasets saved to data/processed/\")\n",
    "print(\"Ready for feature engineering and model building!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}