In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
import yaml
from typing import Tuple, List, Dict
import albumentations as A

## Dataset


In [None]:
class CattleDataset:
    """
    A dataset loader for YOLO-format (e.g. YOLOv8) cattle images.

    Expected layout::

        <data_dir>/dataset.yaml                 # must contain a ``names`` entry
        <data_dir>/<split>/images/<stem>.jpg
        <data_dir>/<split>/labels/<stem>.txt    # "class x_center y_center w h" per row

    Box coordinates are normalized to [0, 1] (albumentations ``format="yolo"``).
    ``img_size`` is ``(height, width)``.
    """

    def __init__(
        self,
        data_dir: str,
        img_size: Tuple[int, int] = (640, 640),
        batch_size: int = 32,
        augment: bool = True,
    ):
        self.data_dir = Path(data_dir)
        self.img_size = img_size
        self.batch_size = batch_size

        self.load_yaml_config()
        self.augment = augment
        if augment:
            self.aug_pipeline = A.Compose(
                [
                    A.RandomBrightnessContrast(p=0.5),
                    A.HorizontalFlip(p=0.5),
                    A.RandomRotate90(p=0.5),
                    A.Blur(blur_limit=3, p=0.3),
                ],
                bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]),
            )

    def load_yaml_config(self):
        """Read ``dataset.yaml`` and expose ``config``, ``class_names`` and ``num_classes``."""
        yaml_path = self.data_dir / "dataset.yaml"
        with open(yaml_path, "r") as f:
            self.config = yaml.safe_load(f)

        self.class_names = self.config["names"]
        self.num_classes = len(self.class_names)

    def load_image(self, image_path: str) -> tf.Tensor:
        """Decode a JPEG with TF ops, resize to ``img_size`` and scale to [0, 1]."""
        img = tf.io.read_file(image_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, self.img_size)
        img = tf.cast(img, tf.float32) / 255.0

        return img

    def load_labels(self, label_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Read one YOLO label file.

        Returns:
            ``(boxes, classes)``: float32 ``(N, 4)`` boxes and int32 ``(N,)``
            class ids. A missing file means the image has no objects. An
            unreadable or malformed file emits a warning and is also treated
            as having no objects.
        """
        import warnings

        try:
            with open(label_path, "r") as f:
                # Skip blank lines (e.g. a trailing newline pair) so they do not
                # make the rows ragged and discard every label in the file.
                rows = [line.split() for line in f if line.strip()]
        except FileNotFoundError:
            rows = []
        except (OSError, UnicodeDecodeError) as exc:
            warnings.warn(f"Could not read label file {label_path}: {exc}")
            rows = []

        try:
            labels = np.array(rows, dtype=np.float32)
        except ValueError:  # ragged rows or non-numeric tokens
            labels = None

        if labels is None or (labels.size and (labels.ndim != 2 or labels.shape[1] != 5)):
            warnings.warn(f"Ignoring malformed label file: {label_path}")
            labels = np.zeros((0, 5), dtype=np.float32)

        # reshape keeps the (N, 5) layout for empty files too.
        labels = labels.reshape(-1, 5)
        boxes = labels[:, 1:].copy()
        classes = labels[:, 0].astype(np.int32)

        return boxes, classes

    def create_dataset(self, split: str = "train") -> tf.data.Dataset:
        """Build a batched ``tf.data`` pipeline for ``split``.

        Each element is ``(images, boxes, classes)`` with shapes
        ``(B, H, W, 3)``, ``(B, N, 4)`` and ``(B, N)``, where ``N`` is the
        largest box count in that batch. Shorter rows are padded with all-zero
        boxes and class id ``-1``.

        Raises:
            FileNotFoundError: if the split contains no ``.jpg`` images.
        """
        split_dir = self.data_dir / split
        image_dir = split_dir / "images"
        img_paths = sorted(image_dir.glob("*.jpg"))
        if not img_paths:
            raise FileNotFoundError(f"No .jpg images found in {image_dir}")

        # Build label paths structurally. A plain str.replace("images", "labels")
        # would also rewrite any parent directory whose name contains "images".
        label_paths = [str(split_dir / "labels" / f"{p.stem}.txt") for p in img_paths]

        dataset = tf.data.Dataset.from_tensor_slices(
            ([str(p) for p in img_paths], label_paths)
        )
        dataset = dataset.map(self._load_sample_tf, num_parallel_calls=tf.data.AUTOTUNE)

        # Images share one size but box counts vary per image, so a plain
        # `batch` cannot stack them. Pad boxes with 0 and classes with -1.
        dataset = dataset.padded_batch(
            self.batch_size,
            padding_values=(
                tf.constant(0.0, tf.float32),
                tf.constant(0.0, tf.float32),
                tf.constant(-1, tf.int32),
            ),
        )
        return dataset.prefetch(tf.data.AUTOTUNE)

    def _load_sample_tf(self, img_path, label_path):
        """Graph-mode wrapper around `_load_sample` that restores static shapes."""
        img, boxes, classes = tf.py_function(
            self._load_sample,
            [img_path, label_path],
            [tf.float32, tf.float32, tf.int32],
        )
        img = tf.ensure_shape(img, (self.img_size[0], self.img_size[1], 3))
        boxes = tf.ensure_shape(boxes, (None, 4))
        classes = tf.ensure_shape(classes, (None,))
        return img, boxes, classes

    def _load_sample(
        self, img_path: str, label_path: str
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Load, optionally augment, resize and normalize one image with its labels.

        Called through ``tf.py_function``, so both paths arrive as eager string tensors.
        """
        img_path = img_path.numpy().decode("utf-8")
        label_path = label_path.numpy().decode("utf-8")

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        boxes, classes = self.load_labels(label_path)

        if self.augment and len(boxes):
            transformed = self.aug_pipeline(
                image=img, bboxes=boxes, class_labels=classes
            )
            img = transformed["image"]
            # Augmentation may drop every box. Reshape so the result is still
            # (N, 4) / (N,) with the dtypes promised to tf.py_function.
            boxes = np.asarray(transformed["bboxes"], dtype=np.float32).reshape(-1, 4)
            classes = np.asarray(transformed["class_labels"], dtype=np.int32).reshape(-1)

        # cv2.resize takes (width, height); img_size is (height, width).
        img = cv2.resize(img, (self.img_size[1], self.img_size[0]))
        img = img.astype(np.float32) / 255.0

        return img, boxes, classes

    def visualize_batch(
        self, batch: Tuple[tf.Tensor, tf.Tensor, tf.Tensor], num_samples: int = 4
    ):
        """Plot up to ``num_samples`` images from a batch with their boxes drawn in red.

        Rows whose class id is negative are treated as padding and skipped.
        """
        images, boxes, classes = batch

        n_show = min(num_samples, len(images))
        if n_show <= 0:
            return

        # Two columns, as many rows as needed (2x2 for the default of 4).
        n_cols = min(n_show, 2)
        n_rows = (n_show + n_cols - 1) // n_cols
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), squeeze=False
        )
        axes = axes.ravel()
        height, width = self.img_size

        for idx in range(n_show):
            img_draw = images[idx].numpy().copy()
            for box, cls in zip(boxes[idx].numpy(), classes[idx].numpy()):
                if cls < 0:  # padding row from padded_batch
                    continue
                x, y, w, h = box
                x1 = int((x - w / 2) * width)
                y1 = int((y - h / 2) * height)
                x2 = int((x + w / 2) * width)
                y2 = int((y + h / 2) * height)

                cv2.rectangle(img_draw, (x1, y1), (x2, y2), (1, 0, 0), 2)
                cv2.putText(
                    img_draw,
                    str(self.class_names[int(cls)]),
                    # Keep the label inside the image for boxes touching the top edge.
                    (x1, max(y1 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    (1, 0, 0),
                    2,
                )

            axes[idx].imshow(img_draw)
            axes[idx].axis("off")
            axes[idx].set_title(f"Sample {idx+1}")

        # Hide any unused panel in the last row.
        for ax in axes[n_show:]:
            ax.axis("off")

        plt.tight_layout()
        plt.show()

## YOLO Model Creation


In [None]:
class YOLOLoss(tf.keras.losses.Loss):
    """Sum-squared YOLO loss for a single-scale detection grid.

    ``y_true`` and ``y_pred`` share the layout ``(..., 5 + num_classes)`` with
    channels ``[x, y, w, h, objectness, class...]``.

    * ``y_true`` is expected to hold x/y and w/h targets in [0, 1], a 0/1
      objectness mask, and a one-hot class vector.
    * ``y_pred`` holds the raw, linear outputs of the detection head. The box
      and objectness channels are passed through a sigmoid here, so any
      inference code must apply the same sigmoid.

    Returns one loss value per example, summed over grid cells and anchors.
    The Keras ``reduction`` then averages over the batch.
    """

    def __init__(self, num_classes: int, anchors: np.ndarray, **kwargs):
        # Forward standard Loss kwargs (name, reduction) to the base class.
        super().__init__(**kwargs)

        self.num_classes = num_classes
        # Stored for reference; `call` does not use `anchors` or `ignore_thresh`.
        self.anchors = anchors
        self.ignore_thresh = 0.5
        self.lambda_coord = 5.0
        self.lambda_noobj = 0.5

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)

        # The head ends in a linear Conv2D, so raw outputs are unbounded.
        # Squash box/objectness channels into (0, 1) to match the targets.
        # (Clipping raw values instead would zero the gradient outside [eps, 1].)
        pred_xy = tf.sigmoid(y_pred[..., 0:2])
        pred_wh = tf.sigmoid(y_pred[..., 2:4])
        pred_conf = tf.sigmoid(y_pred[..., 4:5])
        pred_class_logits = y_pred[..., 5:]

        true_xy = y_true[..., 0:2]
        true_wh = y_true[..., 2:4]
        object_mask = y_true[..., 4:5]
        true_class = y_true[..., 5:]
        no_object_mask = 1.0 - object_mask

        xy_loss = self.lambda_coord * object_mask * tf.square(true_xy - pred_xy)
        # The clip only guards sqrt against sigmoid underflow to exactly 0.
        wh_loss = (
            self.lambda_coord
            * object_mask
            * tf.square(
                tf.sqrt(true_wh) - tf.sqrt(tf.clip_by_value(pred_wh, 1e-10, 1.0))
            )
        )

        conf_loss = object_mask * tf.square(
            1.0 - pred_conf
        ) + self.lambda_noobj * no_object_mask * tf.square(pred_conf)

        class_loss = (
            object_mask
            * tf.keras.losses.categorical_crossentropy(
                true_class, pred_class_logits, from_logits=True
            )[..., tf.newaxis]
        )

        # Reduce each term over its own channel axis before adding. Adding the
        # width-2 xy/wh terms to the width-1 conf/class terms directly would
        # broadcast and count conf/class twice.
        per_anchor = (
            tf.reduce_sum(xy_loss, axis=-1)
            + tf.reduce_sum(wh_loss, axis=-1)
            + tf.reduce_sum(conf_loss, axis=-1)
            + tf.reduce_sum(class_loss, axis=-1)
        )

        # Sum everything except the batch axis, giving one loss per example.
        return tf.reduce_sum(per_anchor, axis=tf.range(1, tf.rank(per_anchor)))

In [None]:
class YOLOBlock(tf.keras.layers.Layer):
    """YOLO convolutional block: bias-free Conv2D -> BatchNorm -> LeakyReLU(0.1).

    Args:
        filters: number of output channels.
        kernel_size: square convolution kernel size.
        strides: convolution stride ("same" padding).
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, filters: int, kernel_size: int, strides: int = 1, **kwargs):
        super().__init__(**kwargs)
        # No bias: the following BatchNormalization supplies its own offset.
        self.conv = tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            use_bias=False,
        )
        self.bn = tf.keras.layers.BatchNormalization()
        self.activation = tf.keras.layers.LeakyReLU(alpha=0.1)

    def call(self, x, training=None):
        x = self.conv(x)
        # Pass `training` explicitly so BatchNorm chooses batch vs. moving
        # statistics even when the block is called on its own.
        x = self.bn(x, training=training)

        return self.activation(x)

In [None]:
class YOLO(tf.keras.Model):
    """Single-scale YOLO-style detector built only from conv / pooling blocks.

    Four 2x2 max-pools reduce the spatial size by 16. The output is reshaped to
    ``(batch, grid_h, grid_w, 3, 5 + num_classes)``: three anchor slots per
    cell, each holding raw (linear) box, objectness and class values.

    ``input_shape`` is accepted but not used by the layers themselves.
    """

    def __init__(self, num_classes: int, input_shape: Tuple[int, int, int]):
        super().__init__()
        self.num_classes = num_classes

        def pool():
            return tf.keras.layers.MaxPooling2D(2, 2)

        # Backbone. Layers are created in the same order as before, so
        # auto-generated layer names and weight order are unchanged.
        self.backbone = [
            YOLOBlock(32, 3), pool(),
            YOLOBlock(64, 3), pool(),
            YOLOBlock(128, 3), YOLOBlock(64, 1), YOLOBlock(128, 3), pool(),
            YOLOBlock(256, 3), YOLOBlock(128, 1), YOLOBlock(256, 3), pool(),
            YOLOBlock(512, 3), YOLOBlock(256, 1), YOLOBlock(512, 3),
            YOLOBlock(256, 1), YOLOBlock(512, 3),
        ]

        # Detection head: alternating 3x3 / 1x1 blocks, then a linear 1x1
        # projection to 3 anchors * (5 + num_classes) channels.
        head_specs = ((1024, 3), (512, 1), (1024, 3), (512, 1), (1024, 3))
        head_blocks = [YOLOBlock(n_filters, k) for n_filters, k in head_specs]
        self.head = head_blocks + [
            tf.keras.layers.Conv2D(3 * (5 + num_classes), 1, padding="same")
        ]

    def call(self, x):
        features = x
        for layer in (*self.backbone, *self.head):
            features = layer(features)

        dims = tf.shape(features)
        return tf.reshape(
            features, [dims[0], dims[1], dims[2], 3, 5 + self.num_classes]
        )

## Trainer


In [None]:
class YOLOTrainer:
    """Custom training loop for the single-scale ``YOLO`` model.

    Datasets may yield either:

    * ``(images, boxes, classes)``: normalized ``(x, y, w, h)`` boxes of shape
      ``(B, N, 4)`` and int class ids of shape ``(B, N)``. Padding rows must
      carry a negative class id. These are encoded onto the prediction grid by
      `encode_targets`.
    * ``(images, labels)``: ``labels`` already grid-encoded to the model's
      output shape ``(B, grid_h, grid_w, 3, 5 + num_classes)``.
    """

    def __init__(
        self,
        num_classes: int,
        input_shape: Tuple[int, int, int] = (640, 640, 3),
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-4,
    ):
        self.num_classes = num_classes
        self.input_shape = input_shape

        self.model = YOLO(num_classes, input_shape)
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=learning_rate, weight_decay=weight_decay
        )

        # Anchor (width, height) pairs in pixels, three rows of three.
        self.anchors = np.array(
            [
                [[10, 13], [16, 30], [33, 23]],
                [[30, 61], [62, 45], [59, 119]],
                [[116, 90], [156, 198], [373, 326]],
            ]
        )
        # The model has one output grid with 3 anchor slots per cell. Used only
        # to choose which slot a box is assigned to.
        # NOTE(review): assumes the middle row (the YOLOv3 stride-16 set) fits
        # this stride-16 head, in pixels of `input_shape` — confirm.
        height, width = input_shape[0], input_shape[1]
        self._anchor_wh = (
            self.anchors[1] / np.array([width, height], dtype=np.float32)
        ).astype(np.float32)

        self.loss_fn = YOLOLoss(num_classes, self.anchors)
        self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
        self.val_loss = tf.keras.metrics.Mean("val_loss", dtype=tf.float32)

    def encode_targets(self, boxes, classes, grid_shape):
        """Rasterize padded YOLO boxes onto the prediction grid.

        Args:
            boxes: ``(B, N, 4)`` normalized ``(x_center, y_center, w, h)``.
            classes: ``(B, N)`` integer class ids; rows with id < 0 are padding.
            grid_shape: ``(grid_h, grid_w)`` of the model output.

        Returns:
            float32 ``(B, grid_h, grid_w, A, 5 + num_classes)`` with, per
            assigned cell/anchor: centre offset inside the cell in [0, 1],
            normalized w/h, objectness 1, and a one-hot class. Each box goes to
            the cell containing its centre and the anchor slot with the best
            shape IoU. If two boxes share a cell and slot, only one is kept
            (unspecified which).
        """
        boxes = tf.cast(boxes, tf.float32)
        classes = tf.cast(classes, tf.int32)
        grid_h = tf.cast(grid_shape[0], tf.int32)
        grid_w = tf.cast(grid_shape[1], tf.int32)
        grid_h_f = tf.cast(grid_h, tf.float32)
        grid_w_f = tf.cast(grid_w, tf.float32)
        batch_size = tf.shape(boxes)[0]

        # Keep only real boxes: index pairs [batch, box] with a non-negative class.
        valid_idx = tf.where(classes >= 0)
        valid_boxes = tf.gather_nd(boxes, valid_idx)
        valid_classes = tf.gather_nd(classes, valid_idx)
        batch_idx = tf.cast(valid_idx[:, 0], tf.int32)

        x, y, w, h = tf.unstack(valid_boxes, num=4, axis=-1)
        col = tf.clip_by_value(tf.cast(tf.floor(x * grid_w_f), tf.int32), 0, grid_w - 1)
        row = tf.clip_by_value(tf.cast(tf.floor(y * grid_h_f), tf.int32), 0, grid_h - 1)
        off_x = x * grid_w_f - tf.cast(col, tf.float32)
        off_y = y * grid_h_f - tf.cast(row, tf.float32)

        # IoU between each box and each anchor shape, both centred at the origin.
        anchor_wh = tf.constant(self._anchor_wh, tf.float32)
        inter = tf.minimum(w[:, None], anchor_wh[None, :, 0]) * tf.minimum(
            h[:, None], anchor_wh[None, :, 1]
        )
        union = (w * h)[:, None] + (anchor_wh[:, 0] * anchor_wh[:, 1])[None, :] - inter
        best_anchor = tf.cast(
            tf.argmax(inter / tf.maximum(union, 1e-9), axis=-1), tf.int32
        )

        indices = tf.stack([batch_idx, row, col, best_anchor], axis=-1)
        updates = tf.concat(
            [
                tf.stack([off_x, off_y, w, h, tf.ones_like(x)], axis=-1),
                tf.one_hot(valid_classes, self.num_classes, dtype=tf.float32),
            ],
            axis=-1,
        )
        target_shape = tf.stack(
            [batch_size, grid_h, grid_w, self._anchor_wh.shape[0], 5 + self.num_classes]
        )
        # tensor_scatter_nd_update (not scatter_nd) so duplicate indices
        # overwrite rather than sum into objectness values > 1.
        return tf.tensor_scatter_nd_update(
            tf.zeros(target_shape, tf.float32), indices, updates
        )

    def _loss_from_boxes(self, images, boxes, classes, training):
        """Forward pass, then encode targets on the actual output grid and score them."""
        predictions = self.model(images, training=training)
        labels = self.encode_targets(boxes, classes, tf.shape(predictions)[1:3])
        return self.loss_fn(labels, predictions)

    # reduce_retracing: padded box counts change per batch, so input shapes vary.
    @tf.function(reduce_retracing=True)
    def train_step(self, images, labels):
        """One optimisation step on grid-encoded ``labels``; returns the batch loss."""
        with tf.GradientTape() as tape:
            predictions = self.model(images, training=True)
            loss = self.loss_fn(labels, predictions)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.train_loss.update_state(loss)
        return loss

    @tf.function(reduce_retracing=True)
    def train_step_from_boxes(self, images, boxes, classes):
        """One optimisation step on raw padded boxes; returns the batch loss."""
        with tf.GradientTape() as tape:
            loss = self._loss_from_boxes(images, boxes, classes, training=True)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.train_loss.update_state(loss)
        return loss

    @tf.function(reduce_retracing=True)
    def val_step(self, images, labels):
        """Evaluate one batch of grid-encoded ``labels``; returns the batch loss."""
        predictions = self.model(images, training=False)
        loss = self.loss_fn(labels, predictions)
        self.val_loss.update_state(loss)
        return loss

    @tf.function(reduce_retracing=True)
    def val_step_from_boxes(self, images, boxes, classes):
        """Evaluate one batch of raw padded boxes; returns the batch loss."""
        loss = self._loss_from_boxes(images, boxes, classes, training=False)
        self.val_loss.update_state(loss)
        return loss

    @staticmethod
    def _dispatch(batch, encoded_step, raw_step):
        """Route a dataset element to the step matching its structure."""
        if len(batch) == 3:
            return raw_step(*batch)
        if len(batch) == 2:
            return encoded_step(*batch)
        raise ValueError(
            "Expected dataset elements (images, labels) or (images, boxes, classes), "
            f"got {len(batch)} components"
        )

    @staticmethod
    def _default_callbacks() -> List[tf.keras.callbacks.Callback]:
        """Checkpoint best weights, stop early on val_loss, and log to TensorBoard."""
        return [
            # Subclassed models cannot be saved whole to HDF5, so save weights
            # only. The ".weights.h5" suffix is what Keras 3 requires.
            tf.keras.callbacks.ModelCheckpoint(
                "yolo_cattle_{epoch:02d}.weights.h5",
                save_best_only=True,
                save_weights_only=True,
                monitor="val_loss",
            ),
            tf.keras.callbacks.EarlyStopping(patience=10, monitor="val_loss"),
            tf.keras.callbacks.TensorBoard(log_dir="./logs"),
        ]

    def train(
        self,
        train_dataset: tf.data.Dataset,
        val_dataset: tf.data.Dataset,
        epochs: int = 100,
        callbacks: List[tf.keras.callbacks.Callback] = None,
    ):
        """Run up to ``epochs`` epochs, invoking Keras callbacks after each one.

        Stops early when a callback sets ``model.stop_training`` (e.g. EarlyStopping).
        """
        if callbacks is None:
            callbacks = self._default_callbacks()

        # CallbackList attaches the model to every callback (set_model). The
        # train/epoch hooks below initialise their state; without them
        # ModelCheckpoint/EarlyStopping fail on first use.
        callback_list = tf.keras.callbacks.CallbackList(
            callbacks, model=self.model, epochs=epochs
        )
        self.model.stop_training = False
        callback_list.on_train_begin()

        for epoch in range(epochs):
            callback_list.on_epoch_begin(epoch)
            self.train_loss.reset_state()
            self.val_loss.reset_state()

            for batch in train_dataset:
                self._dispatch(batch, self.train_step, self.train_step_from_boxes)

            for batch in val_dataset:
                self._dispatch(batch, self.val_step, self.val_step_from_boxes)

            logs = {
                "loss": float(self.train_loss.result()),
                "val_loss": float(self.val_loss.result()),
            }
            print(
                f"Epoch {epoch + 1}, "
                f"Train Loss: {logs['loss']:.4f}, "
                f"Val Loss: {logs['val_loss']:.4f}"
            )

            callback_list.on_epoch_end(epoch, logs)
            if self.model.stop_training:
                break

        callback_list.on_train_end()

## Main Execution


In [None]:
# Root of the YOLO-format dataset (contains dataset.yaml and split folders).
DATA_DIR = "../../data/yolo"

dataset = CattleDataset(
    data_dir=DATA_DIR,
    img_size=(640, 640),
    batch_size=32,
    augment=True,
)

In [None]:
# Preview one (augmented) training batch to sanity-check boxes and labels.
train_ds = dataset.create_dataset(split="train")
for preview_batch in train_ds.take(1):
    dataset.visualize_batch(preview_batch, num_samples=4)

In [None]:
# Validation must see un-augmented images. `dataset` augments on the fly, so
# build a second loader with the same settings but augmentation disabled.
val_loader = CattleDataset(
    data_dir=dataset.data_dir,
    img_size=dataset.img_size,
    batch_size=dataset.batch_size,
    augment=False,
)
val_ds = val_loader.create_dataset(split="val")

In [None]:
# Derive the model input shape from the loader so the two cannot drift apart.
trainer = YOLOTrainer(
    num_classes=dataset.num_classes,
    input_shape=(*dataset.img_size, 3),
    learning_rate=1e-4,
)

In [None]:
EPOCHS = 100
trainer.train(train_dataset=train_ds, val_dataset=val_ds, epochs=EPOCHS)