In [2]:
import os
import json
import random
import warnings


import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import segmentation_models_pytorch as smp

from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm


warnings.filterwarnings(action='ignore', category=UserWarning)

In [3]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), then switches cuDNN to its
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    # Autotuned kernel selection can differ between runs, so turn it off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything()

In [4]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [12]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import cv2

class CariesDataset(Dataset):
    """In-memory dataset of grayscale images paired with binary masks.

    Every readable file in ``img_dir`` is loaded with OpenCV, resized and kept
    in RAM together with its mask from ``mask_dir``. The mask file name is the
    image file name with ``.jpg`` replaced by ``.png``. Pairs for which either
    file cannot be read are skipped as a whole.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        augmentation: When true, random rotation/flip is applied identically
            to image and mask, and sharpness/colour jitter to the image only.
        device: Device the returned tensors are moved to in ``__getitem__``.
        img_size: Passed unchanged to ``cv2.resize`` as ``dsize``, which
            OpenCV interprets as ``(width, height)``.
            NOTE(review): the default ``(384, 768)`` therefore produces arrays
            with 768 rows and 384 columns -- confirm this is intended.
    """

    def __init__(self, img_dir, mask_dir, augmentation=True, device='cpu', img_size=(384, 768)):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.augmentation = augmentation
        self.device = device
        self.img_size = img_size

        # Load images and masks into memory
        self.images, self.masks = self._load_images_and_masks(img_dir, mask_dir, img_size)

        # PIL image -> float tensor in [0, 1] with a leading channel axis
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

        if self.augmentation:
            # Geometric transforms: must be applied identically to image and mask.
            self.same_augmentation = transforms.Compose([
                transforms.RandomRotation(degrees=5),
                transforms.RandomHorizontalFlip(p=0.5)
            ])

            # Photometric transforms: image only, the mask must stay untouched.
            self.different_augmentation = transforms.Compose([
                transforms.RandomAdjustSharpness(2),
                transforms.ColorJitter(brightness=0.5, contrast=0.5)
            ])

    def _load_images_and_masks(self, img_dir, mask_dir, img_size):
        """Read, resize and stack all image/mask pairs.

        Returns:
            ``(images, masks)``: ``images`` holds the resized grayscale images
            scaled to [0, 1]; ``masks`` holds 0/1 integers. Both get a trailing
            singleton channel axis and always contain the same number of
            entries, in sorted file-name order.
        """
        images = []
        masks = []
        # Sorted so that sample indices are stable across runs and machines.
        for img_name in sorted(os.listdir(img_dir)):
            img_path = os.path.join(img_dir, img_name)
            mask_name = img_name.replace('.jpg', '.png')  # Change if mask has different extension
            mask_path = os.path.join(mask_dir, mask_name)

            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Warning: Image {img_path} not found or corrupted.")
                continue

            # Read the mask BEFORE storing anything: appending the image first
            # and then skipping on a missing mask would shift every later
            # image against its mask.
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                print(f"Warning: Mask {mask_path} not found or corrupted.")
                continue

            img = cv2.resize(img, img_size)
            mask = cv2.resize(mask, img_size)
            mask = np.where(mask > 128, 1, 0)  # Binarize the mask

            images.append(img)
            masks.append(mask)

        # Convert lists to numpy arrays and normalize images to [0, 1]
        images = np.array(images) / 255.0
        masks = np.array(masks)

        # Add an extra dimension for channels (as we have grayscale images)
        images = np.expand_dims(images, axis=-1)
        masks = np.expand_dims(masks, axis=-1)

        return images, masks

    def __getitem__(self, idx):
        """Return the ``(image, label)`` float tensors for sample ``idx``.

        Both tensors are moved to ``self.device``; ``label`` contains only
        0.0 and 1.0.
        """
        image = self.images[idx]
        label = self.masks[idx]

        # Convert numpy arrays to PIL Images for augmentation. Round instead of
        # truncating so the /255 -> *255 round trip cannot lose a grey level.
        image = Image.fromarray(np.rint(image.squeeze() * 255).astype(np.uint8))
        label = Image.fromarray((label.squeeze() * 255).astype(np.uint8))

        if self.augmentation:
            # Replay the exact geometric transform on the mask by rewinding
            # torch's RNG to where it was before the image was transformed.
            # Unlike reseeding the global generator per sample, this leaves
            # the rest of the program's random stream intact.
            rng_state = torch.get_rng_state()
            image = self.same_augmentation(image)
            torch.set_rng_state(rng_state)
            label = self.same_augmentation(label)

            # Photometric changes are drawn afterwards and affect the image only.
            image = self.different_augmentation(image)

        image = self.transform(image).to(self.device)
        label = self.transform(label).to(self.device)

        # Collapse any non-zero value back to exactly 1.0
        label = (label != 0).float()

        return image, label

    def __len__(self):
        """Number of image/mask pairs held in memory."""
        return len(self.images)

In [13]:
image_path = './DataSet/Teeth Segmentation PNG/d2/img'
labels_path = './DataSet/Teeth Segmentation PNG/d2/masks_machine'

# Sort the listing: os.listdir order is arbitrary, and an unstable order would
# make the seeded split below differ between machines/runs.
file_names = sorted(os.listdir(image_path))
train_files, val_files = train_test_split(file_names, test_size=0.2, random_state=42)


def _mask_name(file_name):
    """Mask file name for an image: same name with '.jpg' swapped for '.png'."""
    return file_name.replace('.jpg', '.png')


# os.path.join inserts the directory separator that plain string concatenation
# dropped (it produced '.../imgfoo.jpg' instead of '.../img/foo.jpg').
train_image_path = [os.path.join(image_path, file_name) for file_name in train_files]
train_mask_path = [os.path.join(labels_path, _mask_name(file_name)) for file_name in train_files]

eval_image_path = [os.path.join(image_path, file_name) for file_name in val_files]
eval_mask_path = [os.path.join(labels_path, _mask_name(file_name)) for file_name in val_files]

In [16]:
# Two in-memory copies of the same directory: one with augmentation for
# training, one without for evaluation.
train_source = CariesDataset(
    img_dir = image_path,
    mask_dir= labels_path,
    augmentation = True,
    device = device,
)

eval_source = CariesDataset(
    img_dir = image_path,
    mask_dir = labels_path,
    augmentation = False,
    device = device,
)

# Without an explicit split both loaders would iterate over every sample, i.e.
# the model would be validated on its own training data. Split by sample index
# so the two subsets are disjoint.
# Assumes both instances enumerate the directory in the same order, so that
# index i refers to the same file in each.
assert len(train_source) == len(eval_source)
train_indices, eval_indices = train_test_split(
    list(range(len(train_source))), test_size=0.2, random_state=42
)

train_dataset = Subset(train_source, train_indices)
eval_dataset = Subset(eval_source, eval_indices)

toPIL = transforms.ToPILImage()

batch_size = 8

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=1)

In [17]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    ``loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``

    Args:
        smooth: Additive constant in numerator and denominator; keeps the
            ratio defined when both prediction and target are all zero.
        activation: Optional callable applied to ``inputs`` before the loss is
            computed (for example a sigmoid when the model emits raw logits).
    """

    def __init__(self, smooth = 1, activation = None):
        super().__init__()
        self.smooth = smooth
        self.activation = activation

    def forward(self, inputs, targets):
        """Return the scalar Dice loss between ``inputs`` and ``targets``.

        Both tensors must contain the same number of elements; they are
        flattened and compared element-wise.
        """
        if self.activation is not None:
            inputs = self.activation(inputs)

        # reshape (not view) so non-contiguous tensors, e.g. transposed or
        # sliced ones, are flattened instead of raising a RuntimeError.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

        return 1 - dice
    

    
def metric_calculate(prediction: np.ndarray, target: np.ndarray):
    """Compute binary segmentation metrics from a prediction and its target.

    Both arrays are flattened and binarised with a fixed 0.5 threshold before
    the confusion-matrix counts are taken.

    Returns:
        Tuple ``(accuracy, IoU, Dice, precision, specificity, sensitivity)``.
        Every denominator carries a 1e-4 offset so empty classes do not
        divide by zero.
    """
    eps = 1e-4

    truth = target.flatten() > 0.5
    guess = prediction.flatten() > 0.5

    # Confusion-matrix counts
    true_pos = np.sum(guess & truth)
    false_neg = np.sum(~guess & truth)
    true_neg = np.sum(~guess & ~truth)
    false_pos = np.sum(guess & ~truth)

    total = true_pos + true_neg + false_pos + false_neg

    accuracy = (true_pos + true_neg) / (total + eps)
    jaccard = true_pos / (true_pos + false_pos + false_neg + eps)
    dice_score = (2 * true_pos) / (2 * true_pos + false_pos + false_neg + eps)
    precision = true_pos / (true_pos + false_pos + eps)
    specificity = true_neg / (false_pos + true_neg + eps)
    sensitivity = true_pos / (true_pos + false_neg + eps)

    return accuracy, jaccard, dice_score, precision, specificity, sensitivity

In [18]:
# UNet++ decoder on an ImageNet-pretrained EfficientNet-B0 encoder.
# One input channel (grayscale) and one output channel (binary mask logits).
model = smp.UnetPlusPlus(
    encoder_name = 'efficientnet-b0',
    encoder_weights = 'imagenet',
    in_channels = 1,
    classes = 1,
).to(device)

# Used as prefix for the checkpoint files written during training.
# NOTE(review): the name says "UNet" while the model is UNet++ -- kept as is so
# existing checkpoint file names stay valid.
model_name = 'UNetEfficientnetB0'

# The model outputs raw logits, so the loss applies the sigmoid itself.
# torch.sigmoid is the maintained spelling of the legacy F.sigmoid alias.
criterion = DiceLoss(activation=torch.sigmoid)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

num_epoch = 50

In [19]:
print (f'Training {model_name} start.')

IoU_max = 0.
losses_train, losses_val = [], []
metrics = []

for epoch in tqdm(range(num_epoch)):
    # Running means over the batches of this epoch
    current_train_loss, current_val_loss = 0., 0.
    current_metric = np.zeros(6)  # acc, iou, dice, pre, spe, sen

    model.train()
    for images, labels in train_dataloader:
        # No-op when the dataset already yields tensors on `device`.
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        current_train_loss += loss.item() / len(train_dataloader)

    model.eval()
    with torch.no_grad():
        for images, labels in eval_dataloader:
            images, labels = images.to(device), labels.to(device)

            logits = model(images)
            loss = criterion(logits, labels)

            # metric_calculate thresholds at 0.5, which is only meaningful for
            # probabilities; thresholding raw logits at 0.5 would correspond
            # to a probability cut-off of about 0.62.
            probabilities = torch.sigmoid(logits)

            current_val_loss += loss.item() / len(eval_dataloader)
            current_metric += np.array(metric_calculate(
                probabilities.cpu().numpy(),
                labels.cpu().numpy())) / len(eval_dataloader)

    losses_train.append(current_train_loss)
    losses_val.append(current_val_loss)
    metrics.append(current_metric.tolist())

    # Keep the checkpoint with the best validation IoU (index 1 of the metrics).
    # NOTE(review): torch.save(model, ...) pickles the whole module and ties the
    # file to this exact code layout; saving model.state_dict() is more
    # portable, but would change the checkpoint format that loaders expect.
    if IoU_max < metrics[-1][1]:
        torch.save(model, f'{model_name}-best.pth')
        IoU_max = metrics[-1][1]

    print (f'Epoch: {epoch + 1}, train_loss: {losses_train[-1]:.4f}, val_loss: {losses_val[-1]:.4f}, IoU: {metrics[-1][1]:.4f}')


# Persist the training history as JSON next to the checkpoints.
log = {}
log['train_loss'] = losses_train
log['eval_loss'] = losses_val
log['metric'] = metrics
log['best_score'] = IoU_max

torch.save(model, f'{model_name}-last.pth')

with open(f'log.txt', 'w') as outfile:
    json.dump(log, outfile)

torch.cuda.empty_cache()

print ('- - ' * 30)
print (f'Training {model_name} done. Best IoU: {IoU_max:.4f}.')
print ('- - ' * 30)

Training UNetEfficientnetB0 start.


  0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 