In [1]:
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
!pip install onnx onnxruntime

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np

# ---------------- CONFIG ---------------- #
# Dataset locations. /content/drive is presumably a Colab-mounted Google Drive;
# it must be mounted before the dataset cells run, otherwise os.listdir()
# raises FileNotFoundError for these paths.
TRAIN_IMG_DIR = "/content/drive/MyDrive/dataset/train/images"
TRAIN_MASK_DIR = "/content/drive/MyDrive/dataset/train/masks"
VALID_IMG_DIR = "/content/drive/MyDrive/dataset/valid/images"
VALID_MASK_DIR = "/content/drive/MyDrive/dataset/valid/masks"

# Number of label values the model predicts, background included. Every mask
# pixel must lie in [0, NUM_CLASSES) or CrossEntropyLoss fails.
# NOTE(review): if the masks hold background plus 3 object classes (labels
# 0-3), this must be 4 — confirm against the values present in the masks.
NUM_CLASSES = 3
IMG_SIZE = 640  # images and masks are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 4
EPOCHS = 40
LR = 1e-3
SEED = 42  # seeds weight init and DataLoader shuffling

# The U-Net max-pools 4 times by a factor of 2 and concatenates skip
# connections, so the input side length must be a multiple of 2**4 = 16.
assert IMG_SIZE % 16 == 0, f"IMG_SIZE must be divisible by 16, got {IMG_SIZE}"

torch.manual_seed(SEED)
np.random.seed(SEED)

# Device selection: prioritize MPS → CUDA → CPU
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"⚡ Using device: {DEVICE}")


# ---------------- Dataset ---------------- #
class SegmentationDataset(Dataset):
    """Paired image / mask dataset for semantic segmentation.

    Each image file in ``img_dir`` is paired with the mask in ``mask_dir``
    that has the same base name and a ``.png`` extension
    (e.g. ``0001.jpg`` -> ``0001.png``).

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        img_size: Side length both image and mask are resized to.

    ``__getitem__`` returns ``(image, mask)``: ``image`` is a float tensor
    ``[3, img_size, img_size]`` and ``mask`` is an int64 tensor
    ``[img_size, img_size]`` of per-pixel labels.
    """

    # File types treated as input images; anything else in img_dir
    # (e.g. .DS_Store, desktop.ini, sub-folders) is ignored.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, img_dir, mask_dir, img_size=640):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        # Sorted so the sample order is deterministic across runs/filesystems
        # (os.listdir order is arbitrary).
        self.images = sorted(
            name
            for name in os.listdir(img_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(img_dir, name))
        )
        self.img_size = img_size

    def __len__(self):
        return len(self.images)

    def _mask_path(self, image_name):
        """Return the mask path for ``image_name``: same stem, ``.png`` suffix."""
        # splitext only touches the final extension, unlike str.replace which
        # would rewrite every ".jpg" occurring anywhere in the name.
        stem, _ = os.path.splitext(image_name)
        return os.path.join(self.mask_dir, stem + ".png")

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = self._mask_path(self.images[idx])

        # convert() returns a new in-memory image, so the files can be closed
        # immediately by the context managers.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            # NOTE(review): mode "L" assumes masks store class indices as
            # grayscale values; colour-coded RGB masks would need a lookup.
            mask = mask_file.convert("L")

        # Resize to a fixed square; NEAREST keeps mask labels discrete.
        image = image.resize((self.img_size, self.img_size), Image.BILINEAR)
        mask = mask.resize((self.img_size, self.img_size), Image.NEAREST)

        # Convert to tensors
        image = transforms.ToTensor()(image)  # [C, H, W]
        mask = torch.from_numpy(np.array(mask, dtype=np.int64))  # [H, W]

        return image, mask


# ---- Datasets and loaders ----
# Both splits use the same fixed resize. Training batches are reshuffled every
# epoch; validation is iterated in a fixed order.
train_dataset = SegmentationDataset(
    img_dir=TRAIN_IMG_DIR, mask_dir=TRAIN_MASK_DIR, img_size=IMG_SIZE
)
valid_dataset = SegmentationDataset(
    img_dir=VALID_IMG_DIR, mask_dir=VALID_MASK_DIR, img_size=IMG_SIZE
)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False)


# ---------------- U-Net Model ---------------- #
class DoubleConv(nn.Module):
    """Two ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    Spatial size is preserved; channels go ``in_channels -> out_channels``.
    The layers live in ``self.conv`` (an ``nn.Sequential`` of 6 modules), so
    state-dict keys are ``conv.0.*``, ``conv.1.*``, ``conv.3.*``, ``conv.4.*``;
    keep this layout to stay compatible with saved checkpoints.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """Four-level U-Net that outputs per-pixel class logits.

    Args:
        n_classes: Number of output channels (one logit per class).

    Input ``[N, 3, H, W]`` -> output ``[N, n_classes, H, W]``.
    H and W must be divisible by 16 (four 2x max-pools). Otherwise the
    upsampled decoder features and the skip connections differ in size and
    ``torch.cat`` raises.
    """

    def __init__(self, n_classes):
        super().__init__()
        # Registration order is kept fixed: it determines weight-init RNG
        # consumption and state-dict key order.
        # Encoder: 3 -> 64 -> 128 -> 256 -> 512 channels.
        self.enc1 = DoubleConv(3, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: each up-conv halves channels and doubles H/W; the following
        # DoubleConv sees the upsampled map concatenated with its skip (2x ch).
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # 1x1 projection to class logits.
        self.out = nn.Conv2d(64, n_classes, 1)

    @staticmethod
    def _merge(upsample, refine, features, skip):
        """Upsample ``features``, concatenate ``skip`` on channels, then refine."""
        return refine(torch.cat([upsample(features), skip], dim=1))

    def forward(self, x):
        # Contracting path: keep each level's output as a skip connection.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))
        deepest = self.bottleneck(self.pool(skip4))

        # Expanding path, deepest level first.
        decoded = self._merge(self.up4, self.dec4, deepest, skip4)
        decoded = self._merge(self.up3, self.dec3, decoded, skip3)
        decoded = self._merge(self.up2, self.dec2, decoded, skip2)
        decoded = self._merge(self.up1, self.dec1, decoded, skip1)

        return self.out(decoded)  # raw logits (no softmax)


# ---------------- Metrics ---------------- #
def dice_score(pred, target, num_classes):
    """Mean per-class Dice coefficient over a batch.

    Args:
        pred: Logits ``[N, C, H, W]``; the predicted label is the argmax over C.
        target: Integer labels ``[N, H, W]``.
        num_classes: Evaluate labels ``0 .. num_classes - 1``.

    Returns:
        0-dim float tensor: Dice averaged over the classes that appear in the
        prediction or the target. Classes absent from both are skipped and do
        not count toward the mean. Returns 0 if no class is present at all.
    """
    pred = torch.argmax(pred, dim=1)
    per_class = []
    for cls in range(num_classes):
        pred_cls = (pred == cls).float()
        target_cls = (target == cls).float()
        intersection = (pred_cls * target_cls).sum()
        denom = pred_cls.sum() + target_cls.sum()
        # Dice is undefined for a class absent from both maps; skip it rather
        # than letting it drag the mean down.
        if denom.item() > 0:
            per_class.append((2. * intersection) / denom)
    if not per_class:
        return torch.zeros((), device=pred.device)
    # Average only over the classes that were actually scored.
    return torch.stack(per_class).mean()


def iou_score(pred, target, num_classes):
    """Mean per-class Intersection-over-Union over a batch.

    Args:
        pred: Logits ``[N, C, H, W]``; the predicted label is the argmax over C.
        target: Integer labels ``[N, H, W]``.
        num_classes: Evaluate labels ``0 .. num_classes - 1``.

    Returns:
        0-dim float tensor: IoU averaged over the classes that appear in the
        prediction or the target. Classes absent from both are skipped and do
        not count toward the mean. Returns 0 if no class is present at all.
    """
    pred = torch.argmax(pred, dim=1)
    per_class = []
    for cls in range(num_classes):
        pred_cls = (pred == cls).float()
        target_cls = (target == cls).float()
        intersection = (pred_cls * target_cls).sum()
        union = pred_cls.sum() + target_cls.sum() - intersection
        # IoU is undefined for a class absent from both maps; skip it rather
        # than letting it drag the mean down.
        if union.item() > 0:
            per_class.append(intersection / union)
    if not per_class:
        return torch.zeros((), device=pred.device)
    # Average only over the classes that were actually scored.
    return torch.stack(per_class).mean()


# ---------------- Training ---------------- #
model = UNet(n_classes=NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)


def _train_one_epoch(net, loader, loss_fn, opt, device):
    """Run one optimisation pass over ``loader``.

    Returns the *sum* of per-batch loss values; the caller averages it.
    """
    net.train()
    loss_total = 0
    for batch_imgs, batch_masks in loader:
        batch_imgs = batch_imgs.to(device)
        batch_masks = batch_masks.to(device)

        opt.zero_grad()
        batch_loss = loss_fn(net(batch_imgs), batch_masks)
        batch_loss.backward()
        opt.step()

        loss_total += batch_loss.item()
    return loss_total


def _evaluate(net, loader, loss_fn, device, n_classes):
    """Run a gradient-free pass over ``loader``.

    Returns the per-batch *sums* of (loss, Dice, IoU); the caller averages.
    """
    net.eval()
    loss_total = dice_total = iou_total = 0
    with torch.no_grad():
        for batch_imgs, batch_masks in loader:
            batch_imgs = batch_imgs.to(device)
            batch_masks = batch_masks.to(device)
            logits = net(batch_imgs)
            loss_total += loss_fn(logits, batch_masks).item()
            dice_total += dice_score(logits, batch_masks, n_classes)
            iou_total += iou_score(logits, batch_masks, n_classes)
    return loss_total, dice_total, iou_total


best_dice = 0.0

for epoch in range(EPOCHS):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
    val_loss, val_dice, val_iou = _evaluate(
        model, valid_loader, criterion, DEVICE, NUM_CLASSES
    )

    n_val_batches = len(valid_loader)
    val_dice /= n_val_batches
    val_iou /= n_val_batches

    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {train_loss/len(train_loader):.4f} | "
          f"Val Loss: {val_loss/len(valid_loader):.4f} | "
          f"Dice: {val_dice:.4f} | "
          f"IoU: {val_iou:.4f}")

    # Keep the checkpoint with the highest validation Dice seen so far.
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), "unet_arecanut_best.pth")
        print("💾 Best model updated!")

# Also keep the weights from the last epoch.
torch.save(model.state_dict(), "unet_arecanut_final.pth")
print("✅ Training finished. Final model saved.")

# ---------------- Export to ONNX ---------------- #
# Rebuild the network on the CPU and load the weights saved at the end of
# training. The ONNX graph needs no accelerator, so exporting on CPU keeps this
# step the same whichever DEVICE was used for training.
# NOTE(review): this exports the *final* epoch's weights; the best-by-val-Dice
# weights are in "unet_arecanut_best.pth" — switch CHECKPOINT_PATH if wanted.
EXPORT_DEVICE = "cpu"
CHECKPOINT_PATH = "unet_arecanut_final.pth"

model = UNet(n_classes=NUM_CLASSES).to(EXPORT_DEVICE)
# The checkpoint is a plain state_dict of tensors, so weights_only=True can
# load it without unpickling arbitrary objects.
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=EXPORT_DEVICE, weights_only=True)
)
model.eval()  # BatchNorm must use its running statistics in the exported graph

dummy_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=EXPORT_DEVICE)
onnx_path = "unet_arecanut.onnx"

# Batch, height and width are exported as dynamic axes. NOTE: the U-Net still
# requires height and width to be multiples of 16 (four 2x poolings joined by
# skip concatenations); other sizes fail at inference time.
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"},
    },
)

print(f"✅ ONNX model saved as {onnx_path}")


Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118
Collecting onnx
  Downloading onnx-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m115.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m119.6 MB/s[

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/dataset/train/images'

In [None]:
import os
import re

folder = "/content/drive/MyDrive/data ready/train/train_masks_class"  # change if needed

# Rename masks named "<number>_id.png" to "<number>.png".
# The extension stays .png: renaming a PNG file to .jpg would not convert its
# contents, and the segmentation dataset above looks masks up as "<stem>.png".
# With DRY_RUN = True nothing on disk changes; planned renames are only printed.
DRY_RUN = True
MASK_NAME_RE = re.compile(r"(\d+)_id\.png", re.IGNORECASE)

for filename in sorted(os.listdir(folder)):
    if not filename.lower().endswith(".png"):
        continue

    match = MASK_NAME_RE.fullmatch(filename)
    if match is None:
        # Report and move on instead of crashing on match.group() of None.
        print(f"Skipped (unexpected name): {filename}")
        continue

    new_name = f"{match.group(1)}.png"
    old_path = os.path.join(folder, filename)
    new_path = os.path.join(folder, new_name)

    if os.path.exists(new_path):  # prevent overwrite
        print(f"Skipped (already exists): {new_name}")
    elif DRY_RUN:
        print(f"Would rename: {filename} -> {new_name}")
    else:
        os.rename(old_path, new_path)
        print(f"Renamed: {filename} -> {new_name}")