# 🧬 Enhanced Drug Optimization RL (Discovery2 + EvE Bio)

**Advanced reinforcement learning for drug discovery with comprehensive analysis**

## 🚀 Features
1. **Multi-target analysis** - Compare BTK, EGFR, ALK, and other targets
2. **Hyperparameter optimization** - Optuna search over learning rate, discount factor, epsilon decay, and reward weights
3. **Chemical visualization** - RDKit structure rendering and property calculation
4. **Portfolio optimization** - Pareto frontier analysis for efficacy-safety-selectivity
5. **Persistence** - Save models and results to Google Drive
6. **Statistical rigor** - Confidence intervals, hypothesis testing, effect sizes
7. **MILES concepts** - MoE routing and distributed training simulation

---


In [None]:
# @title 📦 Install Enhanced Dependencies
!pip -q install --upgrade pip
!pip -q install numpy pandas matplotlib seaborn scikit-learn
!pip -q install gymnasium==0.29.1 joblib statsmodels lightgbm
!pip -q install datasets huggingface_hub
!pip -q install rdkit-pypi  # Chemical informatics
!pip -q install plotly  # Interactive visualizations
!pip -q install optuna  # Hyperparameter optimization
!pip -q install scipy  # Statistical tests
print('✓ All packages installed successfully!')

In [None]:
# @title 💾 Mount Google Drive for Persistence
from google.colab import drive
from pathlib import Path
import os

drive.mount('/content/drive')

# Everything that must survive a runtime reset lives under this Drive folder.
DRIVE_PROJECT_DIR = Path('/content/drive') / 'MyDrive' / 'DrugRL_Project'
DRIVE_PROJECT_DIR.mkdir(exist_ok=True, parents=True)

for message in (
    f'✓ Project directory: {DRIVE_PROJECT_DIR}',
    '  All results will be saved here for persistence across sessions',
):
    print(message)

In [None]:
# @title 🔧 Enhanced Setup with Directory Structure
import os
from pathlib import Path
from huggingface_hub import login

# Local runtime directories (not persisted; the Drive folders below are)
BASE_DIR = Path('/content')
DATA_DIR = BASE_DIR / 'data'
MODELS_DIR = BASE_DIR / 'models'
OUT_DIR = BASE_DIR / 'outputs'
CACHE_DIR = BASE_DIR / 'cache'

# Drive subdirectories for long-term storage
DRIVE_RESULTS = DRIVE_PROJECT_DIR / 'results'
DRIVE_MODELS = DRIVE_PROJECT_DIR / 'trained_models'
DRIVE_FIGURES = DRIVE_PROJECT_DIR / 'figures'
DRIVE_CHECKPOINTS = DRIVE_PROJECT_DIR / 'checkpoints'

for d in (DATA_DIR, MODELS_DIR, OUT_DIR, CACHE_DIR,
          DRIVE_RESULTS, DRIVE_MODELS, DRIVE_FIGURES, DRIVE_CHECKPOINTS):
    d.mkdir(parents=True, exist_ok=True)

print('✓ Directory structure created:')
print(f'  Local: {OUT_DIR}')
print(f'  Drive: {DRIVE_PROJECT_DIR}')


def _get_hf_token():
    """Return a Hugging Face token, or '' if none is configured.

    Checks the ``HF_TOKEN`` environment variable first, then the Colab
    secret of the same name. Colab secrets are not exported as environment
    variables, so they must be read through ``google.colab.userdata``.
    """
    token = os.environ.get('HF_TOKEN', '')
    if token:
        return token
    try:
        from google.colab import userdata
    except ImportError:  # not running inside Colab
        return ''
    try:
        return userdata.get('HF_TOKEN') or ''
    except Exception:  # best-effort: secret missing or notebook access not granted
        return ''


# HF login
hf_token = _get_hf_token()
if hf_token:
    login(token=hf_token)
    print('✓ Logged into Hugging Face')
else:
    print('ℹ️  No HF_TOKEN found (add as Colab Secret if needed)')

In [None]:
# @title 📝 Write Enhanced Environment with Chemical Features
%%writefile drug_rl_environment_enhanced.py
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pandas as pd
import joblib
from sklearn.preprocessing import MinMaxScaler
try:
    from rdkit import Chem
    from rdkit.Chem import Descriptors, AllChem
    RDKIT_AVAILABLE = True
except ImportError:
    RDKIT_AVAILABLE = False
    print('Warning: RDKit not available, chemical features disabled')

class DrugOptimizationEnvEnhanced(gym.Env):
    """Select compounds for one target gene to maximise a weighted reward.

    The environment has a single observation (always ``0``). Each action is an
    index into ``self.compounds``. The reward is a weighted sum of min-max
    scaled efficacy, predicted safety and (inverted, scaled) promiscuity.
    Re-selecting a compound already chosen in the current episode ends the
    episode with a ``-10.0`` reward. Otherwise the episode ends after
    ``max_steps`` selections.

    Args:
        drug_target_data_path: CSV with ``target__gene``, ``compound_id`` and
            ``outcome_max_activity`` columns. An optional ``compound_smiles``
            column enables RDKit descriptors.
        promiscuity_data_path: CSV with a ``promiscuity_score`` column and a
            ``compound_id`` (or ``cmpd_id``) column.
        cytotox_model_path: joblib file holding a model with
            ``predict_proba``. It is fed the compound's promiscuity score, and
            ``1 - P(class 1)`` is used as the safety score.
        target_gene: ``target__gene`` value whose rows are kept.
        max_steps: number of selections per episode.
        efficacy_weight, safety_weight, selectivity_weight: reward weights.
        use_chemical_features: compute RDKit descriptors when RDKit imports.

    Raises:
        ValueError: if the drug-target CSV has no rows for ``target_gene``.
    """
    metadata = {'render_modes': ['human'], 'render_fps': 30}

    def __init__(self, drug_target_data_path: str, promiscuity_data_path: str,
                 cytotox_model_path: str, target_gene: str, max_steps: int = 10,
                 efficacy_weight: float = 0.4, safety_weight: float = 0.4,
                 selectivity_weight: float = 0.2, use_chemical_features: bool = True):
        super().__init__()
        self.target_gene = target_gene
        self.max_steps = max_steps
        self.current_step = 0
        self.visited_compounds = set()
        self.use_chemical_features = use_chemical_features and RDKIT_AVAILABLE
        # Step records of every finished episode (one list per episode).
        self.episode_history = []
        # Also initialised here (not only in reset()) so step() cannot hit an
        # AttributeError when called before the first reset().
        self.current_episode_data = []

        # Activity rows for the requested target only.
        self.drug_target_df = pd.read_csv(drug_target_data_path)
        self.drug_target_df = self.drug_target_df[
            self.drug_target_df['target__gene'] == target_gene
        ]
        if self.drug_target_df.empty:
            raise ValueError(f'No data for target {target_gene}')

        self.compounds = self.drug_target_df['compound_id'].unique().tolist()
        self.n_compounds = len(self.compounds)
        self.compound_to_idx = {c: i for i, c in enumerate(self.compounds)}
        # Efficacy per compound, taken from its first row in file order.
        # Precomputed so step() does a dict lookup instead of a DataFrame scan.
        self._efficacy_by_compound = (
            self.drug_target_df.drop_duplicates(subset='compound_id')
            .set_index('compound_id')['outcome_max_activity']
            .to_dict()
        )

        # Promiscuity scores, restricted to this target's compounds.
        self.promiscuity_df = pd.read_csv(promiscuity_data_path)
        if 'cmpd_id' in self.promiscuity_df.columns:
            self.promiscuity_df = self.promiscuity_df.rename(
                columns={'cmpd_id': 'compound_id'}
            )
        self.promiscuity_df = self.promiscuity_df[
            self.promiscuity_df['compound_id'].isin(self.compounds)
        ]
        self.promiscuity_scores = self.promiscuity_df.set_index(
            'compound_id'
        )['promiscuity_score'].to_dict()

        # NOTE(security): joblib.load unpickles the file and can execute
        # arbitrary code; only load models from sources you trust.
        self.cytotox_model = joblib.load(cytotox_model_path)

        # Compute chemical features if SMILES available
        self.chemical_features = {}
        if self.use_chemical_features and 'compound_smiles' in self.drug_target_df.columns:
            self._compute_chemical_features()

        self.action_space = spaces.Discrete(self.n_compounds)
        self.observation_space = spaces.Discrete(1)

        self.efficacy_weight = efficacy_weight
        self.safety_weight = safety_weight
        self.selectivity_weight = selectivity_weight

        self.efficacy_scaler = MinMaxScaler()
        self.safety_scaler = MinMaxScaler()
        self.selectivity_scaler = MinMaxScaler()
        self._precalculate_scaling_bounds()

    def _compute_chemical_features(self):
        """Fill ``self.chemical_features`` with RDKit descriptors per compound.

        Only the first parseable SMILES seen for each compound is used.
        Missing or unparseable SMILES are skipped.
        """
        for _, row in self.drug_target_df.iterrows():
            cid = row['compound_id']
            smiles = row.get('compound_smiles', '')
            # Empty CSV cells load as NaN, a truthy float that RDKit rejects
            # with a TypeError, so require a non-empty string.
            if not isinstance(smiles, str) or not smiles or cid in self.chemical_features:
                continue
            mol = Chem.MolFromSmiles(smiles)
            if not mol:
                continue
            self.chemical_features[cid] = {
                'mw': Descriptors.MolWt(mol),
                'logp': Descriptors.MolLogP(mol),
                'hbd': Descriptors.NumHDonors(mol),
                'hba': Descriptors.NumHAcceptors(mol),
                'tpsa': Descriptors.TPSA(mol),
                'rotatable_bonds': Descriptors.NumRotatableBonds(mol),
                'aromatic_rings': Descriptors.NumAromaticRings(mol)
            }

    def _precalculate_scaling_bounds(self):
        """Fit the three min-max scalers on the data's observed ranges.

        Also caches the maximum promiscuity score used by step().
        """
        max_act = self.drug_target_df['outcome_max_activity'].max()
        min_act = self.drug_target_df['outcome_max_activity'].min()
        self.efficacy_scaler.fit(np.array([[min_act], [max_act]]))
        # Safety is already a probability-derived score in [0, 1].
        self.safety_scaler.fit(np.array([[0.0], [1.0]]))
        max_prom = self.promiscuity_df['promiscuity_score'].max()
        min_prom = self.promiscuity_df['promiscuity_score'].min()
        self.selectivity_scaler.fit(np.array([[min_prom], [max_prom]]))
        self._max_promiscuity = max_prom

    def _get_obs(self):
        """Return the single (constant) observation."""
        return 0

    def _get_info(self, compound_id=None, efficacy=None, safety=None, selectivity=None):
        """Build the info dict; per-compound fields are added when an ID is given."""
        info = {
            'current_step': self.current_step,
            'max_steps': self.max_steps,
            'n_compounds': self.n_compounds,
            'visited_compounds_count': len(self.visited_compounds)
        }
        # `is not None` so falsy IDs such as 0 or '' are still reported.
        if compound_id is not None:
            info.update({
                'compound_id': compound_id,
                'efficacy': efficacy,
                'safety': safety,
                'selectivity': selectivity,
                'chemical_features': self.chemical_features.get(compound_id, {})
            })
        return info

    def reset(self, seed=None, options=None):
        """Start a new episode: clear the step counter and visited set."""
        super().reset(seed=seed)
        self.current_step = 0
        self.visited_compounds = set()
        self.current_episode_data = []
        return self._get_obs(), self._get_info()

    def step(self, action: int):
        """Select compound ``self.compounds[action]``.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``truncated`` is
            always ``False``.

        Raises:
            IndexError: if ``action`` is outside ``[0, n_compounds)``.
        """
        # A negative index would otherwise silently pick from the list's end.
        if not 0 <= action < self.n_compounds:
            raise IndexError(
                f'action {action} out of range [0, {self.n_compounds})'
            )
        self.current_step += 1
        compound_id = self.compounds[action]

        if compound_id in self.visited_compounds:
            # Repeat selection: fixed penalty and the episode ends. Record the
            # episode so episode_history also covers episodes ended this way.
            self.episode_history.append(self.current_episode_data.copy())
            return self._get_obs(), -10.0, True, False, self._get_info(compound_id)

        self.visited_compounds.add(compound_id)

        # Efficacy. A missing value is treated like a missing row (0.0) so a
        # NaN cannot propagate into the reward.
        efficacy = self._efficacy_by_compound.get(compound_id, 0.0)
        if pd.isna(efficacy):
            efficacy = 0.0

        # Safety from the cytotoxicity model, fed the promiscuity score.
        promiscuity = self.promiscuity_scores.get(compound_id, 0.0)
        if pd.isna(promiscuity):
            promiscuity = 0.0  # same default as a compound with no score
        try:
            p_class1 = self.cytotox_model.predict_proba(
                np.array([[promiscuity]])
            )[:, 1][0]
            safety = 1.0 - p_class1
        except Exception:
            # Best-effort: fall back to a neutral score if the model fails.
            safety = 0.5

        # Reported selectivity: 1 - promiscuity / max promiscuity, floored at 0.
        max_prom = self._max_promiscuity
        selectivity = 1.0 - (promiscuity / max_prom) if max_prom > 0 else 0.5
        selectivity = max(0.0, selectivity)

        # Scale and compute reward (selectivity term = 1 - scaled promiscuity)
        scaled_eff = self.efficacy_scaler.transform([[efficacy]])[0][0]
        scaled_safe = self.safety_scaler.transform([[safety]])[0][0]
        scaled_sel = 1.0 - self.selectivity_scaler.transform([[promiscuity]])[0][0]

        reward = (
            self.efficacy_weight * scaled_eff +
            self.safety_weight * scaled_safe +
            self.selectivity_weight * scaled_sel
        )

        # Store episode data
        self.current_episode_data.append({
            'step': self.current_step,
            'compound_id': compound_id,
            'efficacy': efficacy,
            'safety': safety,
            'selectivity': selectivity,
            'reward': reward
        })

        terminated = self.current_step >= self.max_steps
        if terminated:
            self.episode_history.append(self.current_episode_data.copy())

        return self._get_obs(), float(reward), terminated, False, self._get_info(
            compound_id, efficacy, safety, selectivity
        )

    def render(self, mode='human'):
        """Print the step counter and number of explored compounds."""
        if mode == 'human':
            print(f'Step: {self.current_step}/{self.max_steps}, '
                  f'Explored: {len(self.visited_compounds)}')

    def close(self):
        """Nothing to release."""
        pass

print('✓ Enhanced environment created')

In [None]:
# @title 📥 Download Discovery2 Data
from huggingface_hub import hf_hub_download

# Download into the local runtime directories. `local_dir_use_symlinks` is
# deprecated (and ignored) in current huggingface_hub releases, so it is not
# passed; files placed in `local_dir` are regular, readable paths.
promiscuity_csv = hf_hub_download(
    repo_id='pageman/discovery2-results',
    repo_type='dataset',
    filename='discovery2_promiscuity_scores.csv',
    local_dir=str(DATA_DIR),
)

cubic_model_path = hf_hub_download(
    repo_id='pageman/discovery2-cytotoxicity-models',
    filename='cubic_logistic_model.pkl',
    local_dir=str(MODELS_DIR),
)

print('✓ Downloaded Discovery2 artifacts')
print(f'  Promiscuity: {promiscuity_csv}')
print(f'  Cytotox model: {cubic_model_path}')

In [None]:
# @title 🎯 Multi-Target Dataset Loading
import pandas as pd
from datasets import load_dataset

# Define multiple targets to analyze
# Targets to analyze (edit this list as needed) and an optional local copy
# of the activity table that takes precedence over the Hugging Face download.
TARGET_GENES = 'BTK EGFR ALK BRAF'.split()
LOCAL_CSV = BASE_DIR.joinpath('drug-target-activity.csv')

def standardize_columns(df):
    """Rename the target/compound columns to the names the environment uses.

    The first column found from each candidate list is renamed to
    ``target__gene`` / ``compound_id``. The outcome columns must already be
    present under their canonical names.

    Args:
        df: raw drug-target activity table.

    Returns:
        A new DataFrame with canonical column names (``df`` is not modified).

    Raises:
        KeyError: listing every required column that could not be found.
    """
    target_candidates = ['target__gene', 'target_gene', 'target_gene_symbol']
    compound_candidates = ['compound_id', 'drug_id', 'compound']
    outcome_cols = ['outcome_is_active', 'outcome_max_activity']

    target_col = next((c for c in target_candidates if c in df.columns), None)
    compound_col = next((c for c in compound_candidates if c in df.columns), None)

    # Collect every problem so one error message explains the whole mismatch.
    missing = []
    if not target_col:
        missing.append(f'target column (one of {target_candidates})')
    if not compound_col:
        missing.append(f'compound column (one of {compound_candidates})')
    missing.extend(c for c in outcome_cols if c not in df.columns)
    if missing:
        raise KeyError(
            f"Missing required columns: {', '.join(missing)}; "
            f'available columns: {list(df.columns)}'
        )

    return df.rename(columns={target_col: 'target__gene', compound_col: 'compound_id'})

# Load the activity table: local CSV first, then the EvE Bio dataset on the
# Hugging Face Hub, and finally a synthetic placeholder so later cells run.
dataset_loaded = False
if LOCAL_CSV.exists():
    try:
        print(f'Loading local CSV: {LOCAL_CSV}')
        df = pd.read_csv(LOCAL_CSV, low_memory=False)
        df = standardize_columns(df)
        print(f'✓ Loaded {len(df):,} rows from local CSV')
        dataset_loaded = True
    except Exception as e:
        print(f'⚠️  Local CSV failed: {e}')

if not dataset_loaded:
    try:
        print('Loading EvE Bio dataset from HuggingFace...')
        ds = load_dataset('eve-bio/drug-target-activity', split='train')
        df = ds.to_pandas()
        df = standardize_columns(df)
        print(f'✓ Loaded {len(df):,} rows from HF')
        dataset_loaded = True
    except Exception as e:
        print(f'⚠️  HF dataset failed: {e}')
        print('Creating synthetic dataset...')
        prom = pd.read_csv(promiscuity_csv)
        # The promiscuity file may name its ID column 'cmpd_id' (the
        # environment class handles the same case), so normalise it first.
        if 'cmpd_id' in prom.columns and 'compound_id' not in prom.columns:
            prom = prom.rename(columns={'cmpd_id': 'compound_id'})
        df_list = []
        for target in TARGET_GENES:
            demo = prom[['compound_id']].sample(n=min(200, len(prom)), random_state=42).copy()
            demo['target__gene'] = target
            # Deterministic placeholder activity in [0, 100] derived from the IDs.
            demo['outcome_max_activity'] = (demo['compound_id'].astype('category').cat.codes % 101).astype(float)
            demo['outcome_is_active'] = demo['outcome_max_activity'] >= 50.0
            df_list.append(demo)
        df = pd.concat(df_list, ignore_index=True)
        print(f'✓ Created synthetic dataset with {len(df):,} rows')

# Columns written to the per-target CSVs. SMILES are kept when the source has
# them; the environment's RDKit features and the structure plot read
# 'compound_smiles' from these files.
export_cols = ['compound_id', 'target__gene', 'outcome_is_active', 'outcome_max_activity']
if 'compound_smiles' in df.columns:
    export_cols.append('compound_smiles')

# Create target-specific CSVs
target_csv_paths = {}
for target in TARGET_GENES:
    df_t = df.loc[df['target__gene'] == target, export_cols]
    if df_t.empty:
        print(f'  {target}: No data found')
        continue
    csv_path = DATA_DIR / f'drug_target_activity_{target}.csv'
    df_t.to_csv(csv_path, index=False)
    target_csv_paths[target] = csv_path
    n_unique = df_t['compound_id'].nunique()
    print(f'  {target}: {len(df_t):,} rows, {n_unique:,} compounds → {csv_path.name}')

print(f'\n✓ Ready to analyze {len(target_csv_paths)} targets')
if not target_csv_paths:
    print('⚠️  No target in TARGET_GENES had data; later cells need at least one target')

In [None]:
# @title 🔬 Hyperparameter Sweep with Optuna
import optuna
from drug_rl_environment_enhanced import DrugOptimizationEnvEnhanced
import sys
sys.path.insert(0, str(BASE_DIR))
from drug_rl_training import QLearningAgent, train_agent, evaluate_agent

# Choose a target for hyperparameter tuning: the first target with data.
# The tuned hyperparameters are reused for every target later on.
if not target_csv_paths:
    raise RuntimeError(
        'No per-target CSVs were created; check TARGET_GENES and the dataset cell.'
    )
TUNE_TARGET = next(iter(target_csv_paths))
print(f'Tuning hyperparameters for target: {TUNE_TARGET}')

def objective(trial):
    """Optuna objective: mean evaluation reward of a Q-learning agent.

    Samples agent hyperparameters and the efficacy/safety reward weights.
    Selectivity receives the remainder, so the three weights sum to 1. That
    derived weight is stored as the trial user attribute
    ``'selectivity_weight'`` because Optuna only reports sampled parameters
    in ``best_params``. Splits that leave selectivity below 0.05 are not
    trained and score a large negative sentinel instead.

    Returns:
        The ``'avg_reward'`` from ``evaluate_agent`` (higher is better).
    """
    # Agent hyperparameters
    lr = trial.suggest_float('learning_rate', 0.01, 0.5, log=True)
    gamma = trial.suggest_float('discount_factor', 0.8, 0.99)
    epsilon_decay = trial.suggest_float('epsilon_decay', 0.95, 0.999)

    # Reward weights: efficacy and safety are sampled, selectivity takes the rest.
    efficacy_weight = trial.suggest_float('efficacy_weight', 0.2, 0.6)
    safety_weight = trial.suggest_float('safety_weight', 0.2, 0.6)
    selectivity_weight = 1.0 - efficacy_weight - safety_weight
    trial.set_user_attr('selectivity_weight', selectivity_weight)

    if selectivity_weight < 0.05:
        return -1000.0  # invalid weight split; sentinel so it is never selected as best

    env = DrugOptimizationEnvEnhanced(
        drug_target_data_path=str(target_csv_paths[TUNE_TARGET]),
        promiscuity_data_path=str(promiscuity_csv),
        cytotox_model_path=str(cubic_model_path),
        target_gene=TUNE_TARGET,
        max_steps=10,
        efficacy_weight=efficacy_weight,
        safety_weight=safety_weight,
        selectivity_weight=selectivity_weight
    )

    agent = QLearningAgent(
        n_actions=env.n_compounds,
        learning_rate=lr,
        discount_factor=gamma,
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay=epsilon_decay
    )

    train_agent(env, agent, n_episodes=100, max_steps=10, verbose=False)

    eval_stats = evaluate_agent(env, agent, n_episodes=10)
    return eval_stats['avg_reward']

# Run optimization
print('Starting hyperparameter optimization (this may take a few minutes)...')
study = optuna.create_study(direction='maximize', study_name='drug_rl_tuning')
study.optimize(objective, n_trials=20, show_progress_bar=True)

# Optuna's best_params contains only sampled parameters. The objective derives
# selectivity_weight as 1 - efficacy_weight - safety_weight, so add it here;
# later cells read it via best_params.get('selectivity_weight', ...).
best_params = dict(study.best_params)
if all(k in best_params for k in ('efficacy_weight', 'safety_weight')):
    best_params['selectivity_weight'] = (
        1.0 - best_params['efficacy_weight'] - best_params['safety_weight']
    )

print('\n' + '='*80)
print('HYPERPARAMETER OPTIMIZATION RESULTS')
print('='*80)
print(f'Best value: {study.best_value:.3f}')
print(f'Best params: {best_params}')
if study.best_value <= -1000:
    print('⚠️  Every trial hit the invalid-weight sentinel; these parameters are not meaningful')

# Save results
trials_path = DRIVE_RESULTS / f'optuna_trials_{TUNE_TARGET}.csv'
trials_df = study.trials_dataframe()
trials_df.to_csv(trials_path, index=False)
print(f'\n✓ Saved trials to {trials_path}')

In [None]:
# @title 🎯 Multi-Target Training & Comparison
from drug_rl_environment_enhanced import DrugOptimizationEnvEnhanced
from drug_rl_training import QLearningAgent, train_agent, evaluate_agent
import joblib

# Train one agent per target with the tuned hyperparameters.
multi_target_results = {}
trained_agents = {}
environments = {}

# Reward weights. The tuning objective samples efficacy and safety and gives
# selectivity the remainder; Optuna's best_params does not contain that
# derived value, so recompute it the same way rather than falling back to an
# unrelated default (which would make the weights not sum to 1).
efficacy_w = best_params.get('efficacy_weight', 0.4)
safety_w = best_params.get('safety_weight', 0.4)
selectivity_w = 1.0 - efficacy_w - safety_w
print(f'Reward weights: efficacy={efficacy_w:.3f}, '
      f'safety={safety_w:.3f}, selectivity={selectivity_w:.3f}')

for target, csv_path in target_csv_paths.items():
    print(f'\n{"="*80}')
    print(f'Training agent for target: {target}')
    print(f'{"="*80}')

    env = DrugOptimizationEnvEnhanced(
        drug_target_data_path=str(csv_path),
        promiscuity_data_path=str(promiscuity_csv),
        cytotox_model_path=str(cubic_model_path),
        target_gene=target,
        max_steps=10,
        efficacy_weight=efficacy_w,
        safety_weight=safety_w,
        selectivity_weight=selectivity_w
    )
    environments[target] = env

    # Agent with the best hyperparameters (defaults if tuning was skipped)
    agent = QLearningAgent(
        n_actions=env.n_compounds,
        learning_rate=best_params.get('learning_rate', 0.1),
        discount_factor=best_params.get('discount_factor', 0.95),
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay=best_params.get('epsilon_decay', 0.995)
    )

    training_stats = train_agent(env, agent, n_episodes=200, max_steps=10, verbose=True)
    eval_stats = evaluate_agent(env, agent, n_episodes=10)

    multi_target_results[target] = {
        'rewards': training_stats['rewards'],
        'eval_mean': eval_stats['avg_reward'],
        'eval_std': eval_stats['std_reward']
    }
    trained_agents[target] = agent

    # Save agent to Drive
    agent_path = DRIVE_MODELS / f'agent_{target}.pkl'
    joblib.dump(agent, agent_path)
    print(f'\n✓ Saved trained agent to {agent_path}')

    print(f'Evaluation: {eval_stats["avg_reward"]:.3f} ± {eval_stats["std_reward"]:.3f}')

print(f'\n{"="*80}')
print('MULTI-TARGET TRAINING COMPLETE')
print(f'{"="*80}')

In [None]:
# @title 📊 Comprehensive Analysis with Enhanced Tools
# Upload enhanced analysis module
!wget -q -O drug_rl_enhanced_analysis.py https://raw.githubusercontent.com/YOUR_REPO/drug_rl_enhanced_analysis.py || echo 'Note: Using embedded analysis'

# For now, embed simplified version
from drug_rl_enhanced_analysis import run_comprehensive_analysis

# Run analysis for each target
for target in multi_target_results.keys():
    print(f'\nAnalyzing target: {target}')
    output_dir = DRIVE_RESULTS / target
    output_dir.mkdir(exist_ok=True)
    
    analyzer, summary = run_comprehensive_analysis(
        training_results={target: multi_target_results[target]},
        env=environments[target],
        agent=trained_agents[target],
        output_dir=output_dir
    )
    print(f'✓ Analysis complete for {target}')
    print(f'  Results saved to: {output_dir}')

In [None]:
# @title 🧪 Chemical Structure Visualization (Top Compounds)
import numpy as np  # used below; no earlier notebook cell imports numpy

try:
    from rdkit import Chem
    from rdkit.Chem import Draw, Descriptors
    import matplotlib.pyplot as plt

    # Visualize the first trained target
    VIZ_TARGET = next(iter(trained_agents))
    print(f'Visualizing top compounds for: {VIZ_TARGET}')

    agent = trained_agents[VIZ_TARGET]
    env = environments[VIZ_TARGET]
    # The environment always observes state 0. This assumes agent.q_table[0]
    # holds one Q-value per compound index — TODO confirm against drug_rl_training.
    q_values = np.asarray(agent.q_table[0])
    top_9_idx = np.argsort(q_values)[-9:][::-1]  # highest Q-value first

    # Separate name so the global `df` (full dataset) is not overwritten.
    target_df = pd.read_csv(target_csv_paths[VIZ_TARGET])
    if 'compound_smiles' not in target_df.columns:
        print('⚠️  No SMILES data in dataset, skipping visualization')
    else:
        smiles_by_id = (
            target_df.drop_duplicates(subset='compound_id')
            .set_index('compound_id')['compound_smiles']
        )

        fig, axes = plt.subplots(3, 3, figsize=(15, 15))
        fig.suptitle(f'Top 9 Compounds for {VIZ_TARGET}', fontsize=16, fontweight='bold')
        for ax in axes.flat:
            ax.axis('off')  # panels left unused (fewer than 9 compounds) stay blank

        for idx, ax in zip(top_9_idx, axes.flat):
            cmpd_id = env.compounds[idx]
            smiles = smiles_by_id.get(cmpd_id)
            # Empty CSV cells load as NaN (a float), which RDKit cannot parse.
            if not isinstance(smiles, str) or not smiles:
                ax.text(0.5, 0.5, 'No\nSMILES', ha='center', va='center')
                continue
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                ax.text(0.5, 0.5, 'Invalid\nSMILES', ha='center', va='center')
                continue

            ax.imshow(Draw.MolToImage(mol, size=(300, 300)))
            mw = Descriptors.MolWt(mol)
            logp = Descriptors.MolLogP(mol)
            ax.set_title(
                f'{str(cmpd_id)[:10]}...\nQ={q_values[idx]:.2f}\nMW={mw:.1f} LogP={logp:.1f}',
                fontsize=10
            )

        plt.tight_layout()
        viz_path = DRIVE_FIGURES / f'top_compounds_{VIZ_TARGET}.png'
        plt.savefig(viz_path, dpi=300, bbox_inches='tight')
        print(f'✓ Saved visualization to {viz_path}')
        plt.show()

except ImportError:
    print('⚠️  RDKit not available, skipping chemical visualization')
except Exception as e:
    print(f'⚠️  Visualization error: {e}')

In [None]:
# @title 🚀 MILES/MoE Concepts Demo
# The MILES/MoE demo lives in a separate module that this notebook does not
# create; upload miles_concepts_drug_rl.py to the runtime to run it.
try:
    from miles_concepts_drug_rl import demonstrate_miles_concepts
except ImportError as e:
    print(f'⚠️  miles_concepts_drug_rl not available ({e}); skipping MILES/MoE demo')
else:
    moe, rollout_system = demonstrate_miles_concepts()

In [None]:
# @title 📋 Generate Summary Report
import json
from datetime import datetime

import numpy as np  # used below; no earlier notebook cell imports numpy

summary_report = {
    'timestamp': datetime.now().isoformat(),
    'targets_analyzed': list(multi_target_results.keys()),
    'best_hyperparameters': best_params,
    'results_by_target': {},
    # Sorted so repeated runs produce a stable, diffable listing.
    'files_generated': {
        'models': sorted(str(p) for p in DRIVE_MODELS.glob('*.pkl')),
        'figures': sorted(str(p) for p in DRIVE_FIGURES.glob('*.png')),
        'results': sorted(str(p) for p in DRIVE_RESULTS.glob('**/*.csv'))
    }
}

for target, results in multi_target_results.items():
    rewards = np.asarray(results['rewards'], dtype=float)
    has_rewards = rewards.size > 0
    summary_report['results_by_target'][target] = {
        # None (JSON null) when training recorded no episode rewards.
        'final_reward': float(rewards[-1]) if has_rewards else None,
        'mean_reward': float(rewards.mean()) if has_rewards else None,
        'eval_performance': f"{results['eval_mean']:.3f} ± {results['eval_std']:.3f}",
        'n_compounds': len(environments[target].compounds)
    }

# Save report (UTF-8 so the '±' is written as-is, not as an escape sequence)
report_path = DRIVE_PROJECT_DIR / 'experiment_summary.json'
with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(summary_report, f, indent=2, ensure_ascii=False)

print('='*80)
print('EXPERIMENT SUMMARY')
print('='*80)
print(json.dumps(summary_report, indent=2, ensure_ascii=False))
print(f'\n✓ Full report saved to: {report_path}')
print(f'\n✅ All results persisted to Google Drive: {DRIVE_PROJECT_DIR}')

## 🔧 Troubleshooting

### Common Issues

1. **RDKit installation fails**: Restart runtime, or skip chemical visualization
2. **EvE dataset access denied**: Add HF_TOKEN as Colab Secret after accepting dataset terms
3. **Out of memory**: Reduce n_trials in Optuna or n_episodes in training
4. **Drive quota exceeded**: Clear old files in DRIVE_PROJECT_DIR

### Performance Tips

- A GPU runtime is not required and is unlikely to speed things up: the training and tuning here run on CPU (NumPy / scikit-learn)
- Reduce TARGET_GENES list to 2-3 targets to save time
- Cache results in Drive and reload instead of retraining
