This notebook plots the "octopuses" (the geodesics in their final state) for previously executed Riemannian K-means runs. It then assembles the saved K-means frames into a video and a GIF. The plots are used in Section 5.3 of my thesis.

# Octopus plotting

In [None]:
import torch, ricci_regularization, json, os
import cv2 #to make videos

In [None]:
import itertools

# Runs to plot: every combination of AE setting, experiment and K-means setup
# listed below is processed. Other values used previously: setting [4],
# experiments [2, 3], K-means setups 0-9.
setting_numbers = [1]
experiment_numbers = [1]
k_means_setup_numbers = [10]

# itertools.product iterates in the same order as the equivalent nested loops.
for setting_number, experiment_number, k_means_setup_number in itertools.product(
    setting_numbers, experiment_numbers, k_means_setup_numbers
):
    pretrained_AE_setting_name = f"MNIST_Setting_{setting_number}_exp{experiment_number}"
    Path_clustering_setup = os.path.join(
        "../../experiments",
        pretrained_AE_setting_name,
        f"K_means_setup_{k_means_setup_number}",
    )
    # Saved K-means results; must provide the "Riemannian_k_means_labels" and
    # "ground_truth_labels" entries read below.
    params_path = os.path.join(Path_clustering_setup, "params.json")
    with open(params_path, "r", encoding="utf-8") as f_Riemannian:
        Riemannian_k_means_params = json.load(f_Riemannian)
    # Load onto the CPU: the memberships tensor passed alongside it is a CPU
    # tensor, and this lets curves saved from a GPU run load on CPU-only hosts.
    geodesic_curve = torch.load(
        os.path.join(Path_clustering_setup, "geodesic_curve.pt"),
        map_location="cpu",
    )
    ricci_regularization.RiemannianKmeansTools.plot_octopus(
        geodesic_curve,
        memberships=torch.tensor(Riemannian_k_means_params["Riemannian_k_means_labels"]),
        ground_truth_labels=Riemannian_k_means_params["ground_truth_labels"],
        saving_folder=Path_clustering_setup,
        suffix=0,
        show_geodesics_in_original_local_charts=False,
        show_only_geodesics_to_nearest_centroids=False,
        periodicity_mode=True,
    )

# Video

In [None]:
# Location of the per-frame PNG snapshots, and of the video assembled from them.
images_folder = "../../plots/Kmeans"
output_video = "../../plots/Kmeans_video.avi"

# Video playback speed, in frames per second.
frame_rate = 30

# Frame order is the lexicographic order of the PNG file names.
# NOTE(review): this matches the intended sequence only if frame numbers in the
# names are zero-padded ('frame_10.png' sorts before 'frame_2.png'); confirm.
images = [name for name in os.listdir(images_folder) if name.endswith(".png")]
images.sort()
if len(images) == 0:
    raise ValueError("No PNG images found in the specified directory.")


In [None]:

# Read the first frame to fix the video dimensions. cv2.VideoWriter silently
# drops frames of any other size, so later frames are checked against it.
first_image_path = os.path.join(images_folder, images[0])
frame = cv2.imread(first_image_path)
if frame is None:  # cv2.imread reports failure by returning None, not raising
    raise ValueError(f"Could not read image {first_image_path}")
height, width = frame.shape[:2]

# XVID codec in an AVI container; other FourCCs such as 'mp4v' can be used.
fourcc = cv2.VideoWriter_fourcc(*'XVID')
video = cv2.VideoWriter(output_video, fourcc, frame_rate, (width, height))
if not video.isOpened():
    raise RuntimeError(f"Could not open a video writer for {output_video}")

try:
    for image in images:
        img_path = os.path.join(images_folder, image)
        frame = cv2.imread(img_path)
        if frame is None:
            raise ValueError(f"Could not read image {img_path}")
        if frame.shape[:2] != (height, width):
            # A size mismatch would make the writer drop the frame silently;
            # rescale it to the video size instead.
            frame = cv2.resize(frame, (width, height))
        video.write(frame)
finally:
    # Release even on failure so the container is finalized and the file closed.
    video.release()
print(f"Video saved as {output_video}")


# GIF

In [None]:
import os
import imageio

# Full path of the output GIF.
output_gif_path = "../../plots/Kmeans_gif.gif"
# Display time per frame: 0.05 s per frame = 20 fps.
# NOTE(review): imageio >= 2.28 writes GIFs via its Pillow plugin, which reads
# `duration` in milliseconds rather than seconds. Confirm against the installed
# version; there, 20 fps would be duration=50.
frame_duration = 0.05

# Same PNG-only, name-sorted frame list as the video. A bare os.listdir would
# also return non-image entries (e.g. .DS_Store, .ipynb_checkpoints), which
# imageio.imread cannot read.
images = sorted(name for name in os.listdir(images_folder) if name.endswith(".png"))
if not images:
    raise ValueError("No PNG images found in the specified directory.")

# Read all frames in order.
frames = [imageio.imread(os.path.join(images_folder, image_name)) for image_name in images]

# Save as GIF
imageio.mimsave(output_gif_path, frames, duration=frame_duration)
print(f"GIF saved as {output_gif_path}")
