# SciTeX AI Module Tutorial

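This notebook demonstrates the AI capabilities in SciTeX: generative AI across multiple providers, classification reporting, training utilities (early stopping, learning-curve logging), dimensionality reduction, custom loss functions, optimizer helpers, and sampling for imbalanced data. Section 2 requires API keys for the AI providers; the remaining sections run offline.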

## 1. Setup and Imports

In [None]:
# Imports: stdlib -> third-party -> local (SciTeX)
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import make_blobs, make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

import scitex as stx

# Silence only deprecation-style chatter from third-party libraries.
# A blanket 'ignore' would also hide RuntimeWarning/UserWarning messages,
# which usually point at real problems in the demos below.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

## 2. Generative AI with GenAI

SciTeX provides a unified interface for multiple AI providers.

### 2.1 Basic Text Generation

In [None]:
# GenAI reads provider credentials from environment variables, e.g.:
#   export OPENAI_API_KEY="your-key"
#   export ANTHROPIC_API_KEY="your-key"

# One prompt -> one completion. The returned dict carries the generated text
# under 'response' plus accounting fields 'total_cost' and 'total_tokens'.
pac_prompt = "Explain phase-amplitude coupling in neuroscience in one paragraph."

try:
    ai = stx.ai.GenAI(provider="openai")
    response = ai.complete(pac_prompt)

    print("Response:")
    print(response['response'])
    print(f"\nCost: ${response['total_cost']:.4f}")
    print(f"Tokens used: {response['total_tokens']}")
except Exception as e:
    # Typically a missing/invalid API key, but any failure is reported here.
    print(f"Note: GenAI requires API keys. Error: {e}")
    print("\nTo use GenAI, set environment variables:")
    for key_name in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY"):
        print(f"export {key_name}='your-key'")

### 2.2 Multi-turn Conversations

In [None]:
# Two consecutive prompts on the same GenAI instance. The follow-up question
# presumably relies on the instance keeping earlier turns in `ai.chat_history`.
conversation_turns = (
    "What is the Fourier transform?",
    "How is it used in signal processing?",
)

try:
    ai = stx.ai.GenAI(provider="openai")

    for question in conversation_turns:
        reply = ai.complete(question)
        print(f"User: {question}")
        # Show only the first 200 characters of each answer
        print(f"AI: {reply['response'][:200]}...\n")

    # Inspect the accumulated history and the running cost
    print(f"Conversation length: {len(ai.chat_history.messages)} messages")
    print(f"Total cost so far: ${ai.total_cost:.4f}")

    # Reset the history once the demo conversation is finished
    ai.chat_history.clear()

except Exception as e:
    print(f"GenAI demo skipped (requires API key): {e}")

### 2.3 Cost Comparison Across Providers

In [None]:
# Send the same short prompt to several provider/model pairs and tabulate
# what each call cost. Any pair whose call fails (e.g. no API key) is skipped.
prompt = "Write a haiku about machine learning."

providers_to_test = [
    ("openai", "gpt-3.5-turbo"),
    ("openai", "gpt-4"),
    ("anthropic", "claude-3-haiku"),
    ("anthropic", "claude-3-sonnet"),
]

cost_comparison = []

for provider, model in providers_to_test:
    try:
        result = stx.ai.GenAI(provider=provider, model=model).complete(prompt)
        cost_comparison.append(dict(
            Provider=provider,
            Model=model,
            Response=result['response'],
            Cost=result['total_cost'],
            Tokens=result['total_tokens'],
        ))
    except Exception as e:
        print(f"Skipping {provider}/{model}: {e}")

if not cost_comparison:
    print("\nCost comparison requires API keys for providers.")
else:
    df = pd.DataFrame(cost_comparison)
    print("\nCost Comparison:")
    print(df[['Provider', 'Model', 'Cost', 'Tokens']])

## 3. Classification and Model Evaluation

### 3.1 Classification Reporter

In [None]:
# Single source of truth for the class labels: used for the number of
# synthetic classes, the report, and the per-class printout, so they cannot
# drift out of sync.
CLASS_NAMES = ['Class A', 'Class B', 'Class C']

# Synthetic multi-class problem: 20 features (15 informative, 5 redundant)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15,
                           n_redundant=5, n_classes=len(CLASS_NAMES), random_state=42)

# Hold out 30% of the samples for evaluation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train a classifier
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Hard predictions plus per-class probabilities (needed for probability-based metrics)
y_pred = clf.predict(X_test)
y_proba = clf.predict_proba(X_test)

# Create classification report
reporter = stx.ai.ClassificationReporter()
report = reporter.generate_report(
    y_true=y_test,
    y_pred=y_pred,
    y_proba=y_proba,
    class_names=CLASS_NAMES,
)

print("Classification Report:")
print(f"Accuracy: {report['accuracy']:.3f}")
print("\nPer-class metrics:")
# NOTE(review): assumes report['precision'/'recall'/'f1'] are indexable in
# CLASS_NAMES order — confirm against ClassificationReporter's output format.
for i, class_name in enumerate(CLASS_NAMES):
    print(f"  {class_name}: Precision={report['precision'][i]:.3f}, "
          f"Recall={report['recall'][i]:.3f}, F1={report['f1'][i]:.3f}")

### 3.2 Confusion Matrix Visualization

In [None]:
# Plot the Random Forest confusion matrix using SciTeX's ML plotting helper.
# The labels mirror the three synthetic classes from section 3.1.
confusion_style = dict(
    classes=['Class A', 'Class B', 'Class C'],
    title='Random Forest Confusion Matrix',
    cmap='Blues',
)

fig, ax = stx.plt.subplots(figsize=(8, 6))
cm = stx.ai.plt.plot_confusion_matrix(y_test, y_pred, ax=ax, **confusion_style)

plt.tight_layout()
plt.show()

## 4. Training Utilities

### 4.1 Early Stopping

In [None]:
def _simulated_val_loss(epoch):
    """Return a noisy synthetic validation loss for ``epoch``.

    The loss falls linearly over the first 20 epochs, then plateaus with a
    slight upward drift. Each call draws exactly one sample from NumPy's
    global RNG.
    """
    noise = 0.05 * np.random.randn()
    if epoch >= 20:
        return 0.2 + 0.001 * (epoch - 20) + noise
    return 1.0 - 0.04 * epoch + noise


# Stopping rule configured via stx.ai.EarlyStopping (patience / min_delta /
# mode semantics are defined by SciTeX — presumably "stop after 5 checks
# without a decrease of at least 0.001").
early_stopper = stx.ai.EarlyStopping(patience=5, min_delta=0.001, mode='min')

np.random.seed(42)
epochs = 50
val_losses = []

print("Training with early stopping...")
for epoch in range(epochs):
    val_loss = _simulated_val_loss(epoch)
    val_losses.append(val_loss)

    if early_stopper.check_stop(val_loss):
        print(f"\nEarly stopping triggered at epoch {epoch}!")
        print(f"Best validation loss: {early_stopper.best_value:.4f}")
        break

    # Progress line every 5th epoch
    if not epoch % 5:
        print(f"Epoch {epoch:3d}: val_loss = {val_loss:.4f}")

# Loss curve with the stopper's best epoch marked
fig, ax = stx.plt.subplots(figsize=(10, 6))
ax.plot(val_losses, label='Validation Loss')
ax.axvline(
    x=early_stopper.best_epoch,
    color='r',
    linestyle='--',
    label=f'Best Epoch ({early_stopper.best_epoch})',
)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training with Early Stopping')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

### 4.2 Learning Curve Logger

In [None]:
# Dedicated, seeded RNG so the simulated curves are identical on every run,
# regardless of what earlier cells did to NumPy's global random state.
curve_rng = np.random.default_rng(42)

# Create learning curve logger
logger = stx.ai.LearningCurveLogger(
    save_dir="./ai_examples/learning_curves",
    experiment_name="demo_training"
)

# Simulate training with multiple metrics
print("Logging training metrics...")
for epoch in range(30):
    # Exponentially decaying losses with noise; val decays more slowly and is offset
    train_loss = 1.0 * np.exp(-0.1 * epoch) + 0.02 * curve_rng.standard_normal()
    val_loss = 1.0 * np.exp(-0.08 * epoch) + 0.03 * curve_rng.standard_normal() + 0.05
    # Pseudo-accuracy derived from loss. Without clipping it exceeds 1.0
    # once the loss drops below 0.1 in later epochs, which is not a valid accuracy.
    train_acc = float(np.clip(1.0 - train_loss + 0.1, 0.0, 1.0))
    val_acc = float(np.clip(1.0 - val_loss + 0.1, 0.0, 1.0))

    # Log metrics
    logger.log(epoch, {
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_accuracy': train_acc,
        'val_accuracy': val_acc,
        'learning_rate': 0.001 * (0.95 ** epoch)
    })

# Generate learning curve plots
logger.plot_curves(metrics=['train_loss', 'val_loss', 'train_accuracy', 'val_accuracy'])
print(f"\nLearning curves saved to: {logger.save_dir}")

# Get summary statistics
summary = logger.get_summary()
print("\nTraining Summary:")
print(f"Best validation loss: {summary['best_val_loss']:.4f} at epoch {summary['best_epoch']}")
print(f"Final validation accuracy: {summary['final_val_accuracy']:.4f}")

## 5. Clustering and Dimensionality Reduction

### 5.1 UMAP for Visualization

In [None]:
# Generate high-dimensional clustered data: 5 Gaussian blobs in 50 dimensions
X_high, y_clusters = make_blobs(n_samples=500, n_features=50, centers=5,
                                cluster_std=1.0, random_state=42)


def _plot_2d_embedding(embedding, labels, xlabel, ylabel, title):
    """Scatter the first two columns of ``embedding`` coloured by ``labels`` and show the figure."""
    fig, ax = stx.plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(embedding[:, 0], embedding[:, 1], c=labels,
                         cmap='tab10', alpha=0.7, s=50)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    plt.colorbar(scatter, ax=ax, label='Cluster')
    plt.show()


print("Applying UMAP dimensionality reduction...")
# Only the reduction is guarded: an ImportError here presumably means
# umap-learn is missing. Plotting stays outside the try so that plotting
# errors are not misreported as a missing package.
try:
    X_umap = stx.ai.clustering.umap_reduce(X_high, n_components=2,
                                           n_neighbors=15, min_dist=0.1)
except ImportError:
    print("UMAP requires 'umap-learn' package. Install with: pip install umap-learn")
    print("Falling back to PCA visualization...")

    # Use PCA as fallback
    X_pca = stx.ai.clustering.pca_reduce(X_high, n_components=2)
    _plot_2d_embedding(X_pca, y_clusters, 'PC1', 'PC2',
                       'PCA Visualization of High-Dimensional Clusters')
else:
    _plot_2d_embedding(X_umap, y_clusters, 'UMAP Component 1', 'UMAP Component 2',
                       'UMAP Visualization of High-Dimensional Clusters')

## 6. Custom Loss Functions

In [None]:
# Demonstrate a multi-task loss that combines one criterion per task head
import torch
import torch.nn as nn

# One criterion per task (keys must match the keys of `outputs`/`targets`)
task_criteria = {
    'classification': nn.CrossEntropyLoss(),
    'regression': nn.MSELoss(),
    'reconstruction': nn.L1Loss()
}

# init_weights=None leaves the task weighting to MultiTaskLoss's own default
# (presumably learnable, given the "Learned Task Weights" printout below —
# confirm against the scitex docs).
mtl_loss = stx.ai.loss.MultiTaskLoss(task_criteria, init_weights=None)

# Seeded local generator: the random demo tensors (and therefore the printed
# losses) are reproducible without touching torch's global RNG state.
gen = torch.Generator().manual_seed(42)
batch_size = 32
n_classes = 10

# Simulate some predictions and targets
outputs = {
    'classification': torch.randn(batch_size, n_classes, generator=gen),  # logits
    'regression': torch.randn(batch_size, 1, generator=gen),
    'reconstruction': torch.randn(batch_size, 100, generator=gen)
}

targets = {
    'classification': torch.randint(0, n_classes, (batch_size,), generator=gen),  # class indices
    'regression': torch.randn(batch_size, 1, generator=gen),
    'reconstruction': torch.randn(batch_size, 100, generator=gen)
}

# Compute multi-task loss: returns the combined loss and a per-task breakdown
total_loss, task_losses_dict = mtl_loss(outputs, targets)

print("Multi-Task Loss Computation:")
print(f"Total Loss: {total_loss.item():.4f}")
print("\nIndividual Task Losses:")
for task, task_loss in task_losses_dict.items():
    print(f"  {task}: {task_loss.item():.4f}")

print("\nLearned Task Weights:")
weights = mtl_loss.get_weights()
for task, weight in weights.items():
    print(f"  {task}: {weight:.4f}")

## 7. Model Optimization Utilities

In [None]:
# Build an optimizer through SciTeX's factory and inspect its configuration
import torch.nn as nn

# Small MLP: 10 -> 64 -> 32 -> 1 with ReLU activations in between
model = nn.Sequential(
    nn.Linear(10, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 1),
)

optimizer_settings = {'optimizer_name': 'AdamW', 'lr': 0.001, 'weight_decay': 0.01}
optimizer = stx.ai.optim.get_optimizer(model.parameters(), **optimizer_settings)

print("Optimizer Configuration:")
print(f"Type: {type(optimizer).__name__}")
group0 = optimizer.param_groups[0]
print(f"Learning Rate: {group0['lr']}")
print(f"Weight Decay: {group0['weight_decay']}")

# Names listed for reference; this list is not queried from get_optimizer
available_optimizers = ['SGD', 'Adam', 'AdamW', 'RMSprop', 'Adagrad', 'Adadelta']
print("\nAvailable Optimizers:")
print("\n".join(f"  - {opt_name}" for opt_name in available_optimizers))

## 8. Sampling Utilities

In [None]:
# Demonstrate undersampling for an imbalanced binary dataset
n_samples = 1000
n_minority = 100
n_majority = n_samples - n_minority

# Seeded local generator so the dataset (and the resampled counts) are
# reproducible, independent of NumPy's global random state left by earlier cells.
sampling_rng = np.random.default_rng(42)

X_majority = sampling_rng.standard_normal((n_majority, 10))
y_majority = np.zeros(n_majority, dtype=int)

X_minority = sampling_rng.standard_normal((n_minority, 10)) + 2  # Shifted distribution
y_minority = np.ones(n_minority, dtype=int)

X_imbalanced = np.vstack([X_majority, X_minority])
y_imbalanced = np.concatenate([y_majority, y_minority])


def _print_class_balance(labels, ratio_label):
    """Print per-class counts for binary labels {0, 1} and the class-0 : class-1 ratio.

    The ratio is reported as "inf" when class 1 is empty, instead of raising
    ZeroDivisionError.
    """
    n_class0 = int(np.sum(labels == 0))
    n_class1 = int(np.sum(labels == 1))
    ratio = f"{n_class0 / n_class1:.1f}" if n_class1 else "inf"
    print(f"Class 0: {n_class0} samples")
    print(f"Class 1: {n_class1} samples")
    print(f"{ratio_label}: {ratio}:1")


print("Original dataset:")
_print_class_balance(y_imbalanced, "Imbalance ratio")

# Apply undersampling
X_balanced, y_balanced = stx.ai.sampling.undersample(
    X_imbalanced, y_imbalanced,
    sampling_strategy='auto',
    random_state=42
)

print("\nAfter undersampling:")
_print_class_balance(y_balanced, "New ratio")

## 9. Feature Extraction with Vision Transformers

In [None]:
# Conceptual walkthrough of ViT-based feature extraction. No model is run
# here: the example code is printed as text because it needs actual image
# data and a pretrained ViT.
VIT_EXAMPLE = """\n# Example usage:
from scitex.ai.feature_extraction import vit

# Load pretrained ViT
feature_extractor = vit.load_pretrained_vit('vit_base_patch16_224')

# Extract features from images
image_batch = torch.randn(4, 3, 224, 224)  # Batch of 4 RGB images
features = feature_extractor(image_batch)

# Features shape: [batch_size, feature_dim]
print(f"Extracted features shape: {features.shape}")
"""

VIT_USE_CASES = (
    "Transfer learning for image classification",
    "Image similarity search",
    "Feature-based clustering of images",
    "Input features for downstream tasks",
)

print("Vision Transformer Feature Extraction:")
print("\nSciTeX provides ViT-based feature extraction for images:")
print(VIT_EXAMPLE)

print("\nUse cases:")
for use_case in VIT_USE_CASES:
    print(f"- {use_case}")

## 10. Summary and Best Practices

### Key Takeaways

1. **Unified AI Interface**: SciTeX provides a single interface for multiple AI providers
2. **Cost Awareness**: Always monitor costs when using AI models
3. **Training Utilities**: Use early stopping and learning curve logging for better training
4. **Classification Tools**: Comprehensive reporting for model evaluation
5. **Visualization**: Built-in support for ML-specific visualizations

### Best Practices

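1. **API Key Management**: keep keys out of the notebook itself. Export them in your shell before launching Jupyter, or prompt for them at runtime:
   ```python
   import os
   from getpass import getpass

   # Preferred: run `export OPENAI_API_KEY=...` in your shell before starting Jupyter
   if "OPENAI_API_KEY" not in os.environ:
       # Prompt without echoing, so the key is never stored in the notebook
       os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")
   # Never hardcode keys (e.g. os.environ["OPENAI_API_KEY"] = "sk-...") in notebooks
   ```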

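2. **Error Handling** (use a standard `logging` logger here, not the `LearningCurveLogger` named `logger` in section 4.2):
   ```python
   import logging
   log = logging.getLogger(__name__)

   try:
       response = ai.complete(prompt)
   except Exception as e:
       log.error(f"AI request failed: {e}")
       # Implement fallback behavior
   ```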

3. **Cost Optimization**:
   ```python
   # Use cheaper models for testing
   ai = stx.ai.GenAI(provider="openai", model="gpt-3.5-turbo")
   # Clear history when not needed
   ai.chat_history.clear()
   ```

4. **Reproducibility**:
   ```python
   # Set random seeds
   stx.repro.set_seed(42)
   # Log all parameters
   logger.log_params(params)
   ```

In [None]:
# Closing message: summarize suggested next steps (nothing is cleaned up here)
next_steps = [
    "Set up API keys for AI providers",
    "Explore advanced GenAI features (streaming, vision models)",
    "Try custom neural network architectures",
    "Implement production ML pipelines with SciTeX",
]

print("Tutorial completed!")
print("\nNext steps:")
for step_number, step in enumerate(next_steps, start=1):
    print(f"{step_number}. {step}")