In [7]:
import sys
from io import BytesIO
import PIL.Image
import torch
from PyQt5.QtWidgets import QApplication, QMainWindow, QAction, QFileDialog, QColorDialog, QInputDialog, QWidget, \
    QStatusBar
from PyQt5.QtGui import QPainter, QPen, QImage, QPixmap
from PyQt5.QtCore import Qt, QPoint, QByteArray, QBuffer, QIODevice, QThread, pyqtSignal
from diffusers import StableDiffusionImg2ImgPipeline
from collections import deque


def pil_image_to_qimage(pil_image):
    """Return a QImage holding the same picture as *pil_image*.

    The conversion goes through an in-memory PNG encoding: PIL writes the PNG
    bytes and Qt decodes them. The result of the decode is not checked, so a
    failed decode yields a null QImage.
    """
    with BytesIO() as png_stream:
        pil_image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    result = QImage()
    result.loadFromData(png_bytes)
    return result


def qimage_to_pil_image(q_image):
    """Return a PIL image holding the same picture as *q_image*.

    The image is encoded to PNG in memory by Qt and decoded by PIL. The returned
    image is fully decoded before it is returned.

    Raises:
        OSError: if Qt cannot encode ``q_image`` (for example a null QImage).
    """
    byte_data = QByteArray()
    buffer = QBuffer(byte_data)
    buffer.open(QIODevice.WriteOnly)
    try:
        saved = q_image.save(buffer, "PNG")
    finally:
        # Release the Qt device as soon as the bytes are written.
        buffer.close()
    if not saved:
        # Without this check an empty buffer reaches PIL and fails with a much
        # less helpful "cannot identify image file" error.
        raise OSError("Failed to encode QImage as PNG")
    pil_image = PIL.Image.open(BytesIO(byte_data.data()))
    # PIL.Image.open decodes lazily; force decoding now so any error surfaces here.
    pil_image.load()
    return pil_image


class StableDiffusionThread(QThread):
    """Run one img2img call of a diffusers pipeline on a worker thread.

    Signals:
        result_ready(QImage): the first generated image, emitted on success.
        error_occurred(str): "<ExceptionType>: <message>", emitted if the call fails.
    """

    result_ready = pyqtSignal(QImage)
    error_occurred = pyqtSignal(str)

    def __init__(self, pipeline, prompt, input_image, strength=0.6, guidance_scale=7.5):
        """Store the pipeline call arguments.

        Args:
            pipeline: callable pipeline (here a StableDiffusionImg2ImgPipeline).
            prompt: text prompt passed to the pipeline.
            input_image: PIL image used as the img2img starting point.
            strength: forwarded as the pipeline's ``strength`` argument.
            guidance_scale: forwarded as the pipeline's ``guidance_scale`` argument.
        """
        super().__init__()
        self.pipeline = pipeline
        self.prompt = prompt
        self.input_image = input_image
        self.strength = strength
        self.guidance_scale = guidance_scale

    def run(self):
        """Call the pipeline once and emit either ``result_ready`` or ``error_occurred``."""
        try:
            response = self.pipeline(
                prompt=self.prompt,
                image=self.input_image,
                strength=self.strength,
                guidance_scale=self.guidance_scale,
            )
            output_image = pil_image_to_qimage(response.images[0])
        except Exception as exc:
            # Thread boundary: an exception escaping QThread.run never reaches the GUI
            # thread, and with PyQt5's default excepthook it aborts the whole process.
            import traceback

            traceback.print_exc()
            self.error_occurred.emit(f"{type(exc).__name__}: {exc}")
            return
        self.result_ready.emit(output_image)


class FloodFillThread(QThread):
    """Flood-fill a QImage on a worker thread.

    Starting at ``point``, every 4-connected pixel whose colour equals
    ``target_color`` is recoloured to ``fill_color``. The (mutated) image is then
    emitted through ``result_ready`` as a QPixmap.
    """

    result_ready = pyqtSignal(QPixmap)

    def __init__(self, image, point, target_color, fill_color):
        super().__init__()
        self.image = image  # QImage modified in place by run()
        self.point = point  # seed position, in image coordinates
        self.target_color = target_color
        self.fill_color = fill_color

    def run(self):
        """Perform the fill and emit the result."""
        image = self.image
        width, height = image.width(), image.height()
        target = self.target_color
        fill = self.fill_color
        seed_x, seed_y = self.point.x(), self.point.y()

        if 0 <= seed_x < width and 0 <= seed_y < height and image.pixelColor(seed_x, seed_y) == target:
            image.setPixelColor(seed_x, seed_y, fill)
            # If the written pixel still reads back as the target colour (fill equals
            # target, or the image format cannot store the fill colour exactly), every
            # filled pixel would keep matching and the fill would never terminate.
            if image.pixelColor(seed_x, seed_y) != target:
                # Pixels are recoloured when they are enqueued, so each pixel enters
                # the queue at most once.
                queue = deque([(seed_x, seed_y)])
                while queue:
                    x, y = queue.popleft()
                    for nx, ny in ((x, y + 1), (x + 1, y), (x, y - 1), (x - 1, y)):
                        if 0 <= nx < width and 0 <= ny < height and image.pixelColor(nx, ny) == target:
                            image.setPixelColor(nx, ny, fill)
                            queue.append((nx, ny))

        # NOTE(review): Qt treats QPixmap as a GUI-thread class; building one here relies
        # on the platform supporting threaded pixmaps. Emitting the QImage and converting
        # in the receiver would be safer, but would change this signal's type for callers.
        self.result_ready.emit(QPixmap.fromImage(image))


class DrawingApp(QMainWindow):
    """Main window: menus, status bar, the Canvas, undo/redo history and Stable Diffusion.

    ``undo_stack`` / ``redo_stack`` hold QPixmap copies of ``canvas.image``. Canvas
    pushes to them directly as well.
    """

    def __init__(self):
        super().__init__()

        model_id = "stabilityai/stable-diffusion-2"
        if torch.cuda.is_available():
            device, dtype = "cuda", torch.float16
        else:
            # Half-precision weights are meant for GPU inference; use float32 on CPU.
            device, dtype = "cpu", torch.float32
        self.pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=dtype)
        self.pipeline.to(device)

        self.is_processing_sd = False  # True while a StableDiffusionThread is running
        # False from the start of a Stable Diffusion run until its result arrives; lets
        # reset_sd_flag detect a run that finished without producing an image.
        self._sd_result_received = True

        self.undo_stack = []  # Stack to keep track of undo states
        self.redo_stack = []  # Stack to keep track of redo states

        self.init_ui()

    def init_ui(self):
        """Build menus, status bar and the central Canvas, then show the window."""
        self.setWindowTitle('Drawing App')

        # Menu Bar
        menubar = self.menuBar()

        # File Menu
        file_menu = menubar.addMenu('File')
        load_action = QAction('Load', self)
        load_action.triggered.connect(self.load_image)
        save_action = QAction('Save', self)
        save_action.triggered.connect(self.save_image)
        file_menu.addAction(load_action)
        file_menu.addAction(save_action)

        # Edit Menu
        edit_menu = menubar.addMenu('Edit')
        undo_action = QAction('Undo', self)
        undo_action.triggered.connect(self.undo)
        redo_action = QAction('Redo', self)
        redo_action.triggered.connect(self.redo)
        edit_menu.addAction(undo_action)
        edit_menu.addAction(redo_action)

        # Draw Menu
        draw_menu = menubar.addMenu('Draw')
        pen_color_action = QAction('Set Pen Color', self)
        pen_color_action.triggered.connect(self.set_pen_color)
        pen_width_action = QAction('Set Pen Width', self)
        pen_width_action.triggered.connect(self.set_pen_width)
        sd_action = QAction('Apply StableDiffusion', self)
        sd_action.triggered.connect(self.apply_stable_diffusion)
        self.fill_action = QAction('Fill with Color', self)
        self.fill_action.setCheckable(True)
        self.fill_action.triggered.connect(self.enable_fill)
        draw_menu.addAction(pen_color_action)
        draw_menu.addAction(pen_width_action)
        draw_menu.addAction(sd_action)
        draw_menu.addAction(self.fill_action)

        # Status Bar
        self.status_bar = QStatusBar()
        self.setStatusBar(self.status_bar)

        # Set central widget
        self.canvas = Canvas(self)
        self.setCentralWidget(self.canvas)

        # NOTE(review): not read by this class; fill mode is actually tracked by canvas.fill_mode.
        self.fill_mode = False
        self.canvas.fill_completed.connect(self.reset_fill_action)

        self.show()

    def load_image(self):
        """Ask for an image file and replace the canvas contents with it (undoable)."""
        options = QFileDialog.Options()
        file_name, _ = QFileDialog.getOpenFileName(self, "Load Image", "",
                                                   "All Files (*);;Image Files (*.png *.jpg *.bmp)", options=options)
        if not file_name:
            return
        image = QImage(file_name)
        if image.isNull():
            self.status_bar.showMessage(f"Could not load image: {file_name}")
            return
        # Snapshot the current canvas *before* replacing it, so Undo restores it.
        self.undo_stack.append(self.canvas.image.copy())
        self.redo_stack.clear()  # Clear redo stack on new action
        self.canvas.load_image(image)

    def save_image(self):
        """Ask for a destination file and save the canvas there (PNG when no suffix is given)."""
        import os

        options = QFileDialog.Options()
        file_name, _ = QFileDialog.getSaveFileName(self, "Save Image", "", "PNG Files (*.png);;All Files (*)",
                                                   options=options)
        if not file_name:
            return
        # QPixmap.save infers the format from the suffix and fails when there is none.
        if not os.path.splitext(file_name)[1]:
            file_name += ".png"
        # Only an explicit False is a reported failure; a None return carries no status.
        if self.canvas.save_image(file_name) is False:
            self.status_bar.showMessage(f"Could not save image to {file_name}")

    def set_pen_color(self):
        """Let the user pick a new pen colour, starting from the current one."""
        color_dialog = QColorDialog(self)
        color_dialog.setCurrentColor(self.canvas.pen_color)
        # A QDialog with a parent is centred over that parent's window by default,
        # so no manual positioning is needed.
        if color_dialog.exec_() == QColorDialog.Accepted:
            color = color_dialog.selectedColor()
            if color.isValid():
                self.canvas.set_pen_color(color)

    def set_pen_width(self):
        """Ask for a pen width in [1, 50], pre-filled with the current width."""
        width, ok = QInputDialog.getInt(self, 'Pen Width', 'Enter pen width:', self.canvas.pen_width, 1, 50, 1)
        if ok:
            self.canvas.set_pen_width(width)

    def apply_stable_diffusion(self):
        """Prompt for text and run img2img on the current canvas in a background thread."""
        if self.is_processing_sd:
            self.status_bar.showMessage("Stable Diffusion is already processing. Please wait.")
            return
        prompt, ok = QInputDialog.getText(self, 'StableDiffusion', 'Enter prompt:')
        if not (ok and prompt):
            return

        # Convert before setting the busy flag, so a conversion error cannot leave it stuck.
        # The pipeline's VAE takes 3-channel input; images with alpha would arrive as RGBA.
        input_image = qimage_to_pil_image(self.canvas.image.toImage()).convert("RGB")

        self.status_bar.showMessage("Processing image with Stable Diffusion...")
        self.is_processing_sd = True
        self._sd_result_received = False

        self.sd_thread = StableDiffusionThread(self.pipeline, prompt, input_image)
        self.sd_thread.result_ready.connect(self.update_image)
        self.sd_thread.finished.connect(self.reset_sd_flag)
        self.sd_thread.start()

    def update_image(self, output_image):
        """Replace the canvas with a Stable Diffusion result (undoable)."""
        self._sd_result_received = True
        self.status_bar.clearMessage()
        self.undo_stack.append(self.canvas.image.copy())  # Save state for undo
        self.redo_stack.clear()  # Clear redo stack on new action
        self.canvas.load_image(output_image)

    def reset_sd_flag(self):
        """Called when the worker thread finishes; reports runs that produced no image."""
        self.is_processing_sd = False
        # result_ready is emitted before the thread finishes, so if no result has been
        # handled by now the run failed.
        if not self._sd_result_received:
            self.status_bar.showMessage("Stable Diffusion did not produce an image.")
            self._sd_result_received = True

    def enable_fill(self):
        """Mirror the checkable Fill action into the canvas' fill mode."""
        self.canvas.fill_mode = self.fill_action.isChecked()

    def reset_fill_action(self):
        """Uncheck the Fill action once the canvas reports the fill request is done."""
        self.fill_action.setChecked(False)

    def undo(self):
        """Restore the previous canvas state, moving the current one onto the redo stack."""
        if self.undo_stack:
            self.redo_stack.append(self.canvas.image.copy())  # Save current state to redo stack
            self.canvas.image = self.undo_stack.pop()  # Restore the last state from undo stack
            self.canvas.update()

    def redo(self):
        """Re-apply the last undone state, moving the current one onto the undo stack."""
        if self.redo_stack:
            self.undo_stack.append(self.canvas.image.copy())  # Save current state to undo stack
            self.canvas.image = self.redo_stack.pop()  # Restore the last state from redo stack
            self.canvas.update()


class Canvas(QWidget):
    """Fixed-size (800x600) drawing surface backed by a QPixmap (``self.image``).

    The pixmap is painted stretched over the whole widget, so it can have a different
    resolution than the widget (e.g. after loading a file). Mouse positions are mapped
    into pixmap coordinates before drawing or filling.

    Undo snapshots are pushed onto the parent window's ``undo_stack`` and its
    ``redo_stack`` is cleared on new edits. ``fill_completed`` is emitted once a fill
    request has been dealt with (finished, rejected, or a no-op).
    """

    fill_completed = pyqtSignal()

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setFixedSize(800, 600)
        self.image = QPixmap(self.size())
        self.image.fill(Qt.white)
        self.drawing = False
        self.last_point = QPoint()  # last stroke point, in pixmap coordinates
        self.pen_color = Qt.black  # Default pen color
        self.pen_width = 4  # Default pen width
        self.fill_mode = False
        # Current/last fill worker; this reference keeps the QThread alive while it runs.
        self.flood_fill_thread = None

    def _to_image_point(self, pos):
        """Map a widget position to the matching position in ``self.image``."""
        widget_w, widget_h = self.width(), self.height()
        image_w, image_h = self.image.width(), self.image.height()
        if (image_w, image_h) == (widget_w, widget_h) or widget_w <= 0 or widget_h <= 0:
            return QPoint(pos)
        return QPoint(pos.x() * image_w // widget_w, pos.y() * image_h // widget_h)

    def paintEvent(self, event):
        canvas_painter = QPainter(self)
        canvas_painter.drawPixmap(self.rect(), self.image, self.image.rect())
        canvas_painter.end()

    def mousePressEvent(self, event):
        if event.button() != Qt.LeftButton:
            return
        if self.fill_mode:
            # Fill mode is one-shot; fill_color records undo state only if a fill starts.
            self.fill_mode = False
            self.fill_color(self._to_image_point(event.pos()))
        else:
            self.drawing = True
            self.parent().undo_stack.append(self.image.copy())  # Save state for undo
            self.parent().redo_stack.clear()  # Clear redo stack on new action
            self.last_point = self._to_image_point(event.pos())

    def mouseMoveEvent(self, event):
        if event.buttons() & Qt.LeftButton and self.drawing:
            current_point = self._to_image_point(event.pos())
            painter = QPainter(self.image)
            painter.setPen(QPen(self.pen_color, self.pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin))
            painter.drawLine(self.last_point, current_point)
            painter.end()
            self.last_point = current_point
            self.update()

    def mouseReleaseEvent(self, event):
        if event.button() == Qt.LeftButton:
            self.drawing = False

    def save_image(self, path):
        """Write the pixmap to *path*; return QPixmap.save's success flag."""
        return self.image.save(path)

    def load_image(self, qimage):
        self.image = QPixmap.fromImage(qimage)
        self.update()

    def clear_canvas(self):
        self.image.fill(Qt.white)
        self.update()

    def set_pen_color(self, color):
        self.pen_color = color

    def set_pen_width(self, width):
        self.pen_width = width

    def fill_color(self, point):
        """Start a background flood fill at *point* (pixmap coordinates) with the pen colour."""
        main_window = self.parent()
        if self.flood_fill_thread is not None and self.flood_fill_thread.isRunning():
            # Rebinding the attribute would drop the only reference to a running QThread.
            main_window.status_bar.showMessage("A fill is already in progress.")
            self.fill_completed.emit()
            return

        image = self.image.toImage()
        target_color = image.pixelColor(point)
        fill_color = self.pen_color
        if target_color == fill_color:
            # Nothing to change; still signal completion so the Fill action is unchecked.
            self.fill_completed.emit()
            return

        main_window.undo_stack.append(self.image.copy())  # Save state for undo
        main_window.redo_stack.clear()  # Clear redo stack on new action
        main_window.status_bar.showMessage("Filling...")

        self.flood_fill_thread = FloodFillThread(image, point, target_color, fill_color)
        self.flood_fill_thread.result_ready.connect(self.update_fill_image)
        self.flood_fill_thread.start()

    def update_fill_image(self, output_image):
        """Install the filled pixmap produced by the worker thread."""
        self.parent().status_bar.clearMessage()
        self.image = output_image
        self.update()
        self.fill_completed.emit()


def run_app():
    """Open the drawing window and block until the Qt event loop exits.

    An already-existing QApplication (e.g. from an earlier notebook run) is
    reused rather than created again.
    """
    existing = QApplication.instance()
    app = existing if existing is not None else QApplication(sys.argv)

    main_window = DrawingApp()
    main_window.show()

    app.exec_()


# Run the application only when executed directly; a notebook cell also runs as __main__,
# but importing this module no longer loads the model and opens a window as a side effect.
if __name__ == "__main__":
    run_app()


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]