# Automation of human karyotype analysis using image segmentation and classification methods. Classification

In [None]:
# Fetch a file from Google Drive by its file ID using the gdown CLI.
# NOTE(review): presumably this is the Data.zip archive that is unzipped afterwards - confirm.
!gdown 1fUWGsTT9GMmQXt9NGqIcmLgaRyWMbWzg

In [None]:
# Extract the archive; the rest of the notebook reads from /content/Data/...
# NOTE(review): assumes the working directory is /content so the files land there - confirm.
!unzip /content/Data.zip

In [None]:
import os
import glob
import random
import xml.etree.ElementTree as ET
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder

# Pick the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Crop every annotated object out of its source image and store the crops in
# one sub-folder per label, i.e. the <root>/<label>/<file>.jpg layout that is
# later loaded from `output_folder`.
xml_folder = '/content/Data/24_chromosomes_object/annotations'
# NOTE(review): 'JEPG' is presumably the dataset's own folder spelling - confirm.
image_folder = '/content/Data/24_chromosomes_object/JEPG'
output_folder = '/content/Data/24_chromosomes_cropped'

# exist_ok avoids the check-then-create race and makes re-running the cell safe.
os.makedirs(output_folder, exist_ok=True)

# Sorted so repeated runs visit annotations (and print warnings) in a stable order.
for xml_file in sorted(os.listdir(xml_folder)):
    if not xml_file.endswith('.xml'):
        continue

    xml_path = os.path.join(xml_folder, xml_file)
    tree = ET.parse(xml_path)
    root = tree.getroot()

    # findtext returns None for a missing element instead of raising on `.text`.
    filename = root.findtext('filename')
    if not filename:
        print(f"Skipping annotation without <filename>: {xml_path}")
        continue

    image_path = os.path.join(image_folder, filename)
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        continue

    # convert() returns an independent copy, so the file can be closed right away.
    with Image.open(image_path) as source_image:
        img = source_image.convert('RGB')

    # Loop-invariant: the crop file names all share the source image's stem.
    base_name = os.path.splitext(filename)[0]

    for obj in root.findall('object'):
        label = obj.findtext('name')
        bbox = obj.find('bndbox')
        if not label or bbox is None:
            print(f"Skipping object without <name>/<bndbox> in {xml_path}")
            continue

        try:
            # int(float(...)) also accepts coordinates written as "12.0".
            xmin, ymin, xmax, ymax = (
                int(float(bbox.findtext(tag)))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
        except (TypeError, ValueError, OverflowError):
            print(f"Skipping object with unreadable bounding box in {xml_path}")
            continue

        # An empty or inverted box cannot be cropped and saved as an image.
        if xmax <= xmin or ymax <= ymin:
            print(f"Skipping empty bounding box ({xmin}, {ymin}, {xmax}, {ymax}) in {xml_path}")
            continue

        cropped = img.crop((xmin, ymin, xmax, ymax))

        label_folder = os.path.join(output_folder, label)
        os.makedirs(label_folder, exist_ok=True)

        # Coordinates in the name keep several crops from one image distinct.
        cropped_filename = f"{base_name}_{xmin}_{ymin}_{xmax}_{ymax}.jpg"
        cropped.save(os.path.join(label_folder, cropped_filename))

In [None]:
# Preprocessing applied to every crop: resize to 224x224, convert to a tensor,
# then normalize each channel with the given mean / std.
# NOTE(review): these look like the usual ImageNet statistics, matching the
# pretrained backbone used for training - confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])])

# One class per sub-folder of output_folder.
dataset = ImageFolder(root=output_folder, transform=transform)

print("Number of classes:", len(dataset.classes))
print("Classes:", dataset.classes)

# 80 / 20 train / validation split; the remainder goes to validation so the
# two sizes always add up to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A seeded generator keeps the split identical across runs, so validation
# numbers from different runs are measured on the same samples.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator)

batch_size = 16

# Only the training data is shuffled; validation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader` and return (mean loss, accuracy in %).

    When `optimizer` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            loss_sum += loss.item() * images.size(0)
            n_seen += labels.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()

    if n_seen == 0:
        raise ValueError("data loader produced no samples")

    return loss_sum / n_seen, 100.0 * n_correct / n_seen


# Pretrained ResNet-18 backbone. `pretrained=True` is deprecated in newer
# torchvision releases in favour of `weights=`; fall back to the old argument
# where the weights enum does not exist.
if hasattr(models, 'ResNet18_Weights'):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)

# Replace the final fully connected layer so the output size matches the
# number of classes found in the dataset.
num_classes = len(dataset.classes)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 5

for epoch in range(num_epochs):
    # Training pass (updates the weights).
    epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

    # Validation pass (no weight updates, no gradients).
    val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%\n")


In [None]:
def classify_chromosome(model, image_path, transform, device, class_names):
    """Predict the class name of a single image file.

    Args:
        model: network called on a one-image batch; switched to eval mode here.
        image_path: path of the image to open; it is converted to RGB.
        transform: callable turning the PIL image into a tensor.
        device: device the input batch is moved to before the forward pass.
        class_names: sequence indexed by the predicted class index.

    Returns:
        The entry of `class_names` at the index of the highest model output.
    """
    model.eval()

    rgb_image = Image.open(image_path).convert('RGB')
    # Add a leading batch dimension of size 1.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        scores = model(batch)

    best_index = scores.argmax(dim=1).item()
    return class_names[best_index]

# Classify one hand-picked image and display it below the printed prediction.
test_image_path = '/content/Screenshot 2025-03-27 at 11.44.07.png'

predicted_label = classify_chromosome(model, test_image_path, transform, device, dataset.classes)
print("Predicted label:", predicted_label)
# Reuse test_image_path so the displayed image is always the one classified.
img = Image.open(test_image_path).convert('RGB')
img

In [None]:
# Cut one fixed (left, upper, right, lower) region out of a sample image and
# display it as the cell output.
sample_image_path = "/content/Data/24_chromosomes_object/JEPG/103064.jpg"
img = Image.open(sample_image_path).convert('RGB')
crop_box = (285, 43, 351, 119)
cropped = img.crop(crop_box)
cropped

In [None]:
# Display a second fixed (left, upper, right, lower) region of the same
# sample image as the cell output.
sample_image_path = "/content/Data/24_chromosomes_object/JEPG/103064.jpg"
img = Image.open(sample_image_path).convert('RGB')
crop_box = (102, 88, 175, 152)
cropped = img.crop(crop_box)
cropped