In [1]:
import os


# Root of the GenerativeUrbanDesign checkout. Every relative path used later in
# this notebook (./models, ./lightning_logs, dataset folders, ...) is resolved
# against it because of the chdir below. Set the PROJECT_DIR environment
# variable to run elsewhere; the original cluster path remains the default.
project_dir = os.environ.get('PROJECT_DIR', '/blue/shenhaowang/qingqisong/GenerativeUrbanDesign-main')
os.chdir(project_dir)

In [2]:
# NOTE(review): unpinned install; presumably needed to build a dependency
# (confirm which) — pin versions for a reproducible environment.
!pip install cmake lit

Defaulting to user installation because normal site-packages is not writeable


In [3]:
# Requirements that contain `>=` are quoted. IPython runs `!` commands through a
# shell, where an unquoted `>` is an output redirection: `test-tube>=0.7.5`
# installed an unpinned test-tube and redirected pip's output into files named
# `=0.7.5` / `=0.1.5` (likely why this cell used to print nothing).
# NOTE(review): later cells reinstall open_clip_torch without a pin and upgrade
# timm past 0.6.12; reconcile those cells with the pins below.
!pip install gradio==3.16.2 \
albumentations==1.3.0 \
opencv-contrib-python \
imageio==2.9.0 \
imageio-ffmpeg==0.4.2 \
pytorch-lightning==1.5.0 \
omegaconf==2.1.1 \
"test-tube>=0.7.5" \
streamlit==1.12.1 \
einops \
transformers \
webdataset==0.2.5 \
kornia==0.6 \
open_clip_torch==2.0.2 \
"invisible-watermark>=0.1.5" \
streamlit-drawable-canvas==0.8.0 \
torchmetrics==0.6.0 \
timm==0.6.12 \
addict==2.4.0 \
yapf==0.32.0 \
prettytable==3.6.0 \
safetensors \
basicsr==1.4.2

In [4]:
# NOTE(review): open_clip_torch==2.0.2 is already requested by the install cell
# above; without --upgrade pip keeps an installed version, so this only matters
# if that install failed. Consider removing or pinning.
!pip install open_clip_torch

Defaulting to user installation because normal site-packages is not writeable


In [5]:
# NOTE(review): this overrides the timm==0.6.12 pin from the install cell above
# (the output below shows 0.6.12 being replaced by 1.0.15). Keep one consistent
# timm version for a reproducible environment.
!pip install --upgrade timm

Defaulting to user installation because normal site-packages is not writeable
Collecting timm
  Using cached timm-1.0.15-py3-none-any.whl.metadata (52 kB)
Using cached timm-1.0.15-py3-none-any.whl (2.4 MB)
Installing collected packages: timm
  Attempting uninstall: timm
    Found existing installation: timm 0.6.12
    Uninstalling timm-0.6.12:
      Successfully uninstalled timm-0.6.12
Successfully installed timm-1.0.15


In [6]:
import cv2
import numpy as np
import os
import pandas as pd
from torch.utils.data import Dataset, ConcatDataset

class MyDataset(Dataset):
    """Paired (control map, target image, prompt) samples for ControlNet training.

    Each CSV row describes one sample. Files are looked up as
    ``{osm_dir}/control_{idx}_{latitude}_{longitude}.png`` (conditioning hint) and
    ``{image_dir}/combined_{idx}_{latitude}_{longitude}.png`` (target image).
    """

    def __init__(self, image_dir, osm_dir, prompt_file, size=512):
        """
        Initialize the dataset
        Args:
            image_dir (str): Directory containing satellite images
            osm_dir (str): Directory containing OSM maps
            prompt_file (str): Path to the CSV file containing prompts. Must have
                ``latitude`` and ``longitude`` columns; ``idx`` and ``prompt`` are
                optional (defaults: row number and empty prompt).
            size (int): Target size for images (both height and width)
        """
        self.image_dir = image_dir
        self.osm_dir = osm_dir
        self.size = size

        # Read CSV file
        self.df = pd.read_csv(prompt_file)

        # NOTE: despite the "tile" names these hold the CSV latitude/longitude
        # values, which are embedded verbatim in the file names.
        self.image_list_xtile = self.df['latitude'].tolist()
        self.image_list_ytile = self.df['longitude'].tolist()
        self.idx_list = self.df['idx'].tolist() if 'idx' in self.df.columns else list(range(len(self.df)))
        # Empty prompt cells are read as NaN (a float); map them to '' so every
        # caption is a string.
        self.prompts = self.df['prompt'].fillna('').tolist() if 'prompt' in self.df.columns else [''] * len(self.df)

    def __len__(self):
        return len(self.image_list_xtile)

    @staticmethod
    def _to_bgr(image):
        """Return an 8-bit, 3-channel BGR copy of an image read with IMREAD_UNCHANGED.

        16-bit data is scaled down to 8 bits, single-channel images are replicated
        to three channels, and for 4-channel images fully transparent pixels are
        painted white before the alpha channel is dropped.
        """
        if image.dtype == np.uint16:
            image = (image >> 8).astype(np.uint8)
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        if image.ndim == 2:
            return np.stack([image, image, image], axis=-1)
        if image.shape[2] == 4:
            bgr = image[:, :, :3].copy()
            bgr[image[:, :, 3] == 0] = 255
            return bgr
        return image

    def __getitem__(self, idx):
        """Load one sample.

        Returns a dict with
            'jpg':  target image, RGB float32 in [-1, 1], shape (size, size, 3)
            'txt':  prompt string
            'hint': control map, RGB float32 in [0, 1], shape (size, size, 3)
        On any error the message is printed and a zero-filled sample with an
        empty prompt is returned, so one bad file does not stop training.
        """
        try:
            xtile = self.image_list_xtile[idx]
            ytile = self.image_list_ytile[idx]
            item_idx = self.idx_list[idx]

            # Construct file paths
            source_filename = os.path.join(self.osm_dir, f"control_{item_idx}_{xtile}_{ytile}.png")
            target_filename = os.path.join(self.image_dir, f"combined_{item_idx}_{xtile}_{ytile}.png")

            # IMREAD_UNCHANGED keeps an alpha channel if the control map has one.
            # The default IMREAD_COLOR always returns 3 channels, which made the
            # transparency handling unreachable.
            # NOTE(review): if control maps do contain transparency, hints now
            # differ (white instead of the hidden RGB) from earlier training runs.
            source = cv2.imread(source_filename, cv2.IMREAD_UNCHANGED)
            target = cv2.imread(target_filename)

            if source is None:
                raise FileNotFoundError(f"Source image not found: {source_filename}")
            if target is None:
                raise FileNotFoundError(f"Target image not found: {target_filename}")

            # Flatten transparency before resizing so interpolation never mixes in
            # colours of fully transparent pixels.
            source = self._to_bgr(source)

            # Resize images to consistent size
            source = cv2.resize(source, (self.size, self.size), interpolation=cv2.INTER_AREA)
            target = cv2.resize(target, (self.size, self.size), interpolation=cv2.INTER_AREA)

            # Convert BGR to RGB
            source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
            target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

            # Normalize
            source = source.astype(np.float32) / 255.0  # [0, 1]
            target = (target.astype(np.float32) / 127.5) - 1.0  # [-1, 1]

            return {
                'jpg': target,
                'txt': self.prompts[idx],
                'hint': source
            }

        except Exception as e:
            print(f"Error loading item {idx}: {str(e)}")
            # Empty prompt: the error message must not become a training caption.
            return {
                'jpg': np.zeros((self.size, self.size, 3), dtype=np.float32),
                'txt': '',
                'hint': np.zeros((self.size, self.size, 3), dtype=np.float32)
            }

In [None]:

# Fine-tune ControlNet on the two combined datasets.
# Depends on MyDataset and ConcatDataset from the dataset cell above.
# NOTE(review): `share` is a project module star-imported first, presumably for
# its side effects (e.g. the "logging improved." line in the output) — confirm
# before reordering or replacing this import.
from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Dataset 1: target images, control maps and prompt CSV for MyDataset.
dataset1_config = {
    "image_dir": "./combinedgrido150/",
    "osm_dir": "./combinedcontrol_mapsgrido150/",
    "prompt_file": "./metrics_datagrido150/dualtrainprompt.csv"

}

# Dataset 2: same layout, second grid.
dataset2_config = {
    "image_dir": "./combinedgridc180/",
    "osm_dir": "./combinedcontrol_mapsgridc180/",
    "prompt_file": "./metrics_datagridc180/dualtrainprompt.csv"
}

# Configs
# Start either from the ControlNet init checkpoint (commented out) or, as here,
# from a previous run. Only the weights are restored (load_state_dict below);
# optimizer state and epoch/step counters are not, so the Trainer restarts at
# epoch 0 (see "Epoch 0" in the output).
#resume_path = './models/control_sd15_ini.ckpt'
resume_path = './lightning_logs/version_5227827/checkpoints/epoch=1-step=15275.ckpt'

batch_size = 4
logger_freq = 300  # ImageLogger batch_frequency; the output shows sampling every 300 steps
learning_rate = 1e-5  # assigned to model.learning_rate below; presumably read by cldm's optimizer setup — confirm
sd_locked = True  # assigned to model.sd_locked; presumably keeps SD weights frozen — confirm in cldm
only_mid_control = False  # assigned to model.only_mid_control; semantics defined in cldm (TODO confirm)


dataset1 = MyDataset(**dataset1_config)
dataset2 = MyDataset(**dataset2_config)


# Train on the union of both datasets.
combined_dataset = ConcatDataset([dataset1, dataset2])


# num_workers=0: images are read and resized in the main process.
dataloader = DataLoader(combined_dataset, num_workers=0, batch_size=batch_size, shuffle=True)


# Build the model from the config on CPU, load the checkpoint weights, then set
# the training options as attributes.
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# gpus=1 uses a single GPU even though two are visible (CUDA_VISIBLE_DEVICES: [0,1]
# in the output); `gpus=` is the pytorch-lightning 1.5 argument.
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(max_epochs=3, gpus=1, precision=32, callbacks=[logger])


trainer.fit(model, dataloader)

  from .autonotebook import tqdm as notebook_tqdm


logging improved.
No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Loaded model config from [./models/cldm_v15.yaml]
Loaded state_dict from [./lightning_logs/version_5227827/checkpoints/epoch=1-step=15275.ckpt]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
  rank_zero_deprecation(
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Set SLURM handle signals.

  | Name              | Type               | Params
---------------------------------------------------------
0 | model             | DiffusionWrapper   | 859 M 
1 | first_stage_model | AutoencoderKL      | 83.7 M
2 | cond_stage_model  | FrozenCLIPEmbedder | 123 M 
3 | control_model     | ControlNet         | 361 M 
---------------------------------------------------------
1.2 B     Trainable params
206 M     Non-trainable params
1.4 B     Total params
5,710.058 Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/7638 [00:00<?, ?it/s] 



Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps
Epoch 0:   4%|▍         | 300/7638 [05:50<2:22:58,  1.17s/it, loss=0.106, v_num=5265894, train/loss_simple_step=0.132, train/loss_vlb_step=0.000606, train/consistency_loss_step=0.953, train/loss_step=0.133, global_step=299.0]   Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps
Epoch 0:   8%|▊         | 600/7638 [11:40<2:16:56,  1.17s/it, loss=0.0929, v_num=5265894, train/loss_simple_step=0.0692, train/loss_vlb_step=0.000266, train/consistency_loss_step=0.549, train/loss_step=0.0697, global_step=599.0] Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps
Epoch 0:  12%|█▏        | 900/7638 [17:30<2:11:05,  1.17s/it, loss=0.0914, v_num=5265894, train/loss_simple_step=0.0557, train/loss_vlb_step=0.000193, train/consistency_loss_step=0.479, train/loss_step=0.0562, global_step=899.0] Data shape for DDIM sampli

In [None]:
import numpy as np
from PIL import Image
import os
from glob import glob
from tqdm import tqdm

# Define standard color dictionary for land use categories.
# Maps category name -> (R, G, B). Used as the palette that classify_pixel
# snaps pixels to. Several categories intentionally share a colour (e.g.
# park/garden/golf_course). Insertion order matters: on an exact distance tie
# the entry listed first wins, so do not reorder.
COLOR_MAPPING = {
    'residential': (255, 202, 101),
    'commercial': (255, 128, 128),
    'industrial': (191, 191, 191),
    'retail': (255, 85, 85),
    'parking': (239, 239, 239),
    'school': (255, 240, 170),
    'university': (255, 240, 170),
    'hospital': (255, 230, 230),
    'park': (178, 216, 178),
    'garden': (178, 216, 178),
    'recreation_ground': (184, 230, 184),
    'playground': (204, 230, 204),
    'sports_centre': (204, 230, 204),
    'stadium': (204, 230, 204),
    'pitch': (204, 230, 204),
    'golf_course': (178, 216, 178),
    'forest': (140, 191, 140),
    'wood': (140, 191, 140),
    'grass': (204, 255, 204),
    'grassland': (204, 255, 204),
    'meadow': (204, 255, 204),
    'heath': (204, 230, 204),
    'scrub': (184, 230, 184),
    'wetland': (186, 230, 230),
    'water': (179, 217, 255),
    'beach': (255, 245, 204),
    'sand': (255, 245, 204),
    'farmland': (255, 255, 204),
    'orchard': (230, 255, 179),
    'vineyard': (230, 255, 179),
    'cemetery': (209, 207, 205),
    # Not a land-use class; presumably catches background / unmapped areas.
    'white': (255, 255, 255)  
}

def classify_pixel(pixel, color_mapping):
    """
    Classify a pixel to the closest standard color using Euclidean distance

    Args:
        pixel: Input pixel values; only the first three channels (RGB) are used,
            so an RGBA pixel is accepted.
        color_mapping: Dictionary mapping category name -> (R, G, B)

    Returns:
        RGB tuple of the closest standard color, taken from ``color_mapping``.
        On an exact tie the entry that appears first in ``color_mapping`` wins.

    Raises:
        ValueError: if ``color_mapping`` is empty.
    """
    if not color_mapping:
        raise ValueError("color_mapping must contain at least one color")

    # Promote to a signed 64-bit type before subtracting. With uint8 pixels,
    # `p - s` wraps around under NumPy 2 (NEP 50 promotion), corrupting distances.
    rgb = np.asarray(pixel, dtype=np.int64)[:3]

    best_category = None
    best_distance = float('inf')
    for category, standard_color in color_mapping.items():
        delta = rgb - np.asarray(standard_color, dtype=np.int64)[:3]
        distance = np.sqrt(np.dot(delta, delta))
        # Strict `<` keeps the first category on ties.
        if distance < best_distance:
            best_distance = distance
            best_category = category

    # Look up the colour in the mapping that was passed in, not the global palette.
    return color_mapping[best_category]

def convert_to_standard_colors(input_path, output_path, color_mapping=None):
    """
    Convert input land use image to standardized colors

    Every pixel is replaced by its nearest palette colour (Euclidean distance in
    RGB; on ties the palette entry listed first wins), matching classify_pixel.

    Args:
        input_path: Path to input image file
        output_path: Path where output image will be saved (RGB)
        color_mapping: Optional palette {category: (R, G, B)}; defaults to the
            module-level COLOR_MAPPING.

    Returns:
        True on success; False if the image could not be read, converted or
        saved (the error is printed).
    """
    if color_mapping is None:
        color_mapping = COLOR_MAPPING
    try:
        # Force 3-channel RGB so palette ('P'), greyscale and RGBA inputs work;
        # RGBA previously failed when writing 3 values into 4-channel pixels.
        with Image.open(input_path) as img:
            rgb = np.asarray(img.convert('RGB'), dtype=np.int32)

        palette = np.array(list(color_mapping.values()), dtype=np.int32)

        # Vectorised nearest-colour search: one pass over the image per palette
        # entry instead of a Python loop per pixel. Squared distance has the same
        # ordering as Euclidean distance, and the strict `<` keeps the first
        # palette entry on ties.
        best_dist = np.full(rgb.shape[:2], np.iinfo(np.int64).max, dtype=np.int64)
        best_idx = np.zeros(rgb.shape[:2], dtype=np.intp)
        for i, colour in enumerate(palette):
            dist = np.sum((rgb - colour) ** 2, axis=-1, dtype=np.int64)
            closer = dist < best_dist
            best_dist[closer] = dist[closer]
            best_idx[closer] = i

        output_array = palette[best_idx].astype(np.uint8)
        Image.fromarray(output_array).save(output_path)
        return True
    except Exception as e:
        print(f"Error processing {input_path}: {str(e)}")
        return False

def process_folder(input_folder, output_folder):
    """
    Standardize the colours of every ``osm_*.png`` image in a folder.

    Each ``<name>.png`` is written to ``output_folder`` as
    ``standardized_<name>.png`` via convert_to_standard_colors, and a
    success/failure summary is printed at the end.

    Args:
        input_folder: Path to folder containing input images
        output_folder: Path to folder where output images will be saved
            (created if missing)
    """
    # Imported locally: a later cell runs `import glob`, which rebinds the
    # notebook-global name `glob` to the module, so calling a global `glob(...)`
    # would fail if that cell ran first.
    from glob import glob as find_files

    os.makedirs(output_folder, exist_ok=True)

    # Sorted so images are processed in a reproducible order.
    osm_images = sorted(find_files(os.path.join(input_folder, "osm_*.png")))

    successful = 0
    failed = 0

    print(f"Found {len(osm_images)} images to process")

    for input_path in tqdm(osm_images, desc="Processing images"):
        filename = os.path.basename(input_path)
        output_path = os.path.join(output_folder, f"standardized_{filename}")

        if convert_to_standard_colors(input_path, output_path):
            successful += 1
        else:
            failed += 1

    print("\nProcessing complete:")
    print(f"Successfully processed: {successful} images")
    print(f"Failed to process: {failed} images")

# Example usage: standardize the extracted OSM maps (paths are relative to
# project_dir from the first cell).
# NOTE(review): the analysis cell below reads ./standardized_the1m6, not this
# output folder — confirm which folder is intended.
input_folder = "./m6fid_test_resultsextracted"  
output_folder = "./standardizedm6fid"  
process_folder(input_folder, output_folder)

In [None]:
import numpy as np
import cv2
from scipy import ndimage
import matplotlib.pyplot as plt
from collections import defaultdict
import os
import glob

def evaluate_land_use_image(image_path, color_similarity_threshold=5, min_block_size_threshold=200):
    """
    Evaluate a standardized land use image based on diversity, density, and design.
    Modified version with better diagnostics for block detection.

    Parameters:
    - image_path: Path to the standardized land use image
    - color_similarity_threshold: Currently unused. Colours are matched exactly,
      which suits palette-standardized images; the parameter is kept so existing
      calls keep working.
    - min_block_size_threshold: A connected single-colour region needs more than
      this many pixels to count as a block (default 200).

    Returns:
    - Dictionary with evaluation metrics (None if the image cannot be read).
      In every section, ``color_{i}`` refers to ``results["diversity"]["colors"][i]``;
      the design dictionaries only contain colours with at least one valid block.
    - List of color masks (uint8 0/1), one per significant colour, same order
    - Building mask (uint8 0/1)
    """
    # Load the image (OpenCV reads BGR; convert to RGB to match the palette)
    img = cv2.imread(image_path)
    if img is None:
        print(f"Failed to load image: {image_path}")
        return None, None, None

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Get image dimensions
    height, width, _ = img.shape
    total_pixels = height * width

    # 1. Diversity: Extract unique colors from standardized image
    unique_colors, counts = np.unique(img.reshape(-1, 3), axis=0, return_counts=True)

    # Filter colors based on significance
    significant_colors = []
    significant_counts = []

    # A colour must cover at least 0.1% of the image to count
    min_pixel_threshold = total_pixels * 0.001

    for color, count in zip(unique_colors, counts):
        # Skip white/very light colors (likely roads or background)
        if np.mean(color) > 240:
            continue

        # Only keep colors that appear in a meaningful portion of the image
        if count >= min_pixel_threshold:
            significant_colors.append(color)
            significant_counts.append(count)

    # Sort by count (descending) for consistency
    sorted_indices = np.argsort(significant_counts)[::-1]
    significant_colors = [significant_colors[i] for i in sorted_indices]
    significant_counts = [significant_counts[i] for i in sorted_indices]

    # Debug information
    print(f"Found {len(unique_colors)} total unique colors")
    print(f"Found {len(significant_colors)} significant colors after filtering")

    # Print top colors for debugging
    print("\nTop significant colors:")
    for i, (color, count) in enumerate(zip(significant_colors[:10], significant_counts[:10])):
        print(f"  Color {i}: RGB{tuple(color)} - {count} pixels ({count/total_pixels*100:.2f}%)")

    # 2. Density: building coverage from RGB ranges of typical building colours
    # Pink/Red buildings (residential, commercial)
    pink_lower = np.array([180, 100, 100])
    pink_upper = np.array([255, 180, 180])
    pink_mask = np.all((img >= pink_lower) & (img <= pink_upper), axis=2)

    # Yellow/Orange buildings (schools, hospitals, etc.)
    yellow_lower = np.array([180, 150, 50])
    yellow_upper = np.array([255, 230, 150])
    yellow_mask = np.all((img >= yellow_lower) & (img <= yellow_upper), axis=2)

    # Gray buildings (industrial)
    gray_lower = np.array([150, 150, 150])
    gray_upper = np.array([200, 200, 200])
    gray_mask = np.all((img >= gray_lower) & (img <= gray_upper), axis=2)

    # Combine all building masks
    building_mask = np.logical_or(np.logical_or(pink_mask, yellow_mask), gray_mask).astype(np.uint8)
    building_pixel_count = np.sum(building_mask)
    building_coverage_ratio = building_pixel_count / total_pixels

    # Significant colours that fall inside one of the building ranges above
    building_colors = []
    for color in significant_colors:
        r, g, b = color
        if (r >= 180 and 100 <= g <= 180 and 100 <= b <= 180) or \
           (r >= 180 and 150 <= g <= 230 and 50 <= b <= 150) or \
           (150 <= r <= 200 and 150 <= g <= 200 and 150 <= b <= 200):
            building_colors.append(color)

    # 3. Design: average size of connected single-colour blocks
    avg_block_sizes = []
    block_size_by_category = {}
    block_counts_by_category = {}
    color_masks = []
    all_block_info = []  # Detailed info for each colour that has valid blocks

    # Reshape the image once for efficient color matching
    reshaped_img = img.reshape(-1, 3)

    for i, color in enumerate(significant_colors):
        print(f"\nColor {i} (RGB: {color}):")

        # Exact colour match (color_similarity_threshold is not applied)
        color_matches = np.all(reshaped_img == color, axis=1)
        color_mask = color_matches.reshape(height, width).astype(np.uint8)

        # Store the mask for later use
        color_masks.append(color_mask)

        # Check how many pixels match this color
        matched_pixels = np.sum(color_mask)
        print(f"  Matched pixels: {matched_pixels} ({matched_pixels/total_pixels*100:.2f}%)")

        # Label connected components
        labeled_mask, num_features = ndimage.label(color_mask)
        print(f"  Total connected components: {num_features}")

        if num_features > 0:
            # Pixel count of each labelled region
            region_sizes = ndimage.sum(color_mask, labeled_mask, range(1, num_features + 1))

            # Show distribution of region sizes
            if len(region_sizes) > 0:
                print(f"  Region size stats: min={np.min(region_sizes):.0f}, "
                      f"max={np.max(region_sizes):.0f}, mean={np.mean(region_sizes):.0f}")

            # Filter out very small regions (noise)
            valid_regions = region_sizes[region_sizes > min_block_size_threshold]

            if len(valid_regions) > 0:
                avg_block_size = np.mean(valid_regions)
                avg_block_sizes.append(avg_block_size)
                # Key by i, the colour's index in significant_colors, so design
                # entries line up with results["diversity"]["colors"]. Keying by
                # position in avg_block_sizes mislabelled colours whenever an
                # earlier colour had no valid block.
                block_size_by_category[f"color_{i}"] = avg_block_size
                block_counts_by_category[f"color_{i}"] = len(valid_regions)

                all_block_info.append({
                    'color': color.tolist(),
                    'valid_blocks': len(valid_regions),
                    'avg_size': avg_block_size,
                    'total_components': num_features
                })
            else:
                print(f"  No blocks larger than threshold ({min_block_size_threshold}px)")
                print(f"  Largest block was: {np.max(region_sizes):.0f} pixels")

                # Try with a lower threshold to see if blocks exist
                test_threshold = 10
                test_valid = region_sizes[region_sizes > test_threshold]
                if len(test_valid) > 0:
                    print(f"  With threshold={test_threshold}px: {len(test_valid)} blocks found")

    # Overall average = mean of the per-colour average block sizes
    if avg_block_sizes:
        overall_avg_block_size = np.mean(avg_block_sizes)
    else:
        overall_avg_block_size = 0
        print("\nWARNING: No valid blocks found! Average block size is 0.")
        print("Consider:")
        print("1. Reducing min_block_size_threshold")
        print("2. Checking that the image uses exact palette colours (matching is exact)")
        print("3. Checking if colors are forming connected regions")

    results = {
        "diversity": {
            "unique_land_use_categories": len(significant_colors),
            "colors": [color.tolist() for color in significant_colors],
            "land_use_distribution": {f"color_{i}": count/sum(significant_counts)
                                     for i, count in enumerate(significant_counts)}
        },
        "density": {
            "building_coverage_ratio": building_coverage_ratio,
            "building_colors": [color.tolist() for color in building_colors],
            "building_pixel_count": int(building_pixel_count),
            "total_pixels": total_pixels
        },
        "design": {
            "average_block_size": overall_avg_block_size,
            "block_size_by_category": block_size_by_category,
            "block_counts": block_counts_by_category,
            "min_block_size_threshold": min_block_size_threshold,
            "detailed_block_info": all_block_info
        }
    }

    return results, color_masks, building_mask

def visualize_results(image_path, results, color_masks, building_mask, save_path=None):
    """
    Visualize the evaluation results in a 2x3 panel figure.

    Parameters:
    - image_path: Path to the original image
    - results: Evaluation results dictionary (as returned by evaluate_land_use_image)
    - color_masks: List of color masks, aligned with results["diversity"]["colors"]
    - building_mask: Building mask
    - save_path: Path to save the visualization (if None, display instead)

    Returns:
    - The RGB segmentation image drawn in panel 3, or None if the image cannot be read
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"Failed to load image for visualization: {image_path}")
        return None

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(16, 12))

    # Original image
    plt.subplot(2, 3, 1)
    plt.imshow(img)
    plt.title("Original Land Use Image")
    plt.axis('off')

    # Building mask
    plt.subplot(2, 3, 2)
    plt.imshow(building_mask, cmap='binary')
    plt.title(f"Building Coverage: {results['density']['building_coverage_ratio']:.2f}")
    plt.axis('off')

    # Color segmentation: paint each mask with its colour (later masks win on overlap)
    plt.subplot(2, 3, 3)
    segmentation = np.zeros_like(img)
    colors = results["diversity"]["colors"]
    for mask, color in zip(color_masks, colors):
        segmentation[mask == 1] = color

    plt.imshow(segmentation)
    plt.title(f"Color Segmentation: {len(colors)} categories")
    plt.axis('off')

    # Diversity visualization
    plt.subplot(2, 3, 4)
    categories = results["diversity"]["unique_land_use_categories"]
    distribution = list(results["diversity"]["land_use_distribution"].values())

    bars = plt.bar(range(len(distribution)), distribution)
    plt.title(f"Land Use Diversity: {categories} categories")
    plt.xlabel("Land Use Category")
    plt.ylabel("Proportion")

    # Distribution entries are in the same order as `colors`
    for bar, color in zip(bars, colors):
        bar.set_color([c / 255 for c in color])

    # Density visualization
    plt.subplot(2, 3, 5)
    building_ratio = results["density"]["building_coverage_ratio"]

    plt.pie([building_ratio, 1 - building_ratio],
            labels=["Buildings", "Non-buildings"],
            autopct='%1.1f%%')
    plt.title(f"Building Density: {building_ratio:.2f}")

    # Design visualization
    plt.subplot(2, 3, 6)
    avg_block_size = results["design"]["average_block_size"]
    block_sizes = list(results["design"]["block_size_by_category"].values())
    block_info = results["design"].get("detailed_block_info", [])

    if block_sizes:
        bars = plt.bar(range(len(block_sizes)), block_sizes)
        plt.title(f"Design: Avg Block Size {avg_block_size:.2f}")
        plt.xlabel("Land Use Category")
        plt.ylabel("Average Block Size (pixels)")

        # Colour each bar with the colour it measures. detailed_block_info is
        # filled in the same order as block_size_by_category; position i in
        # `colors` is generally a different colour because colours without
        # valid blocks are skipped in the design section.
        for i, bar in enumerate(bars):
            if i < len(block_info):
                bar_color = block_info[i]["color"]
            elif i < len(colors):
                bar_color = colors[i]
            else:
                continue
            bar.set_color([c / 255 for c in bar_color])

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

    return segmentation

def process_image_folder(folder_path, output_folder=None, extension="*.png"):
    """
    Process all images in a folder

    Parameters:
    - folder_path: Path to the folder containing images
    - output_folder: Path to save visualization results (if None, display instead).
      Figures go to ``<output_folder>/visualizations`` and all metrics to
      ``<output_folder>/all_results.json``.
    - extension: Glob pattern of the files to process (e.g. "*.png")

    Returns:
    - Dictionary mapping image file name -> results for that image
    """
    # Local imports: another cell runs `from glob import glob`, which rebinds the
    # notebook-global name `glob` to a function, so a global `glob.glob(...)`
    # call would fail depending on cell execution order.
    import glob as glob_module
    import json

    if output_folder:
        # Always ensure the visualizations subfolder exists. Previously it was
        # only created when output_folder itself was new, so saving figures into
        # a pre-existing output_folder failed.
        os.makedirs(os.path.join(output_folder, "visualizations"), exist_ok=True)

    # Sorted for a reproducible processing order
    image_files = sorted(glob_module.glob(os.path.join(folder_path, extension)))

    if not image_files:
        print(f"No {extension} files found in {folder_path}")
        return {}

    all_results = {}

    for img_path in image_files:
        print(f"Processing {img_path}...")

        base_name = os.path.basename(img_path)
        file_name = os.path.splitext(base_name)[0]

        results, color_masks, building_mask = evaluate_land_use_image(img_path)

        if results is None:
            print(f"Skipping {img_path} due to processing error")
            continue

        all_results[base_name] = results

        if output_folder:
            vis_path = os.path.join(output_folder, "visualizations", f"{file_name}_analysis.png")
            visualize_results(img_path, results, color_masks, building_mask, save_path=vis_path)
        else:
            print(f"\nResults for {base_name}:")
            print(f"Diversity: {results['diversity']['unique_land_use_categories']} unique land use categories")
            print(f"Density: Building coverage ratio = {results['density']['building_coverage_ratio']:.2f}")
            print(f"Design: Average block size = {results['design']['average_block_size']:.2f} pixels")

            visualize_results(img_path, results, color_masks, building_mask)

    if output_folder:
        with open(os.path.join(output_folder, "all_results.json"), 'w') as f:
            json.dump(all_results, f, indent=2)

    return all_results

# Example usage
# The guard is true when this notebook cell is executed (the kernel's module is
# __main__), so the folder analysis below runs on every cell execution.
if __name__ == "__main__":
    # Process a single image
    # image_path = "./extracted_osmcubic/osm_map_3.png"
    # results, color_masks, building_mask = evaluate_land_use_image(image_path)
    # print("Evaluation Results:")
    # print(f"Diversity: {results['diversity']['unique_land_use_categories']} unique land use categories")
    # print(f"Density: Building coverage ratio = {results['density']['building_coverage_ratio']:.2f}")
    # print(f"Design: Average block size = {results['design']['average_block_size']:.2f} pixels")
    # visualize_results(image_path, results, color_masks, building_mask)
    
    # Process all images in a folder; per-image figures and all_results.json
    # are written under output_folder.
    # NOTE(review): this reads ./standardized_the1m6, while the standardization
    # cell above writes ./standardizedm6fid — confirm the intended input.
    folder_path = "./standardized_the1m6"
    output_folder = "./anastandardized_the1m6"
    
    # Supports PNG files by default, change extension for other file types
    all_results = process_image_folder(folder_path, output_folder, extension="*.png")
    
    print(f"Processed {len(all_results)} images. Results saved to {output_folder}")

In [4]:
import numpy as np
import os
import time
import torch
from torchvision import models, transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import random
import shutil

class ImageDataset(Dataset):
    """Images (.png/.jpg/.jpeg) found directly inside one folder.

    Args:
        path: folder to scan (not recursive; extension match is case-sensitive).
        transform: optional callable applied to each image after it is loaded
            as a PIL RGB image.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform
        # Keep os.listdir order; only accepted image extensions are kept.
        self.files = []
        for name in os.listdir(path):
            if name.endswith(('.png', '.jpg', '.jpeg')):
                self.files.append(name)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        full_path = os.path.join(self.path, self.files[idx])
        image = Image.open(full_path).convert('RGB')
        return self.transform(image) if self.transform else image

def extract_features(dataloader, model, device):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    The model is put in eval mode and gradients are disabled. Each batch is
    moved to ``device``; outputs are copied back to the CPU as NumPy arrays and
    concatenated along the first (batch) axis.

    Returns:
        np.ndarray with one row per input sample.
    """
    model.eval()
    with torch.no_grad():
        chunks = [model(batch.to(device)).cpu().numpy() for batch in dataloader]
    return np.concatenate(chunks, axis=0)

def calculate_fid(real_features, gen_features, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two feature sets.

    FID = ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 Tr((S1 S2)^{1/2})

    Tr((S1 S2)^{1/2}) is computed as the sum of square roots of the eigenvalues
    of the symmetric PSD matrix S1^{1/2} S2 S1^{1/2}, which has the same
    eigenvalues as S1 S2. This always gives real eigenvalues and is more stable
    than eigen-decomposing the non-symmetric product S1 S2, which could yield
    complex values and hence a complex FID.

    Args:
        real_features: (n_samples, n_features) array, or 1-D for a single feature.
        gen_features: same layout as ``real_features``.
        eps: ridge added to both covariance diagonals for numerical stability.

    Returns:
        FID as a Python float.
    """
    real_features = np.asarray(real_features, dtype=np.float64)
    gen_features = np.asarray(gen_features, dtype=np.float64)

    mu1 = np.mean(real_features, axis=0)
    mu2 = np.mean(gen_features, axis=0)

    # atleast_2d: np.cov returns a 0-d array for a single feature
    sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(gen_features, rowvar=False))

    diff = mu1 - mu2

    ridge = np.eye(sigma1.shape[0]) * eps
    sigma1 = sigma1 + ridge
    sigma2 = sigma2 + ridge

    # Symmetric square root of sigma1 via its eigen-decomposition
    w1, v1 = np.linalg.eigh(sigma1)
    sqrt_sigma1 = (v1 * np.sqrt(np.clip(w1, 0, None))) @ v1.T

    middle = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
    middle = (middle + middle.T) / 2  # remove round-off asymmetry
    # Clip tiny negative eigenvalues caused by rank deficiency / round-off
    eigvals = np.clip(np.linalg.eigvalsh(middle), 0, None)
    covmean_trace = np.sum(np.sqrt(eigvals))

    fid = np.sum(diff ** 2) + np.trace(sigma1) + np.trace(sigma2) - 2 * covmean_trace

    return float(fid)


def main(seed=0):
    """Compute FID between a folder of real images and a folder of generated images.

    Both folders are subsampled without replacement to the same number of
    images. The subsets are copied into private temporary directories, features
    are taken from torchvision's ImageNet Inception-v3 with its final fc layer
    replaced by identity, and calculate_fid is applied.

    NOTE: this uses torchvision's Inception-v3 weights with ImageNet
    normalisation rather than the reference FID network, so values are only
    comparable with other runs of this script.

    Args:
        seed: Seed for the subsampling RNG so the selected subsets, and hence the
            FID, are reproducible. Pass None for a fresh random draw.

    Raises:
        ValueError: if either folder has fewer than 2 images.
    """
    import tempfile

    real_merged_dir = "./merged6_real_images"
    gen_merged_dir = "./merged66_generated_images"
    image_extensions = ('.png', '.jpg', '.jpeg')

    rng = random.Random(seed)

    # Private temporary directories, created next to the data like the old
    # ./temp_*_images folders. The old fixed paths were rmtree'd in `finally`,
    # which deleted any pre-existing folder of the same name and mixed its
    # images into the FID inputs.
    with tempfile.TemporaryDirectory(prefix="temp_real_images_", dir=".") as temp_real_dir, \
            tempfile.TemporaryDirectory(prefix="temp_gen_images_", dir=".") as temp_gen_dir:

        # Sorted so a fixed seed selects the same files regardless of listdir order
        real_images = sorted(f for f in os.listdir(real_merged_dir) if f.endswith(image_extensions))
        gen_images = sorted(f for f in os.listdir(gen_merged_dir) if f.endswith(image_extensions))

        # Balance both sides to the same number of images
        min_count = min(len(real_images), len(gen_images))
        if min_count < 2:
            raise ValueError(
                f"Need at least 2 images in each folder to estimate covariances, got {min_count}"
            )

        selected_real = rng.sample(real_images, min_count)
        selected_gen = rng.sample(gen_images, min_count)

        for img in selected_real:
            shutil.copy(os.path.join(real_merged_dir, img), os.path.join(temp_real_dir, img))

        for img in selected_gen:
            shutil.copy(os.path.join(gen_merged_dir, img), os.path.join(temp_gen_dir, img))

        # "Balanced both folders to {min_count} images"
        print(f"已平衡两个文件夹的图片数量至 {min_count} 张")

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Inception-v3 with the classifier removed so it outputs pooled features
        model = models.inception_v3(pretrained=True, transform_input=False)
        model.fc = torch.nn.Identity()
        model = model.to(device)

        # 299x299 input with ImageNet mean/std normalisation
        transform = transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        real_dataset = ImageDataset(temp_real_dir, transform)
        gen_dataset = ImageDataset(temp_gen_dir, transform)

        real_loader = DataLoader(real_dataset, batch_size=32, shuffle=False, num_workers=4)
        gen_loader = DataLoader(gen_dataset, batch_size=32, shuffle=False, num_workers=4)

        start_time = time.time()

        print("提取真实图像特征...")  # "Extracting real-image features..."
        real_features = extract_features(real_loader, model, device)
        print("提取生成图像特征...")  # "Extracting generated-image features..."
        gen_features = extract_features(gen_loader, model, device)

        print("计算FID值...")  # "Computing FID..."
        fid_value = calculate_fid(real_features, gen_features)

        elapsed_time = time.time() - start_time

        print(f"FID值: {fid_value}")  # "FID value"
        print(f"计算耗时: {elapsed_time:.2f} 秒")  # "Elapsed time (seconds)"

# Runs the FID computation whenever this notebook cell is executed (the kernel's
# module is __main__).
if __name__ == "__main__":
    main()

已平衡两个文件夹的图片数量至 1999 张
提取真实图像特征...
提取生成图像特征...
计算FID值...
FID值: 58.58846480018747
计算耗时: 36.31 秒
