In [6]:
import os
import math
import random
import glob
from typing import Tuple, List
from tqdm import tqdm

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision.io import read_image
from torchvision.ops import box_convert
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

print('Libraries imported.')

# ===============================
# Configuration / Hyperparameters
# ===============================
DATA_DIR = "/kaggle/input/strawdi/data/StrawDI_Db1"  # Path to dataset
VIDEO_PATH = "/kaggle/input/strawdi/data/test.mp4"   # Path to video
TRAIN_EPOCHS = 1   # number of passes over the training split
LR = 1e-3          # Adam learning rate
BATCH_SIZE = 4     # samples per DataLoader batch
# Use the GPU when one is visible to PyTorch; otherwise fall back to the CPU
# so the notebook still runs on CPU-only machines instead of crashing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# ---------------------------------------------------------
# 1. Dataset Definition
# ---------------------------------------------------------

class StrawberriesDataset(Dataset):
    """
    A custom dataset for Strawberry detection+segmentation.

    Assumes structure:
      root/
         train/img/*.png
         train/label/*.png
         val/img/*.png
         val/label/*.png
         test/img/*.png
         test/label/*.png

    Each label image has integer IDs marking strawberries (0 is background).

    Images and labels are paired by position after sorting both directory
    listings, so the two folders are expected to use matching file names.

    Each item is ``(img, (semantic, boxes))`` where
      img:      float tensor [3,H,W] in [0,1]
      semantic: float tensor [2,H,W], channels = [background, strawberry]
      boxes:    float tensor [N,6] = [img_idx, cls, cx, cy, w, h], with the
                box coordinates normalised by the image width/height.
                ``img_idx`` is left at 0 here and filled in by ``collate_fn``.
    """

    def __init__(self, root: str, split: str = "train", transform=None):
        """
        Args:
            root: dataset root directory.
            split: sub-directory of ``root`` to read ("train", "val", "test").
            transform: optional callable ``(img, semantic) -> (img, semantic)``.

        Raises:
            ValueError: if the image and label folders hold a different
                number of files.
        """
        super().__init__()
        self.root = root
        self.split = split
        self.transform = transform

        img_dir = os.path.join(root, split, "img")
        mask_dir = os.path.join(root, split, "label")

        self.img_paths = sorted(glob.glob(os.path.join(img_dir, "*")))
        self.label_paths = sorted(glob.glob(os.path.join(mask_dir, "*")))

        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # which would let mismatched image/label pairs through silently.
        if len(self.img_paths) != len(self.label_paths):
            raise ValueError(
                "Mismatch between image and mask count. "
                f"({len(self.img_paths)} images vs {len(self.label_paths)} masks)"
            )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        mask_path = self.label_paths[idx]

        # Force 3-channel RGB so grayscale or RGBA files still match the
        # model's 3-channel input.
        img = read_image(img_path, mode=torchvision.io.ImageReadMode.RGB)
        img = img.float() / 255.0                   # shape: [3,H,W], range: [0,1]
        mask = read_image(mask_path)                # first channel holds integer IDs

        # Convert from integer IDs to a 2-channel one-hot map [bg, strawberry]
        instance_ids = mask[0]
        is_straw = instance_ids > 0  # shape [H,W]
        semantic = torch.stack([~is_straw, is_straw], dim=0).float()

        # One bounding box per distinct non-background ID
        unique_ids = torch.unique(instance_ids)
        unique_ids = unique_ids[unique_ids > 0]  # exclude background

        if len(unique_ids) == 0:
            # No strawberries
            boxes = torch.zeros((0, 6))  # shape [N,6] => [img_idx, cls, cx,cy,w,h]
        else:
            # One boolean mask per instance ID -> [N,H,W]
            all_obj_masks = torch.stack(
                [(instance_ids == u) for u in unique_ids], dim=0
            )
            raw_boxes_xyxy = torchvision.ops.masks_to_boxes(all_obj_masks)
            raw_boxes_cxcywh = box_convert(
                raw_boxes_xyxy, in_fmt="xyxy", out_fmt="cxcywh"
            )
            # Normalise x-like columns by width and y-like columns by height
            H, W = img.shape[1], img.shape[2]
            scale = torch.tensor([W, H, W, H], dtype=raw_boxes_cxcywh.dtype)
            raw_boxes_cxcywh = raw_boxes_cxcywh / scale

            boxes = torch.zeros((raw_boxes_cxcywh.shape[0], 6))
            # column 0 (image index) is filled by collate_fn;
            # column 1 (class) stays 0 = strawberry for every row
            boxes[:, 2:] = raw_boxes_cxcywh

        # Boxes are not passed to the transform: they are already normalised
        # to [0,1] relative to the image size, so a plain resize of the image
        # leaves them valid.
        if self.transform is not None:
            img, semantic = self.transform(img, semantic)

        return img, (semantic, boxes)

    @staticmethod
    def collate_fn(batch):
        """
        Custom collate to handle variable #boxes.
        Output:
          images: [B,3,H,W]
          segs:   [B,2,H,W]
          boxes:  [sum_all_boxes, 6]
        """
        imgs = []
        segs = []
        all_boxes = []

        for i, (img, (seg, boxes)) in enumerate(batch):
            imgs.append(img)
            segs.append(seg)

            # Fill in "image index" for each box (on a copy, so the
            # dataset's own tensor is not mutated)
            if len(boxes) > 0:
                boxes_with_idx = boxes.clone()
                boxes_with_idx[:, 0] = i
                all_boxes.append(boxes_with_idx)

        imgs_t = torch.stack(imgs, dim=0)
        segs_t = torch.stack(segs, dim=0)

        if all_boxes:
            boxes_t = torch.cat(all_boxes, dim=0)
        else:
            # no boxes at all in this batch
            boxes_t = torch.zeros((0, 6))

        return imgs_t, (segs_t, boxes_t)


# ---------------------------------------------------------
# 2. Simple Augmentations
# ---------------------------------------------------------

class ResizeAndPad:
    """
    Resize an (image, segmentation) pair to a fixed ``(out_h, out_w)``.

    The image goes through torchvision's default resize interpolation with
    antialiasing enabled; the segmentation map uses nearest-neighbour.
    Despite the class name, no padding is applied: both tensors are simply
    resized to the target shape.
    """
    def __init__(self, out_h=480, out_w=640):
        self.out_h = out_h
        self.out_w = out_w

    def __call__(self, img, seg):
        tv_functional = torchvision.transforms.functional
        target_size = (self.out_h, self.out_w)

        resized_img = tv_functional.resize(img, target_size, antialias=True)
        resized_seg = tv_functional.resize(
            seg,
            target_size,
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        )
        return resized_img, resized_seg


# ---------------------------------------------------------
# 3. Model
# ---------------------------------------------------------

class SimpleBackbone(nn.Module):
    """Two stride-2 Conv -> BatchNorm -> ReLU stages (in_ch -> 32 -> out_ch)."""

    def __init__(self, in_ch=3, out_ch=64):
        super().__init__()
        hidden_ch = 32
        # Both stages share the same 3x3 / stride-2 / pad-1 geometry.
        conv_cfg = {"kernel_size": 3, "stride": 2, "padding": 1}

        self.conv1 = nn.Conv2d(in_ch, hidden_ch, **conv_cfg)
        self.bn1 = nn.BatchNorm2d(hidden_ch)
        # A single ReLU module is reused after both stages.
        self.act = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(hidden_ch, out_ch, **conv_cfg)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        stage1 = self.conv1(x)
        stage1 = self.bn1(stage1)
        stage1 = self.act(stage1)

        stage2 = self.conv2(stage1)
        stage2 = self.bn2(stage2)
        stage2 = self.act(stage2)
        return stage2  # [B, out_ch, H/4, W/4]


class SegHead(nn.Module):
    """Segmentation head: two 2x transposed convs producing out_ch logit maps.

    With the defaults this yields 2 channels (background, strawberry) at 4x
    the spatial size of the incoming feature map.
    """
    def __init__(self, in_ch=64, out_ch=2):
        super().__init__()
        half_ch = in_ch // 2
        # Each transposed conv doubles the spatial resolution.
        upsample_cfg = {"kernel_size": 2, "stride": 2}

        self.up1 = nn.ConvTranspose2d(in_ch, half_ch, **upsample_cfg)
        self.bn1 = nn.BatchNorm2d(half_ch)
        self.act = nn.ReLU(inplace=True)

        self.up2 = nn.ConvTranspose2d(half_ch, out_ch, **upsample_cfg)

    def forward(self, x):
        hidden = self.up1(x)
        hidden = self.bn1(hidden)
        hidden = self.act(hidden)
        # Final layer emits raw logits (no activation).
        logits = self.up2(hidden)
        return logits  # [B,2,H,W]


class DetectHead(nn.Module):
    """Naive detection head: a single 1x1 conv over the feature map.

    Output shape is [B, A*(5+num_classes), H, W] for an input feature map of
    spatial size H x W, where A is ``num_anchors`` and each anchor carries
    (cx, cy, w, h, obj) plus one score per class.
    """
    def __init__(self, in_ch=64, num_anchors=3, num_classes=1):
        super().__init__()
        self.num_anchors = num_anchors
        self.num_classes = num_classes
        # Per-anchor channel count: cx, cy, w, h, obj + class scores
        self.out_filters = num_classes + 5

        total_out_ch = num_anchors * self.out_filters
        self.conv = nn.Conv2d(in_ch, total_out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        raw_pred = self.conv(x)
        return raw_pred


class StrawModel(nn.Module):
    """Shared backbone feeding a segmentation head and a detection head.

    ``forward`` returns ``(seg_out, det_out)``: the raw outputs of the
    segmentation head and of the detection head, both computed from the same
    backbone feature map.
    """
    def __init__(self, num_classes=1):
        super().__init__()
        # Registered as a buffer so the anchors follow the module through
        # .to(device)/.cuda(). persistent=False keeps them out of the
        # state_dict, so checkpoints saved before this change still load
        # with strict key matching.
        self.register_buffer(
            "anchors",
            torch.tensor([[10, 13], [16, 30], [33, 50]], dtype=torch.float32),
            persistent=False,
        )
        self.backbone = SimpleBackbone(3, 64)
        self.seg_head = SegHead(in_ch=64, out_ch=2)
        self.det_head = DetectHead(in_ch=64, num_anchors=3, num_classes=num_classes)

    def forward(self, x):
        feat = self.backbone(x)
        # Both heads consume the same shared features.
        seg_out = self.seg_head(feat)
        det_out = self.det_head(feat)
        return seg_out, det_out


# ---------------------------------------------------------
# 4. Losses
# ---------------------------------------------------------

def segmentation_loss(pred_seg, true_seg):
    """Return ``(loss, iou)`` for a batch of segmentation logits.

    ``loss`` is the mean binary cross-entropy (with logits) between
    ``pred_seg`` and ``true_seg``. ``iou`` is a plain Python float computed
    without gradients: the intersection-over-union of the pixels whose
    channel-wise argmax is 1, pooled over the whole batch. When neither the
    prediction nor the target contains any such pixel, the IoU is 1.0.
    """
    loss_seg = F.binary_cross_entropy_with_logits(pred_seg, true_seg)

    # Simple IoU for logging only; kept out of the autograd graph.
    with torch.no_grad():
        pred_fg = torch.argmax(pred_seg, dim=1) == 1  # [B,H,W] bool
        true_fg = torch.argmax(true_seg, dim=1) == 1
        overlap_px = torch.logical_and(pred_fg, true_fg).sum().item()
        either_px = torch.logical_or(pred_fg, true_fg).sum().item()
        if either_px > 0:
            iou = overlap_px / (either_px + 1e-6)
        else:
            iou = 1.0

    return loss_seg, iou


def detection_loss(pred_det, true_boxes, device="cpu", num_anchors=3):
    """
    Toy YOLO-like detection loss for demonstration, not a full YOLO.

    Args:
        pred_det: raw detection output, shape [B, A*(5+num_classes), H, W].
        true_boxes: tensor [N,6] with rows [img_idx, cls, cx, cy, w, h];
            cx, cy, w, h are expected to be normalised to [0,1].
        device: unused; kept for backward compatibility. Targets are created
            on the same device as ``pred_det``.
        num_anchors: number of anchors A encoded in the channel dimension.

    Returns:
        Scalar tensor ``l_obj + 0.5 * l_box + l_cls`` where the objectness
        and class terms are BCE-with-logits and the box term is an MSE, each
        averaged over every anchor and grid cell.

    Raises:
        ValueError: if the channel count is not a multiple of ``num_anchors``.
    """
    bce = nn.BCEWithLogitsLoss()
    mse = nn.MSELoss()

    B, C, H, W = pred_det.shape
    if C % num_anchors != 0:
        raise ValueError(
            f"pred_det has {C} channels, not divisible by num_anchors={num_anchors}"
        )
    # Per-anchor channels: cx, cy, w, h, obj, then one score per class
    out_ch = C // num_anchors
    num_classes = out_ch - 5
    # reshape (not view) so non-contiguous inputs are accepted as well
    pred_det = pred_det.reshape(B, num_anchors, out_ch, H, W)

    pred_box = pred_det[:, :, 0:4]  # [B,A,4,H,W]
    pred_obj = pred_det[:, :, 4]    # [B,A,H,W]
    pred_cls = pred_det[:, :, 5:]   # [B,A,num_classes,H,W]

    obj_target = torch.zeros_like(pred_obj)
    box_target = torch.zeros_like(pred_box)
    cls_target = torch.zeros_like(pred_cls)

    # Copy the boxes to host memory once, rather than triggering a separate
    # device-to-host transfer for every box.
    gt_rows = true_boxes.detach().cpu().tolist()

    # For each GT box, place it in the grid cell containing its centre.
    for img_idx, cls_idx, cx, cy, w_b, h_b in gt_rows:
        b_idx = int(img_idx)

        gx = int(cx * W)
        gy = int(cy * H)
        if not (0 <= gx < W and 0 <= gy < H):
            # centre falls outside the grid; this box contributes no target
            continue

        a_idx = 0  # always anchor 0
        obj_target[b_idx, a_idx, gy, gx] = 1.0
        box_target[b_idx, a_idx, 0, gy, gx] = cx
        box_target[b_idx, a_idx, 1, gy, gx] = cy
        box_target[b_idx, a_idx, 2, gy, gx] = w_b
        box_target[b_idx, a_idx, 3, gy, gx] = h_b

        c_idx = int(cls_idx)
        if 0 <= c_idx < num_classes:
            cls_target[b_idx, a_idx, c_idx, gy, gx] = 1.0

    l_obj = bce(pred_obj, obj_target)
    l_box = mse(pred_box, box_target)
    l_cls = bce(pred_cls, cls_target)

    return l_obj + 0.5 * l_box + l_cls


# ---------------------------------------------------------
# 5. Train / Evaluate
# ---------------------------------------------------------

def train_model(model, train_loader, val_loader, epochs=2, lr=1e-3, device="cpu"):
    """Train ``model`` with Adam on the combined segmentation + detection loss.

    After every epoch the model is evaluated on ``val_loader`` and a summary
    line with the average train/val loss and segmentation IoU is printed.

    Args:
        model: network returning ``(seg_out, det_out)`` for an image batch.
        train_loader: yields ``(imgs, (sems, boxes))`` training batches.
        val_loader: loader passed to ``evaluate`` after each epoch.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device the model and each batch are moved to.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        # evaluate() switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        total_loss = 0.0
        total_iou = 0.0
        count = 0

        for imgs, (sems, boxes) in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
            imgs, sems, boxes = imgs.to(device), sems.to(device), boxes.to(device)

            optimizer.zero_grad()
            seg_out, det_out = model(imgs)

            # segmentation
            l_seg, seg_iou = segmentation_loss(seg_out, sems)
            # detection
            l_det = detection_loss(det_out, boxes, device=device)
            # combined
            loss = l_seg + l_det

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_iou += seg_iou
            count += 1

        # Guard against an empty loader (same convention as evaluate())
        # instead of raising ZeroDivisionError.
        avg_loss = total_loss / count if count else 0
        avg_iou = total_iou / count if count else 0

        val_loss, val_iou = evaluate(model, val_loader, device=device)
        print(f"Epoch [{epoch}/{epochs}] "
              f"Train Loss: {avg_loss:.3f}, SegIOU: {avg_iou:.3f} | "
              f"Val Loss: {val_loss:.3f}, ValSegIOU: {val_iou:.3f}")


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Average the combined loss and segmentation IoU of ``model`` over ``loader``.

    The model is put in eval mode and moved to ``device``; gradients are
    disabled for the whole call. Returns ``(avg_loss, avg_iou)`` as per-batch
    means, or ``(0, 0)`` when the loader yields no batches.
    """
    model.eval()
    model.to(device)

    loss_sum = 0.0
    iou_sum = 0.0
    n_batches = 0

    for batch_imgs, (batch_sems, batch_boxes) in tqdm(loader, desc="Evaluating"):
        batch_imgs = batch_imgs.to(device)
        batch_sems = batch_sems.to(device)
        batch_boxes = batch_boxes.to(device)

        pred_seg, pred_det = model(batch_imgs)

        seg_term, batch_iou = segmentation_loss(pred_seg, batch_sems)
        det_term = detection_loss(pred_det, batch_boxes, device=device)
        batch_loss = seg_term + det_term

        loss_sum += batch_loss.item()
        iou_sum += batch_iou
        n_batches += 1

    # Nothing was evaluated: report zeros rather than dividing by zero.
    if not n_batches:
        return 0, 0
    return loss_sum / n_batches, iou_sum / n_batches


# ---------------------------------------------------------
# 6. Inference & Counting on Video
# ---------------------------------------------------------

@torch.no_grad()
def detect_on_frame(model, frame: torch.Tensor, device="cpu", conf_thresh=0.5,
                    num_anchors=3):
    """Run the model on one frame and decode boxes plus a foreground mask.

    Args:
        model: network returning ``(seg_out, det_out)``; it is switched to
            eval mode but NOT moved, so it must already live on ``device``.
        frame: image tensor laid out [H, W, C].
        device: device the frame is moved to before the forward pass.
        conf_thresh: cells whose objectness probability is below this value
            are discarded.
        num_anchors: number of anchors encoded in the detection channels.

    Returns:
        ``(boxes, seg_label)`` where ``boxes`` is a float32 NumPy array
        [N,4] of pixel-space (x1, y1, x2, y2) boxes, ordered by anchor, then
        grid row, then grid column, and ``seg_label`` is a CPU bool tensor
        [H_out, W_out] that is True where the strawberry logit exceeds the
        background logit.
    """
    model.eval()
    inp = frame.permute(2, 0, 1).unsqueeze(0).to(device)
    seg_out, det_out = model(inp)

    seg_pred = seg_out[0]  # [2,H,W]
    seg_label = (seg_pred[1] > seg_pred[0]).cpu()

    _, channels, H_det, W_det = det_out.shape
    out_ch = channels // num_anchors
    det = det_out[0].reshape(num_anchors, out_ch, H_det, W_det)

    # Vectorised decode: one sigmoid and one boolean mask over the whole grid
    # instead of a Python loop with a host sync per cell.
    obj_prob = torch.sigmoid(det[:, 4])   # [A,H,W]
    keep = ~(obj_prob < conf_thresh)      # same test as "skip if prob < thresh"

    if not bool(keep.any()):
        return np.zeros((0, 4), dtype=np.float32), seg_label

    # [A,4,H,W] -> [A,H,W,4], then boolean-index -> [N,4] in (a, gy, gx) order
    cxcywh = det[:, 0:4].permute(0, 2, 3, 1)[keep].float()

    frame_h, frame_w = frame.shape[0], frame.shape[1]
    # Predictions are treated as fractions of the frame size.
    scale = cxcywh.new_tensor([frame_w, frame_h, frame_w, frame_h])
    cxcywh = cxcywh * scale

    centre = cxcywh[:, 0:2]
    half_size = cxcywh[:, 2:4] / 2
    # Top-left is clamped at 0, bottom-right at the frame size.
    top_left = (centre - half_size).clamp(min=0)
    bottom_right = torch.minimum(
        centre + half_size, cxcywh.new_tensor([frame_w, frame_h])
    )

    boxes = torch.cat([top_left, bottom_right], dim=1)
    return boxes.cpu().numpy().astype(np.float32), seg_label


def iou(b1, b2):
    """Intersection-over-union of two boxes given as (x1, y1, x2, y2).

    A small epsilon (1e-6) is added to the union so that two zero-area boxes
    do not cause a division by zero.
    """
    # Corners of the overlap rectangle
    left = max(b1[0], b2[0])
    top = max(b1[1], b2[1])
    right = min(b1[2], b2[2])
    bottom = min(b1[3], b2[3])

    # Negative extents mean the boxes do not overlap along that axis.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    inter = overlap_w * overlap_h

    area_first = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area_second = (b2[2] - b2[0]) * (b2[3] - b2[1])
    union = area_first + area_second - inter + 1e-6

    return inter / union


def track_and_count(straws_in_frame, tracks, iou_thresh=0.5):
    """Greedily associate this frame's detections with the previous tracks.

    Args:
        straws_in_frame: iterable of detections, each (x1, y1, x2, y2).
        tracks: tracks from the previous frame, each
            [track_id, x1, y1, x2, y2].
        iou_thresh: a detection inherits a track's ID only when their IoU is
            strictly greater than this value.

    Returns:
        A new list of tracks [track_id, x1, y1, x2, y2], one per detection,
        in detection order. Every ID in the result is unique: a previous
        track can be claimed by at most one detection, and each unmatched
        detection receives its own fresh ID.

    Note:
        Fresh IDs start at one past the largest ID in ``tracks``. No state is
        kept beyond ``tracks``, so once the highest-numbered track disappears
        its ID can be handed out again in a later frame.
    """
    updated_tracks = []
    used_ids = set()
    # Allocate fresh IDs from a running counter so that several new
    # detections in the same frame never share an ID.
    next_id = max((t[0] for t in tracks), default=0) + 1

    for sbox in straws_in_frame:
        best_overlap = 0
        best_id = None

        for t in tracks:
            t_id = t[0]
            if t_id in used_ids:
                # already claimed by an earlier detection in this frame
                continue
            overlap = iou(sbox, t[1:])
            if overlap > best_overlap:
                best_overlap = overlap
                best_id = t_id

        if best_id is not None and best_overlap > iou_thresh:
            assigned_id = best_id
        else:
            assigned_id = next_id
            next_id += 1

        used_ids.add(assigned_id)
        updated_tracks.append([assigned_id, sbox[0], sbox[1], sbox[2], sbox[3]])

    return updated_tracks


import cv2
import torch
import numpy as np

def process_video_and_count(
    model,
    video_path,
    device="cpu",
    conf_thresh=0.5,
    iou_thresh=0.5,
    out_video_path="output.mp4"
):
    """
    Open 'video_path' with OpenCV, run detection+tracking on each frame.
    - Writes an output video file to 'out_video_path', with bounding boxes & IDs drawn.
    - Prints how many unique strawberries (tracks) were observed.

    No use of cv2.imshow() or cv2.waitKey() to avoid kernel crashes in headless notebooks.

    Args:
        model: detection/segmentation network; moved to ``device`` here.
        video_path: path of the input video.
        device: device used for inference.
        conf_thresh: objectness threshold forwarded to ``detect_on_frame``.
        iou_thresh: matching threshold forwarded to ``track_and_count``.
        out_video_path: where the annotated video is written.

    Returns:
        The number of distinct track IDs seen over the whole video, or
        ``None`` if the input video could not be opened.
    """

    print(f"Starting video processing: {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Cannot open video {video_path}")
        cap.release()
        return None

    # Prepare video writer for output (same FPS and frame size as the input)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))

    print(f"Video properties: FPS={fps}, Width={width}, Height={height}")

    tracks = []  # each element: [track_id, x1, y1, x2, y2]
    unique_ids_seen = set()
    frame_count = 0
    box_color = (0, 255, 0)  # green (BGR)

    try:
        # Frames are sent to `device`, so the model has to be there as well
        # (train_model/evaluate do the same before their forward passes).
        model.to(device)

        while True:
            ret, frame_bgr = cap.read()
            if not ret:
                print("End of video reached.")
                break

            frame_count += 1
            print(f"Processing frame {frame_count}...")

            # OpenCV delivers BGR uint8; the model is fed RGB floats in [0,1]
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            frame_float = frame_rgb.astype(np.float32) / 255.0
            frame_t = torch.from_numpy(frame_float)  # shape [H, W, 3]

            # Detect bounding boxes & seg mask (the mask is not used here)
            boxes_np, seg_mask = detect_on_frame(
                model, frame_t, device=device, conf_thresh=conf_thresh
            )

            # Associate detections with the previous frame's tracks
            tracks = track_and_count(boxes_np, tracks, iou_thresh=iou_thresh)

            for tid, x1, y1, x2, y2 in tracks:
                # Update global set of IDs
                unique_ids_seen.add(tid)

                # Draw bounding box/ID on the frame
                cv2.rectangle(frame_bgr, (int(x1), int(y1)), (int(x2), int(y2)),
                              box_color, 2)
                cv2.putText(frame_bgr, f"ID:{tid}",
                            (int(x1), int(y1) - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 1)

            # Write the annotated frame to the output video
            writer.write(frame_bgr)
    finally:
        # Always release the capture and finalise the output file, even if
        # inference raised part-way through the video.
        cap.release()
        writer.release()

    total_strawberries = len(unique_ids_seen)
    print(f"Finished writing '{out_video_path}'.")
    print(f"Total number of unique strawberries tracked in video: {total_strawberries}")
    return total_strawberries


# ---------------------------------------------------------
# 7. Single-Run Pipeline
#    => Train => Evaluate => Process video
# ---------------------------------------------------------

# 1) Prepare dataset + loaders
train_ds = StrawberriesDataset(DATA_DIR, "train", transform=ResizeAndPad(480,640))
val_ds   = StrawberriesDataset(DATA_DIR, "val",   transform=ResizeAndPad(480,640))
test_ds  = StrawberriesDataset(DATA_DIR, "test",  transform=ResizeAndPad(480,640))

# Only the training loader is shuffled; val/test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=StrawberriesDataset.collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=StrawberriesDataset.collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=StrawberriesDataset.collate_fn)

print("Datasets and loaders ready.")

# 2) Build model
model = StrawModel(num_classes=1)

# 3) Reuse saved weights when available, otherwise train and save them
WEIGHTS_PATH = "straw_model.pth"
if os.path.exists(WEIGHTS_PATH):
    # weights_only=True restricts unpickling to tensors/containers, which is
    # all a state_dict needs; map_location="cpu" lets a checkpoint saved on
    # a GPU load on any machine (the model is moved to DEVICE later).
    state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Loaded trained weights from {WEIGHTS_PATH}")
else:
    print("Starting training...")
    train_model(model, train_loader, val_loader, epochs=TRAIN_EPOCHS, lr=LR, device=DEVICE)
    torch.save(model.state_dict(), WEIGHTS_PATH)
    print("Training completed. Model saved to straw_model.pth")

# 4) Evaluate on test set
print("Evaluating on test set...")
val_loss, val_iou = evaluate(model, test_loader, device=DEVICE)
print(f"[Eval on test split] Loss: {val_loss:.3f}, SegIOU: {val_iou:.3f}")

# 5) Inference on video
print("Running inference on video for counting...")
process_video_and_count(model, VIDEO_PATH, device=DEVICE)
print("All steps (train -> eval -> video) done!")


Libraries imported.
Datasets and loaders ready.
Training completed. Model saved to straw_model.pth
Evaluating on test set...


Evaluating: 100%|██████████| 50/50 [00:10<00:00,  4.97it/s]


[Eval on test split] Loss: 0.046, SegIOU: 0.749
Running inference on video for counting...
Starting video processing: /kaggle/input/strawdi/data/test.mp4
Video properties: FPS=29.966359034709278, Width=1280, Height=720
Processing frame 1...
Processing frame 2...
Processing frame 3...
Processing frame 4...
Processing frame 5...
Processing frame 6...
Processing frame 7...
Processing frame 8...
Processing frame 9...
Processing frame 10...
Processing frame 11...
Processing frame 12...
Processing frame 13...
Processing frame 14...
Processing frame 15...
Processing frame 16...
Processing frame 17...
Processing frame 18...
Processing frame 19...
Processing frame 20...
Processing frame 21...
Processing frame 22...
Processing frame 23...
Processing frame 24...
Processing frame 25...
Processing frame 26...
Processing frame 27...
Processing frame 28...
Processing frame 29...
Processing frame 30...
Processing frame 31...
Processing frame 32...
Processing frame 33...
Processing frame 34...
Processi