In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.swin_transformer import SwinTransformer

# Shared Spatial Attention Module
def shared_spatial_attention(x):
    B, C, H, W = x.shape
    q = F.normalize(x, p=2, dim=1)
    k = F.normalize(x, p=2, dim=1)
    v = x.view(B, C, -1)
    attn = torch.matmul(q.view(B, C, -1), k.view(B, C, -1).transpose(-2, -1))
    return torch.matmul(attn, v).view(B, C, H, W)

# Shared Channel Attention Module
def shared_channel_attention(x):
    B, C, H, W = x.shape
    q = F.adaptive_avg_pool2d(x, 1).view(B, C, 1, 1)
    k = F.adaptive_max_pool2d(x, 1).view(B, C, 1, 1)
    attn = torch.sigmoid(q + k)
    return x * attn

# Aggregation Feature Module
class AggregationFeature(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

# Swin Transformer Encoder-Decoder Architecture for Akshu Region Dataset
class SwinSegmentation(nn.Module):
    def __init__(self, img_size=512, num_classes=9, in_channels=4):  # Adjusted for multispectral data
        super().__init__()
        self.encoder = SwinTransformer(img_size=img_size, in_chans=in_channels, num_classes=0, pretrained=False)
        self.conv1x1 = nn.Conv2d(768, 256, kernel_size=1)  # Adjust channels from Swin Transformer output
        self.decoder = nn.ModuleList([
            AggregationFeature(256, 128),
            AggregationFeature(128, 64),
            AggregationFeature(64, 32)
        ])
        self.final_seg = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        enc_features = self.encoder.forward_features(x)  # Output shape: (B, H/32, W/32, C)
        enc_features = enc_features.permute(0, 3, 1, 2)  # Reshape to (B, C, H, W)
        enc_features = self.conv1x1(enc_features)  # Reduce channels
        
        for af in self.decoder:
            enc_features = af(enc_features)
        
        combined = shared_spatial_attention(enc_features) + shared_channel_attention(enc_features)
        segmentation_output = self.final_seg(combined)
        return segmentation_output

# Example Usage
model = SwinSegmentation(num_classes=9, in_channels=4)  # Adjusted for Akshu region dataset
input_tensor = torch.randn(1, 4, 512, 512)  # 4-channel multispectral input
output = model(input_tensor)
print("Output Shape:", output.shape)  # Expected: (1, num_classes, 512, 512)
print(output)
print(type(output))


Output Shape: torch.Size([1, 9, 16, 16])
tensor([[[[-2.1036, -2.1275, -2.1643,  ..., -2.1359, -2.1392, -2.0001],
          [-2.0955, -2.1237, -2.1378,  ..., -2.1453, -2.1697, -2.0552],
          [-2.1051, -2.1198, -2.1335,  ..., -2.1584, -2.1378, -2.0366],
          ...,
          [-2.1277, -2.1140, -2.1398,  ..., -2.1097, -2.1495, -2.0509],
          [-2.1127, -2.1120, -2.1354,  ..., -2.1099, -2.1570, -2.0256],
          [-2.0397, -2.0293, -2.0287,  ..., -2.0031, -1.9899, -1.9635]],

         [[ 3.6570,  3.7035,  3.7689,  ...,  3.7192,  3.7249,  3.4715],
          [ 3.6399,  3.6970,  3.7214,  ...,  3.7339,  3.7779,  3.5707],
          [ 3.6585,  3.6890,  3.7125,  ...,  3.7575,  3.7208,  3.5384],
          ...,
          [ 3.6988,  3.6798,  3.7236,  ...,  3.6703,  3.7418,  3.5634],
          [ 3.6719,  3.6741,  3.7162,  ...,  3.6717,  3.7562,  3.5184],
          [ 3.5344,  3.5200,  3.5184,  ...,  3.4723,  3.4493,  3.4021]],

         [[ 9.8113,  9.9230, 10.0921,  ...,  9.9611,  9.9764,

In [15]:
!pip install rasterio
!pip install tifffile


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Collecting tifffile
  Downloading tifffile-2025.2.18-py3-none-any.whl (226 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.4/226.4 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tifffile
Successfully installed tifffile-2025.2.18
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.swin_transformer import SwinTransformer
from torchvision import transforms
import cv2
import numpy as np
import tifffile as tiff

# Load and preprocess satellite image
def load_satellite_image(image_path):
    image = tiff.imread(image_path)  # Load TIFF image
    if image.shape[0] > 3:  # If more than 3 channels, take only the first 3 (RGB)
        image = image[:3, :, :]
    image = np.transpose(image, (1, 2, 0))  # Convert to (H, W, C) format
    image = cv2.resize(image, (512, 512))  # Resize
    image = image.astype(np.float32) / 255.0  # Normalize
    image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0)  # Convert to (B, C, H, W)
    
    return image

# Shared Spatial Attention Module
def shared_spatial_attention(x):
    B, C, H, W = x.shape
    q = F.normalize(x, p=2, dim=1)
    k = F.normalize(x, p=2, dim=1)
    v = x.view(B, C, -1)
    attn = torch.matmul(q.view(B, C, -1), k.view(B, C, -1).transpose(-2, -1))
    return torch.matmul(attn, v).view(B, C, H, W)

# Shared Channel Attention Module
def shared_channel_attention(x):
    B, C, H, W = x.shape
    q = F.adaptive_avg_pool2d(x, 1).view(B, C, 1, 1)
    k = F.adaptive_max_pool2d(x, 1).view(B, C, 1, 1)
    attn = torch.sigmoid(q + k)
    return x * attn

# Aggregation Feature Module
class AggregationFeature(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

# Swin Transformer Encoder-Decoder Architecture for Akshu Region Dataset
class SwinSegmentation(nn.Module):
    def __init__(self, img_size=512, num_classes=5, in_channels=3):  # Adjusted for RGB satellite data
        super().__init__()
        self.encoder = SwinTransformer(img_size=img_size, in_chans=in_channels, num_classes=0, pretrained=False)
        self.conv1x1 = nn.Conv2d(768, 256, kernel_size=1)  # Adjust channels from Swin Transformer output
        self.decoder = nn.ModuleList([
            AggregationFeature(256, 128),
            AggregationFeature(128, 64),
            AggregationFeature(64, 32)
        ])
        self.final_seg = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        enc_features = self.encoder.forward_features(x)  # Output shape: (B, H/32, W/32, C)
        enc_features = enc_features.permute(0, 3, 1, 2)  # Reshape to (B, C, H, W)
        enc_features = self.conv1x1(enc_features)  # Reduce channels
        
        for af in self.decoder:
            enc_features = af(enc_features)
        
        combined = shared_spatial_attention(enc_features) + shared_channel_attention(enc_features)
        segmentation_output = self.final_seg(combined)
        return segmentation_output

# Load an example satellite image and perform segmentation
image_path = "/kaggle/input/akshu-dataset/Aksu/Test/Image/true_color_image_02_07.tif"  # Replace with actual image path
input_image = load_satellite_image(image_path)

# Initialize model and perform inference
model = SwinSegmentation(num_classes=5, in_channels=3)  # Adjusted for RGB satellite image
output = model(input_image)
print("Output Shape:", output.shape)  # Expected: (1, num_classes, 512, 512)


Output Shape: torch.Size([1, 5, 16, 16])
