# NEW CODE

In [None]:
# CRFNet-RS: Semantic Segmentation for Remote Sensing Images

# Cell 1: Setup Environment
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import random
from tqdm.notebook import tqdm
from glob import glob
import time
from datetime import datetime
from skimage import io

# Set random seeds for reproducibility
# The same fixed seed is applied to Python's, PyTorch's and NumPy's global RNGs.
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    # Seed the RNG of every CUDA device as well.
    torch.cuda.manual_seed_all(42)

# Clone the repository
# NOTE: lines starting with `!` are IPython shell escapes; this cell only runs inside a notebook.
# NOTE(review): presumably `git clone` errors out if ./CRFNet-RS already exists (cell re-run) — confirm.
!git clone https://github.com/Ayana-Inria/CRFNet-RS.git
# Put the cloned repository on the import path so its packages can be imported.
sys.path.append('./CRFNet-RS')

# Install dependencies
!pip install -r ./CRFNet-RS/requirements.txt

In [None]:
# Cell 2: Fix Imports in main.py
# Comment out the wildcard `net.unet` import in the cloned repository's main.py.
# Only lines that still *start* with the import are rewritten, so running this
# cell again does not stack extra "# " prefixes on an already-patched file.
import re

with open('./CRFNet-RS/main.py', 'r', encoding='utf-8') as file:
    content = file.read()

# Keep any leading indentation and prefix the statement with "# ".
fixed_content = re.sub(
    r'^([ \t]*)from net\.unet import \*',
    r'\1# from net.unet import *',
    content,
    flags=re.MULTILINE,
)

if fixed_content != content:
    with open('./CRFNet-RS/main.py', 'w', encoding='utf-8') as file:
        file.write(fixed_content)
    print("Fixed import in main.py")
else:
    # Nothing matched: the file is already patched or the import is absent.
    print("No active 'from net.unet import *' line found in main.py; nothing changed")

In [None]:
# Cell 3: Configure Paths and Parameters
import os
from datetime import datetime

# ---- Patch extraction and optimisation hyper-parameters ----
WINDOW_SIZE = (256, 256)  # Patch size
STRIDE = 32  # Stride for testing inference
IN_CHANNELS = 3  # Number of input channels (RGB/IRRG)
BATCH_SIZE = 10  # Mini-batch size
EPOCHS = 50  # Number of training epochs
SAVE_EPOCH = 10  # Save model every N epochs
BASE_LR = 0.01  # Base learning rate
WEIGHT_DECAY = 0.0005  # Weight decay for optimizer

# ---- Dataset selection ----
DATASET_TYPE = "Vaihingen"  # Options: "Vaihingen" or "Potsdam"
GT_TYPE = "conncomp"  # Options: "full", "conncomp", "ero"
ERO_DISK_SIZE = 8  # Size of erosion disk for ground truth processing

# ---- Folder layout ----
DATA_ROOT = "/kaggle/input/potsdamvaihingen/"  # Input data path
OUTPUT_ROOT = "/kaggle/working/"  # Working directory for outputs
WORKING_DATA_ROOT = "/kaggle/working/data"  # Working directory for data processing

# Make sure every working directory exists before anything is written.
for required_dir in (
    OUTPUT_ROOT,
    WORKING_DATA_ROOT,
    f"{WORKING_DATA_ROOT}/top",
    f"{WORKING_DATA_ROOT}/gt",
    f"{WORKING_DATA_ROOT}/gt_eroded",
):
    os.makedirs(required_dir, exist_ok=True)

# Each run gets its own timestamped output folder.
EXPERIMENT_NAME = f"CRFNet_{DATASET_TYPE}_{GT_TYPE}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(f"{OUTPUT_ROOT}/{EXPERIMENT_NAME}", exist_ok=True)

# ---- Train/test tile split ----
# Any value other than "Vaihingen" selects the Potsdam tile IDs.
if DATASET_TYPE != "Vaihingen":
    train_ids = ['3_11', '4_11', '5_10', '6_7', '6_8', '6_9', '7_7', '7_8', '7_9', '7_10']
    test_ids = ['3_12', '4_10', '4_12', '5_11', '6_12']
else:
    train_ids = ['1', '3', '23', '26', '7', '11', '13', '28', '17', '32', '34', '37']
    test_ids = ['5', '15', '21', '30']

# ---- File path templates ----
# The doubled braces leave a literal "{}" in the resulting string, so a tile ID
# is filled in later with `.format(tile_id)`.
# NOTE: both templates point at the Vaihingen folders regardless of DATASET_TYPE.
DATA_FILES = f"{DATA_ROOT}/ISPRS_semantic_labeling_Vaihingen/top/top_mosaic_09cm_area{{}}.tif"
LABEL_FILES = f"{DATA_ROOT}/ISPRS_semantic_labeling_Vaihingen/gts_for_participants/top_mosaic_09cm_area{{}}.tif"

# The unmodified label files stand in for the "eroded" labels.
ERODED_FILES = LABEL_FILES

# ---- Classes ----
LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"]
N_CLASSES = len(LABELS)

In [None]:
# Open the file and fix the indentation
# Print the first 90 lines of utils_network.py so the area around line 89 can be inspected.
# NOTE(review): the absolute path assumes the repository was cloned with /kaggle/working as the
# current directory — confirm.
!cat /kaggle/working/CRFNet-RS/utils/utils_network.py | head -90
# Prepend eight spaces to line 89 of the file, editing it in place.
# NOTE: the substitution is unconditional, so every execution of this cell indents line 89 again;
# run it once per fresh clone.
!sed -i '89s/^/        /' /kaggle/working/CRFNet-RS/utils/utils_network.py
from net.net import CRFNet

In [None]:
# Cell 4: Import Required Modules
from dataset.dataset import ISPRS_dataset
from utils.utils_dataset import convert_from_color, convert_to_color, disk
from utils.utils import metrics, sliding_window, count_sliding_window, grouper
from utils.export_result import set_output_location, export_results
from net.net import CRFNet
from net.loss import CrossEntropy2d
from skimage import io
import torch.optim as optim
from torch.autograd import Variable

# Summarise the run configuration: dataset, ground-truth variant, tile split
# and whether CUDA is available. One print call, one line per item.
print(
    f"Dataset: {DATASET_TYPE}",
    f"Ground Truth Type: {GT_TYPE}",
    f"Training on {len(train_ids)} tiles: {train_ids}",
    f"Testing on {len(test_ids)} tiles: {test_ids}",
    f"Using {'GPU' if torch.cuda.is_available() else 'CPU'} for computation",
    sep="\n",
)

In [None]:
# Cell 5: Initialize Model and Optimizer
# Build the CRFNet network for the configured channel and class counts.
net = CRFNet(n_channels=IN_CHANNELS, n_classes=N_CLASSES, bilinear=True)

# SGD with momentum; the scheduler multiplies the learning rate by 0.1 at
# milestones 25, 35 and 45.
optimizer = optim.SGD(
    net.parameters(),
    lr=BASE_LR,
    momentum=0.9,
    weight_decay=WEIGHT_DECAY,
)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 35, 45], gamma=0.1)

# Uniform per-class weights, moved to the GPU together with the model when
# CUDA is available.
WEIGHTS = torch.ones(N_CLASSES)
if torch.cuda.is_available():
    net.cuda()
    WEIGHTS = WEIGHTS.cuda()

In [None]:
# Cell 6: Define Training Function
def train_model(net, optimizer, scheduler, train_loader, epochs, save_epoch, weights, output_path,
                batch_size=None, window_size=None):
    """
    Train the CRFNet model and save its final weights.

    Training itself is delegated to ``net.train.train`` from the cloned
    repository. Afterwards the model's ``state_dict`` is written to
    ``<output_path>/model_final.pth`` so that the returned path refers to a
    file that actually exists.

    Args:
        net: Model to train.
        optimizer: Optimizer forwarded to the repository's ``train``.
        scheduler: Learning-rate scheduler forwarded to ``train``.
        train_loader: Data loader forwarded to ``train``.
        epochs: Number of epochs, forwarded to ``train``.
        save_epoch: Checkpoint interval, forwarded to ``train``.
        weights: Class-weight tensor, forwarded to ``train``.
        output_path: Directory handed to ``train`` and used for the final
            weights file; created if missing.
        batch_size: Batch size forwarded to ``train``. Defaults to the
            notebook-level ``BATCH_SIZE``.
        window_size: Patch size forwarded to ``train``. Defaults to the
            notebook-level ``WINDOW_SIZE``.

    Returns:
        Path of the saved final weights file.
    """
    # Imported lazily: the module only becomes importable once the repository
    # has been cloned and added to sys.path.
    from net.train import train

    # Fall back to the notebook-wide settings when not given explicitly.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if window_size is None:
        window_size = WINDOW_SIZE

    os.makedirs(output_path, exist_ok=True)

    # Run the training function
    train(net, optimizer, epochs, save_epoch, weights, train_loader, batch_size, window_size, output_path, scheduler)

    # Save final model (state_dict format, loadable with net.load_state_dict)
    final_model_path = f'{output_path}/model_final.pth'
    torch.save(net.state_dict(), final_model_path)
    return final_model_path

In [None]:
# Cell 7: Define Testing Function
def test_model(net, test_ids, data_files, label_files, eroded_files, labels, stride, batch_size, window_size, output_path=None):
    """
    Test the model on the provided test tiles and optionally export results.

    Args:
        net: Model to evaluate.
        test_ids: Tile identifiers substituted into the path templates.
        data_files: Path template for input images, filled with
            ``.format(tile_id)``.
        label_files: Path template for colour-coded label images, filled with
            ``.format(tile_id)``.
        eroded_files: Accepted for interface compatibility but not read; the
            "eroded" labels are derived from the images loaded via
            ``label_files``.
        labels: Class names forwarded to the repository's ``test``.
        stride: Sliding-window stride forwarded to ``test``.
        batch_size: Batch size forwarded to ``test``.
        window_size: Patch size forwarded to ``test``.
        output_path: Directory for the metrics export and the colourised
            predictions. Nothing is written when falsy.

    Returns:
        Tuple ``(acc, all_preds, all_gts)`` exactly as returned by
        ``net.test_network.test``.
    """
    # Imported lazily: the module only becomes importable once the repository
    # has been cloned and added to sys.path.
    from net.test_network import test

    # Load test data from the templates passed in by the caller.
    test_images = [
        1/255 * np.asarray(io.imread(data_files.format(tile_id)), dtype='float32')
        for tile_id in test_ids
    ]
    test_labels = [
        np.asarray(io.imread(label_files.format(tile_id)), dtype='uint8')
        for tile_id in test_ids
    ]
    # The regular labels, converted from colour to class indices, stand in for
    # the eroded ground truth.
    eroded_labels = [convert_from_color(label) for label in test_labels]

    # Run the test
    acc, all_preds, all_gts = test(
        net, test_ids, test_images, test_labels, eroded_labels,
        labels, stride, batch_size, window_size=window_size, all=True
    )

    # Export results
    if output_path:
        # The prediction images below are written into this directory.
        os.makedirs(output_path, exist_ok=True)

        title = "Quantitative results for CRFNet testing"
        export_results(
            all_preds, all_gts,
            os.path.dirname(output_path), os.path.basename(output_path),
            confusionMat=True,
            prodAccuracy=True,
            averageAccuracy=True,
            kappaCoeff=True,
            title=title
        )

        # Save prediction images
        for pred, tile_id in zip(all_preds, test_ids):
            img = convert_to_color(pred)
            io.imsave(f"{output_path}/segmentation_result_area{tile_id}.png", img)

    return acc, all_preds, all_gts

In [None]:
# Path templates with a positional "{}" placeholder (the doubled braces in the
# f-string produce a literal "{}"); a tile ID is filled in with .format(tile_id).
DATA_FILES = f"{DATA_ROOT}/ISPRS_semantic_labeling_Vaihingen/top/top_mosaic_09cm_area{{}}.tif"
LABEL_FILES = f"{DATA_ROOT}/ISPRS_semantic_labeling_Vaihingen/gts_for_participants/top_mosaic_09cm_area{{}}.tif"
# NOTE(review): this template points at the Potsdam no-boundary labels while the two
# templates above point at Vaihingen — confirm this is intended.
ERODED_FILES = f"{DATA_ROOT}/5_Labels_for_participants_no_Boundary/5_Labels_for_participants_no_Boundary/top_potsdam_{{}}_label_noBoundary.tif"

# Build the training dataset from the training tile IDs and the templates above.
train_set = ISPRS_dataset(
    ids=train_ids,
    ids_type='TRAIN',
    gt_type=GT_TYPE,
    gt_modification=disk(ERO_DISK_SIZE),
    data_files=DATA_FILES,
    label_files=LABEL_FILES,
    window_size=WINDOW_SIZE,
    cache=True,
    augmentation=True
)

# Create data loader
# No `shuffle` argument is passed, so the DataLoader default (no shuffling) applies.
# NOTE(review): presumably ISPRS_dataset samples patches randomly on its own — confirm.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE)
print(f"Created training loader with approximately {len(train_set) // BATCH_SIZE} batches per epoch")

In [None]:
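# NOTE: Disabled (commented-out) alternative implementation of test_model that loads every
# tile with per-file error handling. It is kept for reference only and is never executed.
# NOTE(review): if re-enabled, `valid_test_ids = test_ids[:len(test_images)]` pairs results with
# the wrong tile IDs whenever a tile other than the trailing ones fails to load; collect the IDs
# of the successfully loaded tiles inside the loop instead.
# def test_model(net, test_ids, data_files, label_files, eroded_files, labels, stride, batch_size, window_size, output_path=None):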
#     """
#     Test the model on the provided test data
#     """
#     from net.test_network import test
#     from skimage import io
#     import numpy as np
#     import os
#     from utils.utils_dataset import convert_from_color
#     from utils.export_result import export_results
    
#     all_preds = []
#     all_gts = []
    
#     # Load test data safely with error handling
#     test_images = []
#     test_labels = []
#     eroded_labels = []
    
#     for id in test_ids:
#         print(f"Loading test data for tile {id}...")
        
#         # Load input image
#         try:
#             img = 1/255 * np.asarray(io.imread(data_files.format(id)), dtype='float32')
#             test_images.append(img)
#             print(f"Image shape: {img.shape}")
#         except Exception as e:
#             print(f"Error loading image for tile {id}: {e}")
#             continue
        
#         # Load label
#         try:
#             label = np.asarray(io.imread(label_files.format(id)), dtype='uint8')
#             test_labels.append(label)
#             print(f"Label shape: {label.shape}")
#         except Exception as e:
#             print(f"Error loading label for tile {id}: {e}")
#             # Remove the corresponding image
#             test_images.pop()
#             continue
        
#         # Load eroded label
#         try:
#             eroded = io.imread(eroded_files.format(id))
#             print(f"Eroded label initial shape: {eroded.shape}")
            
#             # Check if the image is already a 2D array (grayscale)
#             if len(eroded.shape) == 2:
#                 # This is already a label map, no need to convert from color
#                 eroded_label = eroded
#             else:
#                 # This is an RGB image, convert from color
#                 eroded_label = convert_from_color(eroded)
                
#             eroded_labels.append(eroded_label)
#             print(f"Processed eroded label shape: {eroded_label.shape}")
#         except Exception as e:
#             print(f"Error loading eroded label for tile {id}: {e}")
#             # Remove the corresponding image and label
#             test_images.pop()
#             test_labels.pop()
#             continue
    
#     # Make sure we have data to test
#     if not test_images:
#         print("No valid test data found. Check your file paths and image formats.")
#         return 0, [], []
    
#     # Update test_ids to only include those we successfully loaded
#     valid_test_ids = test_ids[:len(test_images)]
#     print(f"Testing on {len(valid_test_ids)} valid tiles: {valid_test_ids}")
    
#     # Run the test
#     acc, all_preds, all_gts = test(
#         net, valid_test_ids, test_images, test_labels, eroded_labels, 
#         labels, stride, batch_size, window_size=window_size, all=True
#     )
    
#     # Export results
#     if output_path and all_preds:
#         title = "Quantitative results for CRFNet testing"
#         export_results(
#             all_preds, all_gts, 
#             os.path.dirname(output_path), os.path.basename(output_path),
#             confusionMat=True,
#             prodAccuracy=True,
#             averageAccuracy=True,
#             kappaCoeff=True,
#             title=title
#         )
        
#         # Save prediction images
#         for pred, tile_id in zip(all_preds, valid_test_ids):
#             img = convert_to_color(pred)
#             save_path = f"{output_path}/segmentation_result_area{tile_id}.png"
#             io.imsave(save_path, img)
#             print(f"Saved prediction to {save_path}")
    
#     return acc, all_preds, all_gts

In [None]:
# Cell 9: Training (Optional)
# Training switch: True runs training, False skips it.
TRAIN_MODEL = True

if not TRAIN_MODEL:
    print("Skipping model training.")
    # To evaluate existing weights instead, load them here, for example:
    # model_path = 'path/to/pretrained/model.pth'
    # net.load_state_dict(torch.load(model_path))
else:
    print("Starting model training...")
    # Arguments are passed positionally in the order of train_model's signature.
    model_path = train_model(
        net,
        optimizer,
        scheduler,
        train_loader,
        EPOCHS,
        SAVE_EPOCH,
        WEIGHTS,
        f"{OUTPUT_ROOT}/{EXPERIMENT_NAME}",
    )
    print(f"Training completed! Model saved to {model_path}")

In [None]:
# Cell 10: Testing
# Evaluation switch: True runs the test tiles through the model.
TEST_MODEL = True

if TEST_MODEL:
    print("Starting model testing...")

    # Evaluate on the held-out tiles; metrics and prediction images are written
    # to this run's experiment folder. Positional arguments follow the order of
    # test_model's signature.
    accuracy, all_preds, all_gts = test_model(
        net,
        test_ids,
        DATA_FILES,
        LABEL_FILES,
        ERODED_FILES,
        LABELS,
        STRIDE,
        BATCH_SIZE,
        WINDOW_SIZE,
        output_path=f"{OUTPUT_ROOT}/{EXPERIMENT_NAME}",
    )

    # NOTE(review): the "%" suffix assumes `accuracy` is already a percentage — confirm.
    print(f"Testing completed with overall accuracy: {accuracy:.2f}%")

In [None]:
# Cell 11: Visualization
# Plot one row per test tile (original image, ground truth, prediction) and
# save the figure into the experiment folder. Skipped when testing did not run
# or when there are no test tiles (a figure with zero rows cannot be created).
if TEST_MODEL and 'all_preds' in locals() and len(test_ids) > 0:
    print("Generating visualization of results...")

    # Create a visualization of all test results
    n_images = len(test_ids)
    # squeeze=False always yields a 2-D array of axes, so a single test tile
    # needs no special-casing.
    fig, axes = plt.subplots(n_images, 3, figsize=(15, 5*n_images), squeeze=False)

    # `tile_id` rather than `id`: a module-level loop variable named `id` would
    # rebind the builtin for the rest of the notebook session.
    for row_axes, tile_id, pred, gt in zip(axes, test_ids, all_preds, all_gts):
        # Load original image
        img = 1/255 * np.asarray(io.imread(DATA_FILES.format(tile_id)), dtype='float32')

        # Display original image, ground truth, and prediction
        panels = (
            (np.asarray(255 * img, dtype='uint8'), f'Area {tile_id} - Original'),
            (convert_to_color(gt), 'Ground Truth'),
            (convert_to_color(pred), 'Prediction'),
        )
        for ax, (panel_image, panel_title) in zip(row_axes, panels):
            ax.imshow(panel_image)
            ax.set_title(panel_title)

    plt.tight_layout()
    plt.savefig(f"{OUTPUT_ROOT}/{EXPERIMENT_NAME}/all_results.png", dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Results visualization saved to {OUTPUT_ROOT}/{EXPERIMENT_NAME}/all_results.png")

print("Pipeline completed successfully!")