In [None]:
from ultralytics import YOLO
import torch
import wandb

In [None]:
# Sweep Config
# W&B sweep definition: Bayesian search that maximises 'metrics/mAP50-95(B)'.
# The metric name must match a key of ``trainer.metrics``, which the
# ``on_fit_epoch_end`` callback forwards to ``wandb.log``.
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'metrics/mAP50-95(B)',
        'goal': 'maximize',
    },
    'parameters': {
        # lr0 spans two orders of magnitude, so sample it log-uniformly; with a
        # plain uniform range over [1e-4, 1e-2] ~91% of draws would be > 1e-3.
        'lr0': {
            'distribution': 'log_uniform_values',
            'min': 0.0001,
            'max': 0.01,
        },
        'optimizer': {'values': ['SGD', 'AdamW']},
        'batch': {'values': [32, 64, 128]},
        # Fixed (not searched) so every trial gets the same training budget.
        'epochs': {'value': 30},
    },
}

def on_fit_epoch_end(trainer):
    """Ultralytics ``on_fit_epoch_end`` hook: forward ``trainer.metrics`` to W&B.

    On the first epoch (``trainer.epoch == 0``) the available metric keys are
    printed once, presumably so the sweep's ``metric.name`` can be checked
    against the exact keys being logged.

    Args:
        trainer: the Ultralytics trainer; only ``epoch`` and ``metrics`` are read.
    """
    is_first_epoch = trainer.epoch == 0
    if is_first_epoch:
        metric_names = list(trainer.metrics.keys())
        print("\nAVAILABLE METRICS:", metric_names)

    # Nothing to do unless a W&B run is active (e.g. inside wandb.init).
    if not wandb.run:
        return
    wandb.log(trainer.metrics)

def train_sweep(config=None):
    """Run one W&B sweep trial: train YOLOv8x with the hyper-parameters the sweep chose.

    Args:
        config: optional config forwarded to ``wandb.init``. When the function is
            driven by ``wandb.agent`` it is left as None and ``lr0``,
            ``optimizer``, ``batch`` and ``epochs`` come from ``wandb.config``.
    """
    import gc  # local import: only needed for the memory cleanup below

    with wandb.init(config=config, project="YOLOv8-Mushroom", job_type="training"):
        config = wandb.config
        model = YOLO('yolov8x.pt')
        model.add_callback("on_fit_epoch_end", on_fit_epoch_end)

        try:
            model.train(
                data='./data.yaml',
                project='runs/detect',

                # Sweep params
                epochs=config.epochs,
                batch=config.batch,
                optimizer=config.optimizer,
                lr0=config.lr0,

                # Fixed params
                imgsz=640,
                device='0',
                workers=16,
                verbose=False,
                amp=True,
            )
        finally:
            # wandb.agent runs every trial in this same process, so release the
            # model and cached CUDA memory before the next trial starts; this
            # matters most after a large-batch or OOM-failed trial.
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

# Sweep
SWEEP_RUN_COUNT = 15  # number of trials the agent runs for this sweep

if __name__ == '__main__':
    # Create the sweep under the same guard as the agent: at module level it
    # would register a brand-new sweep merely by importing this code. In a
    # Jupyter kernel __name__ is '__main__', so the notebook behaves as before.
    # NOTE: every execution of this cell still creates a new sweep.
    sweep_id = wandb.sweep(sweep_config, project="YOLOv8-Mushroom")
    wandb.agent(sweep_id, train_sweep, count=SWEEP_RUN_COUNT)

In [None]:
# Final training run with the hyper-parameters selected from the sweep.
# NOTE(review): these values were copied by hand from the sweep results; keep
# them in sync if the sweep is re-run.
BEST_PARAMS = {
    'optimizer': 'SGD',
    'lr0': 0.004427,
    # NOTE(review): batch=16 is outside the sweep's batch search space
    # ([32, 64, 128]), so lr0 was tuned at a different batch size. Confirm this
    # is intentional (e.g. chosen for memory on the longer run).
    'batch': 16,
}


def on_fit_epoch_end(trainer):
    """Ultralytics ``on_fit_epoch_end`` hook: log ``trainer.metrics`` to the active W&B run.

    Redefined here, without the sweep cell's first-epoch debug print, so this
    cell does not depend on the sweep cell having been run. It replaces any
    earlier definition of the same name in the kernel.
    """
    if wandb.run:
        wandb.log(trainer.metrics)


# Using the run as a context manager guarantees wandb.finish() is called even
# if training raises, so a crashed run is not left open in the kernel. This
# matches the ``with wandb.init(...)`` pattern used by train_sweep.
with wandb.init(
    project="YOLOv8-Mushroom",
    name="Yolo-xlarge-best-param",
    job_type="training",
) as run:
    model = YOLO('yolov8x.pt')
    model.add_callback("on_fit_epoch_end", on_fit_epoch_end)

    results = model.train(
        data='./data.yaml',
        **BEST_PARAMS,

        epochs=150,
        patience=50,

        imgsz=640,
        device='0',
        workers=16,
        project='runs/detect',
        name='mushroom_final',

        val=True,
        save=True,
        verbose=False,
    )
