# 🚀 RoofVision AI - Training on Cloud GPU

Use this notebook to train your model on **Google Colab** or **Kaggle** for free using a T4 GPU.

### Steps:
1.  **Run All Cells** from top to bottom.
2.  Wait for training to finish (approx 10-15 mins for 10 epochs).
3.  **Download** the `antigravity_model.pt` file at the end.
4.  Copy it into the `solar 3/antigravity/backend/models/` folder on your laptop.

In [None]:
# 1. Install Dependencies
!pip install torch torchvision opencv-python-headless matplotlib pyyaml tqdm
import torch
print(f"Using GPU: {torch.cuda.get_device_name(0)}")

### 2. Upload Your Dataset
1. Click the **Files** icon on the left sidebar.
2. Drag and drop your **`dataset_train`** and **`dataset_valid`** folders here.
   * Tip: Uploading is faster if you zip them first (`dataset.zip`), upload the archive, and then run `!unzip dataset.zip`.

In [None]:
# (Optional) Unzip if you uploaded a zip file
# !unzip dataset.zip

In [None]:
# 3. Define the Model & Training Script

import os
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from tqdm import tqdm

def get_model_instance_segmentation(num_classes):
    """Build a pre-trained Mask R-CNN (ResNet-50 FPN) with fresh prediction heads.

    The box and mask predictors of the pre-trained network are replaced with
    new ones sized for ``num_classes`` so the model can be fine-tuned on a
    custom dataset.

    Args:
        num_classes: Number of classes the new heads predict. The training
            cell of this notebook passes 2 (presumably background plus the
            solar-panel class labelled 1 by the dataset - confirm).

    Returns:
        The torchvision Mask R-CNN model with both heads replaced.
    """
    # `pretrained=True` is deprecated in torchvision >= 0.13 in favour of the
    # `weights` argument. Older releases reject `weights` with a TypeError,
    # in which case the legacy flag is used.
    try:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
    except TypeError:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # Replace box predictor, keeping the input width of the existing head
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace mask predictor, keeping the input channels of the existing head
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model

# --- IMPROVED DATASET CLASS (Supports Boxes & Polygons) ---
class SolarDataset(Dataset):
    """Solar-panel instance-segmentation dataset backed by image and label files.

    Images are read from ``<root>/images`` and annotations from the text file
    with the same stem in ``<root>/labels``. Every label line has the form
    ``<class> <coords...>`` with coordinates normalized to the image size:
    exactly four coordinates are a ``cx cy w h`` box, any other even-length
    list is a polygon given as ``x y`` pairs. Every instance gets label 1.

    Each item is ``(image, target)``. ``image`` is a float32 CHW tensor scaled
    to [0, 1]. ``target`` holds ``boxes`` (N, 4) as ``[xmin, ymin, xmax, ymax]``
    in pixels, ``labels`` (N,), ``masks`` (N, H, W) and ``image_id``.
    """

    # Only files with these extensions (case-insensitive) are treated as
    # images; stray files in the images folder are ignored instead of
    # crashing the loader when cv2 fails to decode them.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, root, transforms=None):
        """Index the images under ``root`` and derive their label file names.

        Args:
            root: Dataset folder containing ``images/`` and ``labels/``.
            transforms: Optional callable applied as
                ``image, target = transforms(image, target)`` to each item.
        """
        self.root = root
        self.transforms = transforms
        self.imgs = sorted(
            name
            for name in os.listdir(os.path.join(root, "images"))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # splitext swaps only the final extension, so ".jpeg"/".JPG" files and
        # names that contain ".jpg" in the middle still map to "<stem>.txt".
        self.labels = [os.path.splitext(name)[0] + ".txt" for name in self.imgs]

    @staticmethod
    def _parse_label_line(line, width, height):
        """Convert one label line into a list of ``[x, y]`` pixel vertices.

        Returns an empty list for a blank line. Raises ``ValueError`` when a
        field is not numeric or a polygon has an odd number of coordinates.
        """
        values = [float(token) for token in line.split()]
        coords = values[1:]  # values[0] is the class id, which is not used

        if len(coords) == 4:
            # Box format: cx, cy, w, h (normalized) -> rectangle corners
            cx, cy, w_box, h_box = coords
            x1 = int((cx - w_box / 2) * width)
            y1 = int((cy - h_box / 2) * height)
            x2 = int((cx + w_box / 2) * width)
            y2 = int((cy + h_box / 2) * height)
            return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

        if len(coords) % 2:
            raise ValueError("polygon needs an even number of coordinates")
        return [
            [int(x * width), int(y * height)]
            for x, y in zip(coords[0::2], coords[1::2])
        ]

    def __getitem__(self, idx):
        """Load image ``idx`` and build its detection/segmentation target.

        Raises:
            OSError: If the image file cannot be read by OpenCV.
        """
        import warnings

        img_path = os.path.join(self.root, "images", self.imgs[idx])
        label_path = os.path.join(self.root, "labels", self.labels[idx])

        # 1. Load Image (cv2.imread signals failure by returning None)
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"cv2.imread could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # 2. Normalize (0-1) - CRITICAL STEP
        img = img.astype(np.float32) / 255.0
        height, width = img.shape[:2]

        boxes, masks_list, labels = [], [], []

        if os.path.exists(label_path):
            with open(label_path) as f:
                for line_no, line in enumerate(f, start=1):
                    try:
                        poly_points = self._parse_label_line(line, width, height)
                    except ValueError:
                        warnings.warn(f"Skipping malformed label line {line_no} in {label_path}")
                        continue

                    if len(poly_points) < 3:
                        continue

                    x_coords = [p[0] for p in poly_points]
                    y_coords = [p[1] for p in poly_points]
                    xmin, xmax = min(x_coords), max(x_coords)
                    ymin, ymax = min(y_coords), max(y_coords)

                    # Validate the box BEFORE storing anything so masks,
                    # boxes and labels always have the same length.
                    if xmax <= xmin or ymax <= ymin:
                        continue

                    mask = np.zeros((height, width), dtype=np.uint8)
                    # int32 is the point type OpenCV's fillPoly works with.
                    cv2.fillPoly(mask, [np.array(poly_points, dtype=np.int32)], 1)

                    masks_list.append(mask)
                    boxes.append([xmin, ymin, xmax, ymax])
                    labels.append(1)  # Class 1 = Solar Panel

        target = {}
        if boxes:
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
            target["masks"] = torch.as_tensor(np.array(masks_list), dtype=torch.uint8)
        else:
            # Image without usable annotations: empty tensors of the right shape
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["masks"] = torch.zeros((0, height, width), dtype=torch.uint8)

        target["image_id"] = torch.tensor([idx])

        # Convert Image to Tensor (CHW format)
        img = torch.as_tensor(img.transpose((2, 0, 1)))

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)

def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    A sequence of ``(image, target)`` pairs becomes ``(images, targets)``,
    each a tuple, so samples of different sizes are kept separate rather than
    stacked. An empty batch yields an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
# 4. Run Training (NUCLEAR OPTION - 100% STABLE)

import os
import torch
import torchvision
import cv2
import numpy as np
import gc
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# ==========================================
# 1. SETUP DATASET & MODEL
# ==========================================

def get_model_instance_segmentation(num_classes):
    """Build a pre-trained Mask R-CNN (ResNet-50 FPN) with fresh prediction heads.

    The box and mask predictors of the pre-trained network are replaced with
    new ones sized for ``num_classes`` so the model can be fine-tuned.

    Args:
        num_classes: Number of classes the new heads predict. The training
            code below passes 2 (presumably background plus the solar-panel
            class labelled 1 by the dataset - confirm).

    Returns:
        The torchvision Mask R-CNN model with both heads replaced.
    """
    # `pretrained=True` is deprecated in torchvision >= 0.13 in favour of the
    # `weights` argument. Older releases reject `weights` with a TypeError,
    # in which case the legacy flag is used.
    try:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
    except TypeError:
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # New box head with the same input width as the one it replaces
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # New mask head with the same input channels as the one it replaces
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model

class SolarDataset(Dataset):
    """Solar-panel instance-segmentation dataset with memory-saving downscaling.

    Images are read from ``<root>/images`` and annotations from the text file
    with the same stem in ``<root>/labels``. Every label line has the form
    ``<class> <coords...>`` with coordinates normalized to the image size:
    exactly four coordinates are a ``cx cy w h`` box, any other even-length
    list is a polygon given as ``x y`` pairs. Every instance gets label 1.

    Images whose longer side exceeds ``max_size`` are shrunk to that size
    before the targets are rasterized. Because label coordinates are
    normalized, masks and boxes are built directly at the resized resolution.

    Each item is ``(image, target)``. ``image`` is a float32 CHW tensor scaled
    to [0, 1]. ``target`` holds ``boxes`` (N, 4) as ``[xmin, ymin, xmax, ymax]``
    in pixels, ``labels`` (N,), ``masks`` (N, H, W) and ``image_id``.
    """

    # Only files with these extensions (case-insensitive) are treated as
    # images; stray files in the images folder are ignored instead of
    # crashing the loader when cv2 fails to decode them.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, root, transforms=None, max_size=480):
        """Index the images under ``root`` and derive their label file names.

        Args:
            root: Dataset folder containing ``images/`` and ``labels/``.
            transforms: Optional callable applied as
                ``image, target = transforms(image, target)`` to each item.
            max_size: Upper bound in pixels for the longer image side; larger
                images are downscaled. ``None`` disables resizing.
        """
        self.root = root
        self.transforms = transforms
        self.max_size = max_size
        self.imgs = sorted(
            name
            for name in os.listdir(os.path.join(root, "images"))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # splitext swaps only the final extension, so ".jpeg"/".JPG" files and
        # names that contain ".jpg" in the middle still map to "<stem>.txt".
        self.labels = [os.path.splitext(name)[0] + ".txt" for name in self.imgs]

    @staticmethod
    def _parse_label_line(line, width, height):
        """Convert one label line into a list of ``[x, y]`` pixel vertices.

        Returns an empty list for a blank line. Raises ``ValueError`` when a
        field is not numeric or a polygon has an odd number of coordinates.
        """
        values = [float(token) for token in line.split()]
        coords = values[1:]  # values[0] is the class id, which is not used

        if len(coords) == 4:  # Box: cx, cy, w, h -> rectangle corners
            cx, cy, w_box, h_box = coords
            x1 = int((cx - w_box / 2) * width)
            y1 = int((cy - h_box / 2) * height)
            x2 = int((cx + w_box / 2) * width)
            y2 = int((cy + h_box / 2) * height)
            return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

        if len(coords) % 2:
            raise ValueError("polygon needs an even number of coordinates")
        return [
            [int(x * width), int(y * height)]
            for x, y in zip(coords[0::2], coords[1::2])
        ]

    def __getitem__(self, idx):
        """Load image ``idx``, downscale it if needed and build its target.

        Raises:
            OSError: If the image file cannot be read by OpenCV.
        """
        import warnings

        img_path = os.path.join(self.root, "images", self.imgs[idx])
        label_path = os.path.join(self.root, "labels", self.labels[idx])

        # cv2.imread signals failure by returning None rather than raising
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"cv2.imread could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # --- AGGRESSIVE RESIZE: only ever shrink, never enlarge ---
        if self.max_size is not None:
            h, w = img.shape[:2]
            scale = self.max_size / max(h, w)
            if scale < 1.0:
                img = cv2.resize(img, (0, 0), fx=scale, fy=scale)

        # Normalize
        img = img.astype(np.float32) / 255.0
        height, width = img.shape[:2]  # New dimensions

        boxes, masks_list, labels = [], [], []

        if os.path.exists(label_path):
            with open(label_path) as f:
                for line_no, line in enumerate(f, start=1):
                    # Malformed lines are skipped (best effort), but only the
                    # parsing error is tolerated and the skip is reported.
                    try:
                        poly_points = self._parse_label_line(line, width, height)
                    except ValueError:
                        warnings.warn(f"Skipping malformed label line {line_no} in {label_path}")
                        continue

                    if len(poly_points) < 3:
                        continue

                    x_coords = [p[0] for p in poly_points]
                    y_coords = [p[1] for p in poly_points]
                    xmin, xmax = min(x_coords), max(x_coords)
                    ymin, ymax = min(y_coords), max(y_coords)

                    # Zero-area boxes (e.g. tiny panels collapsed by the
                    # downscale) are dropped together with their mask.
                    if xmax <= xmin or ymax <= ymin:
                        continue

                    mask = np.zeros((height, width), dtype=np.uint8)
                    # int32 is the point type OpenCV's fillPoly works with.
                    cv2.fillPoly(mask, [np.array(poly_points, dtype=np.int32)], 1)

                    masks_list.append(mask)
                    boxes.append([xmin, ymin, xmax, ymax])
                    labels.append(1)

        target = {}
        if boxes:
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
            target["masks"] = torch.as_tensor(np.array(masks_list), dtype=torch.uint8)
        else:
            # Image without usable annotations: empty tensors of the right shape
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["masks"] = torch.zeros((0, height, width), dtype=torch.uint8)

        target["image_id"] = torch.tensor([idx])
        img = torch.as_tensor(img.transpose((2, 0, 1)))  # HWC -> CHW

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)

def collate_fn(batch):
    """Regroup a batch of samples field by field.

    Turns a sequence of ``(image, target)`` pairs into ``(images, targets)``,
    each a tuple, instead of stacking them into tensors. An empty batch
    yields an empty tuple.
    """
    regrouped = zip(*batch)
    return tuple(regrouped)

def get_batch_stats(predictions, targets, iou_threshold=0.5):
    """Count mask-level true/false positives and false negatives for a batch.

    For every image, predictions scoring above 0.05 are binarized at 0.5 and
    greedily matched to ground-truth masks in descending score order. A
    prediction is a true positive when its best still-unmatched ground truth
    has a positive IoU of at least ``iou_threshold``; otherwise it is a false
    positive. Ground-truth masks left unmatched are false negatives.

    Args:
        predictions: Per-image dicts with ``'scores'`` (P,) and ``'masks'``
            (P, 1, H, W) or (P, H, W).
        targets: Per-image dicts with ``'masks'`` (G, H, W) or (G, 1, H, W).
        iou_threshold: Minimum IoU for a prediction to count as a match.

    Returns:
        ``(tp, fp, fn, matched_iou, matched_instances)`` where ``matched_iou``
        is the sum of the IoUs of all true positives and
        ``matched_instances`` is how many were summed.
    """
    tp, fp, fn, matched_iou, matched_instances = 0, 0, 0, 0.0, 0
    for img_preds, img_targets in zip(predictions, targets):
        keep = img_preds['scores'] > 0.05
        pred_scores = img_preds['scores'][keep]
        pred_masks = img_preds['masks'][keep]
        gt_masks = img_targets['masks']

        num_pred, num_gt = len(pred_masks), len(gt_masks)
        if num_pred == 0 or num_gt == 0:
            # Nothing to match: every prediction is spurious and every
            # ground truth is missed (both counts may be zero).
            fp += num_pred
            fn += num_gt
            continue

        if pred_masks.dim() == 4: pred_masks = pred_masks.squeeze(1)
        if gt_masks.dim() == 4: gt_masks = gt_masks.squeeze(1)

        # Pairwise IoU through a matrix product of the flattened binary
        # masks. This avoids materializing P x G x H x W temporaries. Pixel
        # counts stay exact in float32 for masks below 2**24 pixels.
        pred_bin = (pred_masks > 0.5).reshape(num_pred, -1)
        gt_bin = gt_masks.bool().reshape(num_gt, -1)
        intersection = pred_bin.float() @ gt_bin.float().t()
        pred_area = pred_bin.sum(dim=1).float()
        gt_area = gt_bin.sum(dim=1).float()
        union = pred_area[:, None] + gt_area[None, :] - intersection
        iou_matrix = intersection / torch.clamp(union, min=1e-6)

        gt_matched = torch.zeros(num_gt, dtype=torch.bool, device=iou_matrix.device)

        for pred_idx in torch.argsort(pred_scores, descending=True):
            # Hide already-claimed ground truths, then take the best one left
            # (the first index wins ties).
            candidates = iou_matrix[pred_idx].masked_fill(gt_matched, -1.0)
            best_iou, best_gt = candidates.max(dim=0)
            best_iou = best_iou.item()
            if best_iou > 0.0 and best_iou >= iou_threshold:
                tp += 1
                matched_iou += best_iou
                matched_instances += 1
                gt_matched[best_gt] = True
            else:
                fp += 1
        fn += int((~gt_matched).sum().item())
    return tp, fp, fn, matched_iou, matched_instances

# ==========================================
# 2. RUN TRAINING
# ==========================================

# Train on the GPU when one is attached, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Release memory held by earlier runs of this cell before allocating a new model.
gc.collect(); torch.cuda.empty_cache()

train_dir, valid_dir = "./dataset_train/", "./dataset_valid/"

# Validation scores at most this many images per epoch so each epoch's health
# check stays quick (the loader below uses batch_size=1, so batches == images).
max_val_images = 40

# Both folders are needed: the validation dataset lists its images folder on
# construction, so a missing valid_dir would otherwise fail before training.
missing_dirs = [path for path in (train_dir, valid_dir) if not os.path.exists(path)]

if missing_dirs:
    for missing_dir in missing_dirs:
        print(f"‚ùå Error: {missing_dir} not found! Did you unzip?")
else:
    print(f"‚úÖ Found Dataset at {train_dir}")
    print("‚ö° Training with BATCH_SIZE=1 and RESIZE=480px (Safe Mode)")

    dataset = SolarDataset(train_dir)
    valid_dataset = SolarDataset(valid_dir)

    # BATCH SIZE 1 (Very safe)
    train_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    # Two classes are requested from the model builder; only trainable
    # parameters are handed to the optimizer.
    model = get_model_instance_segmentation(2).to(device)
    optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.005, momentum=0.9, weight_decay=0.0005)

    num_epochs = 10
    print("üöÄ Starting Training...")

    for epoch in range(num_epochs):
        # ---- TRAIN: in train mode the model returns a dict of losses ----
        model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Train")

        for images, targets in pbar:
            images = list(img.to(device) for img in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            total_loss += losses.item()
            pbar.set_postfix({'loss': losses.item()})
            del images, targets, loss_dict, losses # CLEAN

        # ---- VALIDATION (only the first max_val_images images) ----
        model.eval()
        epoch_tp, epoch_fp, epoch_fn, epoch_iou_sum, epoch_iou_count = 0, 0, 0, 0, 0
        val_count = 0

        with torch.no_grad():
            for images, targets in tqdm(valid_loader, desc=f"Epoch {epoch+1} Valid"):
                # Stop once the cap is reached. Checking with >= before
                # counting scores exactly max_val_images images.
                if val_count >= max_val_images: break
                val_count += 1

                images = list(img.to(device) for img in images)
                preds = model(images)

                # Move to CPU immediately so metric tensors do not occupy GPU memory
                preds_cpu = [{k: v.to('cpu') for k, v in p.items()} for p in preds]
                targets_cpu = [{k: v.to('cpu') for k, v in t.items()} for t in targets]

                tp, fp, fn, miou, minnst = get_batch_stats(preds_cpu, targets_cpu)
                epoch_tp += tp; epoch_fp += fp; epoch_fn += fn
                epoch_iou_sum += miou; epoch_iou_count += minnst
                del images, targets, preds, preds_cpu, targets_cpu

        precision = epoch_tp / (epoch_tp + epoch_fp) if (epoch_tp + epoch_fp) > 0 else 0.0
        recall = epoch_tp / (epoch_tp + epoch_fn) if (epoch_tp + epoch_fn) > 0 else 0.0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        final_iou = epoch_iou_sum / epoch_iou_count if epoch_iou_count > 0 else 0.0

        # max(..., 1) keeps an empty training folder from raising ZeroDivisionError
        print(f"   Train Loss: {total_loss/max(len(train_loader), 1):.4f}")
        print(f"   Precision:  {precision:.4f} | Recall: {recall:.4f}")
        print(f"   F1 Score:   {f1:.4f}      | IoU:    {final_iou:.4f}")

    # Only the weights are saved; loading them requires rebuilding the model first.
    torch.save(model.state_dict(), "antigravity_model.pt")
    print("\nüèÅ DONE! Run next cell to download.")

In [None]:
# 5. Download the Brain
# The google.colab helper exists only on Colab. On other hosts (the intro also
# suggests Kaggle) the import fails, so report where the file is instead.
import os

model_path = 'antigravity_model.pt'

if not os.path.exists(model_path):
    print(f"{model_path} not found - run the training cell first.")
else:
    try:
        from google.colab import files
    except ImportError:
        print(f"Not running on Colab - download {os.path.abspath(model_path)} from the file browser.")
    else:
        files.download(model_path)

In [None]:
# DIAGNOSTIC SCRIPT
# Loads the training set and reports how many annotations were parsed for the
# first few images, to spot label files that are missing or not being read.
dataset = SolarDataset("./dataset_train/")
print(f"Total dataset size: {len(dataset)}")

# Check first 5 images (fewer when the dataset is smaller, to avoid an IndexError)
for i in range(min(5, len(dataset))):
    img, target = dataset[i]
    num_boxes = len(target["boxes"])
    print(f"Image {i}: Found {num_boxes} solar panels")

    if num_boxes == 0:
        # Check if label file actually exists
        label_filename = dataset.labels[i]
        expected_path = os.path.join(dataset.root, "labels", label_filename)
        print(f"   ‚ùå WARNING: No labels loaded! Checked path: {expected_path}")
        print(f"      File exists? {os.path.exists(expected_path)}")