# Cocoa Pod Disease: Hybrid CNN–ViT + YOLOv8 Severity (Colab)

This notebook trains and runs the pipeline:
- **Classification**: CNN, ViT, concatenation fusion, or attention fusion (optionally with EMA weights)
- **Detection**: YOLOv8 lesion detector (optionally also detecting pods)
- **Severity**: `lesion_area / pod_area`
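Severity is a ratio of areas. The sketch below is a minimal illustration, not the project's `src.infer_cls_and_severity` code. It assumes detections are axis-aligned pixel boxes `(x1, y1, x2, y2)` and `image_shape` is `(height, width)`. It rasterizes the boxes onto boolean masks with NumPy so overlapping lesion boxes are counted once, and it only counts lesion pixels inside the pod box. The names `box_severity` and `_clip_box` are hypothetical.

```python
import numpy as np


def _clip_box(box, width, height):
    """Round a pixel box (x1, y1, x2, y2) to ints and clip it to the image."""
    x1, y1, x2, y2 = (int(round(v)) for v in box)
    return (min(max(x1, 0), width), min(max(y1, 0), height),
            min(max(x2, 0), width), min(max(y2, 0), height))


def box_severity(pod_box, lesion_boxes, image_shape):
    """Hypothetical helper: severity = lesion_area / pod_area from boxes.

    Assumes boxes are (x1, y1, x2, y2) in pixels and image_shape is (height, width).
    Lesion area is the union of lesion boxes restricted to the pod box.
    """
    height, width = image_shape
    px1, py1, px2, py2 = _clip_box(pod_box, width, height)
    pod_mask = np.zeros((height, width), dtype=bool)
    pod_mask[py1:py2, px1:px2] = True
    pod_area = int(pod_mask.sum())
    if pod_area == 0:
        return 0.0  # no valid pod region; avoid division by zero

    lesion_mask = np.zeros((height, width), dtype=bool)
    for box in lesion_boxes:
        lx1, ly1, lx2, ly2 = _clip_box(box, width, height)
        lesion_mask[ly1:ly2, lx1:lx2] = True  # union: overlaps are counted once

    lesion_area = int((lesion_mask & pod_mask).sum())
    return lesion_area / pod_area
```

Consider a 100×100 pod box with two 10×10 lesion boxes that overlap by 5×5 pixels. Their union is 175 px², which gives a severity of 0.0175. If the detector produces segmentation masks instead of boxes, you can take the same ratio over the masks directly.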

A dataset is required; it is not included.

In [None]:
# === 1) Mount Drive (optional) ===
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# === 2) (Option A) Unzip this project (if you uploaded cocoa_hybrid_project.zip) ===
!unzip -q cocoa_hybrid_project.zip -d /content/cocoa_hybrid_project
%cd /content/cocoa_hybrid_project


In [None]:
# === 3) Install dependencies ===
!pip -q install -r requirements.txt


## Prepare dataset folders

### Classification dataset structure
```
data_cls/
  train/
    Healthy/
    BPR/
    FPR/
  val/
    Healthy/
    BPR/
    FPR/
  test/        (optional; same class folders)
```
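Before training, you can check that every split has the same class folders and a reasonable number of images. The following is a minimal standard-library sketch. It assumes an ImageFolder-style layout (one subfolder per class) and the image extensions listed in `IMAGE_EXTS`. `count_images` and `check_layout` are hypothetical helpers, not part of `src`.

```python
from pathlib import Path

# Assumed set of image extensions; extend it if your data uses others.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def _subdirs(path):
    """Visible subdirectories, sorted (skips e.g. .ipynb_checkpoints)."""
    return sorted(p for p in Path(path).iterdir()
                  if p.is_dir() and not p.name.startswith("."))


def count_images(root):
    """Return {split: {class_name: n_images}} for an ImageFolder-style tree."""
    return {
        split.name: {
            cls.name: sum(1 for f in cls.iterdir()
                          if f.is_file() and f.suffix.lower() in IMAGE_EXTS)
            for cls in _subdirs(split)
        }
        for split in _subdirs(root)
    }


def check_layout(root, required_splits=("train", "val")):
    """Raise if a required split is missing or splits disagree on class folders."""
    counts = count_images(root)
    missing = [s for s in required_splits if s not in counts]
    if missing:
        raise FileNotFoundError(f"missing split folders: {missing}")
    train_classes = set(counts["train"])
    for split, per_class in counts.items():
        if set(per_class) != train_classes:
            raise ValueError(
                f"{split} classes {sorted(per_class)} != train classes {sorted(train_classes)}")
    return counts
```

To use it, run `check_layout("data_cls")` (or your Drive path) in a cell. It returns per-class counts, which also make class imbalance between Healthy, BPR, and FPR easy to spot.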

### YOLO detection dataset structure
```
data_det/
  images/
    train/
    val/
  labels/
    train/
    val/
  data.yaml
```
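Recommended YOLO classes: `[pod, lesion]` (class IDs 0 and 1). Each image in `images/<split>/` needs a label `.txt` file with the same file stem in `labels/<split>/`.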
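Ultralytics YOLOv8 reads a `data.yaml` with three parts: a dataset `path`, `train` and `val` image folders relative to that path, and a `names` mapping from class ID to name. Each label file holds one `class_id x_center y_center width height` line per object, with coordinates normalized by image width and height. The sketch below writes such a file for the layout above and converts a pixel box to a label line. It assumes `src.train_yolo` accepts the standard Ultralytics format. `write_data_yaml` and `to_yolo_line` are hypothetical helpers.

```python
from pathlib import Path

# Index in this list = YOLO class id (0 = pod, 1 = lesion).
CLASS_NAMES = ["pod", "lesion"]


def write_data_yaml(dataset_root, out_path=None):
    """Write an Ultralytics-style data.yaml for the data_det/ layout."""
    root = Path(dataset_root).resolve()
    out_path = Path(out_path) if out_path else root / "data.yaml"
    lines = [
        f"path: {root}",
        "train: images/train",  # relative to `path`
        "val: images/val",
        "names:",
    ]
    lines += [f"  {i}: {name}" for i, name in enumerate(CLASS_NAMES)]
    out_path.write_text("\n".join(lines) + "\n")
    return out_path


def to_yolo_line(class_id, box, img_w, img_h):
    """Convert a pixel box (x1, y1, x2, y2) to a normalized YOLO label line."""
    x1, y1, x2, y2 = box
    cx = (x1 + x2) / 2 / img_w
    cy = (y1 + y2) / 2 / img_h
    w = (x2 - x1) / img_w
    h = (y2 - y1) / img_h
    return f"{class_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"
```

For example, take a 400×500 image with a lesion box from (100, 50) to (300, 250). That box becomes the label line `1 0.500000 0.300000 0.500000 0.400000`.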

In [None]:
# === 4) Train classification ===
# Set --data_dir to your dataset location (a Drive path or a /content path)
!python -m src.train_cls \
  --data_dir data_cls \
  --variant attn \
  --epochs 20 \
  --batch_size 32 \
  --lr 3e-4 \
  --use_ema

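The `--use_ema` flag presumably keeps an exponential moving average (EMA) of the model weights and evaluates or saves the averaged copy. The exact decay and update schedule in `src.train_cls` are not shown here. Below is a framework-agnostic NumPy sketch of the update `shadow = decay * shadow + (1 - decay) * weights`. It uses a hypothetical `WeightEMA` class and an assumed default decay of 0.999.

```python
import numpy as np


class WeightEMA:
    """Hypothetical EMA tracker over a dict of named weight arrays."""

    def __init__(self, params, decay=0.999):
        if not 0.0 <= decay < 1.0:
            raise ValueError("decay must be in [0, 1)")
        self.decay = decay
        # The shadow copy starts equal to the initial weights.
        self.shadow = {name: np.array(w, dtype=np.float64, copy=True)
                       for name, w in params.items()}

    def update(self, params):
        """shadow = decay * shadow + (1 - decay) * current, updated in place."""
        for name, w in params.items():
            self.shadow[name] *= self.decay
            self.shadow[name] += (1.0 - self.decay) * np.asarray(w, dtype=np.float64)
```

Call `update` once per optimizer step. With a decay of 0.999, the average spans roughly the last 1 / (1 - 0.999) = 1000 steps, which smooths out late-epoch noise. In PyTorch, you would apply the same rule to each parameter tensor under `torch.no_grad()`.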

In [None]:
# === 5) Train YOLOv8 detection ===
# Make sure your data.yaml points to the correct dataset path.
!python -m src.train_yolo \
  --data_yaml data_det/data.yaml \
  --model yolov8n.pt \
  --epochs 100 \
  --imgsz 640

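By default, Ultralytics saves each new training run to a fresh folder: `runs/detect/train`, then `train2`, `train3`, and so on. The hard-coded `runs/detect/train/weights/best.pt` used for inference can therefore point at an older run. The helper below picks the highest-numbered run that has a `best.pt`. It uses only the standard library, is hypothetical, and assumes `src.train_yolo` keeps the default `runs/detect` directory and `train` run name.

```python
import re
from pathlib import Path


def latest_yolo_weights(runs_dir="runs/detect", name="train"):
    """Hypothetical helper: best.pt from the highest-numbered run folder."""
    runs = Path(runs_dir)
    if not runs.is_dir():
        raise FileNotFoundError(f"{runs} does not exist")
    pattern = re.compile(rf"{re.escape(name)}(\d*)")
    candidates = []
    for run in runs.iterdir():
        match = pattern.fullmatch(run.name)
        weights = run / "weights" / "best.pt"
        if match and weights.is_file():
            # "train" is the first run; "train2", "train3", and so on follow.
            index = int(match.group(1)) if match.group(1) else 1
            candidates.append((index, weights))
    if not candidates:
        raise FileNotFoundError(f"no {name}*/weights/best.pt under {runs}")
    return max(candidates, key=lambda c: c[0])[1]
```

In a Colab cell, run `w = latest_yolo_weights()` and then pass `{w}` as the `--yolo_ckpt` value in the `!python` command. IPython expands `{name}` inside shell commands.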

In [None]:
# === 6) Run end-to-end inference + severity ===
!python -m src.infer_cls_and_severity \
  --cls_ckpt runs_cls/best.pt \
  --yolo_ckpt runs/detect/train/weights/best.pt \
  --image_path path/to/one_image.jpg \
  --out_path prediction.jpg

from PIL import Image
Image.open('prediction.jpg')

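The inference script takes one image per call. To process a whole folder, you can loop over the images and invoke the same command with `subprocess`, using only the flags shown in the cell above. The sketch assumes it runs from the project root (so `-m src.infer_cls_and_severity` resolves) and that images are `.jpg`. `build_infer_cmd` and `infer_folder` are hypothetical.

```python
import subprocess
import sys
from pathlib import Path


def build_infer_cmd(image_path, out_path,
                    cls_ckpt="runs_cls/best.pt",
                    yolo_ckpt="runs/detect/train/weights/best.pt"):
    """Argument list for one call of the inference module."""
    return [sys.executable, "-m", "src.infer_cls_and_severity",
            "--cls_ckpt", str(cls_ckpt),
            "--yolo_ckpt", str(yolo_ckpt),
            "--image_path", str(image_path),
            "--out_path", str(out_path)]


def infer_folder(image_dir, out_dir, pattern="*.jpg", **ckpts):
    """Run inference on every matching image; outputs go to out_dir/pred_<name>."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    outputs = []
    for image in sorted(Path(image_dir).glob(pattern)):
        out_path = out_dir / f"pred_{image.name}"
        subprocess.run(build_infer_cmd(image, out_path, **ckpts), check=True)
        outputs.append(out_path)
    return outputs
```

With `check=True`, the loop stops at the first image that fails. Drop it if you prefer to skip failures and continue.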

In [None]:
# === 7) Explainability (optional): Grad-CAM + ViT attention rollout ===
!python -m src.xai \
  --cls_ckpt runs_cls/best.pt \
  --image_path path/to/one_image.jpg \
  --mode both \
  --out_dir xai_out

from PIL import Image
display(Image.open('xai_out/gradcam_cnn.jpg'))
display(Image.open('xai_out/vit_attention_rollout.jpg'))

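Attention rollout, as commonly implemented, estimates how much each input patch contributes to the CLS token. It averages attention over heads, adds the identity to account for residual connections, re-normalizes the rows, and multiplies the per-layer matrices together. The NumPy sketch below makes two assumptions. First, attentions are collected as one `(heads, tokens, tokens)` array per layer, in forward order. Second, the CLS token is at index 0 and is followed by a square grid of patch tokens. `src.xai` may fuse heads or discard low attentions differently, and the function names are hypothetical.

```python
import numpy as np


def attention_rollout(attentions):
    """Roll out a list of per-layer attention arrays shaped (heads, T, T)."""
    n_tokens = attentions[0].shape[-1]
    identity = np.eye(n_tokens)
    rollout = identity.copy()
    for attn in attentions:
        a = attn.mean(axis=0)              # fuse heads by averaging
        a = a + identity                   # account for the residual connection
        a = a / a.sum(axis=-1, keepdims=True)  # keep rows summing to 1
        rollout = a @ rollout              # later layers multiply on the left
    return rollout


def cls_attention_map(rollout, grid_size):
    """CLS-row attention to patch tokens, reshaped to (grid, grid) and scaled to [0, 1]."""
    patch_attn = rollout[0, 1:]
    if patch_attn.size != grid_size * grid_size:
        raise ValueError("token count does not match 1 CLS + grid_size**2 patches")
    heat = patch_attn.reshape(grid_size, grid_size)
    peak = heat.max()
    return heat / peak if peak > 0 else heat
```

As an example, a ViT with 16×16 patches at 224×224 input has 14×14 = 196 patch tokens plus CLS, so you would use `grid_size=14`. The resulting map is then upsampled to the image size for overlay.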