Building the Architecture for U-Net

In [15]:
import torch
import torch.nn as nn

In [16]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution stages: (Conv2d -> BatchNorm2d -> ReLU) x 2.

    ``padding=1`` keeps height/width unchanged; only the channel count changes
    (in_channels -> out_channels -> out_channels).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage consumes in_channels, second consumes the first stage's output.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

In [17]:
class UNET(nn.Module):
    """U-Net encoder/decoder that maps an image to a per-pixel logit map.

    Encoder: a DoubleConv stem (in_channels -> 64) followed by four
    MaxPool2d(2) + DoubleConv stages that double the channels up to 1024.
    Decoder: four ConvTranspose2d(kernel_size=2, stride=2) upsampling stages;
    each is concatenated with the matching encoder feature map (skip
    connection) and fused by a DoubleConv. A final 1x1 convolution produces
    ``out_channels`` raw logits (no activation is applied here).

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output logit map (default 1).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder (Downsampling path)
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = nn.MaxPool2d(2)
        self.conv1 = DoubleConv(64, 128)
        self.down2 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(128, 256)
        self.down3 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(256, 512)
        self.down4 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(512, 1024)

        # Decoder (Upsampling path); each DoubleConv input is upsampled + skip channels.
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.upconv1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upconv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv4 = DoubleConv(128, 64)

        # Output layer: 1x1 conv to raw logits
        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _match_and_concat(upsampled, skip):
        """Zero-pad ``upsampled`` to ``skip``'s height/width, then concatenate on channels.

        MaxPool2d(2) floors odd sizes, so after a 2x upsample the decoder map can be
        smaller than its skip connection when the input height/width is not divisible
        by 16; without padding, torch.cat raises a size-mismatch error. When the sizes
        already match, no padding is applied and this is a plain concatenation.
        """
        diff_h = skip.size(-2) - upsampled.size(-2)
        diff_w = skip.size(-1) - upsampled.size(-1)
        if diff_h or diff_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            upsampled = nn.functional.pad(
                upsampled,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        return torch.cat([upsampled, skip], dim=1)

    def forward(self, x):
        """Return logits with the same height/width as ``x``.

        Height/width need not be divisible by 16 (decoder maps are padded to their
        skip connections), but must be at least 16 so the bottleneck is non-empty.
        """
        # Encoder
        x1 = self.inc(x)
        x2 = self.conv1(self.down1(x1))
        x3 = self.conv2(self.down2(x2))
        x4 = self.conv3(self.down3(x3))
        x5 = self.conv4(self.down4(x4))

        # Decoder with skip connections
        u1 = self.upconv1(self._match_and_concat(self.up1(x5), x4))
        u2 = self.upconv2(self._match_and_concat(self.up2(u1), x3))
        u3 = self.upconv3(self._match_and_concat(self.up3(u2), x2))
        u4 = self.upconv4(self._match_and_concat(self.up4(u3), x1))

        logits = self.outc(u4)
        return logits

### Building the pipeline to load the dataset (sourced from Google Drive)

In [18]:
!pip install -q albumentations

In [19]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

In [24]:
class TeethDataset(Dataset):
    """Pairs ``.jpg`` images with ``.png`` segmentation masks for binary segmentation.

    The mask filename is derived from the image filename by replacing
    ``.jpg.rf.`` with ``_jpg.rf.`` and appending ``.png``
    (e.g. ``a.jpg.rf.x.jpg`` -> ``a_jpg.rf.x.jpg.png``).

    Args:
        image_dir: directory containing the ``.jpg`` images.
        mask_dir: directory containing the ``.png`` masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            returning a dict with ``"image"`` and ``"mask"`` (an albumentations
            ``Compose`` in this notebook).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Sorted so that sample indices are reproducible; os.listdir order is arbitrary.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _mask_filename(image_filename):
        """Map an image filename to its mask filename (see class docstring)."""
        return image_filename.replace('.jpg.rf.', '_jpg.rf.') + '.png'

    @staticmethod
    def _binarize_mask(mask):
        """Return a float32 copy of ``mask`` with 0 -> 0.0 and every non-zero value -> 1.0.

        Previously only 255 was remapped, so any other non-zero grey level (e.g. an
        anti-aliased edge) survived as a target > 1 for the BCE / Dice losses.
        """
        return (mask > 0).astype(np.float32)

    def __getitem__(self, index):
        """Load sample ``index`` as ``(image, mask)``.

        Without a transform: ``image`` is the RGB pixel array and ``mask`` a float32
        array of 0.0/1.0. With a transform: the transformed image, and the transformed
        mask with a leading channel dimension added via ``unsqueeze(0)`` (so the
        transform must return a torch tensor mask, as ToTensorV2 does).
        """
        image_filename = self.images[index]
        img_path = os.path.join(self.image_dir, image_filename)
        mask_path = os.path.join(self.mask_dir, self._mask_filename(image_filename))

        # Context managers close the file handles as soon as the pixels are read.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = self._binarize_mask(np.array(mask_img.convert("L"), dtype=np.float32))

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"].unsqueeze(0)  # Add channel dimension

        return image, mask

# --- Data configuration ---
# NOTE(review): these /content directories appear to be produced by the
# "Copy and Unzip Data" cell further down (it unzips the Drive archives into the
# working directory, /content on Colab) — run that cell before this one.
TRAIN_IMG_DIR = "/content/Dentalai/train/img"
TRAIN_MASK_DIR = "/content/processed_dataset/train_masks"
VAL_IMG_DIR = "/content/Dentalai/valid/img"
VAL_MASK_DIR = "/content/processed_dataset/valid_masks"

IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
BATCH_SIZE = 8

# Fail fast with an actionable message; otherwise missing data only surfaces as a bare
# os.listdir error inside TeethDataset, or as a missing-mask error in the middle of an epoch.
for _data_dir in (TRAIN_IMG_DIR, TRAIN_MASK_DIR, VAL_IMG_DIR, VAL_MASK_DIR):
    if not os.path.isdir(_data_dir):
        raise FileNotFoundError(
            f"Dataset directory not found: {_data_dir!r}. "
            "Run the data-setup (copy & unzip) cell first."
        )


def _normalize_and_to_tensor():
    """Shared tail of both pipelines: scale pixels to [0, 1] and convert to torch tensors.

    With mean=0, std=1 and max_pixel_value=255, A.Normalize reduces to img / 255.
    """
    return [
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]


# Training: resize + random geometric augmentation (TeethDataset applies it to image and mask together).
train_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    A.Rotate(limit=35, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    *_normalize_and_to_tensor(),
])

# Validation: resize only (no random augmentation), so scores are comparable across epochs.
val_transform = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    *_normalize_and_to_tensor(),
])

# Create datasets and dataloaders
train_ds = TeethDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, transform=train_transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

val_ds = TeethDataset(VAL_IMG_DIR, VAL_MASK_DIR, transform=val_transform)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# Sanity check: confirm TRAIN_IMG_DIR (defined in the data-loading cell) is reachable
# and show a few of its entries before training.
print(f"Checking path: {TRAIN_IMG_DIR}")

if not os.path.exists(TRAIN_IMG_DIR):
    print("Error: The path does not exist!")
else:
    print("Path exists.")
    entries = os.listdir(TRAIN_IMG_DIR)
    if entries:
        print(f"Found {len(entries)} items. Here are the first 5:")
        print(entries[:5])
    else:
        print("The directory is empty!")

Checking path: /content/drive/MyDrive/AI Teeth/Dataset/Dentalai/train
Path exists.
Found 2 items. Here are the first 5:
['img', 'ann']


In [13]:
# --- Copy and Unzip Data for Fast I/O ---
print("Copying datasets from Google Drive to local Colab disk...")

# Adjust this path if your .zip files are in a different location in your Drive
DRIVE_ZIP_PATH_DATA = "/content/drive/MyDrive/AI Teeth/Dataset/Dentalai-20250802T203630Z-1-001.zip"
DRIVE_ZIP_PATH_MASKS = "/content/drive/MyDrive/AI Teeth/Dataset/processed_dataset-20250802T203733Z-1-001.zip"

!cp "{DRIVE_ZIP_PATH_DATA}" .
!cp "{DRIVE_ZIP_PATH_MASKS}" .

print("Unzipping files...")
# Unzip quietly to the local disk
!unzip -q Dentalai-20250802T203630Z-1-001.zip
!unzip -q processed_dataset-20250802T203733Z-1-001.zip

print("Data setup complete! Ready for fast training.")

Copying datasets from Google Drive to local Colab disk...
Unzipping files...
Data setup complete! Ready for fast training.


In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# --- Hyperparameters & Setup ---
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_EPOCHS = 10  # Start with a smaller number, you can increase later
# BATCH_SIZE, IMAGE_HEIGHT and IMAGE_WIDTH are defined once, in the data-loading cell.
# The DataLoaders are already built from those values, so re-declaring them here
# could only drift out of sync without having any effect.

# Seed torch's global RNG so weight initialisation and DataLoader shuffling repeat
# across runs. NOTE: albumentations' own random augmentations are not seeded here.
SEED = 42
torch.manual_seed(SEED)

# The path to save your trained model
MODEL_SAVE_PATH = "/content/drive/MyDrive/AI Teeth/unet_teeth_v1.pth"


# --- Loss Function ---
# A combination of BCE and Dice Loss is very effective for segmentation
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities, pooled over every element of the batch.

    ``weight`` and ``size_average`` are accepted for signature compatibility but unused.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` for raw logits ``inputs`` against ``targets``.

        Args:
            inputs: raw model logits (sigmoid is applied here).
            targets: ground truth with the same number of elements as ``inputs``
                (presumably 0/1 masks — TODO confirm).
            smooth: additive smoothing in numerator and denominator; keeps the ratio
                finite when both prediction and target sums are ~0.

        Returns:
            0-dim tensor; lies in [0, 1] when targets are in [0, 1].
        """
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. after transpose/permute.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)
        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice

# --- Validation Function ---
def check_accuracy(loader, model, device="cuda"):
    """Evaluate ``model`` on ``loader``: print pixel accuracy and mean Dice, return the Dice.

    Predictions are ``sigmoid(logits) > 0.5``. Pixel accuracy is pooled over every
    pixel; the Dice score is computed per batch and averaged over batches.

    Args:
        loader: iterable of ``(inputs, targets)`` batches.
        model: module returning logits shaped like ``targets``.
        device: device the batches are moved to.

    Returns:
        float: mean per-batch Dice score.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    num_correct = 0
    num_pixels = 0
    dice_sum = 0.0
    num_batches = 0

    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)

                preds = (torch.sigmoid(model(x)) > 0.5).float()

                num_correct += (preds == y).sum().item()
                num_pixels += preds.numel()
                # 1e-8 avoids 0/0 when both prediction and target are empty.
                dice_sum += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
                num_batches += 1
    finally:
        # Restore the caller's mode instead of unconditionally switching to train().
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("check_accuracy: loader yielded no batches")

    accuracy = num_correct / num_pixels * 100
    avg_dice_score = dice_sum / num_batches

    print(f"Validation Accuracy: {accuracy:.2f}%")
    print(f"Validation Dice Score: {avg_dice_score:.4f}")

    return avg_dice_score

# --- Training Function ---
def train_fn(loader, model, optimizer, loss_fn_bce, loss_fn_dice):
    """Run one training epoch over ``loader`` on ``DEVICE``.

    The per-batch loss is ``loss_fn_bce + loss_fn_dice`` evaluated on the model's
    raw logits; progress and the latest batch loss are shown with tqdm.

    Returns:
        float: mean combined loss over the epoch's batches (NaN if ``loader`` is empty).
    """
    # Ensure dropout/BatchNorm are in training mode even if the caller left eval() on.
    model.train()
    loop = tqdm(loader)
    running_loss = 0.0
    num_batches = 0

    for data, targets in loop:
        data = data.to(device=DEVICE)
        targets = targets.to(device=DEVICE)

        # Forward pass: combine BCE and Dice on the same logits
        predictions = model(data)
        loss = loss_fn_bce(predictions, targets) + loss_fn_dice(predictions, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    return running_loss / num_batches if num_batches else float("nan")

# --- Main Execution ---
def main():
    """Train a fresh UNET, checkpointing whenever the validation Dice score improves.

    Uses BCEWithLogitsLoss + DiceLoss with Adam at LEARNING_RATE for NUM_EPOCHS
    epochs on train_loader. After each epoch the model is scored on val_loader via
    check_accuracy, and its state_dict is written to MODEL_SAVE_PATH on improvement.
    """
    net = UNET(in_channels=3, out_channels=1).to(DEVICE)
    bce_criterion = nn.BCEWithLogitsLoss()  # sigmoid + BCE fused: more stable than BCELoss
    dice_criterion = DiceLoss()
    opt = optim.Adam(net.parameters(), lr=LEARNING_RATE)

    best_score = -1.0

    for epoch_num in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch {epoch_num}/{NUM_EPOCHS} ---")
        train_fn(train_loader, net, opt, bce_criterion, dice_criterion)

        # Validation Dice drives model selection
        score = check_accuracy(val_loader, net, device=DEVICE)
        if score > best_score:
            best_score = score
            torch.save(net.state_dict(), MODEL_SAVE_PATH)
            print(f"==> New best model saved with Dice Score: {best_score:.4f}")

# Run the training process
main()


--- Epoch 1/10 ---


100%|██████████| 249/249 [06:02<00:00,  1.46s/it, loss=0.472]


Validation Accuracy: 88.80%
Validation Dice Score: 0.7562
==> New best model saved with Dice Score: 0.7562

--- Epoch 2/10 ---


100%|██████████| 249/249 [06:01<00:00,  1.45s/it, loss=0.519]


Validation Accuracy: 93.05%
Validation Dice Score: 0.8385
==> New best model saved with Dice Score: 0.8385

--- Epoch 3/10 ---


100%|██████████| 249/249 [05:59<00:00,  1.45s/it, loss=0.304]


Validation Accuracy: 93.17%
Validation Dice Score: 0.8487
==> New best model saved with Dice Score: 0.8487

--- Epoch 4/10 ---


100%|██████████| 249/249 [06:08<00:00,  1.48s/it, loss=0.291]


Validation Accuracy: 94.18%
Validation Dice Score: 0.8686
==> New best model saved with Dice Score: 0.8686

--- Epoch 5/10 ---


100%|██████████| 249/249 [06:08<00:00,  1.48s/it, loss=0.243]


Validation Accuracy: 94.50%
Validation Dice Score: 0.8764
==> New best model saved with Dice Score: 0.8764

--- Epoch 6/10 ---


100%|██████████| 249/249 [06:08<00:00,  1.48s/it, loss=0.233]


Validation Accuracy: 94.46%
Validation Dice Score: 0.8733

--- Epoch 7/10 ---


100%|██████████| 249/249 [06:09<00:00,  1.48s/it, loss=0.253]


Validation Accuracy: 94.98%
Validation Dice Score: 0.8851
==> New best model saved with Dice Score: 0.8851

--- Epoch 8/10 ---


100%|██████████| 249/249 [06:09<00:00,  1.49s/it, loss=0.372]


Validation Accuracy: 95.38%
Validation Dice Score: 0.8940
==> New best model saved with Dice Score: 0.8940

--- Epoch 9/10 ---


100%|██████████| 249/249 [06:07<00:00,  1.48s/it, loss=0.208]


Validation Accuracy: 95.62%
Validation Dice Score: 0.8995
==> New best model saved with Dice Score: 0.8995

--- Epoch 10/10 ---


100%|██████████| 249/249 [06:08<00:00,  1.48s/it, loss=0.187]


Validation Accuracy: 95.74%
Validation Dice Score: 0.9016
==> New best model saved with Dice Score: 0.9016
