In [4]:
! kaggle datasets download -d jahidhasan66/isprs-potsdam

Dataset URL: https://www.kaggle.com/datasets/jahidhasan66/isprs-potsdam
License(s): unknown
Downloading isprs-potsdam.zip to /content
 99% 361M/366M [00:08<00:00, 41.3MB/s]
100% 366M/366M [00:08<00:00, 46.9MB/s]


In [5]:
! unzip isprs-potsdam.zip


Archive:  isprs-potsdam.zip
  inflating: patches/Images/Image_0.tif  
  inflating: patches/Images/Image_1.tif  
  inflating: patches/Images/Image_10.tif  
  inflating: patches/Images/Image_100.tif  
  inflating: patches/Images/Image_1000.tif  
  inflating: patches/Images/Image_1001.tif  
  inflating: patches/Images/Image_1002.tif  
  inflating: patches/Images/Image_1003.tif  
  inflating: patches/Images/Image_1004.tif  
  inflating: patches/Images/Image_1005.tif  
  inflating: patches/Images/Image_1006.tif  
  inflating: patches/Images/Image_1007.tif  
  inflating: patches/Images/Image_1008.tif  
  inflating: patches/Images/Image_1009.tif  
  inflating: patches/Images/Image_101.tif  
  inflating: patches/Images/Image_1010.tif  
  inflating: patches/Images/Image_1011.tif  
  inflating: patches/Images/Image_1012.tif  
  inflating: patches/Images/Image_1013.tif  
  inflating: patches/Images/Image_1014.tif  
  inflating: patches/Images/Image_1015.tif  
  inflating: patches/Images/Image_101

In [6]:
! pip install -q kaggle

In [7]:
! ls


isprs-potsdam.zip  patches  sample_data


In [8]:
!unzip isprs-potsdam.zip -d ./isprs_potsdam

Archive:  isprs-potsdam.zip
  inflating: ./isprs_potsdam/patches/Images/Image_0.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_10.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_100.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1000.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1001.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1002.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1003.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1004.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1005.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1006.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1007.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1008.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1009.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_101.tif  
  inflating: ./isprs_potsdam/patches/Images/Image_1010.tif  
  infl

In [9]:
! ls ./isprs_potsdam/patches

Images	Labels


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PartialCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(PartialCrossEntropyLoss, self).__init__()

    def forward(self, predictions, targets, mask):
        # Compute standard cross-entropy loss
        ce_loss = F.cross_entropy(predictions, targets, reduction='none')

        # Focal loss application
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** 2 * ce_loss

        # Apply the mask to focus only on labeled points
        masked_loss = focal_loss * mask

        # Normalize by the sum of mask values
        loss = masked_loss.sum() / mask.sum()
        return loss


In [22]:
import torch
from torchvision import models,transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import random

# Define transformations for images and masks
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

target_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Custom dataset class
class ISPRSPotsdamDataset(Dataset):
    def __init__(self, image_dir, mask_dir, image_transform=None, target_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_transform = image_transform
        self.target_transform = target_transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_name = self.images[idx].replace('Image', 'Label')  # Adjust naming convention
        mask_path = os.path.join(self.mask_dir, mask_name)
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.image_transform:
            image = self.image_transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return image, mask

# Load dataset with the specified transformations
dataset = ISPRSPotsdamDataset(
    image_dir='./isprs_potsdam/patches/Images',
    mask_dir='./isprs_potsdam/patches/Labels',
    image_transform=image_transform,
    target_transform=target_transform
)

# DataLoader to load the data in batches
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)




# Function to create sparse point labels
def create_point_labels(targets, num_points=10):
    # Initialize an empty mask with the same shape as targets
    mask = torch.zeros_like(targets, dtype=torch.float32)

    for i in range(targets.shape[0]):  # Loop over the batch
        points_chosen = 0
        while points_chosen < num_points:
            # Randomly choose a pixel
            y, x = random.randint(0, targets.shape[1] - 1), random.randint(0, targets.shape[2] - 1)

            # Ensure the chosen pixel is not background (class label 255)
            if targets[i, y, x].max().item() != 255:  # Adjust index if mask has multiple channels
                mask[i, y, x] = 1
                points_chosen += 1

    return mask

# Example: Generating a mask for a batch of targets
for images, targets in dataloader:
    mask = create_point_labels(targets)
    break  # Run only for the first batch for demonstration

In [23]:
class SimpleSegmentationNet(nn.Module):
    def __init__(self):
        super(SimpleSegmentationNet, self).__init__()
        backbone = models.resnet18(pretrained=True)
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])
        self.conv = nn.Conv2d(512, 6, kernel_size=1)  # Adjust output channels as needed for your dataset

    def forward(self, x):
        x = self.backbone(x)
        x = self.conv(x)
        return x

In [25]:
import torch.optim as optim

# --- Step 3: Model, Loss, and Optimizer ---
model = SimpleSegmentationNet()

criterion = PartialCrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# --- Step 4: Training Loop ---
num_epochs = 2
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for images, targets in dataloader:
        targets = targets.squeeze(1)  # Ensure targets are the correct shape
        # print(targets.shape)
        mask = create_point_labels(targets)

        outputs = model(images)
        loss = criterion(outputs, targets.long(), mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()


    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / len(dataloader):.4f}')

RuntimeError: size mismatch (got input: [4, 6, 8, 8] , target: [4, 256, 256]