In [5]:
import tkinter as tk
from tkinter import ttk, filedialog, colorchooser, Canvas
from PIL import Image, ImageTk, ImageOps, ImageDraw
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

# Simple colorization model
class SimpleColorizationNet(nn.Module):
    def __init__(self):
        super(SimpleColorizationNet, self).__init__()
        # Encoder
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)

        # Decoder with conditional input
        self.deconv1 = nn.ConvTranspose2d(256 + 3, 128, kernel_size=2, stride=2)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)

        self.final_conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, grayscale, condition_map):
        # Encode grayscale
        x = F.relu(self.conv1(grayscale))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2)

        # Resize condition map to match encoder output
        condition_resized = F.interpolate(condition_map, size=x.shape[2:])

        # Combine with condition map
        x = torch.cat([x, condition_resized], dim=1)

        # Decode
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = F.relu(self.deconv3(x))
        x = torch.sigmoid(self.final_conv(x))

        return x

class ConditionalColorizationApp:
    def __init__(self, root):
        self.root = root
        self.root.title("Conditional Image Colorization")
        self.root.geometry("1000x700")

        # Initialize model
        self.model = SimpleColorizationNet()

        # Current image and mask
        self.original_image = None
        self.grayscale_image = None
        self.condition_mask = None
        self.color_conditions = {}  # mask_id: (color, mask_points)
        self.current_color = "#FF0000"  # Default red
        self.drawing = False
        self.current_mask_points = []
        self.current_mask_id = 0

        # Setup GUI
        self.setup_ui()

    def setup_ui(self):
        # Main frames
        control_frame = ttk.Frame(self.root, padding="10")
        control_frame.grid(row=0, column=0, sticky=(tk.W, tk.E))

        image_frame = ttk.Frame(self.root, padding="10")
        image_frame.grid(row=1, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Control panel
        ttk.Button(control_frame, text="Load Image",
                  command=self.load_image).pack(side=tk.LEFT, padx=5)

        ttk.Button(control_frame, text="Choose Color",
                  command=self.choose_color).pack(side=tk.LEFT, padx=5)

        self.color_display = Canvas(control_frame, width=30, height=30,
                                   bg=self.current_color)
        self.color_display.pack(side=tk.LEFT, padx=5)

        ttk.Label(control_frame, text="Brush Size:").pack(side=tk.LEFT, padx=5)
        self.brush_size = tk.IntVar(value=20)
        ttk.Spinbox(control_frame, from_=5, to=50, width=10,
                   textvariable=self.brush_size).pack(side=tk.LEFT, padx=5)

        ttk.Button(control_frame, text="Start Drawing",
                  command=self.start_drawing).pack(side=tk.LEFT, padx=5)

        ttk.Button(control_frame, text="Clear All",
                  command=self.clear_all).pack(side=tk.LEFT, padx=5)

        ttk.Button(control_frame, text="Colorize",
                  command=self.colorize_image).pack(side=tk.LEFT, padx=5)

        ttk.Button(control_frame, text="Save Result",
                  command=self.save_image).pack(side=tk.LEFT, padx=5)

        # Condition list
        ttk.Label(control_frame, text="Conditions:").pack(side=tk.LEFT, padx=5)
        self.condition_listbox = tk.Listbox(control_frame, height=1, width=20)
        self.condition_listbox.pack(side=tk.LEFT, padx=5)

        # Image display areas
        self.canvas_original = Canvas(image_frame, width=400, height=400, bg="gray")
        self.canvas_original.grid(row=0, column=0, padx=5, pady=5)
        ttk.Label(image_frame, text="Original / Mask").grid(row=1, column=0)

        self.canvas_result = Canvas(image_frame, width=400, height=400, bg="gray")
        self.canvas_result.grid(row=0, column=1, padx=5, pady=5)
        ttk.Label(image_frame, text="Colorized Result").grid(row=1, column=1)

        # Bind events for drawing
        self.canvas_original.bind("<Button-1>", self.start_draw_event)
        self.canvas_original.bind("<B1-Motion>", self.draw_event)
        self.canvas_original.bind("<ButtonRelease-1>", self.stop_draw_event)

    def load_image(self):
        file_path = filedialog.askopenfilename(
            filetypes=[("Image files", "*.jpg *.jpeg *.png *.bmp")]
        )
        if file_path:
            self.original_image = Image.open(file_path).convert("RGB")
            self.grayscale_image = self.original_image.convert("L")

            # Create empty condition mask
            self.condition_mask = Image.new("RGB", self.original_image.size, (128, 128, 128))

            # Clear previous conditions
            self.color_conditions.clear()
            self.current_mask_id = 0
            self.condition_listbox.delete(0, tk.END)

            # Display images
            self.display_images()

    def choose_color(self):
        color = colorchooser.askcolor(title="Choose color for condition")[1]
        if color:
            self.current_color = color
            self.color_display.config(bg=color)

    def start_drawing(self):
        self.drawing = True
        self.current_mask_points = []

    def start_draw_event(self, event):
        if self.drawing and self.original_image:
            self.current_mask_points = [(event.x, event.y)]
            self.draw_on_mask(event.x, event.y)

    def draw_event(self, event):
        if self.drawing and self.original_image and self.current_mask_points:
            self.current_mask_points.append((event.x, event.y))
            self.draw_on_mask(event.x, event.y)

    def stop_draw_event(self, event):
        if self.drawing and self.current_mask_points and self.original_image:
            self.current_mask_id += 1
            self.color_conditions[self.current_mask_id] = (
                self.current_color,
                self.current_mask_points.copy()
            )

            # Add to condition list
            self.condition_listbox.insert(tk.END,
                f"Condition {self.current_mask_id}: {self.current_color}")

            self.current_mask_points = []

    def draw_on_mask(self, x, y):
        if not self.original_image:
            return

        # Convert color from hex to RGB
        color_rgb = self.hex_to_rgb(self.current_color)

        # Create drawing context
        draw = ImageDraw.Draw(self.condition_mask)

        # Draw circle at position (scaled to image size)
        img_width, img_height = self.original_image.size
        canvas_width = self.canvas_original.winfo_width()
        canvas_height = self.canvas_original.winfo_height()

        scale_x = img_width / canvas_width
        scale_y = img_height / canvas_height

        img_x = int(x * scale_x)
        img_y = int(y * scale_y)
        brush_size = int(self.brush_size.get() * scale_x)

        draw.ellipse([img_x - brush_size, img_y - brush_size,
                     img_x + brush_size, img_y + brush_size],
                    fill=color_rgb)

        self.display_images()

    def hex_to_rgb(self, hex_color):
        hex_color = hex_color.lstrip('#')
        return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))

    def display_images(self):
        if self.original_image:
            # Display original with mask overlay
            display_img = self.original_image.copy()
            mask_display = self.condition_mask.copy()
            mask_display = mask_display.resize(self.original_image.size)

            # Blend original with mask
            blended = Image.blend(display_img.convert("RGB"),
                                 mask_display.convert("RGB"), 0.3)

            # Resize for display
            display_size = (400, 400)
            blended_display = blended.resize(display_size, Image.Resampling.LANCZOS)

            # Convert to PhotoImage
            self.photo_original = ImageTk.PhotoImage(blended_display)
            self.canvas_original.create_image(0, 0, anchor=tk.NW,
                                            image=self.photo_original)

            # Display grayscale in result area initially
            grayscale_display = self.grayscale_image.resize(display_size,
                                                          Image.Resampling.LANCZOS)
            self.photo_result = ImageTk.PhotoImage(grayscale_display)
            self.canvas_result.create_image(0, 0, anchor=tk.NW,
                                          image=self.photo_result)

    def clear_all(self):
        if self.original_image:
            self.condition_mask = Image.new("RGB", self.original_image.size,
                                          (128, 128, 128))
            self.color_conditions.clear()
            self.current_mask_id = 0
            self.condition_listbox.delete(0, tk.END)
            self.display_images()

    def colorize_image(self):
        if not self.original_image or not self.color_conditions:
            return

        # Convert images to arrays
        img_array = np.array(self.grayscale_image)
        condition_array = np.array(self.condition_mask)

        # Simple colorization logic (for demonstration)
        # In a real implementation, you would use the trained model here
        result = self.simple_colorization_logic(img_array, condition_array)

        # Display result
        result_img = Image.fromarray(result)
        display_size = (400, 400)
        result_display = result_img.resize(display_size, Image.Resampling.LANCZOS)

        self.photo_result = ImageTk.PhotoImage(result_display)
        self.canvas_result.create_image(0, 0, anchor=tk.NW, image=self.photo_result)

        self.result_image = result_img

    def simple_colorization_logic(self, grayscale, condition_mask):
        # This is a simplified demonstration
        # A real implementation would use the neural network

        # Convert grayscale to RGB
        height, width = grayscale.shape
        result = np.zeros((height, width, 3), dtype=np.uint8)

        # Copy grayscale to all channels
        result[:, :, 0] = grayscale
        result[:, :, 1] = grayscale
        result[:, :, 2] = grayscale

        # Apply color conditions based on mask
        mask_indices = np.any(condition_mask != [128, 128, 128], axis=-1)

        if np.any(mask_indices):
            # Blend condition colors with grayscale
            alpha = 0.7  # Blend factor
            result[mask_indices] = (
                alpha * condition_mask[mask_indices] +
                (1 - alpha) * result[mask_indices]
            ).astype(np.uint8)

        return result

    def save_image(self):
        if hasattr(self, 'result_image'):
            file_path = filedialog.asksaveasfilename(
                defaultextension=".png",
                filetypes=[("PNG files", "*.png"),
                          ("JPEG files", "*.jpg"),
                          ("All files", "*.*")]
            )
            if file_path:
                self.result_image.save(file_path)

def main():
    root = tk.Tk()
    app = ConditionalColorizationApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()

Exception in Tkinter callback
Traceback (most recent call last):
  File "/usr/lib/python3.10/tkinter/__init__.py", line 1921, in __call__
    return self.func(*args)
  File "/tmp/ipykernel_136788/2927982310.py", line 143, in load_image
    self.display_images()
  File "/tmp/ipykernel_136788/2927982310.py", line 224, in display_images
    blended_display = blended.resize(display_size, Image.Resampling.LANCZOS)
  File "/usr/lib/python3/dist-packages/PIL/Image.py", line 65, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'PIL.Image' has no attribute 'Resampling'
