In [None]:
from torchvision.transforms import ToTensor
import rasterio
from rasterio.plot import show
import numpy as np
import matplotlib.pyplot as plt
from enum import Enum
from glob import glob
from typing import Tuple, Dict
import json
import geopandas as gpd
import tifffile as tiff

In [None]:
def plot_image(path, gdf):
    """Display an RGB composite of a multi-band TIFF, with and without annotations.

    The image is read with ``tifffile.imread``, three of its bands are stacked
    into an RGB composite, min-max scaled to 0-255 and shown on two side-by-side
    axes. The left axes additionally gets ``gdf`` drawn on top, coloured by its
    ``class`` column with a legend; the right axes shows the bare image.

    Args:
        path: Path of the TIFF file to read.
        gdf: GeoDataFrame with a ``class`` column to overlay on the left panel.
            NOTE(review): geometries are presumably expressed in the image's
            pixel coordinates - confirm against the annotation format.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    image_array = tiff.imread(path)

    # Indices 3, 2, 1 of the last axis are used as red, green and blue.
    # NOTE(review): assumes a channel-last array with at least 4 bands - confirm
    # the band order of the dataset. NaNs are replaced by 0 so they cannot
    # poison the min/max used for scaling below.
    rgb_image = np.dstack(
        [np.nan_to_num(image_array[:, :, band], nan=0) for band in (3, 2, 1)]
    )

    lowest = np.min(rgb_image)
    value_range = np.max(rgb_image) - lowest
    if value_range == 0:
        # Constant image: min-max scaling would be 0/0 and yield NaNs, whose
        # cast to uint8 is undefined. Render it as uniformly black instead.
        image = np.zeros(rgb_image.shape, dtype=np.uint8)
    else:
        # Scale to [0, 1], then convert to uint8 for display.
        image = ((rgb_image - lowest) / value_range * 255).astype(np.uint8)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 20))
    ax1.imshow(image)
    ax2.imshow(image)
    # figsize is not forwarded here: geopandas ignores it when ``ax`` is given,
    # the figure size is already fixed by plt.subplots above.
    gdf.plot(ax=ax1, column='class', legend=True)
    ax1.axis('off')
    ax2.axis('off')
    plt.show()

In [None]:
# Load the annotation file. JSON is UTF-8 by specification, so the encoding is
# stated explicitly instead of relying on the platform default.
with open("./data/train_annotations.json", encoding="utf-8") as f:
    data = json.load(f)

# Index of the training image to inspect.
idx = 50
file = f"train_{idx}.tif"
json_data = None

# Find the annotations that belong to the selected image.
for img in data["images"]:
    if img["file_name"] == file:
        print(img["file_name"])
        json_data = img["annotations"]
        break
else:
    # Fail here with a clear message rather than later, when iterating over
    # ``None`` would raise an unrelated-looking TypeError.
    raise LookupError(f"No entry for {file!r} in ./data/train_annotations.json")

path = f"./data/train_images/{file}"


def convert_to_geojson(data):
    """Convert annotation dictionaries to a GeoJSON FeatureCollection.

    Each item's ``segmentation`` is a flat sequence ``[x0, y0, x1, y1, ...]``
    that is regrouped into ``[x, y]`` pairs and emitted as the single (outer)
    ring of a ``Polygon`` feature. The ring is closed (first position repeated
    at the end) when the input does not already do so, because GeoJSON
    requires closed linear rings.

    Args:
        data: An iterable of dictionaries containing 'class' and
            'segmentation' keys.

    Returns:
        A GeoJSON feature collection (dict) with one Polygon feature per item;
        the item's 'class' is stored under ``properties['class']``.

    Raises:
        IndexError: If a 'segmentation' holds an odd number of values.
    """
    features = []
    for item in data:
        flat = item['segmentation']
        ring = [[flat[i], flat[i + 1]] for i in range(0, len(flat), 2)]
        # Close the ring if needed; an empty ring is left untouched.
        if ring and ring[0] != ring[-1]:
            ring.append(list(ring[0]))
        features.append({
            "type": "Feature",
            "geometry": {
                "type": "Polygon",
                "coordinates": [ring]
            },
            "properties": {"class": item['class']}
        })
    return {"type": "FeatureCollection", "features": features}


# Turn the selected image's annotations into a GeoDataFrame (via GeoJSON)
# and draw them over the image.
geojson_data = convert_to_geojson(data=json_data)
gdf = gpd.GeoDataFrame.from_features(features=geojson_data)
plot_image(path=path, gdf=gdf)