In [None]:
!rm -rf cvae-quadratreeMRF # Remove old version if exists
!git clone https://github.com/realjules/cvae-quadratreeMRF.git
import sys
sys.path.append('/kaggle/working/cvae-quadratreeMRF')

In [None]:
# Import required libraries
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from skimage.morphology import disk
from glob import glob
import random
import os
from tqdm.notebook import tqdm
import cv2

# Set random seeds for reproducibility: Python, NumPy and PyTorch (CPU and all GPUs)
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# Seeding alone does not make GPU runs repeatable: cuDNN may select nondeterministic
# convolution algorithms. Pin it to deterministic ones (at some speed cost). Other
# CUDA kernels can still be nondeterministic; see PyTorch's reproducibility notes.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Import project modules
from dataset.dataset import ISPRS_dataset
from net.net import CRFNet
from net.loss import CrossEntropy2d
from utils.utils_dataset import convert_to_color, convert_from_color
from utils.utils_network import compute_class_weight
from utils.utils import accuracy

In [None]:
# ---- Configuration -----------------------------------------------------------
# Dataset
WINDOW_SIZE = (256, 256)    # patch size (rows, cols) for training and sliding-window testing
STRIDE = 32                 # sliding-window step used at test time
IN_CHANNELS = 3             # RGB input
FOLDER = "../input/potsdamvaihingen/"   # dataset root
BATCH_SIZE = 10             # mini-batch size

# Optimisation / bookkeeping
EPOCHS = 30                 # number of training epochs
SAVE_EPOCH = 10             # write a checkpoint every SAVE_EPOCH epochs
OUTPUT_FOLDER = "./output"  # checkpoint directory
ERO_DISK_SIZE = 3           # radius of the skimage disk handed to the dataset as gt_modification
BASE_LR = 0.01              # initial SGD learning rate

os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Semantic classes
LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"]
N_CLASSES = len(LABELS)
WEIGHTS = torch.ones(N_CLASSES)   # uniform class weights for CrossEntropy2d during training
CACHE = True                      # keep loaded tiles in memory

# Vaihingen file templates; the literal "{}" is later filled with a tile id
MAIN_FOLDER = f"{FOLDER}ISPRS_semantic_labeling_Vaihingen/"
DATA_FOLDER = f"{MAIN_FOLDER}top/top_mosaic_09cm_area{{}}.tif"
LABEL_FOLDER = f"{MAIN_FOLDER}gts_for_participants/top_mosaic_09cm_area{{}}.tif"

In [None]:
# Tile ids used for training and for held-out evaluation
train_ids = ['1', '3', '23', '26', '7', '11', '13', '28', '17', '32', '34', '37']
test_ids = ['5', '21', '15', '30']
for split_name, split_ids in (("Training", train_ids), ("Testing", test_ids)):
    print(f"{split_name} on {len(split_ids)} tiles: {split_ids}")

In [None]:
# ISPRS colour palette: class id -> RGB triple
# NOTE(review): neither dict is referenced by the other cells shown here; the
# convert_to_color / convert_from_color helpers come from utils.utils_dataset and
# presumably carry their own palette - confirm before editing one without the other.
palette = dict(enumerate([
    (255, 255, 255),  # 0 Impervious surfaces (white)
    (0, 0, 255),      # 1 Buildings (blue)
    (0, 255, 255),    # 2 Low vegetation (cyan)
    (0, 255, 0),      # 3 Trees (green)
    (255, 255, 0),    # 4 Cars (yellow)
    (255, 0, 0),      # 5 Clutter (red)
    (0, 0, 0),        # 6 Undefined (black)
]))

# Reverse lookup: RGB triple -> class id
invert_palette = {rgb: class_id for class_id, rgb in palette.items()}

In [None]:
# Sanity check: show the first training tile next to its raw label image.
# Best-effort only - any failure is reported and the notebook carries on.
try:
    img = io.imread(DATA_FOLDER.format(train_ids[0]))
    gt = io.imread(LABEL_FOLDER.format(train_ids[0]))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
    panels = (
        (ax1, img, 'Sample Image (Area {})'.format(train_ids[0])),
        (ax2, gt, 'Ground Truth'),
    )
    for ax, picture, title in panels:
        ax.imshow(picture)
        ax.set_title(title)
    plt.show()
except Exception as e:
    print(f"Error loading sample data: {e}")
    print("Continuing with training...")

In [None]:
# Build the training dataset from the training tiles (augmented; tiles cached in memory when CACHE)
print("Initializing datasets...")
dataset_kwargs = dict(
    ids=train_ids,
    ids_type='TRAIN',
    # 'full', 'conncomp', or 'ero'. NOTE(review): gt_modification is presumably only
    # consulted by the non-'full' modes - confirm in dataset.dataset.ISPRS_dataset.
    gt_type='full',
    gt_modification=disk(ERO_DISK_SIZE),
    data_files=DATA_FOLDER,
    label_files=LABEL_FOLDER,
    window_size=WINDOW_SIZE,
    cache=CACHE,
    augmentation=True,
)
train_set = ISPRS_dataset(**dataset_kwargs)

In [None]:
# Batch the training set. DataLoader defaults apply: no shuffling, and the last
# partial batch is kept. NOTE(review): no shuffle is only fine if ISPRS_dataset
# samples random patches itself - confirm.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=BATCH_SIZE,
)
print(f"Created data loader with {len(train_loader)} batches per epoch")

In [None]:
# Build the network, then move it and the loss class weights onto the GPU when one is present
print("Initializing model...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = CRFNet(n_channels=IN_CHANNELS, n_classes=N_CLASSES, bilinear=True)
net.to(device)
WEIGHTS = WEIGHTS.to(device)  # returns the same tensor when device is the CPU
print(f"Model will train on: {device}")

In [None]:
# SGD with momentum and L2 weight decay; the LR is multiplied by 0.1 at each milestone.
# train_model calls scheduler.step() once per epoch, so with the current EPOCHS (30)
# only the epoch-25 milestone is ever reached; 35 and 45 never fire.
optimizer = optim.SGD(
    params=net.parameters(),
    lr=BASE_LR,
    momentum=0.9,
    weight_decay=0.0005,
)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 35, 45], gamma=0.1)

In [None]:
# Training function
def _multiscale_targets(target_np, scales):
    """Nearest-neighbour downsample a batch of label maps to several sizes.

    Args:
        target_np: label maps laid out as (batch, H, W).
        scales: iterable of cv2 ``dsize`` tuples, i.e. (width, height).

    Returns:
        A list with one (batch, height, width) array per entry of ``scales``.
    """
    # cv2.resize resizes the first two axes and treats the last one as channels,
    # so move the batch axis last: (B, H, W) -> (H, W, B)
    batch_last = np.transpose(target_np, [1, 2, 0])
    resized_targets = []
    for size in scales:
        resized = cv2.resize(batch_last, dsize=size, interpolation=cv2.INTER_NEAREST)
        # OpenCV hands back a 2-D array for single-channel input, which happens when
        # the batch holds one sample (e.g. a short final batch); restore that axis
        if resized.ndim == 2:
            resized = resized[:, :, np.newaxis]
        resized_targets.append(np.transpose(resized, [2, 0, 1]))
    return resized_targets


def train_model():
    """Train ``net`` for EPOCHS epochs with multi-scale and pairwise supervision.

    The loss for each batch has two parts. The first is the mean of the
    full-resolution cross-entropy on ``output`` and the cross-entropies of
    ``out_fc`` against targets downsampled to 32/64/128 px. The second is a
    pairwise cross-entropy on ``out_neigh``, added on top.

    Every 100 iterations, the running loss curve and one sample prediction are
    plotted. The LR scheduler steps once per epoch. A checkpoint is written every
    SAVE_EPOCH epochs, and ``model_final.pth`` is written at the end.

    Relies on the notebook globals net, optimizer, scheduler, train_loader, device,
    WEIGHTS, EPOCHS, SAVE_EPOCH and OUTPUT_FOLDER.
    """
    # One slot per optimisation step over the whole run
    max_iterations = EPOCHS * len(train_loader)
    losses = np.zeros(max_iterations)
    mean_losses = np.zeros(max_iterations)

    # Auxiliary supervision resolutions (cv2 dsize order: width, height)
    scales = [(32, 32), (64, 64), (128, 128)]
    iter_ = 0

    for e in tqdm(range(1, EPOCHS + 1), desc="Epochs"):
        net.train()

        for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc=f"Epoch {e}", leave=False)):
            # Downsampled label maps for the multi-scale heads
            targets_resized = _multiscale_targets(target.data.cpu().numpy(), scales)

            # Move data to device
            data = data.to(device)
            target_tensor = target.to(device)

            optimizer.zero_grad()

            # Forward pass
            output, out_fc, out_neigh, _ = net(data)

            # Full-resolution loss
            loss = CrossEntropy2d(output, target_tensor, weight=WEIGHTS)

            # Multi-scale losses; class weights are recomputed from each resized target
            fc_losses = []
            for i, t in enumerate(targets_resized):
                t_tensor = torch.from_numpy(t).long().to(device)
                weights = compute_class_weight(t).to(device)
                fc_losses.append(CrossEntropy2d(out_fc[i], t_tensor, weight=weights))

            # Pairwise loss for neighbourhood consistency
            pairwise_loss = CrossEntropy2d(out_neigh, target_tensor, weight=WEIGHTS)

            # Average the unary terms, then add the pairwise term
            total_loss = (loss + sum(fc_losses)) / (1 + len(fc_losses)) + pairwise_loss

            total_loss.backward()
            optimizer.step()

            # Record loss; the running mean covers up to the last 101 iterations
            losses[iter_] = total_loss.item()
            mean_losses[iter_] = np.mean(losses[max(0, iter_-100):iter_+1])

            # Display progress every 100 iterations
            if iter_ % 100 == 0:
                with torch.no_grad():
                    # First sample of the batch, moved to CPU for plotting
                    rgb = np.asarray(255 * np.transpose(data.cpu().numpy()[0], (1, 2, 0)), dtype='uint8')
                    pred = np.argmax(output.cpu().numpy()[0], axis=0)
                    gt = target_tensor.cpu().numpy()[0]

                    acc = accuracy(pred, gt)
                    print(f'Epoch {e}/{EPOCHS} [{batch_idx}/{len(train_loader)} ({100*batch_idx/len(train_loader):.0f}%)] Loss: {total_loss.item():.4f} Acc: {acc:.2f}%')

                    # Loss curve so far
                    plt.figure(figsize=(10, 4))
                    plt.plot(mean_losses[:iter_+1])
                    plt.title('Mean Loss')
                    plt.grid(True)
                    plt.xlabel('Iterations')
                    plt.ylabel('Loss')
                    plt.show()

                    # Input / ground truth / prediction side by side
                    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
                    ax1.imshow(rgb)
                    ax1.set_title('RGB Input')
                    ax2.imshow(convert_to_color(gt))
                    ax2.set_title('Ground Truth')
                    ax3.imshow(convert_to_color(pred))
                    ax3.set_title('Prediction')
                    plt.tight_layout()
                    plt.show()

            iter_ += 1

        # Per-epoch learning-rate schedule
        scheduler.step()

        # Periodic checkpoint
        if e % SAVE_EPOCH == 0:
            torch.save(net.state_dict(), f'{OUTPUT_FOLDER}/model_epoch{e}.pth')

    # Save final model
    torch.save(net.state_dict(), f'{OUTPUT_FOLDER}/model_final.pth')
    print("Training completed!")

In [None]:
# Run the training. train_model writes a checkpoint to OUTPUT_FOLDER every
# SAVE_EPOCH epochs, plus model_final.pth once all EPOCHS have finished.
train_model()

In [None]:
# Define testing function
from net.test_network import test  # NOTE(review): not used by evaluate_model below


def _window_starts(length, window, stride):
    """Start offsets for sliding windows that cover ``range(length)`` completely.

    Offsets advance by ``stride``. If that does not land exactly on the last valid
    offset (``length - window``), one more window flush with the far edge is added,
    so the bottom/right border is never left unpredicted. When
    ``length <= window``, the single offset 0 is returned and the caller's slicing
    truncates the window.
    """
    last = max(length - window, 0)
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


def evaluate_model(model_path):
    """Evaluate a saved checkpoint on the test tiles and print overall metrics.

    Each test tile is scanned with WINDOW_SIZE patches every STRIDE pixels, with the
    windows guaranteed to reach every border. The network's per-class outputs are
    summed over all windows covering a pixel, and the highest-scoring class wins.
    Overlapping windows therefore vote with scores rather than by averaging class
    ids, which would produce meaningless intermediate labels.

    Args:
        model_path: path to a state_dict written by ``torch.save(net.state_dict(), ...)``.

    Returns:
        (all_preds, all_gts): lists holding one (H, W) integer label map per test
        tile, in ``test_ids`` order.
    """
    print(f"Loading model from {model_path}")
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()

    print("Preparing test data...")
    # Images scaled to [0, 1] float32; label images kept as uint8 RGB
    test_images = [1/255 * np.asarray(io.imread(DATA_FOLDER.format(tile_id)), dtype='float32') for tile_id in test_ids]
    test_labels = [np.asarray(io.imread(LABEL_FOLDER.format(tile_id)), dtype='uint8') for tile_id in test_ids]

    print("Running evaluation...")
    all_preds = []
    all_gts = []

    with torch.no_grad():
        for i, (img, gt) in enumerate(zip(test_images, test_labels)):
            print(f"Processing test image {i+1}/{len(test_images)} (Area {test_ids[i]})")

            # RGB label image -> class-id map
            gt_processed = convert_from_color(gt)
            all_gts.append(gt_processed)

            img = np.transpose(img, (2, 0, 1))  # CHW format

            # Per-class scores summed over every window covering each pixel
            scores = np.zeros((N_CLASSES,) + gt_processed.shape, dtype=np.float32)

            for x in _window_starts(img.shape[1], WINDOW_SIZE[0], STRIDE):
                for y in _window_starts(img.shape[2], WINDOW_SIZE[1], STRIDE):
                    patch = img[:, x:x+WINDOW_SIZE[0], y:y+WINDOW_SIZE[1]]
                    patch_tensor = torch.from_numpy(np.ascontiguousarray(patch)).unsqueeze(0).to(device)

                    # The first element of the network's output tuple holds the per-class maps
                    outputs = net(patch_tensor)[0]
                    patch_scores = outputs.cpu().numpy()[0]

                    # Accumulate raw scores (whatever CRFNet emits), not argmax labels
                    scores[:, x:x+WINDOW_SIZE[0], y:y+WINDOW_SIZE[1]] += patch_scores

            pred = np.argmax(scores, axis=0)
            all_preds.append(pred)

            # Visualize results
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
            ax1.imshow(np.transpose(img, (1, 2, 0)))
            ax1.set_title(f'Test Image (Area {test_ids[i]})')
            ax2.imshow(convert_to_color(gt_processed))
            ax2.set_title('Ground Truth')
            ax3.imshow(convert_to_color(pred))
            ax3.set_title('Prediction')
            plt.tight_layout()
            plt.show()

    # Calculate metrics over all test pixels at once
    from utils.utils import metrics
    print("\nComputing overall metrics...")
    metrics(
        np.concatenate([p.flatten() for p in all_preds]),
        np.concatenate([g.flatten() for g in all_gts]),
        LABELS
    )

    return all_preds, all_gts

In [None]:
# Optionally run evaluation
# Uncomment the lines below to evaluate the model after training
# final_model_path = f'{OUTPUT_FOLDER}/model_final.pth'
# all_preds, all_gts = evaluate_model(final_model_path)

In [None]:
# Export results (if needed)
# from utils.export_result import export_results
# 
# def save_results(predictions, ground_truths, exp_name="baseline"):
#     """Save the prediction results"""
#     export_results(
#         predictions, 
#         ground_truths, 
#         OUTPUT_FOLDER, 
#         exp_name,
#         confusionMat=True,
#         prodAccuracy=True,
#         averageAccuracy=True,
#         kappaCoeff=True,
#         title=f"Results for {exp_name}"
#     )
#     
#     # Save visualization of predictions
#     for pred, test_id in zip(predictions, test_ids):
#         img = convert_to_color(pred)
#         io.imsave(f'{OUTPUT_FOLDER}/{exp_name}_area{test_id}.png', img)
#     
#     print(f"Results saved to {OUTPUT_FOLDER}")
# 
# # Uncomment to save results after evaluation
# # save_results(all_preds, all_gts, "baseline")