In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """2-D U-Net producing per-pixel logits for segmentation.

    Architecture: four encoder stages (double 3x3 conv + ReLU, with 2x2 max-pooling
    between stages), a bottleneck double conv, and four decoder stages. Each decoder
    stage upsamples with a stride-2 transposed conv (``decoderN``), concatenates the
    encoder feature map of the same resolution (skip connection), and fuses the
    concatenation back down with a double conv (``dec_convN``). A final 1x1 conv maps
    to ``out_channels``; no activation is applied, so pair it with a logits-based loss
    such as ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of output channels (logits).
        base_channels: Width of the first encoder stage. Stage ``k`` (0-based) has
            ``base_channels * 2**k`` channels, so the default of 64 gives widths
            64, 128, 256, 512 and 1024 in the bottleneck.

    Input height and width must each be at least 16 (four 2x2 poolings). Sizes that
    are not multiples of 16 are supported: upsampled maps are zero-padded to match
    their skip connection before concatenation.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** k for k in range(5))

        # Encoder
        self.encoder1 = self.conv_block(in_channels, c1)
        self.encoder2 = self.conv_block(c1, c2)
        self.encoder3 = self.conv_block(c2, c3)
        self.encoder4 = self.conv_block(c3, c4)

        # Bottleneck
        self.bottleneck = self.conv_block(c4, c5)

        # Decoder upsampling: each halves the channel count and doubles H and W.
        self.decoder4 = self.upconv_block(c5, c4)
        self.decoder3 = self.upconv_block(c4, c3)
        self.decoder2 = self.upconv_block(c3, c2)
        self.decoder1 = self.upconv_block(c2, c1)

        # Decoder fusion: concatenating the skip connection doubles the channels again,
        # so a double conv brings them back down. Without these the channel counts do
        # not line up (e.g. decoder3 expects c4 channels but cat(dec4, enc4) has 2*c4).
        self.dec_conv4 = self.conv_block(2 * c4, c4)
        self.dec_conv3 = self.conv_block(2 * c3, c3)
        self.dec_conv2 = self.conv_block(2 * c2, c2)
        self.dec_conv1 = self.conv_block(2 * c1, c1)

        # Final layer: 1x1 conv to per-pixel logits.
        self.final = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 convs (padding 1, so H and W are preserved), each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def upconv_block(self, in_channels, out_channels):
        """Stride-2 2x2 transposed conv (doubles H and W) followed by ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _decode(x, skip, upconv, fuse):
        """Upsample ``x`` with ``upconv``, concatenate ``skip`` on channels, apply ``fuse``."""
        up = upconv(x)
        # Max-pooling floors odd sizes, so the upsampled map can be one pixel smaller
        # than the skip; zero-pad it (left/right, top/bottom) so torch.cat shapes match.
        diff_h = skip.size(-2) - up.size(-2)
        diff_w = skip.size(-1) - up.size(-1)
        if diff_h or diff_w:
            up = F.pad(up, [diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])
        return fuse(torch.cat((up, skip), dim=1))

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape (N, in_channels, H, W), with H, W >= 16.

        Returns:
            Logits of shape (N, out_channels, H, W).
        """
        # Encoder path: keep each stage's output for the skip connections.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2))
        enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2))
        enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2))

        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2))

        # Decoder path: upsample, concatenate skip connection, fuse.
        dec4 = self._decode(bottleneck, enc4, self.decoder4, self.dec_conv4)
        dec3 = self._decode(dec4, enc3, self.decoder3, self.dec_conv3)
        dec2 = self._decode(dec3, enc2, self.decoder2, self.dec_conv2)
        dec1 = self._decode(dec2, enc1, self.decoder1, self.dec_conv1)

        return self.final(dec1)


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Dataset that loads paired grayscale MRI images and their segmentation masks from disk
class CustomDataset(Dataset):
    """Map-style dataset of grayscale images paired with segmentation masks.

    Both files of a pair are read with ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``.
    NOTE(review): ``cv2`` is not imported anywhere in this notebook; an earlier cell
    must ``import cv2`` before this dataset is indexed.

    Args:
        image_paths: Sequence of image file paths (``str`` or path-like).
        mask_paths: Sequence of mask file paths, aligned index-for-index with
            ``image_paths``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries (the calling
            convention used by albumentations). When ``None`` the raw arrays are returned.

    Raises:
        ValueError: If ``image_paths`` and ``mask_paths`` have different lengths.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Got {len(image_paths)} image paths but {len(mask_paths)} mask paths; "
                "they must be paired one-to-one."
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` and apply ``transform`` if one was given.

        Returns:
            ``(image, mask)``: the 2-D arrays returned by cv2 when there is no
            transform, otherwise whatever the transform stores under ``'image'`` and
            ``'mask'``.

        Raises:
            FileNotFoundError: If either file is missing or cannot be decoded.
        """
        image = self._read_grayscale(self.image_paths[idx])
        mask = self._read_grayscale(self.mask_paths[idx])
        # NOTE(review): masks saved as 0/255 images must be scaled to {0, 1} (here or in
        # the transform) before BCEWithLogitsLoss / Dice — confirm against the data.

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

    @staticmethod
    def _read_grayscale(path):
        """Read ``path`` as a single-channel image, raising if it cannot be read."""
        # cv2.imread does not raise on a missing or corrupt file; it returns None,
        # which would otherwise fail later with a confusing transform/collate error.
        array = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if array is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return array

# Define hyperparameters
num_epochs = 20
batch_size = 8
learning_rate = 0.001
seed = 42

# Seed PyTorch's RNG (weight init and DataLoader shuffle order). Augmentation
# libraries may use their own RNGs and need seeding separately.
torch.manual_seed(seed)

# Train on the GPU when one is available; a 64-channel UNet is very slow on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create DataLoader
# NOTE(review): train_image_paths, train_mask_paths and get_augmentation_pipeline are
# not defined anywhere in this notebook; an earlier cell must define them.
train_dataset = CustomDataset(train_image_paths, train_mask_paths, transform=get_augmentation_pipeline())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Instantiate model, loss function, and optimizer
model = UNet(in_channels=1, out_channels=1).to(device)  # Adjust in_channels and out_channels as needed
criterion = nn.BCEWithLogitsLoss()  # Binary segmentation: takes raw logits, targets in [0, 1]
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0

    for images, masks in train_loader:
        # Batches may arrive as (B, H, W) (raw cv2 arrays) or already as (B, 1, H, W)
        # (e.g. after a to-tensor transform); add the channel dim only when missing.
        if images.dim() == 3:
            images = images.unsqueeze(1)
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)
        # NOTE(review): masks read with cv2 are often stored as 0/255, but
        # BCEWithLogitsLoss needs targets in [0, 1]; confirm the transform scales them.
        images = images.to(device).float()
        masks = masks.to(device).float()

        optimizer.zero_grad()  # Reset gradients
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()  # Backpropagation
        optimizer.step()  # Update weights
        running_loss += loss.item()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')


In [None]:
def evaluate_model(model, val_loader):
    """Evaluate a segmentation model, print its mean Dice score, and return it.

    Logits are turned into binary predictions with ``sigmoid(logits) > 0.5`` and
    compared to the masks with ``compute_dice_coefficient``. One Dice score is computed
    per batch, and the scores are averaged over batches. The model is left in eval mode.

    Args:
        model: Module whose output is assumed to be single-channel logits with the same
            shape as the (channel-expanded) masks.
        val_loader: Iterable of ``(images, masks)`` tensor batches. Each may be
            (B, H, W) or already carry a channel dimension (B, 1, H, W).

    Returns:
        The mean per-batch Dice score as a Python float (``nan`` if ``val_loader`` is
        empty).
    """
    model.eval()  # Set model to evaluation mode
    # Run inputs on whatever device the model's parameters live on (CPU or GPU).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dice_scores = []

    with torch.no_grad():  # No need to compute gradients
        for images, masks in val_loader:
            # Add the channel dimension only when the batch does not already have one.
            if images.dim() == 3:
                images = images.unsqueeze(1)
            if masks.dim() == 3:
                masks = masks.unsqueeze(1)
            images = images.to(device).float()
            masks = masks.to(device).float()

            outputs = model(images)
            preds = torch.sigmoid(outputs) > 0.5  # Apply sigmoid and threshold
            # float() copies the 0-dim result to the host, so this also works on GPU.
            dice_scores.append(float(compute_dice_coefficient(preds, masks)))

    mean_dice_score = sum(dice_scores) / len(dice_scores) if dice_scores else float("nan")
    print(f'Mean DICE Score: {mean_dice_score:.4f}')
    return mean_dice_score

def compute_dice_coefficient(preds, targets, smooth=1e-6):
    """Dice coefficient between binary predictions and targets, pooled over all elements.

    Dice = (2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth), where every element
    of both tensors is flattened together. A batch therefore yields one pooled score,
    not a mean of per-image scores.

    Args:
        preds: Tensor of binary predictions (bool or 0/1 values), any shape.
        targets: Tensor with the same number of elements as ``preds``; assumed to hold
            0/1 values.
        smooth: Constant added to numerator and denominator to avoid division by zero;
            two all-zero inputs score 1.0.

    Returns:
        0-dim float tensor holding the Dice score.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as transposes/slices.
    preds_flat = preds.reshape(-1).float()
    targets_flat = targets.reshape(-1).float()
    intersection = (preds_flat * targets_flat).sum()
    return (2. * intersection + smooth) / (preds_flat.sum() + targets_flat.sum() + smooth)
