In [None]:
import os
import torch
import torch.nn.functional as F
from torchvision.transforms import transforms
from PIL import Image, ImageTk
import tkinter as tk
from tkinter import filedialog, ttk
import numpy as np
from matplotlib import cm
from utils import CustomUNet

class Gui:
    """Tkinter front-end for running a semantic-segmentation model on an image.

    The window contains a dropdown of ``*.pth`` checkpoints found in
    ``models_folder``, an upload button, two canvases (the original image and
    the prediction overlaid on it) and a legend of the classes present in the
    most recent prediction.

    Predicted class ids index into the module-level ``classes`` (names) and
    ``voc_colormap`` (RGB colours) lists. The same palette colours both the
    overlay and the legend, so the two always agree.
    """

    CANVAS_SIZE = 500      # side length (px) of both display canvases
    BG_COLOR = "#f0f4f8"   # window / frame background colour
    OVERLAY_ALPHA = 0.7    # blend weight of the colour mask over the image

    def __init__(self, root, device, models_folder="models"):
        """Build the window and its widgets.

        Args:
            root: The Tk window to populate.
            device: ``torch.device`` on which models are loaded and run.
            models_folder: Directory scanned for ``*.pth`` checkpoints;
                created if it does not exist.
        """
        self.root = root
        self.device = device
        self.models_folder = models_folder
        self.models = self.load_model_names()
        self.current_model = None
        self.image_path = None

        # Model input preprocessing: fixed 224x224 size, ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.root.title("GUI")
        self.root.geometry("1200x700")
        self.root.configure(bg=self.BG_COLOR)

        header = tk.Label(self.root, text="GUI", font=("Lato", 24, "bold"),
                          bg=self.BG_COLOR, fg="#333")
        header.pack(pady=20)

        self.model_var = tk.StringVar(self.root)
        self.model_var.set("Select a model")
        model_label = tk.Label(self.root, text="Choose a model:", font=("Lato", 12),
                               bg=self.BG_COLOR, fg="#555")
        model_label.pack(pady=5)
        self.model_dropdown = ttk.Combobox(self.root, textvariable=self.model_var,
                                           values=self.models, state="readonly",
                                           font=("Lato", 12))
        self.model_dropdown.bind("<<ComboboxSelected>>", self.change_model)
        self.model_dropdown.pack(pady=10)

        self.upload_btn = tk.Button(self.root, text="Upload Image", command=self.upload_image,
                                    font=("Lato", 12), bg="#007BFF", fg="white",
                                    relief=tk.FLAT, padx=12, pady=8)
        self.upload_btn.config(activebackground="#0056b3", activeforeground="white",
                               borderwidth=0)
        self.upload_btn.pack(pady=10)

        frame = tk.Frame(self.root, bg=self.BG_COLOR)
        frame.pack(padx=30, pady=20, expand=True, fill=tk.BOTH)

        self.canvas_original = self._make_canvas(frame, tk.LEFT)
        self.canvas_segmented = self._make_canvas(frame, tk.RIGHT)

        # Legend of classes present in the latest prediction.
        self.legend_frame = tk.Frame(self.root, bg=self.BG_COLOR)
        self.legend_frame.pack(pady=10)

        self._reset_canvases()

    # ------------------------------------------------------------------ widgets

    def _make_canvas(self, parent, side):
        """Create, pack and return one bordered, square display canvas."""
        canvas = tk.Canvas(parent, width=self.CANVAS_SIZE, height=self.CANVAS_SIZE,
                           bg="#ffffff", highlightthickness=2, highlightbackground="#ccc")
        canvas.pack(side=side, padx=20, pady=20, expand=True, fill=tk.BOTH)
        return canvas

    def _show_placeholder(self, canvas, text):
        """Clear *canvas* and draw a centred grey placeholder caption."""
        canvas.delete("all")
        center = self.CANVAS_SIZE // 2
        canvas.create_text(center, center, text=text,
                           font=("Lato", 16, "italic"), fill="#aaa")

    def _reset_canvases(self):
        """Restore both canvases to their placeholder captions and clear the legend."""
        self._show_placeholder(self.canvas_original, "Original Image")
        self._show_placeholder(self.canvas_segmented, "Segmented Image")
        self._clear_legend()

    def _clear_legend(self):
        """Remove every widget currently shown in the legend frame."""
        for widget in self.legend_frame.winfo_children():
            widget.destroy()

    # ------------------------------------------------------------------ models

    def load_model_names(self):
        """Return the sorted ``*.pth`` file names in the models folder (creating it if absent)."""
        os.makedirs(self.models_folder, exist_ok=True)
        return sorted(f for f in os.listdir(self.models_folder) if f.endswith(".pth"))

    def change_model(self, event=None):
        """Load the checkpoint selected in the dropdown and reset the display.

        The new model only replaces ``self.current_model`` once its weights
        have loaded successfully. On failure, the error is printed and the
        previously loaded model (if any) stays active. This prevents running
        a randomly initialised network.
        """
        model_name = self.model_var.get()
        model_path = os.path.join(self.models_folder, model_name)
        try:
            model = CustomUNet(n_classes=len(classes)).to(self.device)
            # weights_only=True: the file is expected to hold a plain state_dict,
            # so refuse to unpickle arbitrary objects from it.
            state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
            model.load_state_dict(state_dict)
        except Exception as e:  # GUI boundary: report and keep the app alive
            print(f"Error loading model {model_name}: {e}")
            return

        model.eval()
        self.current_model = model
        print(f"Loaded model: {model_name}")
        self._reset_canvases()

    # ------------------------------------------------------------------ images

    def upload_image(self):
        """Ask the user for an image file, show it, and segment it.

        Cancelling the dialog keeps the previously chosen image path.
        """
        # A tuple of patterns is portable; "a;b" only works on Windows Tk.
        path = filedialog.askopenfilename(
            filetypes=[("Image files", ("*.jpg", "*.jpeg", "*.png"))]
        )
        if not path:  # '' or () when the dialog is cancelled
            return
        self.image_path = path
        self.display_image(self.image_path)
        self.run_segmentation()

    def display_image(self, path):
        """Show the image at *path*, resized to the canvas size, on the left canvas."""
        with Image.open(path) as src:
            shown = src.resize((self.CANVAS_SIZE, self.CANVAS_SIZE))
        # Keep a reference on self: Tk stops drawing a PhotoImage once it is garbage-collected.
        self.original_img = ImageTk.PhotoImage(shown)
        self.canvas_original.delete("all")
        self.canvas_original.create_image(0, 0, anchor=tk.NW, image=self.original_img)

    def run_segmentation(self):
        """Segment the current image with the current model and display the overlay.

        Prints a message and returns without doing anything if no model or no
        image has been chosen yet.
        """
        if not self.image_path or self.current_model is None:
            print("Please select a model and upload an image.")
            return

        # Force 3 channels: grayscale, palette or RGBA files would otherwise
        # fail the 3-channel Normalize in self.transform.
        with Image.open(self.image_path) as src:
            img = src.convert("RGB")

        prediction = self._predict(img)
        combined_img = self._overlay(img, prediction)

        self.segmented_img = ImageTk.PhotoImage(combined_img)
        self.canvas_segmented.delete("all")
        self.canvas_segmented.create_image(0, 0, anchor=tk.NW, image=self.segmented_img)

        self.update_legend(prediction)

    def _predict(self, img):
        """Run the current model on RGB *img* and return a 2-D numpy array of class ids."""
        input_tensor = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.current_model(input_tensor)
            # Assumes output is (1, n_classes, H, W) logits -- TODO confirm against CustomUNet.
            # squeeze(0) drops only the batch axis, never a size-1 spatial axis.
            return output.argmax(dim=1).squeeze(0).cpu().numpy()

    @staticmethod
    def _colorize(prediction, palette):
        """Map an (H, W) integer array of class ids to an (H, W, 3) uint8 RGB array.

        Each id is looked up in *palette* (a sequence of RGB triples), so ids
        must be in ``range(len(palette))``.
        """
        lut = np.asarray(palette, dtype=np.uint8)
        return lut[prediction]

    def _overlay(self, img, prediction):
        """Blend the colour-coded *prediction* over *img* at canvas size and return RGBA."""
        size = (self.CANVAS_SIZE, self.CANVAS_SIZE)
        # Use the same palette as the legend. NEAREST keeps class edges crisp
        # instead of interpolating between class colours.
        mask = Image.fromarray(self._colorize(prediction, voc_colormap))
        mask = mask.resize(size, Image.NEAREST).convert("RGBA")

        # Blending with a fully transparent layer yields a semi-transparent mask.
        transparent = Image.new("RGBA", size, (255, 255, 255, 0))
        mask = Image.blend(transparent, mask, alpha=self.OVERLAY_ALPHA)

        base = img.resize(size).convert("RGBA")
        return Image.alpha_composite(base, mask)

    # ------------------------------------------------------------------ legend

    def update_legend(self, prediction):
        """Rebuild the legend: one colour swatch and name per class id present in *prediction*."""
        self._clear_legend()
        for class_idx in np.unique(prediction):
            idx = int(class_idx)
            color_rect = tk.Canvas(self.legend_frame, width=20, height=20,
                                   bg=self.rgb_to_hex(voc_colormap[idx]), highlightthickness=0)
            color_rect.pack(side=tk.LEFT, padx=5, pady=5)
            class_label = tk.Label(self.legend_frame, text=classes[idx], font=("Lato", 12),
                                   bg=self.BG_COLOR, fg="#333")
            class_label.pack(side=tk.LEFT, padx=5, pady=5)

    def rgb_to_hex(self, rgb):
        """Convert an (r, g, b) triple of 0-255 ints to a ``#rrggbb`` colour string."""
        return '#{:02x}{:02x}{:02x}'.format(*rgb)


# Pascal VOC class names; the legend looks a predicted class id up here.
classes = (
    "background aeroplane bicycle bird boat bottle bus "
    "car cat chair cow diningtable dog horse motorbike "
    "person pottedplant sheep sofa train tvmonitor"
).split()

def _voc_color(index):
    """Return the Pascal VOC palette RGB triple for class *index*.

    Bits 0, 1 and 2 of each successive 3-bit group of *index* are spread,
    most-significant first, into the red, green and blue channels.
    """
    red = green = blue = 0
    for shift in range(7, -1, -1):
        red |= (index & 1) << shift
        green |= ((index >> 1) & 1) << shift
        blue |= ((index >> 2) & 1) << shift
        index >>= 3
    return (red, green, blue)


# RGB colour per class id (same order as ``classes``), used by the legend.
voc_colormap = [_voc_color(class_id) for class_id in range(21)]

# Launch the GUI only when run as a script or notebook cell (where __name__ is
# "__main__"), so importing this module does not open a window or block.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    root = tk.Tk()
    app = Gui(root, device)
    root.mainloop()

  self.current_model.load_state_dict(torch.load(model_path, map_location=self.device))


Loaded model: model.pth


  colormap = cm.get_cmap('jet', 21)


Loaded model: Adam_CrossEntropy.pth
