In [None]:
# Import required libraries for training U-Net model
from cgitb import reset
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import skimage as sm
import skimage.io
from matplotlib import pyplot as plt
import tifffile
import timm
from fastai.vision.all import *
import cv2

In [None]:
# ---------------------------------------------------------------------------
# Training configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Compute device: the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Optimisation hyperparameters.
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
NUM_EPOCHS = 4

# Input geometry in pixels.
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

# DataLoader settings.
NUM_WORKERS = 2
PIN_MEMORY = True

# When True, main() restores the saved checkpoint before it starts training.
LOAD_MODEL = True

# Dataset locations: inputs and their masks, split into training/validation.
TRAIN_IMG_DIR = "dat/training/input/"
TRAIN_MASK_DIR = "dat/training/mask/"
VAL_IMG_DIR = "dat/validation/input/"
VAL_MASK_DIR = "dat/validation/mask/"


# Utility function to display single image
def visualize(image):
    """Show ``image`` on its own 10x10 inch figure with the axes hidden."""
    _, ax = plt.subplots(figsize=(10, 10))
    ax.axis('off')
    ax.imshow(image)
    plt.show()


# Utility function to display multiple example images
def plot_examples(images):
    """Show every image in ``images`` on one figure, laid out on a grid.

    The grid is 4 columns wide. It has 5 rows for up to 20 images and grows
    by whole rows beyond that, so no image is dropped and ``add_subplot`` is
    never asked for a position outside the grid.

    Args:
        images: sequence of arrays accepted by ``plt.imshow``.
    """
    columns = 4
    # Ceiling division without importing math; never fewer than 5 rows so the
    # layout for small inputs stays the same as before.
    rows = max(5, -(-len(images) // columns))

    fig = plt.figure(figsize=(15, 15))
    # Subplot positions are 1-based; enumerate from 1 so the last image is
    # included (the previous range(1, len(images)) stopped one short).
    for position, img in enumerate(images, start=1):
        fig.add_subplot(rows, columns, position)
        plt.imshow(img)
    plt.show()


# Custom dataset class for loading electric field images and petal masks
class CarvanaDataset(Dataset):
    """Pairs every image file in ``image_dir`` with a mask from ``mask_dir``.

    The mask file name is the image file name with ``"field"`` replaced by
    ``"petal"``.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the matching masks.
        transform: optional callable invoked as
            ``transform(image=..., image0=..., image1=..., mask=...)`` that
            returns a mapping with the same keys.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Keep only regular, non-hidden files: stray entries such as
        # ".ipynb_checkpoints" or ".DS_Store" would otherwise be handed to the
        # image reader. Sorting makes the sample order independent of the
        # filesystem's listing order.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(images, mask)``.

        ``images`` is a float tensor built from the image with its last axis
        moved to the front. ``mask`` is the mask with the value 255 mapped
        to 1; it is a NumPy array when no transform is set, otherwise
        whatever the transform returns under ``"mask"``.
        """
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(
            self.mask_dir, self.images[index].replace("field", "petal"))
        image = sm.io.imread(img_path).astype(np.float32)
        mask = np.array(Image.open(mask_path), dtype=np.float32)
        # NOTE: this is an alias, not a copy, so it also sees the in-place
        # binarisation below. Only the optional debug call at the end uses it.
        mask0 = mask
        mask[mask == 255] = 1
        # Move the last axis first (assumes the file is read as
        # height x width x channel - TODO confirm).
        image = np.transpose(image, (2, 0, 1))
        # Default scaling, used as-is when no transform is configured.
        images = torch.tensor(image/256).float()

        if self.transform is not None:
            # Each of the first three channels is passed as its own target so
            # the transform is applied to all of them and to the mask in a
            # single call; the results overwrite the default-scaled channels.
            transformed = self.transform(image=image[0], image0=image[1], image1=image[2], mask=mask)
            images[0] = transformed["image"]
            images[1] = transformed["image0"]
            images[2] = transformed["image1"]

            mask = transformed["mask"]

            # save_transform(image, mask0, transformed)

        return images, mask


# Utility function to save augmented training examples
def save_transform(image, mask0, transformed):
    """Write a 3x250x250 before/after montage of one augmented sample.

    Layout of the canvas (a 10-pixel gap separates the tiles):
      top-left     - ``image``
      top-right    - transformed channels "image", "image0", "image1" (x255)
      bottom-left  - ``mask0`` (x255), copied into all three planes
      bottom-right - transformed "mask" (x255), copied into all three planes

    The result is cast to uint8 and written to
    ``transformResults/transform.tif``.
    """
    canvas = np.zeros([3, 250, 250])
    near = slice(0, 120)      # rows/columns of the first tile
    far = slice(130, None)    # rows/columns of the second tile

    canvas[:, near, near] = image
    for plane, key in enumerate(("image", "image0", "image1")):
        canvas[plane, near, far] = np.array(transformed[key])*255

    canvas[:, far, near] = mask0*255
    canvas[:, far, far] = np.array(transformed["mask"])*255

    montage = np.asarray(canvas, "uint8")
    tifffile.imwrite(f"transformResults/transform.tif", montage)


# Save model checkpoint
def save_checkpoint(state, filename="models/UNetPetal.pth.tar"):
    """Serialise ``state`` to ``filename`` with ``torch.save``.

    The parent directory of ``filename`` is created when it does not exist,
    so a finished epoch is not lost to a missing ``models/`` folder.

    Args:
        state: object to save (main() passes a dict of state dicts).
        filename: destination path of the checkpoint file.
    """
    print("=> Saving checkpoint")
    parent = os.path.dirname(filename)
    # A bare file name has no parent component; makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)


# Load model checkpoint
def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


# Create training and validation data loaders
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True
):
    """Build the ``(train_loader, val_loader)`` pair.

    Both loaders wrap a ``CarvanaDataset`` and share the batch size, worker
    count and pin-memory setting; only the training loader shuffles.
    """
    # Options common to both loaders.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    training_set = CarvanaDataset(
        image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform)
    training_loader = DataLoader(training_set, shuffle=True, **shared_options)

    validation_set = CarvanaDataset(
        image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform)
    validation_loader = DataLoader(
        validation_set, shuffle=False, **shared_options)

    return training_loader, validation_loader


# Evaluate model accuracy and Dice score on validation set
def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and the mean per-batch Dice score over ``loader``.

    Predictions are ``sigmoid(model(x)) > 0.5``. Metrics are accumulated as
    plain Python numbers, so the report prints floats rather than tensor
    reprs. An empty loader prints a notice instead of dividing by zero.

    Side effects: the model is put in eval mode for the evaluation and is
    switched to train mode before returning.

    Args:
        loader: iterable of ``(x, y)`` batches; ``y`` gets a channel axis
            inserted at dim 1 to line up with the model output.
        model: network to evaluate.
        device: device the batches are moved to.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    model.eval()

    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = (torch.sigmoid(model(x)) > 0.5).float()
            # .item() turns the 0-dim tensors into Python scalars.
            num_correct += (preds == y).sum().item()
            num_pixels += torch.numel(preds)
            # The 1e-8 keeps the ratio defined when both prediction and
            # target are all zeros.
            dice_score += (
                (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
            ).item()

    if num_pixels == 0:
        print("No samples to evaluate; skipping accuracy and Dice report")
    else:
        print(
            f"Accuracy {num_correct/num_pixels*100}%"
        )
        print(f"Dice score {dice_score/len(loader)}")
    model.train()


# Save predicted masks as images for visualization
def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save thresholded predictions and targets for the first batch only.

    For sample ``i`` of the first batch, writes ``pred_{i}.png`` (the
    prediction, ``sigmoid(model(x)) > 0.5``) and ``img_{i}.png`` (the target
    ``y`` with a channel axis added) into ``folder``. Files from earlier
    calls are overwritten.

    ``folder`` is created when missing and is joined with the file names via
    ``os.path.join``, so it works with or without a trailing slash.

    Side effects: the model is put in eval mode and is switched to train mode
    before returning.
    """
    # An empty folder means "current directory"; makedirs("") would raise.
    if folder:
        os.makedirs(folder, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            preds = (torch.sigmoid(model(x)) > 0.5).float()
            targets = y.unsqueeze(1)
            for i in range(preds.shape[0]):
                torchvision.utils.save_image(
                    preds[i], os.path.join(folder, f"pred_{i}.png")
                )
                torchvision.utils.save_image(
                    targets[i], os.path.join(folder, f"img_{i}.png"))

            # Only the first batch is exported.
            break

    model.train()

# Training function: forward pass, backward pass, and optimization step
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one pass over ``loader``, stepping ``optimizer`` once per batch.

    The forward pass runs under CUDA autocast; the loss is scaled by
    ``scaler`` for the backward pass, and the scaler is updated after every
    optimizer step. The current batch loss is shown on the progress bar.
    """
    progress = tqdm(loader)

    for inputs, masks in progress:
        inputs = inputs.to(device=DEVICE)
        # Insert a channel axis at dim 1 before moving the targets over.
        masks = masks.unsqueeze(1).to(device=DEVICE)

        # Mixed-precision forward pass
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, masks)

        # Scaled backward pass and parameter update
        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress.set_postfix(loss=batch_loss.item())


# Main training loop
def main():
    """Build the ResNet34-encoder U-Net and train it for ``NUM_EPOCHS`` epochs.

    Reads the module-level configuration (``LEARNING_RATE``, ``DEVICE``,
    ``BATCH_SIZE``, ``IMAGE_HEIGHT``/``IMAGE_WIDTH``, data directories, ...)
    at call time, so changing ``LEARNING_RATE`` before calling ``main()``
    changes the optimizer's step size.

    When ``LOAD_MODEL`` is set and ``models/UNetPetal.pth.tar`` exists, the
    weights are restored (mapped onto ``DEVICE``) and the restored model is
    evaluated once before training. When the file is missing, training starts
    from the pretrained encoder instead of raising.

    After every epoch the checkpoint is overwritten, validation metrics are
    printed and first-batch predictions are written to ``saved_images/``.
    """
    # Extra albumentations targets. The dataset passes its channels as
    # "image", "image0" and "image1", plus "mask".
    target3 = {'image0': 'image', 'image1': 'image', 'image2': 'image', 'mask': 'mask'}
    # Data augmentation for training set
    train_transform = A.Compose(
        [
            A.Rotate(limit=45, p=1.0, border_mode=cv2.BORDER_CONSTANT, value=0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.GaussianBlur(blur_limit=(3, 5), p=0.3),
            A.Normalize(
                mean=0,
                std=1,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
        additional_targets=target3,
    )
    # Normalization only for validation set
    val_transform = A.Compose(
        [
            A.Normalize(
                mean=0,
                std=1,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
        additional_targets=target3,
    )

    # Load pretrained ResNet34 and build U-Net model. The last two child
    # modules of the timm model are dropped before it is used as the encoder.
    resnet = timm.create_model("resnet34", pretrained=True)
    encoder = nn.Sequential(*list(resnet.children())[:-2])
    # Image size comes from the configuration constants instead of a
    # duplicated (512, 512) literal.
    model = DynamicUnet(
        encoder, 1, (IMAGE_HEIGHT, IMAGE_WIDTH), norm_type=None).to(DEVICE)
#     x = cast(torch.randn(2, 3, 512, 512), TensorImage)
#     y = model(x)

    # Set loss function and optimizer
    loss_fn = nn.BCEWithLogitsLoss()  # if out_channels > 1 => cross entropy loss
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(
        0.9, 0.999), eps=1e-08)

    # Create data loaders
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    # Load existing model if specified
    checkpoint_path = "models/UNetPetal.pth.tar"
    if LOAD_MODEL:
        if os.path.isfile(checkpoint_path):
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only machine (and vice versa).
            load_checkpoint(
                torch.load(checkpoint_path, map_location=DEVICE), model)
            save_predictions_as_imgs(
                val_loader, model, folder="saved_images/", device=DEVICE)
            check_accuracy(val_loader, model, device=DEVICE)
        else:
            print(f"=> No checkpoint found at {checkpoint_path}; "
                  "starting from the pretrained encoder")
    # NOTE: the optimizer state stored in the checkpoint is not restored, so
    # every call starts Adam afresh with the current LEARNING_RATE.

    scaler = torch.cuda.amp.GradScaler()

    # Training loop for specified number of epochs
    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # Save model checkpoint
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # Check validation accuracy
        check_accuracy(val_loader, model, device=DEVICE)
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE)

# Run training with lower learning rate: this overrides the 1e-4 value assigned
# to LEARNING_RATE in the configuration above; main() reads the global when it
# builds its optimizer.
LEARNING_RATE = 5e-5
main()

100%|██████████| 27/27 [27:34<00:00, 61.29s/it, loss=0.0714]


=> Saving checkpoint


100%|██████████| 7/7 [02:09<00:00, 18.56s/it]


Accuracy TensorBase(96.7480)%
Dice score TensorBase(0.6474)


  0%|          | 0/27 [00:00<?, ?it/s]

In [9]:
# Train for another NUM_EPOCHS epochs at lr=5e-5. With LOAD_MODEL set, main()
# first restores the weights from the saved checkpoint.
LEARNING_RATE = 5e-5
main()


  0%|          | 0/112 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=0.0104] 


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.67132568359375%
Dice score 0.09293054789304733


100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=0.00212]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.685791015625%
Dice score 0.12588801980018616


100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=0.00666]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.77it/s]


Accuracy 99.7662353515625%
Dice score 0.5328267216682434


100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=0.00052]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.77it/s]


Accuracy 99.78858947753906%
Dice score 0.5896043181419373


In [10]:
# Continue from the saved checkpoint with a smaller step size (lr=1e-5).
LEARNING_RATE = 1e-5
main()


  0%|          | 0/112 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00219]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.80276489257812%
Dice score 0.6420930027961731


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00807]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.80888366699219%
Dice score 0.6915411353111267


100%|██████████| 112/112 [02:22<00:00,  1.28s/it, loss=0.00152]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]

Accuracy 99.81129455566406%
Dice score 0.6538491249084473



100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.0059] 


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.82037353515625%
Dice score 0.7002066969871521


In [11]:
# Continue from the saved checkpoint, halving the learning rate to 5e-6.
LEARNING_RATE = 5e-6
main()


  0%|          | 0/112 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00187]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.8251724243164%
Dice score 0.6943625211715698


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00454]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.77it/s]


Accuracy 99.8243408203125%
Dice score 0.6834194660186768


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.0039] 


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.77it/s]


Accuracy 99.82669067382812%
Dice score 0.7175348997116089


100%|██████████| 112/112 [02:22<00:00,  1.28s/it, loss=0.00666]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.825439453125%
Dice score 0.6996853947639465


In [13]:
# Two back-to-back runs: NUM_EPOCHS epochs at lr=1e-5, then NUM_EPOCHS more at
# lr=5e-6. The second call picks up the checkpoint written by the first when
# LOAD_MODEL is set.
LEARNING_RATE = 1e-5
main()
LEARNING_RATE = 5e-6
main()


  0%|          | 0/112 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00398]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.82569122314453%
Dice score 0.7320536971092224


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.000298]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.77it/s]


Accuracy 99.83374786376953%
Dice score 0.7188710570335388



  0%|          | 0/112 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.005]  


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.83172607421875%
Dice score 0.7115412354469299


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.005]  


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.83241271972656%
Dice score 0.7155706882476807


In [16]:
# One more run at lr=5e-6, continuing from the saved checkpoint.
LEARNING_RATE = 5e-6
main()


  0%|          | 0/112 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00953]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.78it/s]


Accuracy 99.83970642089844%
Dice score 0.7425869107246399


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00263]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.76it/s]


Accuracy 99.84320068359375%
Dice score 0.7368522882461548


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00136]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.75it/s]


Accuracy 99.84391021728516%
Dice score 0.7402313947677612


100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.000884]


=> Saving checkpoint


100%|██████████| 56/56 [00:20<00:00,  2.77it/s]


Accuracy 99.84194946289062%
Dice score 0.7370697259902954


In [5]:
# Build an untrained ResNet34 whose first convolution takes 10 input channels
# instead of 3, then wrap it as the encoder of a single-output U-Net.
resnet = timm.create_model("resnet34", pretrained=False)
resnet.conv1 = nn.Conv2d(
    10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Drop the last two child modules so only the feature extractor remains.
m = nn.Sequential(*list(resnet.children())[:-2])
model = DynamicUnet(m, 1, (512, 512), norm_type=None).to(DEVICE)

In [6]:
# Layer-by-layer summary of `model` for an input of size (10, 512, 512).
# NOTE(review): torchsummary is imported here rather than in the import cell
# at the top of the notebook - consider moving it there.
from torchsummary import summary
summary(model, (10, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]          31,360
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
         MaxPool2d-4         [-1, 64, 128, 128]               0
            Conv2d-5         [-1, 64, 128, 128]          36,864
       BatchNorm2d-6         [-1, 64, 128, 128]             128
          Identity-7         [-1, 64, 128, 128]               0
              ReLU-8         [-1, 64, 128, 128]               0
          Identity-9         [-1, 64, 128, 128]               0
           Conv2d-10         [-1, 64, 128, 128]          36,864
      BatchNorm2d-11         [-1, 64, 128, 128]             128
             ReLU-12         [-1, 64, 128, 128]               0
       BasicBlock-13         [-1, 64, 128, 128]               0
           Conv2d-14         [-1, 64, 1

In [7]:
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable-parameter count of the `model` currently defined in the kernel
# (the 10-channel U-Net when the cells are run in order).
count_parameters(model)

41268871