In [None]:
import torch

In [None]:
# Number of CUDA devices visible to PyTorch; displayed as the cell output.
num_gpus = torch.cuda.device_count() 
num_gpus

In [None]:
import segmentation_models_pytorch as smp
from PIL import Image
import pandas as pd
import torch
# from unet import UNet
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
import torchvision.io as io
from scipy.io import loadmat
from scipy.io import whosmat
import os
import numpy as np
import pydicom
import torchvision
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision import transforms as T
import timm
import random

In [None]:
# ImageNet-style preprocessing: Resize(224), convert to a tensor, then normalise
# with timm's default ImageNet mean/std.
# This object is the default value of CTDataset's `transform` parameter (bound
# when the class is defined). The name `transform` itself is rebound to a plain
# ToTensor() in the dataset-setup cell further down.
transform = T.Compose([T.Resize(224),
                   T.ToTensor(),
                   T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD )])

In [None]:
# Separate image / mask pipelines: Resize(256) followed by ToTensor, with
# normalisation currently disabled (commented out below).
# NOTE(review): neither pipeline is referenced anywhere else in the visible
# notebook — CTDataset applies a single `transform` to both image and mask.
img_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
#         transforms.Normalize([0.5], [0.3])
    ])

mask_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
#         transforms.Normalize([0.1], [0.1])
    ])

In [None]:
class CTDataset(Dataset):
    """Slice-level dataset pairing DICOM CT slices with ground-truth masks.

    Expected layout (as read by ``__init__``):

    * ``scans_path/<patient>/*.dcm`` - one folder per patient, one file per slice.
    * ``ground_truth_path/<patient>.mat`` - a MATLAB file whose ``"Mask"``
      variable is indexed as ``masks[:, :, i]`` for the i-th ``.dcm`` file in
      sorted filename order.

    Every (dcm path, mask slice) pair is classified as *background* (mask sums
    to zero) or *object*; both groups are shuffled together into
    ``self.samples``.

    Args:
        scans_path: Root directory containing the patient folders.
        ground_truth_path: Directory containing the ``<patient>.mat`` files.
        transform: Callable applied to BOTH the RGB image and the single-channel
            mask. Defaults to the module-level ``transform`` that exists when
            the class is defined. If falsy, PIL images are returned unchanged.
            NOTE(review): because the same callable is applied to the 1-channel
            mask, a transform containing a 3-channel ``Normalize`` is probably
            not applicable to it - confirm before relying on the default.
        background_ratio: NOTE(review): accepted but currently NOT used - every
            background slice is kept. The intended formula (fraction of the
            background slices vs. fraction relative to the object slices) is
            not derivable from this file, so behaviour is left unchanged.
    """

    def __init__(self, scans_path, ground_truth_path, transform=transform, background_ratio=0.3):
        self.scans_path = scans_path
        self.ground_truth_path = ground_truth_path
        self.transform = transform

        self.samples = []
        self.background_samples = []
        self.object_samples = []

        # Get a list of all patient folders
        patients = sorted(os.listdir(scans_path))

        for patient_folder in patients:
            patient_path = os.path.join(scans_path, patient_folder)
            ground_truth_file = os.path.join(ground_truth_path, f"{patient_folder}.mat")

            # Load ground truth masks from .mat file
            ground_truth_data = loadmat(ground_truth_file)
            masks = ground_truth_data["Mask"]

            # Sorted so that the i-th file lines up with masks[:, :, i]
            dcm_files = sorted(f for f in os.listdir(patient_path) if f.endswith(".dcm"))

            for i, dcm_file in enumerate(dcm_files):
                dcm_path = os.path.join(patient_path, dcm_file)
                mask = masks[:, :, i]

                # A slice without any positive mask pixel counts as background
                if np.sum(mask) == 0:
                    self.background_samples.append((dcm_path, mask))
                else:
                    self.object_samples.append((dcm_path, mask))

        # NOTE(review): `background_ratio` is not applied here, so the
        # truncation below keeps every background sample (see class docstring).
        num_background_samples = int(len(self.background_samples))

        # Shuffle and truncate the background_samples list
        random.shuffle(self.background_samples)
        self.background_samples = self.background_samples[:num_background_samples]

        # Combine object and background samples into the final list
        self.samples = self.background_samples + self.object_samples
        random.shuffle(self.samples)

        print("list object:",len(self.object_samples))
        print("list background:",len(self.background_samples))

    def __len__(self):
        """Return the number of (slice, mask) samples."""
        return len(self.samples)

    @staticmethod
    def _to_uint8(pixels):
        """Convert a raw pixel array to uint8 in [0, 255] without wrap-around.

        Arrays whose values already lie in [0, 1] are scaled by 255, which is
        what the previous code did and is correct for that range. Anything else
        (e.g. 12/16-bit integer DICOM data, where ``np.uint8(x * 255)`` silently
        overflowed) is min-max normalised to the full 8-bit range.
        """
        pixels = np.asarray(pixels)
        if pixels.size == 0:
            return pixels.astype(np.uint8)

        lo, hi = float(pixels.min()), float(pixels.max())
        if lo >= 0.0 and hi <= 1.0:
            return np.uint8(pixels * 255)
        if hi <= lo:
            # Constant image outside [0, 1]: no contrast to preserve
            return np.zeros(pixels.shape, dtype=np.uint8)

        scaled = (pixels.astype(np.float32) - lo) / (hi - lo) * 255.0
        return np.round(scaled).astype(np.uint8)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The image is the DICOM slice converted to 8-bit and replicated to RGB;
        the mask is binarised to {0, 255} so that a ``ToTensor`` transform
        yields a {0, 1} target. Both are passed through ``self.transform`` when
        one is set; otherwise the PIL images are returned.
        """
        dcm_path, mask = self.samples[idx]

        # Read .dcm file (`dcmread` replaces the deprecated `read_file`)
        dcm = pydicom.dcmread(dcm_path).pixel_array
        image = Image.fromarray(self._to_uint8(dcm), 'L')
        image = image.convert('RGB')

        # Any positive label becomes 255: ToTensor divides uint8 data by 255,
        # so a raw 0/1 mask would otherwise turn into 0 / 0.0039 targets.
        mask = Image.fromarray(np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8), 'L')

        # Apply transformation, if specified
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask
    

In [None]:
# Initialize your custom train dataset
train_scans_path = "data/train_scans"
train_ground_truth_path = "data/train_truth"
# NOTE: this rebinds the module-level `transform` (previously the Resize/Normalize
# pipeline). Only ToTensor is applied to the training samples from here on, so
# images and masks are neither resized nor normalised.
transform = ToTensor()
train_dataset = CTDataset(train_scans_path, train_ground_truth_path, transform=transform)

# Create a DataLoader to handle batching
# NOTE(review): with no resizing, batching presumably relies on every slice
# having the same height/width — confirm for this data set.
batch_size = 12
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# Total number of (slice, mask) samples; shown as the cell output.
len(train_dataset)

In [None]:
# Spot-check a single sample; the mask's size is shown as the cell output.
img, mask = train_dataset[100]
mask.size()

In [None]:
# Pick three samples for the visual check in the plotting cell below.
# NOTE(review): the numeric suffixes do not match the dataset indices used
# (18, 60 and 61). Sample order also depends on the unseeded random.shuffle
# in CTDataset, so these indices select different slices on every run.
image100, mask100 = train_dataset[18]
image60, mask60 = train_dataset[60]
image17, mask17 = train_dataset[61]

In [None]:
def visualize(image_pairs):
    """Show each (image, mask) pair on its own row: image left, mask right.

    :param image_pairs: sequence of ``(image, mask)`` tuples, each element
        being something ``plt.imshow`` accepts.
    """
    rows = len(image_pairs)
    plt.figure(figsize=(10, 5 * rows))

    def _panel(slot, caption, data):
        # One tick-less subplot in the rows x 2 grid
        plt.subplot(rows, 2, slot)
        plt.xticks([])
        plt.yticks([])
        plt.title(caption)
        plt.imshow(data)

    for row, (image, mask) in enumerate(image_pairs):
        first_slot = 2 * row + 1
        _panel(first_slot, f"Image {row+1}", image)
        _panel(first_slot + 1, f"Mask {row+1}", mask)

    plt.show()

In [None]:
# Move the channel axis last (C, H, W) -> (H, W, C) so imshow can display the samples.
# NOTE(review): this cell rebinds its own inputs (image100, mask100, image17,
# mask17), so running it a second time transposes already-transposed data.
image100 = np.transpose(image100,(1,2,0))
mask100 = np.transpose(mask100, (1,2,0))
# NOTE(review): the second pair is stored as image55/mask55 although it comes
# from image60/mask60.
image55 = np.transpose(image60,(1,2,0))
mask55 = np.transpose(mask60, (1,2,0))
image17 = np.transpose(image17,(1,2,0))
mask17 = np.transpose(mask17, (1,2,0))


# One row per (image, mask) pair
image_pairs = [(image100, mask100), (image55,mask55),(image17,mask17)]
visualize(image_pairs)

In [None]:
# Shape of the transposed mask (channel axis last after the cell above).
mask17.shape

In [None]:
# Shape of the transposed image (channel axis last after the plotting cell).
image100.shape

In [None]:
# Use the GPU when one is available; fall back to the CPU so the notebook still
# runs end-to-end on machines without CUDA. Kept as a plain string, as before.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class ConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d, optionally followed by ReLU.

    :param in_channels: channels of the incoming feature map
    :param out_channels: channels produced by the convolution
    :param padding: convolution padding (default 1)
    :param kernel_size: convolution kernel size (default 3)
    :param stride: convolution stride (default 1)
    :param with_nonlinearity: when False the ReLU is skipped in ``forward``
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # The ReLU module is always created; `with_nonlinearity` only gates its use
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        if not self.with_nonlinearity:
            return normalized
        return self.relu(normalized)


class Bridge(nn.Module):
    """
    Middle layer of the UNet: two ConvBlocks applied one after the other
    (in_channels -> out_channels -> out_channels).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
        ]
        self.bridge = nn.Sequential(*stages)

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of
    Upsample -> concatenate skip connection -> ConvBlock -> ConvBlock.

    :param in_channels: channels entering the first ConvBlock, i.e. the
        up-sampled channels plus the skip-connection channels
    :param out_channels: channels produced by the block
    :param up_conv_in_channels: channels entering the up-sampling layer
        (defaults to ``in_channels``)
    :param up_conv_out_channels: channels leaving the up-sampling layer
        (defaults to ``out_channels``)
    :param upsampling_method: ``"conv_transpose"`` (2x2 transposed convolution,
        stride 2) or ``"bilinear"`` (2x bilinear upsampling followed by a 1x1
        convolution)
    :raises ValueError: if ``upsampling_method`` is not one of the two values
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            # The 1x1 conv must map the up-sampling channels, not the ConvBlock
            # channels; they only coincide when no up_conv_* override is given.
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels, kernel_size=1, stride=1)
            )
        else:
            # Fail at construction time instead of with an AttributeError in forward()
            raise ValueError(
                f"Unknown upsampling_method {upsampling_method!r}; "
                "expected 'conv_transpose' or 'bilinear'"
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        # Concatenate along the channel axis; the sum of both channel counts
        # has to equal the `in_channels` given to conv_block_1.
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNet(nn.Module):
    """
    UNet whose encoder is an ImageNet-pretrained torchvision ResNet50.

    Encoder: the first three ResNet children form ``input_block``, the fourth
    is ``input_pool``, and every ``nn.Sequential`` child becomes a down block.
    Decoder: a Bridge followed by five up blocks that consume the stored skip
    connections from the deepest to the raw input, then a 1x1 convolution
    producing ``n_classes`` output channels.
    """
    DEPTH = 6

    def __init__(self, n_classes=1):
        super().__init__()
        # ResNet50_Weights.IMAGENET1K_V1
        resnet = torchvision.models.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        backbone_layers = list(resnet.children())

        self.input_block = nn.Sequential(*backbone_layers[:3])
        self.input_pool = backbone_layers[3]
        self.down_blocks = nn.ModuleList(
            [layer for layer in backbone_layers if isinstance(layer, nn.Sequential)]
        )

        self.bridge = Bridge(2048, 2048)

        # The last two up blocks concatenate with 64-channel and 3-channel
        # skips, so their ConvBlock input width differs from the up-conv width.
        self.up_blocks = nn.ModuleList([
            UpBlockForUNetWithResNet50(2048, 1024),
            UpBlockForUNetWithResNet50(1024, 512),
            UpBlockForUNetWithResNet50(512, 256),
            UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                       up_conv_in_channels=256, up_conv_out_channels=128),
            UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                       up_conv_in_channels=128, up_conv_out_channels=64),
        ])

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=False):
        """
        :param x: input batch
        :param with_output_feature_map: also return the 64-channel feature map
            that feeds the final 1x1 convolution
        :return: output of the final convolution, or ``(output, feature_map)``
        """
        # Skip connections, keyed by encoder level; level 0 is the raw input
        skips = {"layer_0": x}
        x = self.input_block(x)
        skips["layer_1"] = x
        x = self.input_pool(x)

        deepest = UNet.DEPTH - 1
        for level, stage in enumerate(self.down_blocks, 2):
            x = stage(x)
            # The deepest encoder output goes straight to the bridge, not to a skip
            if level != deepest:
                skips[f"layer_{level}"] = x

        x = self.bridge(x)

        for step, up_block in enumerate(self.up_blocks, 1):
            x = up_block(x, skips[f"layer_{deepest - step}"])

        feature_map = x
        result = self.out(feature_map)
        if with_output_feature_map:
            return result, feature_map
        return result

In [None]:
# Build the network (this loads the pretrained ResNet50 encoder weights) and
# move it to the selected device.
model = UNet().to(device)

# Wrap for multi-GPU data parallelism.
# NOTE(review): a state_dict saved from the wrapper carries a "module." key
# prefix — keep that in mind when loading the checkpoint into a bare UNet.
model= nn.DataParallel(model)


In [None]:
# pos_weight = torch.tensor([6]).to(device)


# BCEWithLogitsLoss applies the sigmoid internally, so the model's raw outputs
# (logits) are passed to it directly. The positive-class weighting above is
# currently disabled.
criterion = nn.BCEWithLogitsLoss()
# Plain SGD over all model parameters with a learning rate of 1e-4.
optimizer = optim.SGD(model.parameters(), lr=0.0001)

In [None]:
def dice_coefficient(y_true, y_pred):
    """Smoothed (soft) Dice coefficient over all elements of two tensors.

    Computes ``(2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)``.

    Args:
        y_true: Ground-truth tensor (any shape).
        y_pred: Prediction tensor with the same number of elements. Pass
            probabilities (or a binary mask), not raw logits, for the value to
            be a meaningful overlap score.

    Returns:
        A 0-dim tensor holding the Dice score.
    """
    smooth = 1.0
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # permuted or sliced tensors, on which view(-1) raises a RuntimeError.
    y_true_f = y_true.reshape(-1)
    y_pred_f = y_pred.reshape(-1)
    intersection = (y_true_f * y_pred_f).sum()
    return (2.0 * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)

def iou(pred, target, n_classes=1):
    """Intersection-over-union of two tensors after thresholding at 0.5.

    Args:
        pred: Prediction tensor; elements > 0.5 count as foreground. Pass
            probabilities (or a binary mask), not raw logits.
        target: Ground-truth tensor; elements > 0.5 count as foreground.
        n_classes: Unused; kept so existing callers keep working.

    Returns:
        The IoU as a Python float. A 1e-6 term in numerator and denominator
        makes the score 1.0 when both masks are empty.
    """
    # reshape (rather than view) also accepts non-contiguous inputs
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    pred_inds = pred > 0.5
    target_inds = target > 0.5
    intersection = (pred_inds & target_inds).float().sum().item()
    union = (pred_inds | target_inds).float().sum().item()
    iou_score = (intersection + 1e-6) / (union + 1e-6)
    return iou_score

num_epochs = 10
checkpoint_path = "resnet50_unet_pretrained2.pth"

model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    running_dice = 0.0
    running_iou = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Raw logits: BCEWithLogitsLoss applies the sigmoid itself
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Metrics are monitoring only: compute them on probabilities (not
        # logits) and outside autograd, so the running sums are plain floats
        # and no computation graph is retained across iterations.
        with torch.no_grad():
            probs = torch.sigmoid(outputs)
            dice_score = dice_coefficient(labels, probs).item()
            iou_score = iou(probs, labels)

        running_loss += loss.item()
        running_dice += dice_score
        running_iou += iou_score

    # Guard against an empty loader so the averages never divide by zero
    num_batches = max(len(train_loader), 1)
    epoch_loss = running_loss / num_batches
    epoch_dice = running_dice / num_batches
    epoch_iou = running_iou / num_batches

    print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.3f}, Dice Score: {epoch_dice:.3f}, IoU Score: {epoch_iou:.3f}")

print("Finished Training")

# `model` is the DataParallel wrapper, so the saved keys keep its "module." prefix
torch.save(model.state_dict(), checkpoint_path)
print(f"\nModel saved to {checkpoint_path}")

In [None]:
# Take one training sample (image and ground-truth mask) for a qualitative
# check of the trained model.
imgt,maskt = train_dataset[167]

In [None]:
# Second reference to the sample image (an alias, not a copy); it keeps
# pointing at this tensor even if `imgt` is rebound later.
imgp = imgt

# Size of the ground-truth mask, shown as the cell output.
maskt.size()

In [None]:
# Switch to inference mode (BatchNorm layers use their running statistics)
model.eval()

# NOTE(review): not used in this cell, but it rebinds the global batch_size
batch_size = 5

# Add the batch dimension on a NEW name: rebinding `imgt` made this cell
# non-idempotent and left `imgt` 4-D for the plotting cell, which expects
# (C, H, W). The batch is also moved to the model's device explicitly.
input_batch = imgt.unsqueeze(0).to(device)
print(input_batch.shape)

with torch.no_grad():
    output = model(input_batch)
    

In [None]:
# Size of the raw model output for the single-image batch.
print((output.size()))

In [None]:
# Raw model output (logits, given the BCEWithLogitsLoss training objective).
# NOTE(review): this prints the whole tensor; summary statistics such as
# min/max/mean would be easier to read.
print(output)

In [None]:
# The model emits logits (it is trained with BCEWithLogitsLoss), so convert
# them to probabilities before applying the 0.5 decision threshold.
predicted_mask = (torch.sigmoid(output).squeeze() > 0.5).float().cpu().numpy()
predicted_mask.shape

In [None]:
# Displays the binary prediction array (NumPy abbreviates large arrays).
predicted_mask

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# `imgt` may carry a leading batch dimension if the inference cell added one;
# drop it so the (C, H, W) -> (H, W, C) permute below always applies.
input_image = imgt.squeeze(0) if imgt.dim() == 4 else imgt
axs[0].imshow(input_image.permute(1, 2, 0).cpu().numpy())  # Transpose dimensions and convert to numpy array
axs[0].set_title('Input Image')

print(maskt.shape)
axs[1].imshow(maskt.permute(1, 2, 0).cpu().numpy())
axs[1].set_title('Mask')

# Show the model's prediction here; this panel previously re-plotted the
# ground-truth mask under the 'Predicted Mask' title.
axs[2].imshow(predicted_mask)
axs[2].set_title('Predicted Mask')

# Remove axis ticks
for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()