# Load libraries

In [1]:
%matplotlib inline
# %config InlineBackend.figure_format = 'retina'

from matplotlib import pyplot as plt
from glob import glob
import os
from copy import deepcopy

import numpy as np
import cv2
from PIL import Image
from time import sleep

import torch
from torch.utils.data import Dataset
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torchvision import transforms

# Lorenz's libs
import math
import pandas as pd
import requests
from io import BytesIO
from pyproj import Proj, Transformer
import random
from tqdm import tqdm
import folium
from folium.plugins import MarkerCluster

# Define helper functions/classes

**SwisstopoTileFetcher Class**

This class facilitates fetching map tiles from the Swisstopo WMTS service. It converts geographic coordinates (latitude and longitude) into tile indices, constructs the appropriate URL for the tile image, and downloads the image. The class also provides a method to display the fetched tile image using matplotlib.

Key Methods:

*   **lat_lon_to_tile_indices():** Converts latitude and longitude to tile indices based on the zoom level.
*   **fetch_tile():** Downloads the tile image from Swisstopo.
*   **show_tile():** Displays the fetched tile image.

Parameters:



*   **longitude:** The longitude of the point for which the tile is to be fetched.
*   **latitude:** The latitude of the point for which the tile is to be fetched.
*   **zoom_level:** The zoom level for the map tile.

In [2]:
class SwisstopoTileFetcher:
    def __init__(self, longitude, latitude, zoom_level):
        self.scheme = "https"
        self.server_name = "wmts0.geo.admin.ch"  # Can be wmts0 to wmts9
        self.version = "1.0.0"
        self.layer_name = "ch.swisstopo.swissimage"
        self.style_name = "default"
        self.time = "current"
        self.tile_matrix_set = "3857"
        self.format_extension = "jpeg"
        self.longitude = longitude
        self.latitude = latitude
        self.zoom_level = zoom_level

    def lat_lon_to_tile_indices(self):
        n = 2 ** self.zoom_level
        lat_rad = math.radians(self.latitude)
        x_tile = int((self.longitude + 180.0) / 360.0 * n)
        y_tile = int((1.0 - math.log(math.tan(lat_rad) + (1 / math.cos(lat_rad))) / math.pi) / 2.0 * n)
        return x_tile, y_tile

    def fetch_tile(self):
        # Convert coordinates to tile indices
        x, y = self.lat_lon_to_tile_indices()

        # Construct the URL
        url = f"{self.scheme}://{self.server_name}/{self.version}/{self.layer_name}/{self.style_name}/{self.time}/{self.tile_matrix_set}/{self.zoom_level}/{x}/{y}.{self.format_extension}"

        # Download the tile
        response = requests.get(url)
        if response.status_code == 200:
            # Open and return the image
            image = Image.open(BytesIO(response.content))
            return image, url
        else:
            print(f"Failed to download tile. Status code: {response.status_code}")
            return None

    def show_tile(self):
        image = self.fetch_tile()
        if image:
            # Display the image
            plt.imshow(image)
            plt.axis('off')  # Hide the axis
            plt.show()

**ArealstatistikSampler Class**

This class is designed to sample geographic points from a dataset provided in LV95 coordinates and convert them to WGS84 coordinates. It reads the CSV file*, filters the data based on a specified column, and randomly selects a given number of points from each unique value in that column. The selected points are then transformed from the LV95 coordinate system to the WGS84 coordinate system.

Key Methods:


*   **lv95_to_wgs84(lon, lat):** Converts coordinates from LV95 to WGS84.
*   **sample_points():** Samples the specified number of points for each unique value in the specified column, converts their coordinates, and returns a list of these points.


Parameters:



*   **file_path:** Path to the CSV file containing the data.
*   **column_to_filter:** Column name used to filter and categorize the data.
*   **num_samples:** Number of samples to select for each unique value in the column.
*   **random_state:** Optional parameter to ensure reproducibility of random sampling.

*available on https://www.bfs.admin.ch/bfs/en/home/services/geostat/swiss-federal-statistics-geodata/land-use-cover-suitability/swiss-land-use-statistics.html

In [3]:
class ArealstatistikSampler:
    def __init__(self, file_path, column_to_filter, num_samples, random_state=None):
        self.file_path = file_path
        self.column_to_filter = column_to_filter
        self.num_samples = num_samples
        self.random_state = random_state

    def lv95_to_wgs84(self, lon, lat):
        in_proj = Proj("epsg:2056")
        out_proj = Proj("epsg:4326")
        transformer = Transformer.from_proj(in_proj, out_proj)
        lon_wgs84, lat_wgs84 = transformer.transform(lon, lat)
        return lon_wgs84, lat_wgs84

    def sample_points(self):
        # Read the CSV file into a DataFrame
        df = pd.read_csv(self.file_path, delimiter=";")

        # Filter out rows with missing values in the specified column
        df_filtered = df.dropna(subset=[self.column_to_filter])

        # Create an empty list to store the selected points
        selected_points = []

        # Set random state if provided
        if self.random_state is not None:
            random_state = self.random_state
        else:
            random_state = 42  # Default random state

        # Iterate over each unique value in the specified column
        for class_value in df_filtered[self.column_to_filter].unique():
            # Filter rows for the current class value
            class_df = df_filtered[df_filtered[self.column_to_filter] == class_value]

            # Randomly select specified number of examples for the current class value
            selected_samples = class_df.sample(n=self.num_samples, random_state=random_state)

            # Convert LV95 coordinates to WGS84 and store them in the selected_points list
            for _, row in selected_samples.iterrows():
                lon_wgs84, lat_wgs84 = self.lv95_to_wgs84(row["E_COORD"], row["N_COORD"])
                selected_points.append([lon_wgs84, lat_wgs84, class_value])

        return selected_points


Define Parameters

In [4]:
# file_path = "/content/drive/MyDrive/CAS Avanced Machine Learning/Luftbild_Colarization/ag-b-00.03-37-area-csv.csv"
file_path = "/Volumes/Ruben/datasets/land_use_data/ag-b-00.03-37-area-csv.csv"
column_to_filter = "AS18_72" #column in the dataset with the classes
num_samples = 50 #number of samples per class
random_state = 42
zoom_levels = [16, 17, 18] #zoom levels to fetch images from randomly

Collect sample points and show the spatial distribution on a map

In [5]:
# Instantiate ArealstatistikSampler and sample points
sampler = ArealstatistikSampler(file_path, column_to_filter, num_samples, random_state)
coordinates = sampler.sample_points()

# Print the number of samples collected
print("Number of samples collected:", len(coordinates))

Number of samples collected: 3600


In [6]:
len(coordinates[0])

3

In [7]:
# Generate a color map for each unique class value
unique_classes = list(set([coord[2] for coord in coordinates]))
colors = {class_value: f'#{random.randint(0, 0xFFFFFF):06x}' for class_value in unique_classes}

# Create a map centered around Switzerland
map_center = [46.8182, 8.2275]  # Approximate geographical center of Switzerland
mymap = folium.Map(location=map_center, zoom_start=8)

# Add MarkerCluster to the map
marker_cluster = MarkerCluster(
    control=True,
    maxClusterRadius=30
).add_to(mymap)

# Add points to the marker cluster with colors based on class
for lat, lon, class_value in coordinates:
    folium.CircleMarker(
        location=[lat, lon],
        radius=5,
        color=colors[class_value],
        fill=True,
        fill_color=colors[class_value],
        fill_opacity=0.6,
        popup=f'Coordinates: {lat}, {lon}<br>Class: {class_value}'
    ).add_to(marker_cluster)

# Save and display the map
mymap.save('map.html')
mymap

In [8]:
def fetch_images_with_random_zoom_levels(sampled_points, zoom_levels, save_to = None):
    """
    Fetch images for sampled points using random zoom levels.

    Args:
        sampled_points (list): List of sampled points, where each point is represented as a list [lat, lon] or [lat, lon, class_value].
        zoom_levels (list): List of zoom levels to choose from.
        save_to (str): if provided, valid path where to save the  fetched image.

    Returns:
        list: List of dictionaries, each containing fetched image and its metadata (lat, lon, zoom_level, class).
    """
    fetched_images = []
    total_points = len(sampled_points)

    for indx, point in enumerate(tqdm(sampled_points, desc="Fetching Images", total=total_points)):
        lat, lon = point[:2]  # Extract latitude and longitude
        class_value = point[2] if len(point) > 2 else None
        zoom_level = random.choice(zoom_levels)
        tile_fetcher = SwisstopoTileFetcher(lon, lat, zoom_level)
        image, url = tile_fetcher.fetch_tile()
        image_data = {
            'img_id': indx,
            'img_name': '_'.join(url.split("/")[-4:]).split(".")[0],
            'image': image,
            'latitude': lat,
            'longitude': lon,
            'zoom_level': zoom_level,
            'class': class_value,
            'link':url, 
        }
        fetched_images.append(image_data)
        if save_to:
            if isinstance(save_to, str) and os.path.exists(save_to):
                data_path = os.path.join(save_to, "data")
                if not os.path.exists(data_path):
                    os.mkdir(data_path)
                # print("saving images")
                image.save(os.path.join(data_path, f"img_id_{image_data['img_id']}.jpg"))
                # sleep(0.2)

    if save_to:
        print("saving metadata")
        my_df = pd.DataFrame([ {key: d[key] for key in d if key != "image"} for d in fetched_images])
        my_df.to_csv(os.path.join(save_to, "metadata.csv"))
            




    return fetched_images



In [9]:
def show_images_grouped_by_class(coordinates, fetched_images, num_images_per_class=3):
    """
    Show a specified number of images from each class in the same row.

    Args:
        coordinates (list): List of coordinates with optional class values. Each entry is a list [lat, lon] or [lat, lon, class_value].
        fetched_images (list): List of fetched images corresponding to the coordinates.
        num_images_per_class (int): Number of images to display per class. Default is 3.
    """
    class_images = {}

    for point, image_data in zip(coordinates, fetched_images):
        class_value = image_data['class'] if image_data['class'] is not None else 'No Class'
        if class_value not in class_images:
            class_images[class_value] = [image_data]
        else:
            class_images[class_value].append(image_data)

    for class_value, images in class_images.items():
        plt.figure(figsize=(15, 3))
        plt.suptitle(f'Class: {class_value}', fontsize=14)
        for i in range(min(num_images_per_class, len(images))):
            plt.subplot(1, num_images_per_class, i + 1)
            plt.imshow(images[i]['image'])
            plt.title(
                f"Zoom: {images[i]['zoom_level']}\nLat: {images[i]['latitude']}\nLon: {images[i]['longitude']}",
                fontsize=10
            )
            plt.axis('off')
        plt.tight_layout()
        plt.show()



Example without a dataset

In [10]:
# Example for one image
# https://wmts100.geo.admin.ch/1.0.0/ch.swisstopo.swissimage/default/current/2056/28/7043/6890.jpeg
# https://wmts100.geo.admin.ch/1.0.0/ch.swisstopo.swissimage/default/current/2056/27/2817/2756.jpeg
# https://wmts100.geo.admin.ch/1.0.0/ch.swisstopo.swissimage/default/current/2056/26/1408/1377.jpeg
img_coord = [
    (46.41360594778221, 7.4425600819620765, 40),  # With class value
    (46.71360594778221, 7.4425600819620765)       # Without class value
]

zoom_levels_example = [17, 18]


example_images = fetch_images_with_random_zoom_levels(img_coord, zoom_levels_example)



Fetching Images: 100%|██████████| 2/2 [00:00<00:00, 13.59it/s]


In [11]:
example_images

[{'img_id': 0,
  'img_name': '3857_17_68245_46412',
  'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256>,
  'latitude': 46.41360594778221,
  'longitude': 7.4425600819620765,
  'zoom_level': 17,
  'class': 40,
  'link': 'https://wmts0.geo.admin.ch/1.0.0/ch.swisstopo.swissimage/default/current/3857/17/68245/46412.jpeg'},
 {'img_id': 1,
  'img_name': '3857_17_68245_46253',
  'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256>,
  'latitude': 46.71360594778221,
  'longitude': 7.4425600819620765,
  'zoom_level': 17,
  'class': None,
  'link': 'https://wmts0.geo.admin.ch/1.0.0/ch.swisstopo.swissimage/default/current/3857/17/68245/46253.jpeg'}]

In [12]:

Show images grouped by class
show_images_grouped_by_class(img_coord, example_images)

Fetch the images from the given coordinates in the dataset

In [13]:
# Fetch the images
# Fetch the images

path_to_save_raw_images = r"/Volumes/Ruben/datasets/fetched_data_again"
# path_to_save_raw_images = r"/Volumes/Ruben/datasets/fetched_raw_imgs_via_api"

fetched_images = fetch_images_with_random_zoom_levels(coordinates, zoom_levels, save_to=path_to_save_raw_images)

Fetching Images: 100%|██████████| 3600/3600 [13:57<00:00,  4.30it/s]


saving metadata


In [48]:
len(fetched_images)

3600

In [None]:
item['image'].save

In [57]:
[item['image'].save(os.path.join(path_to_save_raw_images, "data", f"img_id_{indx}.jpg")) for indx, item in enumerate(fetched_images)]

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,

In [15]:
fetched_images[0]

{'img_id': '3857_18_137090_91987',
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256>,
 'latitude': 47.20178026276514,
 'longitude': 8.264822308080678,
 'zoom_level': 18,
 'class': 61,
 'link': 'https://wmts0.geo.admin.ch/1.0.0/ch.swisstopo.swissimage/default/current/3857/18/137090/91987.jpeg'}

In [22]:
jpg_files = sorted(glob(os.path.join(path_to_save_raw_images, "*", "*.jpg")))
len(jpg_files)

3594

In [42]:
jpg_files_ids = ["_".join(my_file.split("/")[-1].split("ImageId_")[-1].split("_")[1:]) for my_file in jpg_files]
len(jpg_files_ids)
# len(sorted(list(set(jpg_files_ids))))
# jpg_files_ids[0]


3594

In [43]:
len(sorted(list(set(jpg_files_ids))))

3586

In [29]:
jpg_files_ids[0]

'3857_16_33868_23254'

In [30]:
fetched_images[0]['img_id']

'3857_17_68545_45993'

In [26]:
img_not_fetched = [ img for img in fetched_images if img["img_id"] in jpg_files_ids]
len(img_not_fetched)

3153

In [44]:
img_not_fetched[0]["img_id"]

'3857_18_137650_93036'

In [33]:
len([ img for img in fetched_images if img["img_id"] in jpg_files_ids])

3152