In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn.functional as F


In [19]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform_image=None, transform_mask=None):
        self.root_dir = root_dir
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        self.images_dir = os.path.join(root_dir, 'imgs')
        self.masks_dir = os.path.join(root_dir, 'masks')
        self.image_filenames = os.listdir(self.images_dir)
        
        self.image_filenames = [name for name in self.image_filenames if not name.startswith(".ipynb")]
        


    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = os.path.join(self.images_dir, self.image_filenames[idx])
        mask_name = os.path.join(self.masks_dir, self.image_filenames[idx].replace('img_', 'mask_').replace('.jpg', '.png'))
        
        

        image = Image.open(img_name).convert('RGB')
        mask = Image.open(mask_name).convert('L')  # Convert to grayscale for single channel mask

        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_mask:
            mask = self.transform_mask(mask)
            
        mask = mask.squeeze(0).long()  # Ensure mask is in long type for cross-entropy loss

        return image, mask

In [20]:
# Transforms dataset
#transform = transforms.Compose([
#    transforms.Resize((256, 256)),  # Resize image
#    transforms.ToTensor(),
#])

# Define Transforms
transform_image = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

transform_mask = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

In [21]:
# Defines dataset and dataloaders
train_dataset = CustomDataset(root_dir='CAVS/Main_Trail/Train', transform_image=transform_image, transform_mask=transform_mask)
test_dataset = CustomDataset(root_dir='CAVS/Main_Trail/Test', transform_image=transform_image, transform_mask=transform_mask)

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [22]:
# Define ResNet model for semantic segmentation
class ResNetSegmentation(nn.Module):
    def __init__(self, num_classes):
        super(ResNetSegmentation, self).__init__()
        self.resnet = models.resnet50(pretrained=True)  # Load pre-trained ResNet
        
          # Remove the fully connected layer and global average pooling
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])

        # Add 1x1 convolution layer for the final segmentation map
        self.conv1x1 = nn.Conv2d(2048, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.resnet(x)
        x = self.conv1x1(x)
        x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)  # Upsample to (256, 256)
        return x

In [27]:
# Initialize model
model = ResNetSegmentation(num_classes=1)  # Assuming you have binary segmentation (1 channel mask)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [28]:
# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Binary Cross Entropy loss for binary segmentation
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [29]:
# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, masks in train_dataloader:
        images = images.to(device)
        masks = masks.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        
        # Ensure output size matches mask size
        outputs = torch.sigmoid(outputs)  # Apply sigmoid to output for binary segmentation
        outputs = outputs.view(-1,1,256,256)  # Ensure output size matches mask size
        
        # Calculate loss
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
    
    epoch_loss = running_loss / len(train_dataloader.dataset)
    print(f'Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss:.4f}')

ValueError: Target size (torch.Size([4, 1, 256, 256])) must be the same as input size (torch.Size([12, 1, 256, 256]))

In [26]:
# Evaluation loop (example)
model.eval()
total_loss = 0.0
with torch.no_grad():
    for images, masks in test_dataloader:
        images = images.to(device)
        masks = masks.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        total_loss += loss.item() * images.size(0)
    
    avg_loss = total_loss / len(test_dataloader.dataset)
    print(f'Average Test Loss: {avg_loss:.4f}')

Average Test Loss: 1.3056


In [None]:
# Saves model
torch.save(model.state_dict(), 'resnet_segmentation_model.pth')