# Automation of human karyotype analysis using image segmentation and classification methods. Segmentation

# Imports

In [None]:
!gdown 1fUWGsTT9GMmQXt9NGqIcmLgaRyWMbWzg

In [None]:
!unzip /content/Data.zip

# Segmentation

## Faster R-CNN and RetinaNet

In [None]:
!pip install ultralytics
import os
import numpy as np
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms as T
from PIL import Image
import xml.etree.ElementTree as ET
from torchvision.models.detection import retinanet_resnet50_fpn
import pandas as pd
import matplotlib.pyplot as plt
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from functools import partial
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
import matplotlib.patches as patches
import math
import cv2
from ultralytics import YOLO
import time
from collections import defaultdict
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import glob
import random
import shutil

In [4]:
class ChromosomeDataset(torch.utils.data.Dataset):
    """Detection dataset of images paired with Pascal VOC style XML annotations.

    Images and annotation files are paired by position after sorting both
    directory listings, so the two folders must contain matching files whose
    names sort in the same order.

    Every annotated object receives label ``1`` (single foreground class).

    Args:
        root: Directory that contains the image and annotation folders.
        image_folder: Name of the sub-folder with the images.
        annotation_folder: Name of the sub-folder with the XML files.
        transforms: Optional callable applied to the image only (the target
            is returned untouched).
    """

    def __init__(self, root, image_folder="images", annotation_folder="annotations", transforms=None):
        self.root = root
        self.image_folder = image_folder
        self.annotation_folder = annotation_folder
        self.transforms = transforms

        self.imgs = sorted(os.listdir(os.path.join(root, image_folder)))
        self.annotations = sorted(os.listdir(os.path.join(root, annotation_folder)))

    @staticmethod
    def _parse_annotation(annot_path):
        """Read one XML file and return ``(boxes, labels)`` tensors.

        ``boxes`` is float32 with shape ``(N, 4)`` in
        ``[xmin, ymin, xmax, ymax]`` order and ``labels`` is int64 with shape
        ``(N,)`` filled with ones.
        """
        root_xml = ET.parse(annot_path).getroot()

        boxes = []
        for obj in root_xml.findall("object"):
            bndbox = obj.find("bndbox")
            boxes.append(
                [float(bndbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")]
            )

        # reshape keeps the (N, 4) layout even when the image has no objects;
        # without it an empty list becomes a tensor of shape (0,).
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.ones((boxes.shape[0],), dtype=torch.int64)
        return boxes, labels

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the sample at position ``idx``.

        ``target`` is a dict with keys ``"boxes"``, ``"labels"`` and
        ``"image_id"`` (a one-element tensor holding ``idx``).
        """
        img_path = os.path.join(self.root, self.image_folder, self.imgs[idx])
        img = Image.open(img_path).convert("RGB")

        annot_path = os.path.join(self.root, self.annotation_folder, self.annotations[idx])
        boxes, labels = self._parse_annotation(annot_path)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        """Return the number of image files found in the image folder."""
        return len(self.imgs)

In [5]:
def get_model_faster_rcnn(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    Args:
        num_classes: Number of outputs of the new ``FastRCNNPredictor``
            (torchvision counts the background class in this number).

    Returns:
        The model with pretrained weights everywhere except the replaced
        box predictor head.
    """
    # `pretrained=True` is deprecated in torchvision; the `weights` API is
    # what the RetinaNet builder in this file already uses. "DEFAULT" selects
    # the same COCO weights that `pretrained=True` loaded.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

def get_model_retinanet(num_classes):
    """Build a RetinaNet (ResNet-50 FPN v2) with a fresh classification head.

    The COCO-pretrained model is loaded and only its classification head is
    replaced, so the backbone and the box regression head keep their
    pretrained weights.

    Args:
        num_classes: Number of classes of the new classification head.

    Returns:
        The model with the replaced classification head.
    """
    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(
        weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
    )
    model.head.classification_head = RetinaNetClassificationHead(
        # Read the feature width from the backbone instead of hard-coding 256.
        in_channels=model.backbone.out_channels,
        num_anchors=model.head.classification_head.num_anchors,
        num_classes=num_classes,
        norm_layer=partial(torch.nn.GroupNorm, 32),
    )
    return model

def get_transform(train):
    """Return the image transform pipeline used by the datasets.

    Args:
        train: Kept for interface compatibility; the pipeline is currently the
            same for training and evaluation.

    Returns:
        A ``T.Compose`` that converts the PIL image to a tensor.
    """
    # NOTE: no RandomHorizontalFlip here on purpose. ChromosomeDataset applies
    # this pipeline to the image only (`self.transforms(img)`), so a random
    # flip would mirror the pixels while the boxes stay unflipped, corrupting
    # about half of the training targets. A flip has to transform image and
    # boxes together before it can be re-enabled.
    return T.Compose([T.ToTensor()])

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training pass over ``data_loader`` and report the mean loss.

    Args:
        model: Detection model that returns a dict of losses when called with
            ``(images, targets)`` in training mode.
        optimizer: Optimizer stepping the model parameters.
        data_loader: Iterable of ``(images, targets)`` batches, where
            ``targets`` is a sequence of dicts of tensors.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used only in the printed message.

    Returns:
        The mean of the summed per-batch losses, or ``nan`` when the loader
        yields no batch.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0  # counted here so loaders without __len__ also work

    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        num_batches += 1

    if num_batches == 0:
        # Avoid a ZeroDivisionError on an empty loader.
        print(f"Epoch {epoch}: data loader yielded no batches")
        return float("nan")

    avg_loss = running_loss / num_batches
    print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")
    return avg_loss

In [6]:
# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Folder names are spelled exactly as they appear in the extracted archive.
root_dir = "/content/Data/single_chromosomes_object"
image_folder = "JEPG"
annotation_folder = "anntations"

# Two views of the same files so train and test can use different transforms.
dataset = ChromosomeDataset(root=root_dir, image_folder=image_folder, annotation_folder=annotation_folder, transforms=get_transform(train=True))
dataset_test = ChromosomeDataset(root=root_dir, image_folder=image_folder, annotation_folder=annotation_folder, transforms=get_transform(train=False))

# Seeded permutation: the 80/20 split is identical in every session, so a
# checkpoint reloaded later is evaluated on the same held-out images and
# never on images it was trained on.
split_generator = torch.Generator().manual_seed(42)
indices = torch.randperm(len(dataset), generator=split_generator).tolist()
split_index = int(0.8 * len(indices))
dataset = torch.utils.data.Subset(dataset, indices[:split_index])
dataset_test = torch.utils.data.Subset(dataset_test, indices[split_index:])

# The collate function keeps images and targets as tuples instead of
# stacking them.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4,
    collate_fn=lambda x: tuple(zip(*x))
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=lambda x: tuple(zip(*x))
)

In [None]:
# Train Faster R-CNN and store the learned weights on disk.
num_classes = 2
model = get_model_faster_rcnn(num_classes)
model.to(device)

# Optimise only the parameters that require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

num_epochs = 10
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch)

torch.save(model.state_dict(), "fasterrcnn_chromosomes.pth")

In [None]:
# Train RetinaNet and store the learned weights on disk.
num_classes = 2
model = get_model_retinanet(num_classes)
model.to(device)

# Optimise only the parameters that require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

num_epochs = 10
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch)

torch.save(model.state_dict(), "retinanet_chromosomes.pth")

In [9]:
def iou(boxA, boxB):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    A small epsilon in the denominator keeps the ratio defined when both
    boxes have zero area.
    """
    overlap_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
    overlap_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
    intersection = overlap_w * overlap_h

    area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    union = area_a + area_b - intersection + 1e-6
    return intersection / float(union)

def evaluate_image(gt_boxes, pred_boxes, iou_threshold=0.5):
    """Compute precision, recall and F1 for the detections of one image.

    Predictions are visited in descending score order. Each one is compared
    with the ground-truth box it overlaps most; it is a true positive when
    that overlap reaches ``iou_threshold`` and the box has not been claimed
    by an earlier prediction.

    Args:
        gt_boxes: Sequence of ``[x1, y1, x2, y2]`` ground-truth boxes.
        pred_boxes: Iterable of dicts with keys ``'box'`` and ``'score'``.
        iou_threshold: Minimum IoU for a match.

    Returns:
        Tuple ``(precision, recall, f1)``.
    """
    ranked = sorted(pred_boxes, key=lambda det: det['score'], reverse=True)

    claimed = set()
    for det in ranked:
        top_overlap, top_index = 0, -1
        for gt_index, gt_box in enumerate(gt_boxes):
            overlap = iou(det['box'], gt_box)
            if overlap > top_overlap:
                top_overlap, top_index = overlap, gt_index

        if top_overlap < iou_threshold or top_index in claimed:
            continue
        claimed.add(top_index)

    # Every true positive claims exactly one distinct ground-truth index.
    tp = len(claimed)
    fp = len(ranked) - tp
    fn = len(gt_boxes) - tp

    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall + 1e-6)
    else:
        f1 = 0
    return precision, recall, f1

def compute_ap_single_class(all_preds, all_gts, iou_threshold=0.5):
    """Compute 11-point interpolated average precision for a single class.

    Args:
        all_preds: Mapping ``image_id -> list`` of dicts with keys ``'box'``
            and ``'score'``.
        all_gts: Mapping ``image_id -> list`` of ground-truth boxes; it must
            contain every image id present in ``all_preds``.
        iou_threshold: Minimum IoU for a prediction to match a ground truth.

    Returns:
        Mean of the best precision reachable at recall levels 0.0, 0.1, ..., 1.0.
    """
    ranked = sorted(
        (
            {'image_id': img_id, 'box': p['box'], 'score': p['score']}
            for img_id, preds in all_preds.items()
            for p in preds
        ),
        key=lambda det: det['score'],
        reverse=True,
    )

    total_gts = sum(len(gts) for gts in all_gts.values())

    # One "already matched" flag per ground-truth box of every image.
    taken = {img_id: [False] * len(gts) for img_id, gts in all_gts.items()}

    hits = []
    for det in ranked:
        img_id = det['image_id']
        top_overlap, top_index = 0, -1
        for gt_index, gt_box in enumerate(all_gts[img_id]):
            overlap = iou(det['box'], gt_box)
            if overlap > top_overlap:
                top_overlap, top_index = overlap, gt_index

        is_hit = top_overlap >= iou_threshold and not taken[img_id][top_index]
        if is_hit:
            taken[img_id][top_index] = True
        hits.append(1 if is_hit else 0)

    tp_array = np.cumsum(hits)
    fp_array = np.cumsum([1 - hit for hit in hits])
    precisions = tp_array / (tp_array + fp_array + 1e-6)
    recalls = tp_array / (total_gts + 1e-6)

    ap = 0.0
    for level in np.linspace(0, 1, 11):
        reachable = recalls >= level
        best_precision = np.max(precisions[reachable]) if np.any(reachable) else 0
        ap += best_precision / 11.0

    return ap


def evaluate_detection_single_class(model, data_loader, device,
                                    iou_threshold=0.5, score_threshold=0.5):
    """Evaluate a detector and print per-image metrics plus the overall AP.

    Every image of every batch is evaluated, so the loader may use any batch
    size. Images are numbered consecutively in the order the loader yields
    them.

    Args:
        model: Detection model that returns, per image, a dict with
            ``'boxes'`` and ``'scores'`` when called in eval mode.
        data_loader: Iterable of ``(images, targets)`` batches; each target
            holds the ground-truth ``'boxes'``.
        device: Device the images are moved to before inference.
        iou_threshold: Minimum IoU for a prediction to count as a match.
        score_threshold: Predictions scoring below this value are discarded.
    """
    model.eval()
    all_ground_truths = {}
    all_predictions = {}

    with torch.no_grad():
        for images, targets in data_loader:
            outputs = model([image.to(device) for image in images])

            # Walk through the whole batch, not only its first image.
            for target, output in zip(targets, outputs):
                img_id = len(all_ground_truths)
                all_ground_truths[img_id] = target['boxes'].cpu().tolist()
                all_predictions[img_id] = [
                    {'box': box, 'score': score}
                    for box, score in zip(output['boxes'].cpu().tolist(),
                                          output['scores'].cpu().tolist())
                    if score >= score_threshold
                ]

    metrics_per_image = {}
    for img_id, gt_boxes in all_ground_truths.items():
        prec, rec, f1 = evaluate_image(gt_boxes,
                                       all_predictions.get(img_id, []),
                                       iou_threshold=iou_threshold)
        metrics_per_image[img_id] = {
            'precision': prec,
            'recall': rec,
            'f1': f1
        }

    ap = compute_ap_single_class(all_predictions, all_ground_truths, iou_threshold=iou_threshold)

    df = pd.DataFrame.from_dict(metrics_per_image, orient='index')
    print("\nPer-Image Metrics (IoU >= {:.2f}):".format(iou_threshold))
    print(df)
    print(f"\nOverall AP@IoU={iou_threshold:.2f}: {ap:.3f}")

In [None]:
# Reload the trained Faster R-CNN weights and score them on the test split.
num_classes = 2
model_faster_rcnn = get_model_faster_rcnn(num_classes)
model_faster_rcnn.load_state_dict(
    torch.load("fasterrcnn_chromosomes.pth", map_location=device)
)
model_faster_rcnn.to(device)

evaluate_detection_single_class(
    model_faster_rcnn,
    data_loader_test,
    device,
    iou_threshold=0.5,
    score_threshold=0.5,
)

In [None]:
# Reload the trained RetinaNet weights and score them on the test split.
num_classes = 2
model_retinanet = get_model_retinanet(num_classes)
model_retinanet.load_state_dict(
    torch.load("retinanet_chromosomes.pth", map_location=device)
)
model_retinanet.to(device)

evaluate_detection_single_class(
    model_retinanet,
    data_loader_test,
    device,
    iou_threshold=0.5,
    score_threshold=0.5,
)

In [15]:
def convert_to_coco_api(dataset):
    """Build an in-memory ``COCO`` ground-truth object from a detection dataset.

    Each sample is read through ``dataset[i]``; its position ``i`` becomes the
    COCO image id. Boxes are converted from ``[xmin, ymin, xmax, ymax]`` to
    COCO's ``[x, y, width, height]``.

    Args:
        dataset: Indexable dataset returning ``(image, target)`` where
            ``target`` has CPU tensors ``"boxes"`` and ``"labels"``.

    Returns:
        An indexed ``COCO`` object with one category named ``"chromosomes"``.
    """
    images = []
    annotations = []

    for img_idx in range(len(dataset)):
        _, target = dataset[img_idx]
        images.append({"id": img_idx, "file_name": str(img_idx)})

        box_rows = target["boxes"].numpy()
        label_row = target["labels"].numpy()

        for (xmin, ymin, xmax, ymax), label in zip(box_rows, label_row):
            w = xmax - xmin
            h = ymax - ymin
            annotations.append({
                # Annotation ids are consecutive and start at 1.
                "id": len(annotations) + 1,
                "image_id": img_idx,
                "category_id": int(label),
                "bbox": [float(xmin), float(ymin), float(w), float(h)],
                "area": float(w * h),
                "iscrowd": 0,
            })

    coco = COCO()
    coco.dataset = {
        "images": images,
        "categories": [{"id": 1, "name": "chromosomes"}],
        "annotations": annotations,
    }
    coco.createIndex()
    return coco

def prepare_predictions(predictions, img_ids, label_offset=0):
    """Convert model outputs into a flat list of COCO result dicts.

    Args:
        predictions: Per-image dicts with tensors ``"boxes"`` (in
            ``[xmin, ymin, xmax, ymax]`` order), ``"scores"`` and ``"labels"``.
        img_ids: Image ids, aligned with ``predictions``.
        label_offset: Value added to every predicted label to obtain the
            COCO category id.

    Returns:
        List of dicts with keys ``image_id``, ``category_id``, ``bbox``
        (``[x, y, width, height]``) and ``score``.
    """
    coco_results = []
    for img_id, prediction in zip(img_ids, predictions):
        rows = zip(
            prediction["boxes"].cpu().numpy(),
            prediction["scores"].cpu().numpy(),
            prediction["labels"].cpu().numpy(),
        )
        coco_results.extend(
            {
                "image_id": img_id,
                "category_id": int(label + label_offset),
                "bbox": [
                    float(box[0]),
                    float(box[1]),
                    float(box[2] - box[0]),
                    float(box[3] - box[1]),
                ],
                "score": float(score),
            }
            for box, score, label in rows
        )
    return coco_results

def coco_evaluate(model, data_loader, device, label_offset=0):
    """Run COCO bounding-box evaluation of ``model`` on ``data_loader``.

    Ground truth is built from the loader's dataset; when that dataset is a
    ``Subset``, the wrapped dataset is used and the subset indices serve as
    image ids.

    Predictions are assigned to image ids by position, so the loader must
    yield samples in dataset order (``shuffle=False``). Batches may have any
    size because a running offset is used instead of
    ``i * data_loader.batch_size``.

    Args:
        model: Detection model returning ``boxes``/``scores``/``labels`` per
            image in eval mode.
        data_loader: Loader over the evaluation samples.
        device: Device the images are moved to before inference.
        label_offset: Value added to predicted labels to get category ids.

    Returns:
        The ``COCOeval`` object after ``evaluate``, ``accumulate`` and
        ``summarize`` have run.
    """
    dataset = data_loader.dataset
    if isinstance(dataset, torch.utils.data.Subset):
        subset_indices = list(dataset.indices)
        dataset = dataset.dataset
    else:
        subset_indices = list(range(len(dataset)))

    coco_gt = convert_to_coco_api(dataset)

    model.eval()
    results = []
    img_ids = []
    offset = 0  # number of samples consumed so far

    with torch.no_grad():
        for images, _ in data_loader:
            images = [img.to(device) for img in images]

            batch_indices = subset_indices[offset:offset + len(images)]
            offset += len(images)

            outputs = model(images)

            results.extend(prepare_predictions(outputs, batch_indices, label_offset=label_offset))
            img_ids.extend(batch_indices)

    coco_dt = coco_gt.loadRes(results) if results else COCO()

    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
    # Restrict the evaluation to the images that were actually predicted on.
    coco_eval.params.imgIds = list(img_ids)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    return coco_eval

In [None]:
# COCO-style evaluation of the trained Faster R-CNN on the test split.
num_classes = 2
model_faster_rcnn = get_model_faster_rcnn(num_classes)
model_faster_rcnn.load_state_dict(
    torch.load("fasterrcnn_chromosomes.pth", map_location=device)
)
model_faster_rcnn.to(device)

coco_eval_faster_rcnn = coco_evaluate(
    model=model_faster_rcnn,
    data_loader=data_loader_test,
    device=device,
    label_offset=0,
)

In [None]:
# COCO-style evaluation of the trained RetinaNet on the test split.
num_classes = 2
model_retinanet = get_model_retinanet(num_classes)
model_retinanet.load_state_dict(
    torch.load("retinanet_chromosomes.pth", map_location=device)
)
model_retinanet.to(device)

coco_eval_retinanet = coco_evaluate(
    model=model_retinanet,
    data_loader=data_loader_test,
    device=device,
    label_offset=0,
)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def visualize_sample(model, dataset, idx=0, threshold=0.5, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Show one sample as three panels: image, ground truth, predictions.

    Args:
        model: Detection model returning ``"boxes"`` and ``"scores"`` per image.
        dataset: Indexable dataset returning ``(image_tensor, target)``.
        idx: Position of the sample inside ``dataset``.
        threshold: Predictions scoring below this value are not drawn.
        device: Device used for inference.
    """

    def _outline(ax, boxes, colour):
        # Draw one unfilled rectangle per [xmin, ymin, xmax, ymax] box.
        for xmin, ymin, xmax, ymax in boxes:
            ax.add_patch(
                patches.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    linewidth=2,
                    edgecolor=colour,
                    facecolor="none",
                )
            )

    model.eval()

    img, target = dataset[idx]
    # Channels-first tensor -> channels-last array for matplotlib.
    image_np = img.permute(1, 2, 0).cpu().numpy()
    fig, axs = plt.subplots(1, 3, figsize=(20, 7))
    plain_ax, gt_ax, pred_ax = axs[0], axs[1], axs[2]

    plain_ax.imshow(image_np)
    plain_ax.set_title("Original Image")
    plain_ax.axis("off")

    gt_ax.imshow(image_np)
    gt_ax.set_title("Ground Truth Boxes")
    _outline(gt_ax, (box.tolist() for box in target["boxes"]), "green")
    gt_ax.axis("off")

    with torch.no_grad():
        prediction = model([img.to(device)])[0]

    pred_ax.imshow(image_np)
    pred_ax.set_title("Predicted Boxes")
    confident = (
        box.cpu().numpy()
        for box, score in zip(prediction["boxes"], prediction["scores"])
        if not score < threshold
    )
    _outline(pred_ax, confident, "red")
    pred_ax.axis("off")

    plt.tight_layout()
    plt.show()


# Compare both detectors on the first test sample.
visualize_sample(model_faster_rcnn, dataset_test, 0, 0.5)
visualize_sample(model_retinanet, dataset_test, 0, 0.5)

## YOLO

In [None]:
def convert_voc_to_yolo(xml_file, labels_dir, image_dir):
    """Convert one Pascal VOC XML annotation into a YOLO label file.

    The label file is named after the ``<filename>`` entry of the XML (with a
    ``.txt`` extension) and written to ``labels_dir``. Every object is written
    with class id ``0``; box coordinates are normalised by the image size
    stored in the ``<size>`` element.

    Args:
        xml_file: Path to the VOC annotation.
        labels_dir: Existing directory that receives the ``.txt`` file.
        image_dir: Unused; kept so existing callers keep working.
    """
    root = ET.parse(xml_file).getroot()

    filename = root.find('filename').text

    size_tag = root.find('size')
    img_width = int(size_tag.find('width').text)
    img_height = int(size_tag.find('height').text)

    txt_file = os.path.join(labels_dir, os.path.splitext(filename)[0] + '.txt')

    class_id = 0  # single-class dataset, the object name is not needed
    lines = []
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        xmin = float(bndbox.find('xmin').text)
        ymin = float(bndbox.find('ymin').text)
        xmax = float(bndbox.find('xmax').text)
        ymax = float(bndbox.find('ymax').text)

        # YOLO format: normalised box centre followed by normalised size.
        x_center = ((xmin + xmax) / 2.0) / img_width
        y_center = ((ymin + ymax) / 2.0) / img_height
        w = (xmax - xmin) / img_width
        h = (ymax - ymin) / img_height

        lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {w:.6f} {h:.6f}\n")

    # An annotation without objects still produces an (empty) label file.
    with open(txt_file, 'w') as f:
        f.writelines(lines)


def prepare_dataset(data_dir, train_split=0.8, seed=42):
    """Create the YOLO folder layout and ``data.yaml`` from the VOC data.

    Reads annotations from ``<data_dir>/anntations`` and images from
    ``<data_dir>/JEPG``, then fills ``images/{train,val}`` and
    ``labels/{train,val}`` under ``data_dir``.

    The split is deterministic for a given ``seed``. Existing output folders
    are reused, so an unseeded shuffle would put a different split on top of
    the previous one on every re-run and leak images between train and val.

    Args:
        data_dir: Dataset root directory.
        train_split: Fraction of the annotations assigned to the train set.
        seed: Seed of the shuffle; ``None`` gives a different split each run.
    """
    images_dir = os.path.join(data_dir, "JEPG")
    annotations_dir = os.path.join(data_dir, "anntations")

    output_images_train = os.path.join(data_dir, "images", "train")
    output_images_val = os.path.join(data_dir, "images", "val")
    output_labels_train = os.path.join(data_dir, "labels", "train")
    output_labels_val = os.path.join(data_dir, "labels", "val")

    for out_dir in (output_images_train, output_images_val,
                    output_labels_train, output_labels_val):
        os.makedirs(out_dir, exist_ok=True)

    # Sort first: glob order is filesystem dependent, and the seed only
    # reproduces the split when the input order is stable.
    xml_files = sorted(glob.glob(os.path.join(annotations_dir, "*.xml")))

    # A private Random instance leaves the global random state untouched.
    random.Random(seed).shuffle(xml_files)
    train_count = int(len(xml_files) * train_split)
    train_xmls = xml_files[:train_count]
    val_xmls = xml_files[train_count:]

    def copy_and_convert(xml_list, images_subdir, labels_subdir):
        # Write the YOLO label and copy the matching image for each annotation.
        for xml_file in xml_list:
            convert_voc_to_yolo(xml_file, labels_subdir, images_dir)

            filename = ET.parse(xml_file).getroot().find('filename').text
            src_img_path = os.path.join(images_dir, filename)
            if os.path.exists(src_img_path):
                shutil.copy2(src_img_path, os.path.join(images_subdir, filename))

    copy_and_convert(train_xmls, output_images_train, output_labels_train)
    copy_and_convert(val_xmls, output_images_val, output_labels_val)

    data_yaml = os.path.join(data_dir, "data.yaml")
    with open(data_yaml, 'w') as f:
        f.write("train: {}/images/train\n".format(data_dir))
        f.write("val: {}/images/val\n".format(data_dir))
        f.write("names: ['chromosomes']\n")


def train_yolo(data_dir, model_size='n', epochs=50, imgsz=640):
    """Train a YOLOv8 model on the dataset described by ``<data_dir>/data.yaml``.

    Args:
        data_dir: Dataset root; run outputs are written to ``<data_dir>/runs``.
        model_size: YOLOv8 size suffix ('n', 's', 'm', 'l' or 'x').
        epochs: Number of training epochs.
        imgsz: Training image size.
    """
    dataset_config = os.path.join(data_dir, "data.yaml")
    detector = YOLO(f"yolov8{model_size}.pt")

    run_settings = {
        "data": dataset_config,
        "epochs": epochs,
        "imgsz": imgsz,
        "project": os.path.join(data_dir, "runs"),
        "name": f"yolo_chromosomes_{model_size}",
        "exist_ok": True,
    }
    detector.train(**run_settings)

# Build the YOLO dataset layout, then train the nano model on it.
DATA_DIR = "/content/Data/single_chromosomes_object"
prepare_dataset(DATA_DIR, 0.8)

train_yolo(DATA_DIR, model_size='n', epochs=50, imgsz=640)

In [None]:
def run_inference(data_dir, model_path, test_image_path):
    """Run a YOLO model on one image and let Ultralytics save the output.

    Args:
        data_dir: Not used by the function.
        model_path: Path to the YOLO weights file.
        test_image_path: Image passed to ``predict`` as ``source``.
    """
    detector = YOLO(model_path)
    detector.predict(source=test_image_path, conf=0.25, save=True)

# Best checkpoint of the training run and one image to try it on.
best_model_path = os.path.join(
    DATA_DIR, "runs", "yolo_chromosomes_n", "weights", "best.pt"
)
test_image = os.path.join(DATA_DIR, "images", "train", "103064.jpg")

run_inference(DATA_DIR, best_model_path, test_image)

In [None]:
def parse_voc_annotations(xml_file):
    """Return the boxes of a Pascal VOC XML file as integer lists.

    Args:
        xml_file: Path to the annotation file.

    Returns:
        One ``[xmin, ymin, xmax, ymax]`` list per ``<object>``, with every
        coordinate truncated to ``int``.
    """
    corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    boxes = []
    for obj in ET.parse(xml_file).getroot().findall('object'):
        bndbox = obj.find('bndbox')
        boxes.append([int(float(bndbox.find(tag).text)) for tag in corner_tags])
    return boxes


def draw_boxes_cv2(image, boxes, color=(0, 255, 0), label_text=None):
    """Draw rectangles (and an optional caption) onto ``image`` in place.

    Args:
        image: Image array handed to the OpenCV drawing calls.
        boxes: Iterable of ``(x1, y1, x2, y2)`` corner coordinates.
        color: Colour tuple used for the rectangle and the caption.
        label_text: Caption placed just above each box; skipped when empty.

    Returns:
        The same ``image`` object that was passed in.
    """
    for x1, y1, x2, y2 in boxes:
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness=2)
        if not label_text:
            continue
        # Keep the caption inside the image when the box touches the top edge.
        caption_origin = (x1, max(0, y1 - 5))
        cv2.putText(image, label_text, caption_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return image


def plot_original_gt_pred(image_path, xml_path, model_path, conf_threshold=0.25):
    """Plot an image next to its ground-truth boxes and its YOLO predictions.

    Args:
        image_path: Image to read and to run the model on.
        xml_path: Pascal VOC annotation; when it is empty or missing a notice
            is printed and the middle panel shows no boxes.
        model_path: Path to the YOLO weights.
        conf_threshold: Confidence threshold passed to ``predict``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    original_bgr = cv2.imread(image_path)
    if original_bgr is None:
        raise FileNotFoundError(f"Could not read image at: {image_path}")

    # Separate copies so the drawings do not end up on the plain panel.
    gt_bgr = original_bgr.copy()
    pred_bgr = original_bgr.copy()

    if not (xml_path and os.path.exists(xml_path)):
        print("No valid XML file found")
    else:
        gt_bgr = draw_boxes_cv2(gt_bgr, parse_voc_annotations(xml_path), color=(0, 255, 0))

    detector = YOLO(model_path)
    detections = detector.predict(source=image_path, conf=conf_threshold)

    pred_boxes = []
    for result in detections:
        for det in result.boxes:
            x1, y1, x2, y2 = det.xyxy[0]
            pred_boxes.append([int(x1), int(y1), int(x2), int(y2)])

    pred_bgr = draw_boxes_cv2(pred_bgr, pred_boxes, color=(0, 0, 255))

    # OpenCV images are BGR, matplotlib expects RGB.
    panels = [
        (cv2.cvtColor(original_bgr, cv2.COLOR_BGR2RGB), "Original Image"),
        (cv2.cvtColor(gt_bgr, cv2.COLOR_BGR2RGB), "Ground Truth Boxes"),
        (cv2.cvtColor(pred_bgr, cv2.COLOR_BGR2RGB), "Predicted Boxes"),
    ]

    fig, axs = plt.subplots(1, 3, figsize=(20, 8))
    for ax, (panel_rgb, title) in zip(axs, panels):
        ax.imshow(panel_rgb)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Inputs for the side-by-side comparison of ground truth and predictions.
best_model_path = "/content/Data/single_chromosomes_object/runs/yolo_chromosomes_n/weights/best.pt"
test_image_path = "/content/Data/single_chromosomes_object/images/train/103064.jpg"
xml_path = "/content/Data/single_chromosomes_object/anntations/103064.xml"

plot_original_gt_pred(test_image_path, xml_path, best_model_path, conf_threshold=0.25)