In [5]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from tqdm.auto import tqdm

In [7]:
# =============== 1. Preparation of Device, Model, and Transform ===============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of output classes of the classifier head in the checkpoint.
# load_state_dict raises on a tensor-shape mismatch even with strict=False,
# so this must equal the head size the checkpoint was trained with.
# NOTE(review): the palette cell below defines only 3 colours — confirm that
# 9 really is the trained head size for this Sen1Flood11 checkpoint.
num_classes = 9
checkpoint_path = "./deeplabv3_finetuned_RS_Sen1Flood11_V1.pth"

# 1.1 Define the network and load local weights
model = deeplabv3_resnet50(num_classes=num_classes)

# Assumes the file holds a plain state_dict. If torch.load returns a whole
# nn.Module instead, the full model was pickled and should be used directly.
state_dict = torch.load(checkpoint_path, map_location=device)

# strict=False tolerates extra checkpoint keys (for example an auxiliary head
# that this inference model does not build), but it would also silently skip
# every weight whose name doesn't match (e.g. a "module." prefix), leaving the
# model randomly initialised. Inspect the result instead of trusting it blindly.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint {checkpoint_path!r} is missing "
        f"{len(load_result.missing_keys)} model parameters "
        f"(e.g. {load_result.missing_keys[:5]}); inference would use untrained weights."
    )
if load_result.unexpected_keys:
    print(
        f"Ignored {len(load_result.unexpected_keys)} unexpected checkpoint keys "
        f"(e.g. {load_result.unexpected_keys[:5]})"
    )

model.to(device)
model.eval()  # Switch to evaluation mode (fixes batch-norm stats, disables dropout)

# 1.2 Define the image transformation for inference (must match training).
# ToTensor only converts HWC uint8 [0, 255] to CHW float [0, 1].
# NOTE(review): if training used Resize/Crop/Normalize, add the same steps here.
transform = transforms.Compose([
    transforms.ToTensor()
])


In [8]:
# =============== 2. Specify Input and Output Directories ===============
input_dir = "./s2input"    # Folder containing images for inference
output_dir = "./s2output"  # Folder that receives the colour-coded masks
os.makedirs(output_dir, exist_ok=True)

# =============== 3. Define the Color Palette idx2color Corresponding to Your Training ===============
# Row i is the RGB colour drawn for predicted class index i. The legend below
# labels index i with raw value i - 1 (-1, 0, 1), which looks like the
# Sen1Flood11 label set (-1 = no data, 0 = not water, 1 = water) shifted by +1
# for training. NOTE(review): confirm against the training label encoding.
# The inference loop draws any prediction >= len(idx2color) with row 0.
idx2color = np.array([
    [128, 128, 128],  # index 0 -> label -1
    [0, 0, 0],        # index 1 -> label 0
    [255, 255, 255],  # index 2 -> label 1
], dtype=np.uint8)

# (Optional) Legend for visualisation: raw label (as a string) -> RGB list.
# Derived from idx2color so the legend can never drift from the palette.
color_mapping = {str(idx - 1): rgb.tolist() for idx, rgb in enumerate(idx2color)}

In [9]:
# =============== 4. Batch Process All Images in the Specified Folder ===============
valid_exts = {".jpg", ".png", ".jpeg"}


def predict_mask(image: Image.Image, net: torch.nn.Module, preprocess,
                 target_device: torch.device) -> np.ndarray:
    """Run the segmentation network on one RGB PIL image.

    Args:
        image: input image (already converted to RGB).
        net: model whose forward returns a dict with an "out" logits tensor.
        preprocess: callable turning the PIL image into a CHW tensor.
        target_device: device the model lives on.

    Returns:
        (H, W) integer array of per-pixel argmax class indices.
    """
    input_tensor = preprocess(image).unsqueeze(0).to(target_device)  # [1, 3, H, W]
    with torch.no_grad():
        logits = net(input_tensor)["out"]  # [1, num_classes, H, W]
    return logits.argmax(dim=1).squeeze(0).cpu().numpy()  # (H, W)


def colorize_prediction(pred: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Map (H, W) class indices to an (H, W, 3) image using `palette` rows.

    Indices >= len(palette) have no colour and are drawn with palette[0]
    (treated as background). `pred` itself is not modified.
    """
    safe_pred = np.where(pred < len(palette), pred, 0)
    return palette[safe_pred]


# Sorted so processing order (and log output) is reproducible across platforms.
image_files = sorted(
    file_name for file_name in os.listdir(input_dir)
    if os.path.splitext(file_name)[1].lower() in valid_exts
)
if not image_files:
    print(f"No images with extensions {sorted(valid_exts)} found in {input_dir}")

for file_name in tqdm(image_files, desc="Processing Images", unit="image"):
    img_path = os.path.join(input_dir, file_name)
    # tqdm.write prints above the bar instead of breaking it apart.
    tqdm.write(f"Inferencing on: {img_path}")

    # The context manager closes the file; convert() returns an in-memory copy.
    with Image.open(img_path) as raw_image:
        input_image = raw_image.convert("RGB")

    pred = predict_mask(input_image, model, transform, device)
    pred_image = Image.fromarray(colorize_prediction(pred, idx2color))

    # Always save masks as lossless PNG: JPEG compression would smear the
    # palette colours. "scene.jpg" -> "scene.png".
    # NOTE: inputs differing only by extension (a.jpg / a.png) share an output.
    save_name = os.path.splitext(file_name)[0] + ".png"
    save_path = os.path.join(output_dir, save_name)
    pred_image.save(save_path)
    tqdm.write(f"Saved predicted segmentation to: {save_path}")

Processing Images:   0%|          | 0/2 [00:00<?, ?image/s]

Inferencing on: ./s2input\Sentinel2_SWIR_NIR_Red.png


Processing Images: 100%|██████████| 2/2 [00:00<00:00,  4.16image/s]

Saved predicted segmentation to: ./s2output\Sentinel2_SWIR_NIR_Red.png
Inferencing on: ./s2input\Sentinel2_TrueColor_RGB.png
Saved predicted segmentation to: ./s2output\Sentinel2_TrueColor_RGB.png



