# StrandWeaver Model Training on Google Colab

This notebook trains lightweight ML models for StrandWeaver v0.2 using free Colab resources.

**What we'll do:**
1. Clone StrandWeaver from GitHub
2. Install dependencies
3. Generate lightweight training data (10 genomes × 100 kb)
4. Train two models (EdgeAI, DiploidAI)
5. Save models to Google Drive

**Time estimate:** 2-3 hours total

**GPU:** Enable GPU in Runtime > Change runtime type > T4 GPU

## Step 1: Setup Environment

In [None]:
# Clone StrandWeaver
!git clone https://github.com/pgrady1322/strandweaver.git
%cd strandweaver

In [None]:
# Install dependencies
# Quote the extras so the shell never treats the brackets as a glob pattern
!pip install -q -e ".[ai]"

# Verify installation
import strandweaver
print(f"StrandWeaver version: {strandweaver.__version__}")

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
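
If no GPU is attached, the notebook still runs, only more slowly. The check above can be folded into a small helper that falls back to the CPU. This is a minimal sketch: `pick_device` is a hypothetical name, not part of StrandWeaver, and nothing here assumes the training scripts accept a device argument.

```python
def pick_device() -> str:
    """Return "cuda" when PyTorch can see a CUDA GPU, otherwise "cpu"."""
    try:
        import torch  # imported lazily so the helper also works without PyTorch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Using device: {device}")
```

If this prints `cpu` on Colab, revisit the runtime type setting before starting the DiploidAI step.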

## Step 2: Mount Google Drive (for saving models)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Step 3: Generate Training Data

We'll use the "simple" scenario (10 genomes × 100 kb) for quick testing.

For production models, use the "balanced" scenario (100 genomes × 1 Mb), which takes ~4-6 hours.

In [None]:
# Generate lightweight training data
!python scripts/generate_assembly_training_data.py \
    --scenario simple \
    --output-dir training_data/simple \
    --num-workers 2

In [None]:
# Check generated data
!ls -lh training_data/simple/
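
A directory listing shows the files but not whether generation produced anything usable. The sketch below counts the files and sums their sizes so that an empty output directory is caught before training starts. `summarize_dir` is a hypothetical helper, and it makes no assumption about the file names or formats the generator writes.

```python
from pathlib import Path


def summarize_dir(path):
    """Return (file_count, total_bytes) for all regular files under path."""
    root = Path(path)
    if not root.is_dir():
        raise FileNotFoundError(f"Missing directory: {root}")
    files = [p for p in root.rglob("*") if p.is_file()]
    return len(files), sum(p.stat().st_size for p in files)


data_dir = Path("training_data/simple")
if data_dir.is_dir():  # only report when the data has actually been generated
    count, total = summarize_dir(data_dir)
    print(f"{count} files, {total / 1e6:.1f} MB in {data_dir}")
    if count == 0:
        print("No files found; re-run the generation cell before training.")
```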

## Step 4: Train EdgeAI Classifier (XGBoost)

This is the easiest model to train; it takes ~30-60 minutes on the "simple" data.

In [None]:
!python scripts/train_models/train_edge_classifier.py \
    --data-dir training_data/simple \
    --output models/edge_classifier_v0.1_colab.model

## Step 5: Train DiploidAI Classifier (PyTorch MLP)

This model trains faster on a GPU; expect ~30-45 minutes.

In [None]:
!python scripts/train_models/train_diploid_classifier.py \
    --data-dir training_data/simple \
    --output models/diploid_classifier_v0.1_colab.pth \
    --epochs 50 \
    --batch-size 128

## Step 6: Save Models to Google Drive

In [None]:
# Create directory in Google Drive
!mkdir -p /content/drive/MyDrive/StrandWeaver_Models

# Copy trained models
!cp models/edge_classifier_v0.1_colab.model /content/drive/MyDrive/StrandWeaver_Models/
!cp models/edge_classifier_v0.1_colab.json /content/drive/MyDrive/StrandWeaver_Models/
!cp models/diploid_classifier_v0.1_colab.pth /content/drive/MyDrive/StrandWeaver_Models/
!cp models/diploid_classifier_v0.1_colab.json /content/drive/MyDrive/StrandWeaver_Models/

print("✅ Models saved to Google Drive!")
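
The `cp` commands above print an error for a missing file, but the cell keeps running and still prints the success message. A Python version can report exactly which files were copied. This is a sketch only: `copy_models` is a hypothetical helper, and the file names are the ones used in this notebook.

```python
import shutil
from pathlib import Path


def copy_models(src_dir, dest_dir, names):
    """Copy each named file from src_dir to dest_dir; return (copied, missing)."""
    src, dest = Path(src_dir), Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    copied, missing = [], []
    for name in names:
        source = src / name
        if source.is_file():
            shutil.copy2(source, dest / name)  # copy2 also preserves timestamps
            copied.append(name)
        else:
            missing.append(name)
    return copied, missing


MODEL_FILES = [
    "edge_classifier_v0.1_colab.model",
    "edge_classifier_v0.1_colab.json",
    "diploid_classifier_v0.1_colab.pth",
    "diploid_classifier_v0.1_colab.json",
]

drive_root = Path("/content/drive/MyDrive")
if drive_root.is_dir():  # only run when Google Drive is mounted
    copied, missing = copy_models("models", drive_root / "StrandWeaver_Models", MODEL_FILES)
    print(f"Copied: {copied}")
    if missing:
        print(f"Missing, not copied: {missing}")
```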

## Step 7: Test Models (Optional)

In [None]:
# Load and test edge classifier
import xgboost as xgb

model = xgb.Booster()
model.load_model('models/edge_classifier_v0.1_colab.model')

print("✅ Edge classifier loaded successfully!")
print(f"Number of boosting rounds: {model.num_boosted_rounds()}")

In [None]:
# Load and test diploid classifier
import torch
import json

# Load model architecture from metadata
with open('models/diploid_classifier_v0.1_colab.json') as f:
    metadata = json.load(f)

from strandweaver.assembly_core.diploid_disentangler_module import DiploidClassifier

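model = DiploidClassifier(input_dim=metadata['input_dim'])
# map_location lets a GPU-trained checkpoint load on CPU-only runtimes too
state_dict = torch.load('models/diploid_classifier_v0.1_colab.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before any predictions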

print("✅ Diploid classifier loaded successfully!")
print(f"Input dim: {metadata['input_dim']}, Accuracy: {metadata['accuracy']:.3f}")
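
The cell above reads `input_dim` and `accuracy` from the metadata file and fails with a bare `KeyError` if either is absent. A small loader can check for them first and name the file that is at fault. This is a sketch: `load_metadata` is a hypothetical helper, and the only keys it requires are the two that this notebook reads.

```python
import json

REQUIRED_KEYS = ("input_dim", "accuracy")  # the keys used in the cell above


def load_metadata(path):
    """Load a model metadata JSON file and verify the required keys exist."""
    with open(path) as f:
        metadata = json.load(f)
    missing = [key for key in REQUIRED_KEYS if key not in metadata]
    if missing:
        raise KeyError(f"{path} is missing keys: {missing}")
    return metadata
```

Calling `load_metadata('models/diploid_classifier_v0.1_colab.json')` in place of the plain `json.load` gives a clearer error if the training script wrote incomplete metadata.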

## Next Steps

1. **Download models from Google Drive** to your local machine
2. **Copy to StrandWeaver:** Place in `strandweaver/ai/models/trained_models/`
3. **Test locally:** Run assembly with trained models
4. **Train more models:** Repeat with "balanced" scenario for production quality

### For Production Models:

```python
# Use balanced scenario (100 genomes × 1Mb)
!python scripts/generate_assembly_training_data.py \
    --scenario balanced \
    --output-dir training_data/balanced \
    --num-workers 4

# Train with more data (4-6 hours)
!python scripts/train_models/train_edge_classifier.py \
    --data-dir training_data/balanced \
    --output models/edge_classifier_v0.2_production.model
```

---

**Need help?** Open an issue on GitHub: https://github.com/pgrady1322/strandweaver/issues