In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [13]:
import pandas as pd
import numpy as np
from pathlib import Path
import shutil
import os
import re

def generate_label_mapping(root_dir, other_dir, input_subdir, output_csv):
    """
    Generate a CSV mapping input chips to corresponding segmentation maps.

    Args:
        root_dir (str or Path): Root directory containing the subdirectories for chips and segmentation maps.
        input_subdir (str): Subdirectory path for chips within the root directory.
        output_csv (str or Path): Output path for the generated CSV file.
    """
    root_dir = Path(root_dir)
    chips_orig = os.listdir(root_dir / input_subdir / "chips")

    chips = [chip.replace("chip", f"{input_subdir}/chips/chip") for chip in chips_orig]
    seg_maps = [chip.replace("chip", f"{input_subdir}/seg_maps/seg_map") for chip in chips_orig]

    df = pd.DataFrame({"Input": chips, "Label": seg_maps})
    df.to_csv(other_dir + '/' + output_csv, index=False)
    
    print(f"Number of rows is: {df.shape[0]}")
    print(f"CSV generated and saved to: {root_dir / output_csv}")
    

In [15]:
generate_label_mapping('/kaggle/input/geo-ai-hack/', '/kaggle/working/',"s2_train/s2_train/", "s2_train_ds.csv")
generate_label_mapping('/kaggle/input/geo-ai-hack/', '/kaggle/working/', "s2_test/s2_test/", "s2_test_ds.csv")

Number of rows is: 11764
CSV generated and saved to: /kaggle/input/geo-ai-hack/s2_train_ds.csv
Number of rows is: 3937
CSV generated and saved to: /kaggle/input/geo-ai-hack/s2_test_ds.csv


In [None]:
!pip install rasterio tqdm

import os
import torch
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models.segmentation import deeplabv3_resnet50
import torch.optim as optim
import torch.nn as nn
import pandas as pd
from tqdm import tqdm  # Import tqdm for the progress bar

# Custom Dataset Class for 8-Channel Segmentation
class CustomSegmentationDataset(Dataset):
    def __init__(self, image_paths, mask_paths=None, transform=None, is_test=False):
        self.image_paths = image_paths
        self.mask_paths = mask_paths if not is_test else None
        self.transform = transform
        self.is_test = is_test
        self.samples = []

        if not is_test:
            for img_path, mask_path in zip(image_paths, mask_paths):
                if os.path.exists(img_path) and os.path.exists(mask_path):
                    self.samples.append((img_path, mask_path))
        else:
            for img_path in image_paths:
                if os.path.exists(img_path):
                    self.samples.append((img_path, None))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, mask_path = self.samples[idx]
        preprocessing = Preprocessing(image_path)
        image = preprocessing.preprocess_image()
        image = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)  # (C, H, W)

        if self.is_test:
            return image, {"image_id": torch.tensor([idx])}
        else:
            mask = plt.imread(mask_path)

            # Convert mask to grayscale (if it has extra channels)
            if mask.ndim == 3:
                mask = mask[..., 0]  # Take only the first channel
                
            # Create a writable copy of the mask
            mask = mask.copy()  # Make the mask writable

            # Convert mask values to integers (0 for background, 1 for class)
            mask[mask == -9999] = 0  # Set -9999 as background
            mask[mask == 1] = 1       # Set 1 as the foreground class

            # Ensure it's a 2D tensor (H, W) and correct dtype
            mask = torch.tensor(mask, dtype=torch.long).squeeze(0)

            if self.transform:
                image = self.transform(image)
                
            print(mask.shape)  # Check the shape of the mask

            return image, mask  # Image (C, H, W), Mask (H, W)

# Preprocessing Class for 8 Channels
class Preprocessing:
    def __init__(self, image_path):
        self.image_path = image_path

    def load_bands(self):
        with rasterio.open(self.image_path) as src:
            blue = src.read(1)
            green = src.read(2)
            red = src.read(3)
            nir = src.read(4)
            swir1 = src.read(5)
            swir2 = src.read(6)
        return blue, green, red, nir, swir1, swir2

    def preprocess_image(self):
        blue, green, red, nir, swir1, swir2 = self.load_bands()
        ndvi = self.compute_ndvi(red, nir)
        evi = self.compute_evi(nir, red, blue)
        normalized_bands = [self.normalize_band(band) for band in [blue, green, red, nir, swir1, swir2]]
        image = np.stack(normalized_bands + [ndvi, evi], axis=-1)  # Stack 8 channels
        return image

    def normalize_band(self, band):
        return (band - np.min(band)) / (np.max(band) - np.min(band))

    def compute_ndvi(self, red, nir):
        return (nir - red) / (nir + red + 1e-6)

    def compute_evi(self, nir, red, blue, g=2.5, c1=6, c2=7.5, l=1):
        return np.clip(g * (nir - red) / (nir + c1 * red - c2 * blue + l), 0, 1)

# Training and Model Setup
class DeepLabV3Model:
    def __init__(self, num_classes=2, device='cuda'):
        self.device = device
        self.model = deeplabv3_resnet50(pretrained=True)

        # Modify input layer to accept 8 channels
        in_features = self.model.backbone.conv1.in_channels
        self.model.backbone.conv1 = nn.Conv2d(8, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Modify output layer for segmentation classes
        self.model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)

        self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

    def train(self, dataloader, num_epochs=10, checkpoint_interval=5):
        self.model.train()
        for epoch in range(num_epochs):
            running_loss = 0.0
            # Initialize tqdm for progress bar
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

            for images, masks in pbar:
                images, masks = images.to(self.device), masks.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(images)['out']
                loss = self.criterion(outputs, masks)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()

                # Update the progress bar description with the current loss
                pbar.set_postfix({"Loss": running_loss / (pbar.n + 1)})

            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(dataloader):.4f}")
            
            # Save checkpoint every `checkpoint_interval` epochs
            if (epoch + 1) % checkpoint_interval == 0:
                self.save_checkpoint(epoch + 1)

    def evaluate(self, dataloader):
        self.model.eval()
        iou_scores = []
        with torch.no_grad():
            for images, masks in dataloader:
                images, masks = images.to(self.device), masks.to(self.device)
                outputs = self.model(images)['out']
                preds = torch.argmax(outputs, dim=1)
                intersection = (preds & masks).float().sum()
                union = (preds | masks).float().sum()
                iou_scores.append(intersection / union)
        mean_iou = sum(iou_scores) / len(iou_scores)
        print(f"Mean IoU: {mean_iou:.4f}")

    def save_checkpoint(self, epoch):
        checkpoint_path = f"checkpoint_epoch_{epoch}.pth"
        torch.save(self.model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

# Load dataset
train_csv = pd.read_csv("/kaggle/working/s2_train_ds.csv")
test_csv = pd.read_csv("/kaggle/working/s2_test_ds.csv")

train_image_paths = train_csv["Input"].tolist()
train_mask_paths = train_csv["Label"].tolist()
test_image_paths = test_csv["Input"].tolist()

# Data Transformations
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])
%cd /kaggle/input/geo-ai-hack

# Create Datasets and Dataloaders
train_dataset = CustomSegmentationDataset(image_paths=train_image_paths, mask_paths=train_mask_paths, transform=transform)
test_dataset = CustomSegmentationDataset(image_paths=test_image_paths, is_test=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Initialize and train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepLabV3Model(device=device)
model.train(train_dataloader, num_epochs=10, checkpoint_interval=5)

# Evaluate the model
model.evaluate(test_dataloader)

Collecting rasterio
  Downloading rasterio-1.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Downloading rasterio-1.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m56.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading affine-2.4.0-py3-none-any.whl (15 kB)
Installing collected packages: affine, rasterio
Successfully installed affine-2.4.0 rasterio-1.4.3
/kaggle/input/geo-ai-hack


Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth
100%|██████████| 161M/161M [00:01<00:00, 161MB/s]  
Epoch 1/10:   0%|          | 0/1471 [00:00<?, ?it/s]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


  return np.clip(g * (nir - red) / (nir + c1 * red - c2 * blue + l), 0, 1)


torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 1/1471 [00:23<9:43:20, 23.81s/it, Loss=0.669]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 2/1471 [00:46<9:22:54, 22.99s/it, Loss=0.639]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 3/1471 [01:08<9:09:49, 22.47s/it, Loss=0.601]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 4/1471 [01:29<8:58:19, 22.02s/it, Loss=0.566]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 5/1471 [01:51<8:58:17, 22.03s/it, Loss=0.531]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 6/1471 [02:12<8:52:05, 21.79s/it, Loss=0.503]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 7/1471 [02:35<8:55:35, 21.95s/it, Loss=0.47] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 8/1471 [02:56<8:51:49, 21.81s/it, Loss=0.441]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 9/1471 [03:17<8:48:14, 21.68s/it, Loss=0.422]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 10/1471 [03:39<8:45:11, 21.57s/it, Loss=0.397]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


  return (band - np.min(band)) / (np.max(band) - np.min(band))


torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 11/1471 [04:01<8:47:12, 21.67s/it, Loss=nan]  

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 12/1471 [04:22<8:45:07, 21.60s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 13/1471 [04:43<8:42:30, 21.50s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 14/1471 [05:05<8:40:50, 21.45s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 15/1471 [05:26<8:39:01, 21.39s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 16/1471 [05:47<8:35:45, 21.27s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 17/1471 [06:08<8:35:41, 21.28s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 18/1471 [06:30<8:36:47, 21.34s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|▏         | 19/1471 [06:51<8:34:27, 21.26s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|▏         | 20/1471 [07:13<8:38:50, 21.45s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|▏         | 21/1471 [07:34<8:40:42, 21.55s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|▏         | 22/1471 [07:56<8:36:45, 21.40s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 23/1471 [08:17<8:36:14, 21.39s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 24/1471 [08:38<8:35:23, 21.37s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 25/1471 [09:02<8:49:18, 21.96s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 26/1471 [09:25<8:59:54, 22.42s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 27/1471 [09:48<9:05:33, 22.67s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 28/1471 [10:11<9:05:17, 22.67s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 29/1471 [10:34<9:05:55, 22.72s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 30/1471 [10:57<9:09:02, 22.86s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 31/1471 [11:18<8:57:57, 22.42s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 32/1471 [11:39<8:44:37, 21.87s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 33/1471 [12:00<8:39:21, 21.67s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 34/1471 [12:21<8:35:55, 21.54s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 35/1471 [12:42<8:26:57, 21.18s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 36/1471 [13:03<8:26:00, 21.16s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 37/1471 [13:24<8:26:09, 21.18s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 38/1471 [13:45<8:22:47, 21.05s/it, Loss=nan]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
