In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn.functional as F


In [2]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform_image=None, transform_mask=None):
        self.root_dir = root_dir
        self.transform_image = transform_image
        self.transform_mask = transform_mask
        self.images_dir = os.path.join(root_dir, 'imgs')
        self.masks_dir = os.path.join(root_dir, 'masks')
        self.image_filenames = os.listdir(self.images_dir)
        
        self.image_filenames = [name for name in self.image_filenames if not name.startswith(".ipynb")]
        


    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = os.path.join(self.images_dir, self.image_filenames[idx])
        mask_name = os.path.join(self.masks_dir, self.image_filenames[idx].replace('img_', 'mask_').replace('.jpg', '.png'))
        
        

        image = Image.open(img_name).convert('RGB')
        mask = Image.open(mask_name)#.convert('L')  # Convert to grayscale for single channel mask

        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_mask:
            mask = self.transform_mask(mask)
            
        #mask = mask.squeeze(0).long()  # Ensure mask is in long type for cross-entropy loss
        mask = torch.squeeze(mask)  # Ensure mask is [H, W] without the channel dimension
        

        return image, mask

In [3]:
# Transforms dataset
#transform = transforms.Compose([
#    transforms.Resize((256, 256)),  # Resize image
#    transforms.ToTensor(),
#])

# Define Transforms
transform_image = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

transform_mask = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

In [4]:
# Defines dataset and dataloaders
train_dataset = CustomDataset(root_dir='CAVS/Main_Trail/Train', transform_image=transform_image, transform_mask=transform_mask)
test_dataset = CustomDataset(root_dir='CAVS/Main_Trail/Test', transform_image=transform_image, transform_mask=transform_mask)

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [5]:
# Define ResNet model for semantic segmentation
class ResNetSegmentation(nn.Module):
    def __init__(self, num_classes):
        super(ResNetSegmentation, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        
        # Remove the fully connected layer and global average pooling
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])

        # Add 1x1 convolution layer for the final segmentation map
        self.conv1x1 = nn.Conv2d(2048, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.resnet(x)
        x = self.conv1x1(x)
        x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)  # Upsample to (256, 256)
        return x

In [6]:
#def pixel_accuracy(output, mask):
#    output = torch.argmax(output, dim=1)  # output: [batch_size, height, width]
#    correct = (output == mask).float()
#    accuracy = correct.sum() / correct.numel()
#    return accuracy.item()

def iou(pred, target, n_classes=3):
    pred = torch.argmax(pred, dim=1)  # pred: [batch_size, height, width]
    ious = []
    pred = pred.view(-1)
    target = target.view(-1)
    for cls in range(n_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds[target_inds]).sum().float().item()
        union = pred_inds.sum().float().item() + target_inds.sum().float().item() - intersection
        if union != 0:
            ious.append(intersection / union)
    return np.mean(ious)

In [7]:
# Initialize model
model = ResNetSegmentation(num_classes=3)  
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)



In [8]:
# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Binary Cross Entropy loss for binary segmentation
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20, device='cuda:0'):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0
        epoch_acc = 0
        epoch_iou = 0
        
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
           # epoch_acc += pixel_accuracy(outputs, masks)
            epoch_iou += iou(outputs, masks)

        epoch_loss /= len(train_loader)
        epoch_acc /= len(train_loader)
        epoch_iou /= len(train_loader)
        
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}, IoU: {epoch_iou:.4f}')

        # Add validation step if val_loader is provided
        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for images, masks in val_loader:
                    images = images.to(device)
                    masks = masks.to(device)
                    outputs = model(images)
                    loss = criterion(outputs, masks)
                    val_loss += loss.item() * images.size(0)
            val_loss /= len(val_loader.dataset)
            print(f"Validation Loss: {val_loss:.4f}")


In [12]:
# Train the model
train_model(model, train_dataloader, None, criterion, optimizer, num_epochs=20, device=device)

IndexError: The shape of the mask [786432] at index 0 does not match the shape of the indexed tensor [262144] at index 0

In [None]:
# Evaluation loop (example)
model.eval()
total_loss = 0.0
with torch.no_grad():
    for images, masks in test_dataloader:
        images = images.to(device)
        masks = masks.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        total_loss += loss.item() * images.size(0)
    
    avg_loss = total_loss / len(test_dataloader.dataset)
    print(f'Average Test Loss: {avg_loss:.4f}')

In [None]:
# Saves model
torch.save(model.state_dict(), 'resnet_segmentation_model.pth')