In [None]:
# =========================================================
# INSTALLATION
# Installs the training dependencies into the Colab runtime.
# `!pip` is IPython shell-escape syntax: this cell only runs inside
# Jupyter/Colab, not as a plain Python script.
# NOTE(review): the non-ASCII text in this notebook's strings (e.g. '‚úÖ',
# 'Depend√™ncias') looks like UTF-8 that was decoded as Mac Roman
# (presumably '✅', 'Dependências'); verify the file encoding.
# =========================================================
!pip install ultralytics pillow

print("‚úÖ Depend√™ncias instaladas!")

In [None]:
from ultralytics import YOLO
from google.colab import drive
from pathlib import Path
import shutil
import os
import yaml
import glob
import xml.etree.ElementTree as ET
from collections import Counter
import zipfile

# =========================================================
# 0. Mount Google Drive and define project paths
# =========================================================
# Persistent artifacts live on Google Drive; scratch work happens on the
# Colab VM's local disk under CONTENT_PROJECT.
drive.mount('/content/drive')

DRIVE_ROOT = "/content/drive/MyDrive/colab"
PROJECT_NAME = "cloud-arch-security-mvp"
CONTENT_PROJECT = "/content/yolo-project"

DRIVE_PROJECT = os.path.join(DRIVE_ROOT, PROJECT_NAME)
DRIVE_CHECKPOINTS = os.path.join(DRIVE_PROJECT, "checkpoints")
# ZIP produced by prepare_dataset.py and uploaded to Drive.
DRIVE_DATASET_ZIP = os.path.join(DRIVE_PROJECT, "kaggle_dataset_cache", "dataset_ready.zip")

# The checkpoint backup folder must exist before anything is copied into it.
os.makedirs(DRIVE_CHECKPOINTS, exist_ok=True)

# =========================================================
# 1. Load the dataset from Drive as a single ZIP
#    (one large copy is much faster than many small files over the mount)
# =========================================================
print("üì• Carregando dataset do Google Drive...")

# Start from a clean local workspace. NOTE: this deletes everything under
# CONTENT_PROJECT, including outputs of earlier runs in this session;
# Drive contents are untouched.
if Path(CONTENT_PROJECT).exists():
    shutil.rmtree(CONTENT_PROJECT)
os.makedirs(CONTENT_PROJECT, exist_ok=True)

RAW_DATA_PATH = f"{CONTENT_PROJECT}/raw_data"

# Fail early, with instructions, if the prepared ZIP was not uploaded.
if not os.path.exists(DRIVE_DATASET_ZIP):
    print("‚ùå Dataset ZIP n√£o encontrado!")
    print(f"   Esperado em: {DRIVE_DATASET_ZIP}")
    print("\nüìã Instru√ß√µes:")
    print("   1. No seu PC, execute: python prepare_dataset.py")
    print("   2. Fa√ßa upload de 'dataset_ready.zip' para o Google Drive em:")
    print(f"      {DRIVE_DATASET_ZIP}")
    # FileNotFoundError subclasses Exception, so existing handlers still catch it.
    raise FileNotFoundError("Dataset ZIP n√£o encontrado no Drive")

# Copy the ZIP to the local disk (a single file, so this is fast).
print("   üì¶ Copiando ZIP para ambiente local...")
zip_size_mb = os.path.getsize(DRIVE_DATASET_ZIP) / (1024 * 1024)
print(f"      Tamanho: {zip_size_mb:.1f} MB")

local_zip = "/content/dataset.zip"
shutil.copy(DRIVE_DATASET_ZIP, local_zip)
print("   ‚úÖ ZIP copiado!")

# Extract locally.
print("   üìÇ Descompactando...")
os.makedirs(RAW_DATA_PATH, exist_ok=True)

try:
    with zipfile.ZipFile(local_zip, 'r') as zip_ref:
        zip_ref.extractall(RAW_DATA_PATH)
finally:
    # Free the disk space even if extraction fails.
    os.remove(local_zip)

# If the archive wrapped everything in one top-level folder, lift its
# contents up one level.
extracted_items = os.listdir(RAW_DATA_PATH)
if len(extracted_items) == 1 and os.path.isdir(f"{RAW_DATA_PATH}/{extracted_items[0]}"):
    # Rename the wrapper first. Otherwise a child with the same name as the
    # wrapper (e.g. dataset/dataset/) would collide with the wrapper itself
    # when moved into RAW_DATA_PATH, and shutil.move would fail.
    subfolder = f"{RAW_DATA_PATH}/.zip_wrapper_tmp"
    os.rename(f"{RAW_DATA_PATH}/{extracted_items[0]}", subfolder)
    for item in os.listdir(subfolder):
        shutil.move(f"{subfolder}/{item}", RAW_DATA_PATH)
    os.rmdir(subfolder)

file_count = len(os.listdir(RAW_DATA_PATH))
print(f"‚úÖ Dataset carregado! ({file_count} arquivos)")

# The rest of the notebook uses paths relative to the project folder.
os.chdir(CONTENT_PROJECT)

# =========================================================
# 2. Convert Pascal VOC (XML) to YOLO format -- collect class names
# =========================================================
print("\nüîÑ Convertendo Pascal VOC para YOLO format...")

RAW_DATA = Path("raw_data")

# Collect every class name used in the annotations.
all_classes = set()
# Recursive search. Entries under "__MACOSX/" and "._*" AppleDouble files are
# macOS archive metadata: they share the .xml suffix but are not XML, so they
# are excluded. Sorted so processing order does not depend on the filesystem.
xml_files = sorted(
    p for p in RAW_DATA.glob("**/*.xml")
    if "__MACOSX" not in p.parts and not p.name.startswith("._")
)
print(f"   Encontrados {len(xml_files)} arquivos XML")

if len(xml_files) == 0:
    print("‚ùå Nenhum arquivo XML encontrado!")
    print(f"   Verificando conte√∫do de {RAW_DATA}:")
    for i, item in enumerate(RAW_DATA.iterdir()):
        print(f"      - {item.name}")
        if i > 20:
            print("      ... (mais arquivos)")
            break
    raise Exception("Dataset inv√°lido - sem arquivos XML")

unreadable_xml = []
for xml_file in xml_files:
    try:
        root = ET.parse(xml_file).getroot()
    except (ET.ParseError, OSError) as exc:
        # Record unreadable files instead of hiding them, so the user can see
        # how many annotations were dropped and why.
        unreadable_xml.append((xml_file, exc))
        continue
    for obj in root.findall('object'):
        name_node = obj.find('name')
        # Skip objects with a missing or empty <name>. Adding None here would
        # make the sorted() call below raise TypeError.
        if name_node is None or not name_node.text:
            continue
        all_classes.add(name_node.text)

if unreadable_xml:
    print(f"   Aviso: {len(unreadable_xml)} arquivos XML ilegiveis foram ignorados")
    for bad_file, exc in unreadable_xml[:5]:
        print(f"      - {bad_file}: {exc}")

all_classes = sorted(all_classes)
print(f"   Total de classes encontradas: {len(all_classes)}")

# =========================================================
# 3. Map raw dataset classes onto 14 categories
#    (the 13 named categories below + 'other' for anything unmatched;
#    the original header called these "STRIDE" categories)
# =========================================================
print("\nüìä Mapeando para categorias STRIDE...")

# Keyword lists per category: generic service names plus the full
# AWS/Azure/GCP-prefixed class names seen in the dataset. Matching is
# case-insensitive through the reverse index built below.
# NOTE(review): some keywords are listed under two categories -- 'Lambda'
# (compute, serverless), 'EMR' (compute, analytics), 'BigQuery' (database,
# analytics) and 'Gateway' (network, api_gateway). In the reverse index the
# later category overwrites the earlier one, so these resolve to serverless,
# analytics, analytics and api_gateway. Confirm that is intended.
CATEGORY_MAPPING = {
    'compute': [
        'EC2', 'Lambda', 'EKS', 'Fargate', 'Container', 'ECS',
        'App Service', 'Virtual Machine', 'VM', 'Compute Engine',
        'Cloud Run', 'App Engine', 'GKE', 'AKS', 'Kubernetes',
        'Elastic Beanstalk', 'Batch', 'Lightsail', 'EMR',
        # Full class names as they appear in the dataset
        'aws_amazon_ec2', 'aws_amazon_ec2_instance', 'aws_ec2',
        'aws_auto_scaling', 'aws_autoscaling', 'aws_elastic_beanstalk',
        'aws_amazon_elastic_container_service', 'aws_amazon_eks',
        'azure_virtual_machines', 'azure_app_services', 'azure_vm',
        'azure_container_instances', 'azure_kubernetes_service',
        'gcp_compute_engine', 'gcp_gke', 'gcp_cloud_run',
    ],
    
    'database': [
        'RDS', 'DynamoDB', 'Aurora', 'DocumentDB', 'ElastiCache',
        'Cosmos DB', 'SQL Database', 'Cloud SQL', 'Firestore',
        'BigQuery', 'Redshift', 'Neptune', 'Cloud Spanner',
        'Managed Database', 'Database', 'DB', 'Redis', 'Memcached',
        # Full class names as they appear in the dataset
        'aws_amazon_rds', 'aws_amazon_dynamodb', 'aws_amazon_aurora',
        'aws_amazon_elasticache', 'aws_amazon_redshift',
        'azure_sql_database', 'azure_cosmos_db', 'azure_sql_databases',
        'gcp_cloud_sql', 'gcp_cloud_spanner', 'gcp_firestore',
    ],
    
    'storage': [
        'S3', 'EBS', 'EFS', 'Glacier', 'Storage', 'Blob Storage',
        'Cloud Storage', 'File Storage', 'Azure Storage', 'GCS',
        'Backup', 'Archive', 'Data Lake',
        # Full class names as they appear in the dataset
        # NOTE(review): 'aws_elactic_file_system' is spelled "elactic";
        # presumably this mirrors a dataset label -- verify.
        'aws_amazon_s3', 'aws_amazon_simple_storage_service',
        'aws_amazon_elastic_block_store', 'aws_elastic_block_store',
        'aws_elastic_block_store_volume', 'aws_elactic_file_system',
        'aws_amazon_glacier', 'aws_amazon_efs',
        'azure_blob_storage', 'azure_storage_accounts', 'azure_files',
        'gcp_cloud_storage', 'gcp_persistent_disk',
    ],
    
    'network': [
        'VPC', 'Virtual Network', 'VNet', 'Subnet', 'Gateway',
        'Load Balancer', 'ALB', 'NLB', 'ELB', 'CloudFront',
        'CDN', 'Route 53', 'DNS', 'VPN', 'Direct Connect',
        'ExpressRoute', 'Cloud Interconnect', 'NAT', 'Firewall',
        'Network', 'Internet', 'Internet Gateway', 'Transit Gateway',
        # Full class names as they appear in the dataset
        'aws_amazon_vpc', 'aws_amazon_virtual_private_cloud',
        'aws_virtual_private_cloud', 'aws_amazon_route_53',
        'aws_route_53_hosted_zone', 'aws_amazon_cloudfront',
        'aws_elastic_load_balancing', 'aws_elb', 'aws_alb', 'aws_nlb',
        'aws_internet_gateway', 'aws_nat_gateway', 'aws_transit_gateway',
        'aws_region', 'aws_availability_zone',
        'azure_virtual_network', 'azure_load_balancer', 'azure_cdn',
        'azure_application_gateway', 'azure_traffic_manager',
        'azure_expressroute', 'azure_vpn_gateway', 'azure_firewall',
        'gcp_vpc', 'gcp_cloud_load_balancing', 'gcp_cloud_cdn',
    ],
    
    'security': [
        'IAM', 'Identity', 'Cognito', 'WAF', 'Shield', 'GuardDuty',
        'Security Hub', 'Key Vault', 'KMS', 'Secrets Manager',
        'Certificate', 'Azure AD', 'Cloud Identity', 'SSO',
        # Full class names as they appear in the dataset
        'aws_amazon_iam', 'aws_iam', 'aws_amazon_cognito',
        'aws_key_management_service', 'aws_amazon_kms',
        'aws_secrets_manager', 'aws_waf', 'aws_shield',
        'aws_amazon_guardduty', 'aws_security_hub',
        'azure_key_vault', 'azure_active_directory', 'azure_ad',
        'azure_security_center', 'azure_sentinel',
        'gcp_cloud_iam', 'gcp_secret_manager',
    ],
    
    'api_gateway': [
        'API Gateway', 'API Management', 'Apigee', 'AppSync',
        'API', 'Gateway', 'Endpoints',
        # Full class names as they appear in the dataset
        'aws_amazon_api_gateway', 'aws_api_gateway',
        'azure_api_management', 'azure_api_apps',
        'gcp_apigee', 'gcp_cloud_endpoints',
    ],
    
    'messaging': [
        'SQS', 'SNS', 'SES', 'EventBridge', 'Service Bus', 'Pub/Sub',
        'Kinesis', 'Event Hub', 'MQ', 'Queue', 'Topic',
        'Notification', 'Event Grid', 'Email',
        # Full class names as they appear in the dataset
        'aws_amazon_sqs', 'aws_amazon_sns', 'aws_amazon_ses',
        'aws_amazon_eventbridge', 'aws_amazon_kinesis',
        'aws_amazon_mq', 'aws_simple_queue_service',
        'aws_simple_notification_service', 'aws_simple_email_service',
        'azure_service_bus', 'azure_event_hubs', 'azure_event_grid',
        'azure_notification_hubs', 'azure_queue_storage',
        'gcp_pub_sub', 'gcp_cloud_tasks',
    ],
    
    'monitoring': [
        'CloudWatch', 'CloudTrail', 'Monitor', 'Log Analytics', 'Stackdriver',
        'Cloud Monitoring', 'X-Ray', 'Application Insights',
        'Logging', 'Metrics', 'Trace', 'Grafana', 'Prometheus', 'Audit',
        # Full class names as they appear in the dataset
        'aws_amazon_cloudwatch', 'aws_cloudwatch',
        'aws_cloud_trail', 'aws_cloudtrail', 'aws_amazon_cloudtrail',
        'aws_x-ray', 'aws_xray',
        'azure_monitor', 'azure_application_insights', 'azure_log_analytics',
        'gcp_cloud_monitoring', 'gcp_cloud_logging', 'gcp_cloud_trace',
    ],
    
    'identity': [
        'User', 'Client', 'Application', 'Service Principal',
        'OAuth', 'OIDC', 'SAML', 'Directory', 'Active Directory',
    ],
    
    'ml_ai': [
        'SageMaker', 'Machine Learning', 'AI Platform', 'Databricks',
        'Cognitive Services', 'Vertex AI', 'Rekognition', 'Comprehend',
        'Textract', 'Vision', 'Speech', 'Natural Language',
        # Full class names as they appear in the dataset
        'aws_amazon_sagemaker', 'aws_sagemaker',
        'azure_machine_learning', 'azure_cognitive_services',
        'gcp_vertex_ai', 'gcp_automl',
    ],
    
    'devops': [
        'CodePipeline', 'CodeBuild', 'CodeDeploy', 'DevOps',
        'Cloud Build', 'Artifact Registry', 'Container Registry',
        'ECR', 'ACR', 'GCR', 'CI/CD', 'Pipeline', 'Build',
        # Full class names as they appear in the dataset
        'aws_codepipeline', 'aws_codebuild', 'aws_codedeploy',
        'aws_codecommit', 'aws_amazon_ecr',
        'aws_cloudformation', 'aws_cloudformation_template',
        'azure_devops', 'azure_container_registry', 'azure_pipelines',
        'gcp_cloud_build', 'gcp_artifact_registry',
    ],
    
    'serverless': [
        'Lambda', 'Functions', 'Azure Functions', 'Cloud Functions',
        'Step Functions', 'Logic Apps', 'Workflows',
        # Full class names as they appear in the dataset
        'aws_amazon_lambda', 'aws_lambda', 'aws_step_functions',
        'azure_functions', 'azure_function_apps', 'azure_logic_apps',
        'gcp_cloud_functions',
    ],
    
    'analytics': [
        'Athena', 'BigQuery', 'Synapse', 'Data Factory',
        'Glue', 'Dataflow', 'EMR', 'HDInsight', 'Dataproc',
        'Analytics', 'Data Warehouse', 'ETL',
        # Full class names as they appear in the dataset
        'aws_amazon_athena', 'aws_amazon_emr', 'aws_glue',
        'aws_amazon_kinesis_data_firehose', 'aws_amazon_quicksight',
        'azure_synapse_analytics', 'azure_data_factories',
        'azure_data_lake', 'azure_hdinsight', 'azure_stream_analytics',
        'gcp_bigquery', 'gcp_dataflow', 'gcp_dataproc',
    ]
}

# Reverse index: lower-cased keyword -> category (case-insensitive lookup).
# When a keyword is listed under several categories, the category iterated
# last overwrites the earlier ones. Keywords are only lower-cased, so
# space-separated forms such as 'load balancer' keep their spaces and will
# not equal underscore-separated class names.
name_to_category = {}
for category, keywords in CATEGORY_MAPPING.items():
    for keyword in keywords:
        name_to_category[keyword.lower()] = category

# Substring fallbacks, tried only after keyword matching fails. Several entries
# are deliberately partial stems (e.g. 'load_balanc', 'data_factor') so they
# match any suffix; these are plain substring tests.
_FALLBACK_PATTERNS = {
    'ec2': 'compute',
    'lambda': 'serverless',
    's3': 'storage',
    'rds': 'database',
    'vpc': 'network',
    'iam': 'security',
    'sqs': 'messaging',
    'sns': 'messaging',
    'cloudwatch': 'monitoring',
    'cloudtrail': 'monitoring',
    'api_gateway': 'api_gateway',
    'load_balanc': 'network',
    'elastic_block': 'storage',
    'route_53': 'network',
    'route53': 'network',
    'cloudfront': 'network',
    'cognito': 'security',
    'kms': 'security',
    'key_management': 'security',
    'secrets': 'security',
    'dynamodb': 'database',
    'elasticache': 'database',
    'redshift': 'database',
    'aurora': 'database',
    'glacier': 'storage',
    'efs': 'storage',
    'ebs': 'storage',
    'kinesis': 'messaging',
    'eventbridge': 'messaging',
    'event_hub': 'messaging',
    'service_bus': 'messaging',
    'sagemaker': 'ml_ai',
    'machine_learning': 'ml_ai',
    'codepipeline': 'devops',
    'codebuild': 'devops',
    'cloudformation': 'devops',
    'step_function': 'serverless',
    'logic_app': 'serverless',
    'function_app': 'serverless',
    'data_factor': 'analytics',
    'synapse': 'analytics',
    'athena': 'analytics',
    'glue': 'analytics',
    'bigquery': 'analytics',
    'virtual_machine': 'compute',
    'app_service': 'compute',
    'container': 'compute',
    'kubernetes': 'compute',
    'eks': 'compute',
    'aks': 'compute',
    'gke': 'compute',
}


def _normalize_name(name):
    """Lower-case *name* and turn runs of spaces, hyphens and slashes into '_'.

    This lines up human-written keywords ('Load Balancer', 'Pub/Sub', 'X-Ray')
    with underscore-style dataset class names ('aws_elastic_load_balancer').
    """
    for separator in ("-", "/"):
        name = name.replace(separator, " ")
    return "_".join(name.lower().split())


def get_category(class_name, mapping=None):
    """Map a dataset class name to one of the simplified categories.

    Args:
        class_name: Raw class label from a Pascal VOC ``<name>`` element.
        mapping: Optional ``{keyword: category}`` dict. Defaults to the
            module-level ``name_to_category``.

    Returns:
        The matching category, or ``'other'`` when nothing matches (including
        an empty or None class name).

    Matching order:
        1. Exact match after normalization (see ``_normalize_name``).
        2. Keyword match on whole '_'-separated tokens, longest keyword first,
           ignoring keywords shorter than 3 characters. Token matching avoids
           false positives such as 'nat' inside 'natural' or 'ses' inside
           'databases'. Longest-first lets 'load_balancer' beat 'application'
           and 'transit_gateway' beat 'gateway'.
        3. Substring match against ``_FALLBACK_PATTERNS``.
    """
    if not class_name:
        return 'other'
    if mapping is None:
        mapping = name_to_category

    normalized = {_normalize_name(keyword): category for keyword, category in mapping.items()}
    name = _normalize_name(class_name)

    # 1. Exact match
    if name in normalized:
        return normalized[name]

    # 2. Whole-token match. Padding both sides with '_' makes "contains as a
    #    token sequence" a plain substring test.
    padded_name = f"_{name}_"
    for keyword in sorted(normalized, key=len, reverse=True):
        if len(keyword) < 3:
            continue  # too short to be a reliable signal
        if f"_{keyword}_" in padded_name:
            return normalized[keyword]

    # 3. Stem/substring fallbacks
    for pattern, category in _FALLBACK_PATTERNS.items():
        if pattern in name:
            return category

    return 'other'

# Assign every discovered class to a category.
class_to_category = dict(zip(all_classes, map(get_category, all_classes)))

# Report how many classes landed in each category, largest first
# (ties keep first-seen order).
category_counts = Counter(class_to_category.values())
print("\nüìä Distribui√ß√£o por categoria:")
for cat, count in category_counts.most_common():
    print(f"   {cat}: {count} classes")

# Show up to 20 of the classes that no rule matched.
other_classes = [cls for cls in class_to_category if class_to_category[cls] == 'other']
if other_classes:
    print(f"\n‚ö†Ô∏è Classes em 'other' ({len(other_classes)}):")
    for cls in other_classes[:20]:
        print(f"   - {cls}")
    if len(other_classes) > 20:
        print(f"   ... e mais {len(other_classes) - 20}")

# =========================================================
# 4. Build the YOLO folder layout and convert annotations
# =========================================================
print("\nüìÅ Criando estrutura YOLO...")

# YOLO class ids follow CATEGORY_MAPPING's key order, with 'other' last.
SIMPLIFIED_NAMES = [*CATEGORY_MAPPING, 'other']
category_to_id = dict(zip(SIMPLIFIED_NAMES, range(len(SIMPLIFIED_NAMES))))

# dataset/<split>/images and dataset/<split>/labels for each split.
for split in ['train', 'valid', 'test']:
    for subdir in ('images', 'labels'):
        os.makedirs(f"dataset/{split}/{subdir}", exist_ok=True)

def convert_voc_to_yolo(xml_file, img_width, img_height, class_map=None, cat_ids=None):
    """Convert one Pascal VOC annotation file into YOLO label lines.

    Args:
        xml_file: Path (or readable file object) of the VOC XML annotation.
        img_width: Image width in pixels, used to normalize x coordinates.
        img_height: Image height in pixels, used to normalize y coordinates.
        class_map: Optional ``{class_name: category}`` lookup. Defaults to the
            module-level ``class_to_category``; unknown names become 'other'.
        cat_ids: Optional ``{category: class_id}`` lookup. Defaults to the
            module-level ``category_to_id``; it must contain 'other'.

    Returns:
        A list of ``"<id> <x_center> <y_center> <width> <height>"`` strings with
        coordinates normalized to [0, 1] and written with 6 decimals. Objects
        without a ``<bndbox>``, and boxes with no area inside the image, are
        skipped.

    Raises:
        ValueError: if the image size is not positive.
        Errors from XML parsing or malformed coordinate elements propagate
        to the caller.
    """
    if class_map is None:
        class_map = class_to_category
    if cat_ids is None:
        cat_ids = category_to_id
    if img_width <= 0 or img_height <= 0:
        raise ValueError(f"invalid image size: {img_width}x{img_height}")

    root = ET.parse(xml_file).getroot()

    yolo_lines = []
    for obj in root.findall('object'):
        bbox = obj.find('bndbox')
        if bbox is None:
            continue  # nothing to localize; keep the other objects
        name_node = obj.find('name')
        class_name = name_node.text if name_node is not None else None
        class_id = cat_ids[class_map.get(class_name, 'other')]

        # Clip the corners to the image first, then convert. Clamping the
        # derived center and size independently could still describe a box
        # that sticks out past the image border.
        x1 = min(max(float(bbox.find('xmin').text), 0.0), img_width)
        y1 = min(max(float(bbox.find('ymin').text), 0.0), img_height)
        x2 = min(max(float(bbox.find('xmax').text), 0.0), img_width)
        y2 = min(max(float(bbox.find('ymax').text), 0.0), img_height)
        if x2 <= x1 or y2 <= y1:
            continue  # empty or inverted box: not a valid YOLO label

        # YOLO format: normalized center x, center y, width, height
        x_center = (x1 + x2) / 2 / img_width
        y_center = (y1 + y2) / 2 / img_height
        width = (x2 - x1) / img_width
        height = (y2 - y1) / img_height

        yolo_lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

    return yolo_lines

# Process every annotation/image pair
from PIL import Image
import random

# Collect (image, xml) pairs from the recursive search.
# The XML list is sorted first so the seeded shuffle below is reproducible:
# Path.glob() yields files in filesystem order, which is not guaranteed, so a
# fresh session (e.g. when resuming training) could otherwise draw a different
# train/valid/test split and mix earlier test images into training.
pairs = []
for xml_file in sorted(xml_files):
    img_name = xml_file.stem
    xml_dir = xml_file.parent
    for ext in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']:
        img_path = xml_dir / f"{img_name}{ext}"
        if img_path.exists():
            pairs.append((img_path, xml_file))
            break

print(f"   Encontrados {len(pairs)} pares imagem/anota√ß√£o")

if len(pairs) == 0:
    print("‚ùå Nenhum par imagem/anota√ß√£o encontrado!")
    raise Exception("Dataset inv√°lido")

# Shuffle and split (80% train, 10% valid, 10% test)
random.seed(42)
random.shuffle(pairs)

n_train = int(len(pairs) * 0.8)
n_valid = int(len(pairs) * 0.1)

train_pairs = pairs[:n_train]
valid_pairs = pairs[n_train:n_train + n_valid]
test_pairs = pairs[n_train + n_valid:]

print(f"   Split: {len(train_pairs)} train, {len(valid_pairs)} valid, {len(test_pairs)} test")


def _unique_stem(stem, taken):
    """Return *stem*, or *stem*_<n> for the first n not in *taken*; record the result."""
    candidate, suffix = stem, 1
    while candidate in taken:
        candidate = f"{stem}_{suffix}"
        suffix += 1
    taken.add(candidate)
    return candidate


# Convert annotations and copy images
label_counts = Counter()
errors = 0
error_samples = []  # first few failure messages, shown to the user

for split, split_pairs in [('train', train_pairs), ('valid', valid_pairs), ('test', test_pairs)]:
    # Pairs come from a recursive search, so different source folders (or the
    # same stem with different extensions) can collide once flattened into one
    # folder, silently overwriting image/label files. Track used stems.
    used_stems = set()
    for img_path, xml_path in split_pairs:
        try:
            # Read the image size
            with Image.open(img_path) as img:
                img_width, img_height = img.size

            # Convert the annotation
            yolo_lines = convert_voc_to_yolo(xml_path, img_width, img_height)

            if yolo_lines:
                # Count labels per category
                for line in yolo_lines:
                    class_id = int(line.split()[0])
                    label_counts[SIMPLIFIED_NAMES[class_id]] += 1

                stem = _unique_stem(img_path.stem, used_stems)

                # Copy the image
                dest_img = Path(f"dataset/{split}/images") / f"{stem}{img_path.suffix}"
                shutil.copy(img_path, dest_img)

                # Write the YOLO label (same stem as the image)
                label_name = stem + ".txt"
                dest_label = Path(f"dataset/{split}/labels") / label_name
                with open(dest_label, "w") as f:
                    f.write("\n".join(yolo_lines))
        except Exception as e:
            # Best effort: one bad file must not abort the conversion, but keep
            # a few messages so the failures are not invisible.
            errors += 1
            if len(error_samples) < 5:
                error_samples.append(f"{img_path.name}: {e}")

if errors > 0:
    print(f"   ‚ö†Ô∏è {errors} arquivos com erro (ignorados)")
    for sample in error_samples:
        print(f"      - {sample}")

print("\nüìä Distribui√ß√£o de labels por categoria:")
total_labels = sum(label_counts.values())
for cat in SIMPLIFIED_NAMES:
    count = label_counts.get(cat, 0)
    pct = (count / total_labels * 100) if total_labels > 0 else 0
    bar = "#" * min(50, count // 50)
    print(f"   {cat:12}: {count:5} ({pct:5.1f}%) {bar}")

print(f"\n   Total: {total_labels} labels")

# =========================================================
# 5. Write data.yaml (Ultralytics dataset config)
# =========================================================
# The dataset root is derived from CONTENT_PROJECT rather than repeating the
# literal path, so changing CONTENT_PROJECT cannot desync this file.
data_config = {
    'path': f"{CONTENT_PROJECT}/dataset",
    'train': 'train/images',
    'val': 'valid/images',
    'test': 'test/images',
    'nc': len(SIMPLIFIED_NAMES),
    'names': SIMPLIFIED_NAMES
}

with open("dataset/data.yaml", "w", encoding="utf-8") as f:
    yaml.dump(data_config, f, default_flow_style=False)

print(f"\n‚úÖ Dataset preparado com {len(SIMPLIFIED_NAMES)} categorias!")

# =========================================================
# 6. Look for earlier checkpoints on Drive (resume training)
# =========================================================
print("\nüîç Verificando checkpoints anteriores...")

# Drive's last.pt takes priority; otherwise fall back to the newest
# epoch*.pt (lexicographic order, assumed to match epoch order thanks to
# zero-padded names).
checkpoint_files = sorted(glob.glob(f"{DRIVE_CHECKPOINTS}/epoch*.pt"))
drive_last = f"{DRIVE_CHECKPOINTS}/last.pt"

if os.path.exists(drive_last):
    last_checkpoint = drive_last
    print(f"‚úÖ Checkpoint encontrado: last.pt")
    print("   üîÑ Treinamento ser√° RETOMADO do √∫ltimo checkpoint!")
elif checkpoint_files:
    last_checkpoint = checkpoint_files[-1]
    print(f"‚úÖ Checkpoint encontrado: {os.path.basename(last_checkpoint)}")
    print("   üîÑ Treinamento ser√° RETOMADO deste checkpoint!")
else:
    last_checkpoint = None
    print("üì≠ Nenhum checkpoint encontrado - iniciando do zero")

# Resume exactly when some checkpoint was found.
resume_training = last_checkpoint is not None

# =========================================================
# 7. Callback that backs up checkpoints to Google Drive
# =========================================================
SAVE_EVERY_N_EPOCHS = 5


def _copy_atomic(src, dst):
    """Copy *src* to *dst* through a temporary sibling, then rename into place.

    If the runtime disconnects mid-copy, the worst case is a stray ``.tmp``
    file, never a truncated ``last.pt`` that a later resume would try to load.
    """
    tmp = f"{dst}.tmp"
    shutil.copy(src, tmp)
    os.replace(tmp, dst)


def save_checkpoint_to_drive(trainer, dest_dir=None):
    """Back up the trainer's weights every SAVE_EVERY_N_EPOCHS epochs.

    Meant to be registered as an Ultralytics trainer callback. When
    ``trainer.epoch + 1`` is a multiple of SAVE_EVERY_N_EPOCHS, it copies
    ``<save_dir>/weights/last.pt`` to ``epoch_NNN.pt`` and ``last.pt`` in
    *dest_dir*, and ``best.pt`` (if present) to ``best.pt``. It copies
    whatever files are on disk when it is called.

    Args:
        trainer: Object exposing ``epoch`` (``epoch + 1`` is used as the epoch
            number) and ``save_dir`` (path of the run folder).
        dest_dir: Destination folder. Defaults to DRIVE_CHECKPOINTS.
    """
    if dest_dir is None:
        dest_dir = DRIVE_CHECKPOINTS

    current_epoch = trainer.epoch + 1
    if current_epoch % SAVE_EVERY_N_EPOCHS != 0:
        return

    weights_dir = Path(trainer.save_dir) / "weights"

    last_pt = weights_dir / "last.pt"
    if last_pt.exists():
        epoch_name = f"epoch_{current_epoch:03d}.pt"
        _copy_atomic(last_pt, f"{dest_dir}/{epoch_name}")
        _copy_atomic(last_pt, f"{dest_dir}/last.pt")
        print(f"\nüíæ Checkpoint salvo: {epoch_name}")

    best_pt = weights_dir / "best.pt"
    if best_pt.exists():
        _copy_atomic(best_pt, f"{dest_dir}/best.pt")

# =========================================================
# 8. Load the model (from a checkpoint or the base weights)
# =========================================================
if resume_training and last_checkpoint:
    print(f"\nüì¶ Carregando checkpoint: {last_checkpoint}")
    model = YOLO(last_checkpoint)
else:
    print("\nüì¶ Carregando modelo base: yolov8n.pt")
    model = YOLO("yolov8n.pt")

# Back up to Drive on "on_model_save", which Ultralytics fires right after it
# writes weights/last.pt (and best.pt) for the current epoch.
# "on_train_epoch_end" runs before validation and saving, so a copy taken
# there holds the previous epoch's weights and a stale best.pt.
model.add_callback("on_model_save", save_checkpoint_to_drive)

# =========================================================
# 9. Training
# =========================================================
print("\nüöÄ Iniciando treinamento...")
print(f"   üìä {len(train_pairs)} imagens de treino")
print(f"   üìä {len(SIMPLIFIED_NAMES)} categorias")
if resume_training:
    print("   üîÑ RETOMANDO do checkpoint anterior!")

# Ultralytics training arguments, grouped by purpose and merged in the call.
# NOTE(review): with resume=True, confirm which of these Ultralytics honours;
# a resumed run may restore settings stored in the checkpoint instead.
schedule_args = dict(epochs=150, patience=30, batch=16, imgsz=640)

optimizer_args = dict(
    optimizer='AdamW',
    lr0=0.001,
    lrf=0.01,
    weight_decay=0.0005,
    warmup_epochs=3,
    cos_lr=True,
)

# Moderate augmentation (per the original note, the dataset is already augmented).
augmentation_args = dict(
    hsv_h=0.015, hsv_s=0.4, hsv_v=0.3,
    degrees=10, translate=0.1, scale=0.4,
    fliplr=0.5, mosaic=0.8, mixup=0.1,
)

loss_weight_args = dict(cls=1.0, box=7.5, dfl=1.5)

runtime_args = dict(
    cache=True, workers=4, device=0,
    exist_ok=True, plots=True, save_period=5,
    name='train_kaggle', project='runs/detect',
)

results = model.train(
    data="dataset/data.yaml",
    resume=resume_training,
    **schedule_args,
    **optimizer_args,
    **augmentation_args,
    **loss_weight_args,
    **runtime_args,
)

# =========================================================
# 10. Save the final model to Drive
# =========================================================
DEST_WEIGHTS = f"{DRIVE_PROJECT}/weights_backup"
SOURCE_WEIGHTS = "runs/detect/train_kaggle/weights"

os.makedirs(DEST_WEIGHTS, exist_ok=True)

print("\nüíæ Salvando modelo final...")

best_local = f"{SOURCE_WEIGHTS}/best.pt"
if not os.path.exists(best_local):
    print("‚ö†Ô∏è best.pt n√£o encontrado")
else:
    for target in (f"{DEST_WEIGHTS}/best_kaggle.pt", f"{DRIVE_CHECKPOINTS}/best_final.pt"):
        shutil.copy(best_local, target)

    # Store the class -> category mapping next to the weights so predictions
    # can be interpreted later.
    mapping_record = {
        'simplified_names': SIMPLIFIED_NAMES,
        'category_mapping': CATEGORY_MAPPING,
        'original_classes': list(all_classes),
    }
    with open(f"{DEST_WEIGHTS}/class_mapping_kaggle.yaml", "w") as f:
        yaml.dump(mapping_record, f)

    print(f"‚úÖ Modelo salvo: {DEST_WEIGHTS}/best_kaggle.pt")

# Keep only the 5 most recent periodic checkpoints on Drive
# (slicing is empty when there are 5 or fewer).
checkpoint_files = sorted(glob.glob(f"{DRIVE_CHECKPOINTS}/epoch*.pt"))
for old_ckpt in checkpoint_files[:-5]:
    os.remove(old_ckpt)

banner = "=" * 50
print("\n" + banner)
print("‚úÖ TREINAMENTO CONCLU√çDO!")
print(banner)
print(f"\nüìÅ Modelo final: {DEST_WEIGHTS}/best_kaggle.pt")
print("   Baixe esse arquivo e coloque em models/best.pt no seu PC")

In [None]:
# =========================================================
# 11. EVALUATE THE TRAINED MODEL
# =========================================================
from ultralytics import YOLO
import matplotlib.pyplot as plt
import random
from pathlib import Path
import os

# Configuration (repeated so this cell can run on its own)
DRIVE_ROOT = "/content/drive/MyDrive/colab"
PROJECT_NAME = "cloud-arch-security-mvp"
DRIVE_PROJECT = f"{DRIVE_ROOT}/{PROJECT_NAME}"
DRIVE_CHECKPOINTS = f"{DRIVE_PROJECT}/checkpoints"

# Mount Drive only if it is not mounted yet
from google.colab import drive
if not os.path.exists('/content/drive/MyDrive'):
    drive.mount('/content/drive')

# Look for trained weights: the local run first, then the Drive backups
print("\nüîç Procurando modelo treinado...")

model_paths = [
    "runs/detect/train_kaggle/weights/best.pt",
    f"{DRIVE_CHECKPOINTS}/best.pt",
    f"{DRIVE_CHECKPOINTS}/best_final.pt",
    f"{DRIVE_PROJECT}/weights_backup/best_kaggle.pt",
]

best_model_path = None
for path in model_paths:
    if os.path.exists(path):
        best_model_path = path
        print(f"‚úÖ Modelo encontrado: {path}")
        break

if not best_model_path:
    print("‚ùå Nenhum modelo encontrado!")
else:
    val_model = YOLO(best_model_path)
    print(f"üìä Modelo tem {len(val_model.names)} classes:")
    for idx, name in val_model.names.items():
        print(f"   {idx}: {name}")

    # Evaluate on the test split declared in data.yaml
    print("\nüß™ Validando modelo...")

    val_results = val_model.val(
        data="dataset/data.yaml",
        split="test",
        plots=True,
        save_json=True
    )

    print("\nüìä M√âTRICAS DE VALIDA√á√ÉO:")
    print(f"   mAP50: {val_results.box.map50:.4f}")
    print(f"   mAP50-95: {val_results.box.map:.4f}")
    print(f"   Precis√£o: {val_results.box.mp:.4f}")
    print(f"   Recall: {val_results.box.mr:.4f}")

    print("\nüìà mAP50 por categoria:")
    # box.ap50 has one entry per class that actually occurs in the evaluated
    # split, ordered like box.ap_class_index -- not one per model class.
    # Indexing it by position in model.names mislabels every class after the
    # first one missing from the split, so map through ap_class_index.
    for i, class_idx in enumerate(val_results.box.ap_class_index):
        name = val_model.names[int(class_idx)]
        ap = val_results.box.ap50[i]
        print(f"   {name}: {ap:.4f}")

    # Visual check on a few test images
    print("\nüñºÔ∏è Testando em imagem de exemplo...")

    # Match extensions case-insensitively: the dataset step accepts .png,
    # .jpg and .jpeg in either case and keeps the source extension.
    test_dir = Path("dataset/test/images")
    image_suffixes = {".png", ".jpg", ".jpeg"}
    test_images = (
        sorted(p for p in test_dir.iterdir() if p.suffix.lower() in image_suffixes)
        if test_dir.is_dir() else []
    )

    if test_images:
        # Predict on up to 3 random test images
        for test_img in random.sample(test_images, min(3, len(test_images))):
            print(f"\n   üì∑ {test_img.name}")

            results = val_model(str(test_img), conf=0.25, verbose=False)

            detected = set()
            for r in results:
                for box in r.boxes:
                    cls_name = val_model.names[int(box.cls[0])]
                    conf = float(box.conf[0])
                    detected.add(f"{cls_name} ({conf:.2f})")

            if detected:
                print(f"      Detectado: {', '.join(detected)}")
            else:
                print("      ‚ö†Ô∏è Nenhuma detec√ß√£o")

            # Results.plot() returns the annotated image in BGR channel order;
            # reverse the last axis so matplotlib (RGB) shows true colours.
            result_img = results[0].plot()[..., ::-1]
            plt.figure(figsize=(12, 8))
            plt.imshow(result_img)
            plt.axis('off')
            plt.title(f"Detec√ß√µes em {test_img.name}")
            plt.show()
    else:
        print("   ‚ö†Ô∏è Nenhuma imagem de teste encontrada")