In [2]:
# Workaround for AMD GPUs that ROCm does not officially support.
#
# The values are assigned through os.environ rather than os.putenv:
# assigning to os.environ also calls putenv under the hood (so native
# libraries loaded later see the values), whereas putenv alone leaves
# os.environ stale, so any Python code that reads these settings through
# os.environ would not see them. Run this cell before `import torch`.
import os

os.environ.update(
    {
        # Override the GFX version the HSA runtime reports; 10.3.0 corresponds
        # to the gfx1030 arch named below (presumably the closest supported target).
        "HSA_OVERRIDE_GFX_VERSION": "10.3.0",
        "PYTORCH_ROCM_ARCH": "gfx1030",
        "TORCH_USE_HIP_DSA": "1",
        "AMD_SERIALIZE_KERNEL": "3",
    }
)


In [3]:
import torch
from torch import nn
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

In [3]:
# import os

# os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
# os.environ["TORCH_USE_HIP_DSA"] = "1"
# os.environ["AMD_SERIALIZE_KERNEL"] = "3"
# os.environ["HIP_LOG_LEVEL"] = "3"  # Set to debug level
# os.environ["HIP_VISIBLE_DEVICES"] = "0"  # Ensure only the first GPU is visible

In [5]:
# Release cached-but-unused blocks held by PyTorch's GPU caching allocator.
# Earlier cells allocate nothing on the GPU, so on a fresh kernel this is
# effectively a no-op; it only matters when re-running after GPU work.
torch.cuda.empty_cache()

In [None]:
# Sanity check of the accelerator setup: is a device visible to torch, and is
# this a ROCm (HIP) build? The last line prints the raw HIP version (None on
# non-ROCm builds).
print("CUDA Available:", torch.cuda.is_available())
print("ROCm Available:", torch.version.hip is not None)
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
else:
    print("GPU Name:", "No GPU detected")
print(torch.version.hip)

In [6]:
# # Create a tensor and move it to the GPU
# tensor = torch.randn(3, 3).to('cuda')
# print(tensor)

# # Check if the tensor is on the GPU
# print(tensor.device)

In [7]:
def plot_img(x, channels=(0, 10, -1)):
    """Show selected channels of the first sample of a batch side by side.

    Args:
        x: tensor indexed as ``x[sample][channel]`` -> 2-D map; presumably an
            (N, C, H, W) activation tensor (TODO confirm at call sites).
        channels: channel indices to display, left to right. The default
            (first, index 10, last) reproduces the previous hard-coded choice.
    """
    # .cpu() first: .numpy() raises on tensors that live on a GPU.
    panels = [x[0][c].detach().cpu().numpy() for c in channels]
    plt.figure()
    plt.imshow(np.concatenate(panels, axis=1))

def crop_image(source_tensor, target_tensor):
    """Center-crop ``source_tensor`` in dims 2 and 3 to the size of ``target_tensor``.

    In this notebook it crops the (larger) encoder feature map of a U-Net skip
    connection to the size of the upsampled decoder map before concatenation.

    Args:
        source_tensor: 4-D tensor indexed (batch, channel, dim2, dim3).
        target_tensor: 4-D tensor whose dims 2 and 3 give the crop size.

    Returns:
        A view of ``source_tensor`` whose dims 2 and 3 equal those of
        ``target_tensor``; for odd size differences the extra row/column is
        dropped from the end.

    Raises:
        ValueError: if the target is larger than the source in dim 2 or 3
            (a negative offset would otherwise silently produce a wrong slice).
    """
    src_h, src_w = source_tensor.shape[2], source_tensor.shape[3]
    tgt_h, tgt_w = target_tensor.shape[2], target_tensor.shape[3]
    if tgt_h > src_h or tgt_w > src_w:
        raise ValueError(
            f"cannot crop {src_h}x{src_w} source to larger {tgt_h}x{tgt_w} target"
        )

    top = (src_h - tgt_h) // 2
    left = (src_w - tgt_w) // 2
    return source_tensor[:, :, top:top + tgt_h, left:left + tgt_w]

In [8]:
def double_conv(in_channel, out_channel):
    """Build two unpadded 3x3 convolutions, each followed by an in-place ReLU.

    The first conv maps ``in_channel`` -> ``out_channel`` and the second keeps
    ``out_channel``. With no padding, each conv shrinks height and width by 2,
    so the block reduces each spatial dimension by 4 in total.
    """
    layers = []
    for conv_in in (in_channel, out_channel):
        layers.append(nn.Conv2d(conv_in, out_channel, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

class UNet(nn.Module):
    """U-Net with unpadded ("valid") 3x3 convolutions.

    The layout matches the original U-Net:
    - four encoder stages, each a double 3x3 conv followed by a 2x2 max-pool;
    - a 1024-channel bottleneck;
    - four decoder stages, each a 2x2 transposed conv, concatenation with the
      center-cropped encoder map of the same level, and a double 3x3 conv;
    - a final 1x1 conv.

    No activation is applied to the result, so the output is a map of raw
    scores (logits). Because the 3x3 convolutions are unpadded, the output is
    spatially smaller than the input; for example, a 256x256 image gives a
    68x68 output.

    Args:
        in_channels: channels of the input image (default 3, as before).
        out_channels: number of output maps (default 1, as before).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.down_conv1 = double_conv(in_channels, 64)
        self.down_conv2 = double_conv(64, 128)
        self.down_conv3 = double_conv(128, 256)
        self.down_conv4 = double_conv(256, 512)
        self.down_conv5 = double_conv(512, 1024)

        # NOTE: "tarnspose" (sic) is kept on purpose. Attribute names are the
        # state_dict keys (model_parameters.pth is loaded by key), so renaming
        # them would break existing checkpoints.
        self.tarnspose_conv_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.tarnspose_conv_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.tarnspose_conv_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.tarnspose_conv_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv1 = double_conv(1024, 512)
        self.up_conv2 = double_conv(512, 256)
        self.up_conv3 = double_conv(256, 128)
        self.up_conv4 = double_conv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, image):
        # Encoder: keep each stage's pre-pool output for the skip connections.
        x1 = self.down_conv1(image)
        x2 = self.down_conv2(self.max_pool(x1))
        x3 = self.down_conv3(self.max_pool(x2))
        x4 = self.down_conv4(self.max_pool(x3))
        x5 = self.down_conv5(self.max_pool(x4))

        # Decoder: upsample, concatenate the cropped skip tensor along the
        # channel dim, then convolve. Deepest level first.
        decoder_stages = (
            (self.tarnspose_conv_1, self.up_conv1, x4),
            (self.tarnspose_conv_2, self.up_conv2, x3),
            (self.tarnspose_conv_3, self.up_conv3, x2),
            (self.tarnspose_conv_4, self.up_conv4, x1),
        )
        x = x5
        for upsample, up_conv, skip in decoder_stages:
            x = upsample(x)
            x = up_conv(torch.cat([x, crop_image(skip, x)], 1))

        return self.out(x)

In [9]:
# image_path = "./test.png"
# image = Image.open(image_path)
# print(type(image))
# image = np.array(image)
# print(image.size)
# image_tensor = torch.from_numpy(image)
# plt.imshow(image_tensor)
# image_tensor = image_tensor.float()
# image_tensor = image_tensor.permute(2, 0, 1)
# image_tensor = image_tensor.unsqueeze(0)
# print(image_tensor.size())


In [10]:
# model = UNet()
# output = model(image_tensor)


In [11]:
# combined_image = np.concatenate((output[0][0].detach().numpy(), output[0][1].detach().numpy()), axis=1)
# print(output.shape)
# plt.imshow(output[0][0].detach().numpy())

In [9]:
class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs loaded from two parallel lists of file paths.

    ``image_paths[i]`` is paired with ``mask_paths[i]``; the caller is
    responsible for ordering both lists consistently.

    Args:
        image_paths: paths of the input images.
        mask_paths: paths of the masks, same length and order as ``image_paths``.
        transform: applied to the image, and also to the mask unless
            ``mask_transform`` is given (the previous behaviour).
        mask_transform: optional separate transform for masks, e.g. one that
            resizes with nearest-neighbour interpolation so label values are
            not blended.
        image_mode: PIL mode the image is converted to. The default "RGB" turns
            grayscale/RGBA/palette files into the 3 channels a 3-channel model
            expects; ``None`` keeps the file's own mode.
        mask_mode: PIL mode for masks; ``None`` (default) keeps the file's mode.

    Raises:
        ValueError: if the two path lists have different lengths (pairs would
            otherwise be silently misaligned or raise IndexError later).
    """

    def __init__(self, image_paths, mask_paths, transform=None,
                 mask_transform=None, image_mode="RGB", mask_mode=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = mask_transform
        self.image_mode = image_mode
        self.mask_mode = mask_mode

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _load(path, mode):
        """Open ``path``, read its pixels into memory, and close the file."""
        with Image.open(path) as img:
            # convert()/copy() force the pixel data to be read before the file is closed.
            return img.convert(mode) if mode else img.copy()

    def __getitem__(self, idx):
        image = self._load(self.image_paths[idx], self.image_mode)
        mask = self._load(self.mask_paths[idx], self.mask_mode)

        if self.transform:
            image = self.transform(image)
        mask_transform = self.mask_transform or self.transform
        if mask_transform:
            mask = mask_transform(mask)

        return image, mask
    

In [10]:
# Shared preprocessing: resize every PIL image to 256x256, then convert it to
# a float tensor via ToTensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)

# test_transform = transforms.Compose([
#     transforms.Resize((836, 836)),
#     transforms.ToTensor()
# ])

In [11]:
import os
from glob import glob
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F


image_dir = "./original"
mask_dir = "./mask"

# Images and masks are paired purely by sorted filename order, so both folders
# must contain the same number of files with matching names. The glob pattern
# matches *.jpg and *.png (and, incidentally, *.jng / *.ppg) but NOT *.jpeg.
image_paths = sorted(glob(os.path.join(image_dir, '*.[jp][pn]g')))
mask_paths = sorted(glob(os.path.join(mask_dir, '*.[jp][pn]g')))

if not image_paths:
    raise FileNotFoundError(f"No .jpg/.png images found in {image_dir!r}")
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"{len(image_paths)} images in {image_dir!r} but "
        f"{len(mask_paths)} masks in {mask_dir!r}; pairs would be misaligned"
    )

dataset = SegmentationDataset(image_paths, mask_paths, transform=transform)

# Split the dataset (80% train, 10% val, 10% test)
train_size = int(0.80 * len(dataset))
val_size = int(0.10 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Seed the split. An unseeded random_split draws a different partition after
# every kernel restart, so a checkpoint trained in an earlier session can be
# "tested" on images it was trained on. A checkpoint produced before this seed
# was introduced should be retrained to get a clean test set.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Pinned memory only speeds up host->GPU copies; without an accelerator it has
# no benefit (and typically just triggers a warning).
pin_memory = torch.cuda.is_available()

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4, pin_memory=pin_memory)

In [12]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    When ``optimizer`` is given, the model is put in training mode and
    backpropagation + an optimizer step run on each batch. Otherwise the model
    is put in eval mode and gradients are disabled. Returns ``float('nan')``
    for an empty loader instead of dividing by zero.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            # The model's output need not match the input size, so resample the
            # masks to the output's spatial size and average their channels
            # into a single target map.
            target = F.interpolate(masks, size=outputs.shape[2:], mode='bilinear', align_corners=False)
            target = target.mean(dim=1, keepdim=True)
            loss = criterion(outputs, target)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
    return total_loss / len(loader) if len(loader) else float('nan')


def train(model, train_loader, val_loader, optimizer, criterion, num_epochs=10,
          device=None, plot=True):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        model: the network to train (updated in place).
        train_loader: iterable of (images, masks) training batches.
        val_loader: iterable of (images, masks) validation batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking (outputs, target).
        num_epochs: number of passes over ``train_loader``.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, or CPU for a parameterless model.
        plot: if True (default), show the training/validation loss curves.

    Returns:
        (train_losses, val_losses): per-epoch mean batch losses. A value is
        NaN when the corresponding loader is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        avg_train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_losses.append(avg_train_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss}')

        # Validation must average the *validation* loss; the original code
        # divided the training loss sum by len(val_loader).
        avg_val_loss = _run_epoch(model, val_loader, criterion, device)
        val_losses.append(avg_val_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss}')

    if plot:
        plt.figure(figsize=(10, 5))
        plt.plot(train_losses, label='Training Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.title('Training and Validation Loss over Epochs')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

    return train_losses, val_losses

In [13]:

# Everything runs on the CPU: the GPU path is deliberately disabled here
# (presumably because of the ROCm workaround at the top; confirm). To re-enable
# it, use torch.device('cuda' if torch.cuda.is_available() else 'cpu').
device = torch.device("cpu")
model = UNet().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
# UNet.forward applies no final activation, so use the logits form of BCE.
criterion = nn.BCEWithLogitsLoss()


In [None]:
# train(model, train_loader, val_loader, optimizer, criterion)

In [None]:
# Load trained weights onto `device`. Without map_location, a checkpoint saved
# from a GPU session fails to load on a CPU-only machine. torch.load is
# pickle-based; weights_only=True restricts it to tensors and plain
# containers, which is all a state_dict needs.
model.load_state_dict(torch.load('model_parameters.pth', map_location=device, weights_only=True))
model.eval();  # Set the model to evaluation mode (`;` hides the large module repr)

In [None]:
# Run the model on a single image and show the predicted mask as probabilities.
image_path = './test.png'
with Image.open(image_path) as raw_img:
    # Force 3-channel RGB: a PNG may be RGBA, grayscale or palette-based, but the
    # model's first convolution expects 3 input channels.
    img = raw_img.convert("RGB")
plt.figure()
plt.imshow(img)

# Add the batch dimension and move the input to the same device as the model.
img_tensor = transform(img).unsqueeze(0).to(device)
with torch.no_grad():  # No need to compute gradients during inference
    output = model(img_tensor)
    # The model returns raw logits (no final activation); the sigmoid maps them
    # to per-pixel probabilities. Take sample 0, channel 0 -> 2-D map.
    prob_map = torch.sigmoid(output)[0, 0].cpu().numpy()
plt.figure()
plt.imshow(prob_map, cmap='gray', vmin=0, vmax=1)


In [24]:
def dice_coefficient(preds, targets, epsilon=1e-8):
    """Mean Dice coefficient between predicted and target maps.

    For each entry of dims 0 and 1, with sums over dims 2 and 3:
    ``Dice = (2 * sum(P * T) + eps) / (sum(P) + sum(T) + eps)``.
    The result is averaged over the remaining entries. ``epsilon`` appears in
    both the numerator and the denominator, so an empty prediction against an
    empty target scores 1 (perfect agreement) instead of 0. For non-empty maps
    its effect is on the order of ``epsilon``.

    Args:
        preds: predicted maps (binary or in [0, 1]), indexed (N, C, H, W).
        targets: ground-truth maps, same shape as ``preds``.
        epsilon: smoothing term.

    Returns:
        0-d tensor with the mean Dice score.
    """
    intersection = (preds * targets).sum(dim=(2, 3))
    total = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))
    return ((2.0 * intersection + epsilon) / (total + epsilon)).mean()

def iou_score(preds, targets, epsilon=1e-8):
    """Mean intersection-over-union between predicted and target maps.

    For each entry of dims 0 and 1, with sums over dims 2 and 3:
    ``IoU = (sum(P * T) + eps) / (sum(P + T) - sum(P * T) + eps)``.
    The result is averaged over the remaining entries. ``epsilon`` appears in
    both the numerator and the denominator, so an empty prediction against an
    empty target scores 1 instead of 0.

    Args:
        preds: predicted maps (binary or in [0, 1]), indexed (N, C, H, W).
        targets: ground-truth maps, same shape as ``preds``.
        epsilon: smoothing term.

    Returns:
        0-d tensor with the mean IoU.
    """
    intersection = (preds * targets).sum(dim=(2, 3))
    union = (preds + targets).sum(dim=(2, 3)) - intersection
    return ((intersection + epsilon) / (union + epsilon)).mean()

def test(model, test_loader, criterion, threshold=0.5, device=None):
    """Evaluate a segmentation model: mean loss, Dice and IoU over ``test_loader``.

    The model is expected to return logits (the notebook trains it with
    ``BCEWithLogitsLoss``). Binary predictions are therefore obtained by
    applying a sigmoid and thresholding the probabilities at ``threshold``.
    Thresholding the raw logits at 0.5, as before, effectively used a cut-off
    of about 0.62 probability. Targets are the masks resampled to the output
    size and averaged over channels, as in training.

    Args:
        model: network to evaluate (switched to eval mode).
        test_loader: iterable of (images, masks) batches.
        criterion: loss taking (logits, target).
        threshold: probability cut-off for a positive pixel.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, or CPU for a parameterless model.

    Returns:
        (avg_test_loss, avg_dice_score, avg_iou): means over batches; all
        NaN if ``test_loader`` is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0
    dice_score_total = 0.0
    iou_total = 0.0

    with torch.no_grad():  # No gradient calculation during testing
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            masks_resized = F.interpolate(masks, size=outputs.shape[2:], mode='bilinear', align_corners=False)
            masks_resized = masks_resized.mean(dim=1, keepdim=True)

            test_loss += criterion(outputs, masks_resized).item()

            preds = (torch.sigmoid(outputs) > threshold).float()
            dice_score_total += dice_coefficient(preds, masks_resized).item()
            iou_total += iou_score(preds, masks_resized).item()

    n_batches = len(test_loader)
    if n_batches == 0:
        nan = float('nan')
        return nan, nan, nan
    return test_loss / n_batches, dice_score_total / n_batches, iou_total / n_batches

def _show_prediction_triplet(image, mask, pred):
    """Plot one image, its ground-truth mask and the predicted mask side by side.

    ``image`` is channel-first (C, H, W); ``mask`` and ``pred`` are shown
    from their channel 0. All are CPU tensors.
    """
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        (image.permute(1, 2, 0).numpy(), 'Original Image'),
        (mask[0].numpy(), 'Original Mask'),
        (pred[0].numpy(), 'Predicted Mask'),
    )
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data, cmap='gray')  # cmap is ignored for multi-channel (RGB) input
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # avoid piling up open figures when many samples are shown


def test_and_visualize(model, test_loader, criterion, threshold=0.5, device=None):
    """Evaluate like ``test`` and plot every sample of the first batch.

    Predictions are ``sigmoid(logits) > threshold``, because the model returns
    logits (trained with ``BCEWithLogitsLoss``). Targets are the masks
    resampled to the output size and averaged over channels, as in training.
    The averages are printed and returned.

    Args:
        model: network to evaluate (switched to eval mode).
        test_loader: iterable of (images, masks) batches.
        criterion: loss taking (logits, target).
        threshold: probability cut-off for a positive pixel.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, or CPU for a parameterless model.

    Returns:
        (avg_test_loss, avg_dice_score, avg_iou): means over batches; all
        NaN if ``test_loader`` is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0
    dice_score_total = 0.0
    iou_total = 0.0
    shown = False  # only the first batch is visualized

    with torch.no_grad():  # No gradient calculation during testing
        for images, masks in test_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            masks_resized = F.interpolate(masks, size=outputs.shape[2:], mode='bilinear', align_corners=False)
            masks_resized = masks_resized.mean(dim=1, keepdim=True)

            test_loss += criterion(outputs, masks_resized).item()

            preds = (torch.sigmoid(outputs) > threshold).float()
            dice_score_total += dice_coefficient(preds, masks_resized).item()
            iou_total += iou_score(preds, masks_resized).item()

            if not shown:
                shown = True
                # Copy to the CPU only for the batch that is actually plotted.
                images_cpu, masks_cpu, preds_cpu = images.cpu(), masks.cpu(), preds.cpu()
                for i in range(images_cpu.shape[0]):
                    _show_prediction_triplet(images_cpu[i], masks_cpu[i], preds_cpu[i])

    n_batches = len(test_loader)
    if n_batches:
        avg_test_loss = test_loss / n_batches
        avg_dice_score = dice_score_total / n_batches
        avg_iou = iou_total / n_batches
    else:
        avg_test_loss = avg_dice_score = avg_iou = float('nan')

    print(f'Test Loss: {avg_test_loss}')
    print(f'Dice Score: {avg_dice_score}')
    print(f'IoU: {avg_iou}')

    return avg_test_loss, avg_dice_score, avg_iou

In [None]:
# Evaluate on the held-out split (plots the first batch) and report the metrics.
test_loss, dice_score, iou = test_and_visualize(model, test_loader, criterion)
print(
    f'Final Test Loss: {test_loss}',
    f'Final Dice Score: {dice_score}',
    f'Final IoU: {iou}',
    sep='\n',
)