# Improved Training (rect, cache=disk, optimizer control, copy-paste)
This version uses better defaults for small datasets and Colab T4, adds an optional fine-tuning stage, and keeps exports organized.

# Train YOLO on Colab (T4) and export best.pt

> This notebook mounts Google Drive, installs Ultralytics, verifies your project paths, runs training using `tms-fault-detection-model/train_yolo.py`, and copies `best.pt` plus key artifacts into an `exports/` folder for easy download.

Expected Drive layout:
- `/content/drive/MyDrive/Software design project/Annotated_dataset/data.yaml`
- `/content/drive/MyDrive/Software design project/tms-fault-detection-model/train_yolo.py`

Tip: Before running the cells, open Runtime -> Change runtime type in Colab and select the T4 GPU hardware accelerator.

In [None]:
# 1) Mount Google Drive
#    After this, "My Drive" is reachable under /content/drive/MyDrive, which is
#    where every later cell looks for the project files.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# 2) Install Ultralytics (YOLOv8/YOLO11)
%pip install -q ultralytics

In [None]:
# 3) Verify folder structure and GPU availability
#    Reports what the runtime can see, then stops here if either required
#    Drive path is missing so later cells don't fail with confusing errors.
import os, torch
ROOT = "/content/drive/MyDrive/Software design project"

cuda_ready = torch.cuda.is_available()
print("GPU available:", cuda_ready)
gpu_label = torch.cuda.get_device_name(0) if cuda_ready else "CPU"
print("GPU name:", gpu_label)

script_ok = os.path.exists(os.path.join(ROOT, "tms-fault-detection-model", "train_yolo.py"))
data_ok = os.path.exists(os.path.join(ROOT, "Annotated_dataset", "data.yaml"))
for label, ok in (("Training script exists:", script_ok), ("Dataset YAML exists:", data_ok)):
    print(label, ok)
assert all((script_ok, data_ok)), "Drive structure mismatch. Ensure both paths exist as shown above."

In [None]:
# 4) Run training (adjust epochs/batch/imgsz/name as needed)
# Runs the project's train_yolo.py in a separate process through the IPython `!`
# shell escape. How each flag is interpreted is defined by that script (not shown here).
# The script path is quoted because ROOT contains spaces.
# NOTE(review): if train_yolo.py forwards --rect to Ultralytics model.train(rect=True),
#   Ultralytics disables DataLoader shuffling and mosaic for rectangular training, so
#   --mosaic 0.2 would likely have no effect. Confirm this trade-off is intended.
# NOTE(review): "--rect True" reaches the script as the string "True"; confirm it parses
#   booleans explicitly (argparse type=bool treats any non-empty string, even "False", as True).
# NOTE(review): --freeze 10 presumably freezes the first 10 model layers. Confirm in train_yolo.py.
!python "/content/drive/MyDrive/Software design project/tms-fault-detection-model/train_yolo.py" \
  --weights yolo11s.pt \
  --epochs 120 \
  --batch 16 \
  --imgsz 640 \
  --name yolo11s_colab_t4_rect_cache \
  --patience 30 \
  --optimizer AdamW \
  --lr0 0.0015 \
  --weight_decay 0.0005 \
  --mosaic 0.2 \
  --mixup 0.0 \
  --copy_paste 0.2 \
  --rect True \
  --cache disk \
  --workers 2 \
  --freeze 10

In [None]:
# 5) Locate best.pt and copy it to a stable path for easy download
import glob, os, shutil
ROOT = "/content/drive/MyDrive/Software design project"

# Assumes train_yolo.py saves runs as <model dir>/runs/<name>/weights/best.pt; runs
# written elsewhere (e.g. runs/detect/<name>) will not match this pattern.
candidates = glob.glob(f"{ROOT}/tms-fault-detection-model/runs/*/weights/best.pt")
if not candidates:
    raise FileNotFoundError("No best.pt found. Check the training logs for errors.")
# "Latest" means most recently modified, so a later fine-tune run wins over the base run.
best = max(candidates, key=os.path.getmtime)
print("Latest best.pt:", best)

export_dir = f"{ROOT}/tms-fault-detection-model/exports"
os.makedirs(export_dir, exist_ok=True)
dst = f"{export_dir}/best.pt"
shutil.copy2(best, dst)
print("Copied to:", dst)

# Optional: also copy summary plots/metrics from the same run.
# Keys are the file names used inside exports/. Values are the names to look for in the
# run folder, tried in order. Recent Ultralytics releases prefix detection curves with
# "Box" (BoxPR_curve.png, BoxF1_curve.png), and older releases do not.
ARTIFACTS = {
    "results.png": ("results.png",),
    "results.csv": ("results.csv",),
    "confusion_matrix.png": ("confusion_matrix.png",),
    "PR_curve.png": ("PR_curve.png", "BoxPR_curve.png"),
    "F1_curve.png": ("F1_curve.png", "BoxF1_curve.png"),
}
run_dir = os.path.dirname(os.path.dirname(best))  # .../runs/<name>
for export_name, source_names in ARTIFACTS.items():
    target = os.path.join(export_dir, export_name)
    found = [
        os.path.join(run_dir, name)
        for name in source_names
        if os.path.exists(os.path.join(run_dir, name))
    ]
    if found:
        shutil.copy2(found[0], target)
        print("Exported:", export_name)
    elif os.path.exists(target):
        # This run lacks the artifact. Drop the copy left over from an earlier run so
        # exports/ never mixes diagnostics from different runs.
        os.remove(target)
        print("Removed stale:", export_name)

## What to download after training
- Primary: `tms-fault-detection-model/exports/best.pt`
- Optional diagnostics: `results.png`, `results.csv`, `confusion_matrix.png`, `PR_curve.png`, `F1_curve.png` in `tms-fault-detection-model/exports/` (copied from the run folder).

In [None]:
# 6) (Optional) Download best.pt to your local machine (from Colab)
#    Sends the copy in exports/ to this browser as a regular file download.
from google.colab import files

EXPORTED_BEST = "/content/drive/MyDrive/Software design project/tms-fault-detection-model/exports/best.pt"
files.download(EXPORTED_BEST)

In [None]:
# 7) (Optional) Fine-tune from best with lower LR for a few more epochs
#    This starts a NEW training run (FINETUNE_NAME) initialised from an earlier
#    best.pt. It is not a resume of that run.
import glob, os, shlex, subprocess, sys
ROOT = "/content/drive/MyDrive/Software design project"
FINETUNE_NAME = "yolo11s_colab_t4_finetune"

# Ignore previous fine-tune runs (repeated names may get a numeric suffix, e.g.
# ..._finetune2) so re-running this cell does not fine-tune a fine-tune.
candidates = [
    p
    for p in glob.glob(f"{ROOT}/tms-fault-detection-model/runs/*/weights/best.pt")
    if not os.path.basename(os.path.dirname(os.path.dirname(p))).startswith(FINETUNE_NAME)
]
if candidates:
    best = max(candidates, key=os.path.getmtime)
    print("Fine-tuning from:", best)
    cmd = [
        # sys.executable guarantees the same interpreter/env that %pip installed into.
        sys.executable, f"{ROOT}/tms-fault-detection-model/train_yolo.py",
        "--weights", best,
        "--epochs", "40",
        "--batch", "16",
        "--imgsz", "640",
        "--name", FINETUNE_NAME,
        "--patience", "20",
        "--optimizer", "AdamW",
        "--lr0", "0.0005",
        "--weight_decay", "0.0005",
        "--mosaic", "0.0",
        "--mixup", "0.0",
        "--copy_paste", "0.0",
        # NOTE(review): how "True" is parsed depends on train_yolo.py. Confirm it is a real boolean.
        "--rect", "True",
        "--cache", "disk",
        "--workers", "2",
    ]
    # shlex.join quotes arguments containing spaces (ROOT has them), so the printed
    # command can be copy-pasted into a shell.
    print("Running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)
else:
    print("No best.pt found to fine-tune from. Skipping.")

## Quick inference and visualization
This cell loads the latest `best.pt`, runs inference on a few sample images, displays them inline, and saves annotated images into `tms-fault-detection-model/exports/infer/`.
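- Adjust `SAMPLES` to point to any maintenance images you want to preview. Entries can be folders (scanned for `.jpg`/`.jpeg`/`.png` files) or individual image files. At most 6 images are used.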

In [None]:
# 8) Run inference on a few images and visualize
import os, glob, shutil
from pathlib import Path
import matplotlib.pyplot as plt
from ultralytics import YOLO

ROOT = "/content/drive/MyDrive/Software design project"
# 1) Find latest best.pt (most recently modified across all runs, including fine-tunes)
candidates = glob.glob(f"{ROOT}/tms-fault-detection-model/runs/*/weights/best.pt")
if not candidates:
    # Explicit raise rather than assert, so the check is not stripped under `python -O`.
    raise FileNotFoundError("No best.pt found. Train first, then re-run this cell.")
best = max(candidates, key=os.path.getmtime)
print("Using weights:", best)

# 2) Define sample images for preview (adjust these)
#    Entries may be folders (scanned for images) or individual image files.
SAMPLES = [
    f"{ROOT}/Annotated_dataset/valid/images",  # folder allowed
    # You can also put specific image paths here, e.g.:
    # f"{ROOT}/test/images/T1_faulty_017_jpg.rf.343c025053fbbe4412ddb9db5b2a2517.jpg",
    # f"{ROOT}/test/images/T13_normal_002_jpg.rf.5fe515713bbcde73bb886d4cd1530855.jpg",
]

# 3) Collect image paths (from folders or files), limit to N
def collect_images(paths, limit=6, exts=(".jpg", ".jpeg", ".png")):
    """Gather image file paths from a mix of folders and individual files.

    Folders are scanned non-recursively for regular, non-hidden files whose
    extension matches one of ``exts`` case-insensitively (so ``.JPG`` counts).
    Each folder's matches are sorted by name, so the preview set is the same on
    every run. Explicit file paths are kept as given; paths that do not exist
    are skipped.

    Args:
        paths: iterable of folder and/or file paths.
        limit: maximum number of paths to return; ``None`` returns all of them.
        exts: lowercase extensions accepted when scanning folders.

    Returns:
        Unique paths in first-seen order, truncated to ``limit``.
    """
    found = []
    for p in paths:
        if os.path.isdir(p):
            for name in sorted(os.listdir(p)):
                full = os.path.join(p, name)
                # Skip dotfiles, as glob("*.jpg") would, and folders named like images.
                if (not name.startswith(".")
                        and name.lower().endswith(exts)
                        and os.path.isfile(full)):
                    found.append(full)
        elif os.path.isfile(p):
            found.append(p)
    uniq = list(dict.fromkeys(found))  # de-duplicate, keep first-seen order
    return uniq if limit is None else uniq[:limit]

images = collect_images(SAMPLES, limit=6)
if not images:
    raise FileNotFoundError("No images found for inference. Update SAMPLES to valid image paths.")
print("Images:", images)

# 4) Run prediction and visualize
import cv2
import torch

model = YOLO(best)
export_dir = f"{ROOT}/tms-fault-detection-model/exports/infer"
os.makedirs(export_dir, exist_ok=True)

# A freshly loaded YOLO model typically reports its device as CPU until predict/train
# moves it, so model.device cannot tell us whether a GPU exists. Ask torch directly.
device = 0 if torch.cuda.is_available() else "cpu"
# Ultralytics predict() takes its inputs via `source` (a list of paths yields one
# Results per image, in order). `images=` is not a recognised argument.
results = model.predict(
    source=images,
    conf=0.25,
    iou=0.45,
    imgsz=640,
    device=device,
    save=False,
    verbose=False,
)

for img_path, res in zip(images, results):
    # Render boxes/labels via Ultralytics and save the annotated image with OpenCV.
    plotted = res.plot()
    out_path = os.path.join(export_dir, Path(img_path).name)
    if not cv2.imwrite(out_path, plotted):
        raise OSError(f"Failed to write annotated image: {out_path}")
    print("Saved:", out_path)
    # Show inline; reverse the channel order (BGR -> RGB) for matplotlib.
    plt.figure(figsize=(6, 6))
    plt.imshow(plotted[..., ::-1])
    plt.title(Path(img_path).name)
    plt.axis('off')
    plt.show()