In [5]:
# !conda info -e

In [1]:
#### import statements

# general
import os, time, re, csv
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from dotenv import load_dotenv
LOAD_ENV = load_dotenv()

# api
import openai
import replicate
import requests

# interactive widgets
import ipywidgets as widgets
from ipywidgets import Box, VBox, HBox, Layout, Output, Text, Button
from IPython.display import display, clear_output
from IPython.display import Image
from IPython.core.display import HTML 


import warnings

# Silence every warning category (presumably to keep the Voila UI clean).
# One catch-all 'ignore' filter is all that is needed for that.
warnings.simplefilter('ignore')

# Seed Python's `random`; RANDOM_SEED is also the starting Stable Diffusion seed.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

#### using Monash University account (Sensilab) ####
# Credentials are read from the environment (load_dotenv() above may populate it
# from a .env file); a missing variable yields None here rather than an error.
openai.organization = os.environ.get('OPENAI_ORG')
openai.api_key = os.environ.get('OPENAI_API_KEY')

#### text-to-image (stable diffusion)
# Resolve the pinned Stable Diffusion version on Replicate. These presumably hit
# the network at import time and authenticate via REPLICATE_API_TOKEN.
model = replicate.models.get("stability-ai/stable-diffusion")
version = model.versions.get("db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")

In [3]:
###### COMMENTED ########

# # keep these lines commented out for a cleaner web interface on Voila!
# !jupyter nbextension enable --py widgetsnbextension --sys-prefix
# !jupyter serverextension enable voila --sys-prefix

In [4]:
class Prompt_Assistant:
    """Stateful chat wrapper around the OpenAI ChatCompletion API.

    The whole conversation (premise messages, user turns and assistant
    replies) is kept in ``self.chat_history`` and re-sent on every request,
    so the model always sees the full context.

    Args:
        model: chat model name passed to ``openai.ChatCompletion.create``.
        premise_path: CSV file of seed messages loaded on construction. Each
            row becomes one message dict (column name -> value); presumably
            the columns are ``role`` and ``content`` -- confirm against the CSV.

    Note:
        Construction performs one API request (see ``__set_premise__``).
    """

    def __init__(self, model="gpt-3.5-turbo", premise_path='./premise.csv'):
        self.chat_history = []
        self.model = model
        self.__set_premise__(premise_path)

    # NOTE(review): ``__name__``-style dunder names are reserved for Python
    # special methods; they are kept as-is so the class interface is unchanged.
    def __append_chat_history__(self, role: str, content: str):
        """Append one ``{'role': ..., 'content': ...}`` message to the history."""
        self.chat_history.append({'role': role, 'content': content})

    def generate_response(self, input_text: str) -> str:
        """Send ``input_text`` as a user turn and return the assistant's reply text.

        The user turn and the reply are appended to ``chat_history`` only after
        the request succeeds. If the API call raises, the history is left
        untouched, so a retry does not send two consecutive user turns.

        Raises:
            Whatever ``openai.ChatCompletion.create`` raises (e.g. network or
            authentication errors); the history is unchanged in that case.
        """
        user_message = {'role': 'user', 'content': input_text}
        completion = openai.ChatCompletion.create(
            model=self.model,
            messages=self.chat_history + [user_message],
        )

        response = completion.choices[0].message
        res_content = response['content']
        res_role = response['role']

        # Commit the exchange only now that the call has succeeded.
        self.__append_chat_history__(role='user', content=input_text)
        self.__append_chat_history__(role=res_role, content=res_content)
        return res_content

    def get_chat_history(self):
        """Return the live chat-history list (not a copy)."""
        return self.chat_history

    def get_last_chat_item(self):
        """Return the most recent message; raises IndexError if the history is empty."""
        return self.chat_history[-1]

    def __set_premise__(self, premise_path='./premise.csv'):
        """Seed the conversation from ``premise_path`` and request a welcome turn.

        Every CSV row is appended to the history as a message dict. A welcome
        message is then requested from the model, so it becomes part of the
        conversation context. Its text is not returned from here.
        """
        premise = pd.read_csv(premise_path).to_dict(orient='records')
        self.chat_history.extend(premise)
        self.generate_response("Let's start again with a welcome message from you.")
        

In [5]:
# Define a chat assistant object (its constructor reads premise.csv and makes
# one OpenAI request for a welcome turn).
prompt_assistant = Prompt_Assistant()

# Define input_box widget: free-text field plus the two chat actions.
input_field = Text(placeholder="Type your message...")

refine_button = Button(description='Refine prompt')
refine_button.add_class("refine-button")

generate_button = Button(description="Generate prompt!")
generate_button.add_class("generate-button")

input_box = HBox(children=[input_field, refine_button, generate_button])
input_box.add_class("chat-input")

# Define header.
# Output has no `description` trait (passing one does nothing, so the header
# rendered empty), and ".chat-header h3" in the CSS is a descendant selector,
# not a class name. Render a real <h3> inside the .chat-header element.
chat_header = Output()
chat_header.add_class("chat-header")
chat_header.append_display_data(HTML("<h3>Prompt Assistant</h3>"))

# Define output_box: the scrollable area that holds the chat bubbles.
output_box = Output()
output_box.add_class("chat-body")

# Chat window
chat_window = VBox(children=[chat_header, output_box, input_box])
chat_window.add_class("body")
chat_window.add_class("chat-window")


# Define input and output widgets for storing user prompt (chat history).
# Widget sizes go through `layout`; a bare `width=` keyword is not an Output trait.
output_text_box = Output(layout=Layout(width='80%'))

input_text_box = Text(
    description='Enter prompt:',
    layout=Layout(width='80%')
)

negative_prompt_box = Text(
    description='Negative prompt:',
    layout=Layout(width='80%')
)

progress_box = Output(layout=Layout(width='80%'))

randomize_toggle = widgets.Checkbox(
    value=True,
    description='Randomize image',
    disabled=False
)

# Receives the most recently generated image.
placeholder_image = widgets.Output()

In [6]:
# Stylesheet for the chat UI. Class names match the add_class() calls made on
# the widgets. The `body` rule restyles the whole page, which is presumably
# intended for the Voila view.
CSS = """
body {
  margin: 0;
  padding: 0;
  font-family: Arial, sans-serif;
  background-color: #f0f2f5;
}

.chat-window {
  max-width: 600px;
  margin: 20px auto;
  border: 1px solid #d3d3d3;
  border-radius: 10px;
  background-color: #fff;
  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}

.chat-header {
  display: flex;
  align-items: center;
  justify-content: space-between;
  padding: 10px 20px;
  border-bottom: 1px solid #d3d3d3;
}

.chat-header h3 {
  margin: 0;
  font-weight: normal;
}

.chat-body {
  height: 400px;
  overflow-y: scroll;
  padding: 10px;
}

.bubble {
  display: inline-block;
  max-width: 80%;
  margin: 5px;
  padding: 10px;
  border-radius: 10px;
  box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}

.user-bubble {
  float: right;
  background-color: #e5e5ea;
  color: #1c1e21;
  align-self: flex-end;
}

.bot-bubble {
  background-color: #0078ff;
  color: #fff;
  align-self: flex-start;
}

.chat-input {
  display: flex;
  width: 100%;
  align-items: center;
  justify-content: space-between;
  border-top: 1px solid #d3d3d3;
  padding: 10px;
}

.chat-input input[type="text"] {
  margin-right: 10px;
  width: 500px;
  padding: 10px;
  border-radius: 30px;
  border: none;
  box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
  background-color: #f0f2f5;
  font-size: 14px;
}

.chat-input button[type="submit"] {
  border: none;
  border-radius: 50%;
  width: 40px;
  height: 40px;
  background-color: #0078ff;
  color: #fff;
  box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
  cursor: pointer;
  font-size: 14px;
}

.refine-button {
  border: none;
  border-radius: 10px;
  width: 160px;
  height: 32px;
  background-color: #f0f2f5;
  color: #4e4e4e;
  box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
  cursor: pointer;
  font-size: 14px;
}

.generate-button {
  border: none;
  border-radius: 10px;
  width: 160px;
  height: 32px;
  background-color: #56953e;
  color: #fff;
  box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
  cursor: pointer;
  font-size: 14px;
}
"""

# Inject the stylesheet into the page by rendering it inside a <style> tag.
display(HTML(f'<style>{CSS}</style>'))


In [10]:
# Define function to handle user input
def handle_generate(sender):
    # Get user input
    input_text = input_field.value.strip()
    
    # Display input text in chat history
    with output_box:
        display(HTML(f'<div class="bubble user-bubble">{input_text}</div>'))
        
    # Clear input text
    input_field.value = ''
        
    additional_info = 'Now generate the prompt with the available info so far.'
    
    # 1. Append user input to chat history 
    # 2. Generate response from API
    # 3. Append response to chat history
    response_text = prompt_assistant.generate_response(f"{input_text}; {additional_info}")
    
    # Display output text in chat history
    with output_box:
        display(HTML(f'<div class="bubble bot-bubble">{response_text}</div>'))

def handle_refine(sender):
    # Get user input
    input_text = input_field.value.strip()
    
    # Display input text in chat history
    with output_box:
        display(HTML(f'<div class="bubble user-bubble">{input_text}</div>'))
        
    # Clear input text
    input_field.value = ''
        
    # 1. Append user input to chat history 
    # 2. Generate response from API
    # 3. Append response to chat history
    response_text = prompt_assistant.generate_response(input_text)
    
    # Display output text in chat history
    with output_box:
        display(HTML(f'<div class="bubble bot-bubble">{response_text}</div>'))
            

# Define function to handle storing user text
def handle_store(sender):
    global RANDOM_SEED
    global image_box
    # Get user text
    user_text = input_text_box.value.strip()
    neg_prompt = negative_prompt_box.value.strip()
    
    # Clear input box
    input_text_box.value = ''
    negative_prompt_box.value = ''
    
    if randomize_toggle.value:
        RANDOM_SEED = random.randint(1, 100)
        
    # Store user text in variable
    # (for this example, we'll just print it)
    with output_text_box:
        print(f'Prompt tried: \"{user_text}\"; WITHOUT \"{neg_prompt}\"')
    with progress_box:
        clear_output()
        print('(Wait while the image is rendering...)')
    
    inputs = {
        # Input prompt
        'prompt': user_text,

        # pixel dimensions of output image
        'image_dimensions': "512x512",

        # Specify things to not see in the output
        'negative_prompt': neg_prompt,

        # Number of images to output.
        # Range: 1 to 4
        'num_outputs': 1,

        # Number of denoising steps
        # Range: 1 to 500
        'num_inference_steps': 50,

        # Scale for classifier-free guidance
        # Range: 1 to 20
        'guidance_scale': 7.5,

        # Choose a scheduler.
        'scheduler': "DPMSolverMultistep",

        # Random seed. Leave blank to randomize the seed
        'seed': RANDOM_SEED
    }

    # https://replicate.com/stability-ai/stable-diffusion/versions/db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf#output-schema
    output = version.predict(**inputs)
    img = Image(url=output[0])
    with placeholder_image:
        placeholder_image.clear_output()
        display(img)
    with progress_box:
        clear_output()

In [11]:
# Button that triggers Stable Diffusion rendering of the prompt in input_text_box.
store_button = Button(description='Generate image!')

# Register click handlers: two chat actions plus the image action.
refine_button.on_click(handle_refine)
generate_button.on_click(handle_generate)
store_button.on_click(handle_store)

In [12]:
# Render the chat window, then the image-generation panel beneath it.
display(chat_window)
display(
    output_text_box,
    progress_box,
    input_text_box,
    negative_prompt_box,
    HBox([store_button, randomize_toggle]),
    placeholder_image,
)

# Start with an empty prompt log.
output_text_box.clear_output()

# Reset the chat pane and greet the user with a fixed (not model-generated) message.
output_box.clear_output()
welcome_msg = "Hi, I'm your prompt-engineering assistant for generating an image using a text-to-image system. Please type your initial text prompt and I'll help you refine it to generate your desired image."
with output_box:
    display(HTML(f'<div class="bubble bot-bubble">{welcome_msg}</div>'))

VBox(children=(Output(_dom_classes=('chat-header', 'chat-header h3')), Output(outputs=({'output_type': 'displa…

Output(outputs=({'output_type': 'stream', 'text': 'Prompt tried: panda on a beach; WITHOUT \nPrompt tried: goo…

Output()

Text(value='', description='Enter prompt:', layout=Layout(width='80%'))

Text(value='', description='Negative prompt:', layout=Layout(width='80%'))

HBox(children=(Button(description='Generate image!', style=ButtonStyle()), Checkbox(value=True, description='R…

Output(outputs=({'output_type': 'display_data', 'data': {'text/html': '<img src="https://replicate.delivery/pb…

In [10]:
# %%time
# %env REPLICATE_API_TOKEN=<your-replicate-token>  (never commit real tokens: set it in .env, and rotate any token that was exposed)
# # replicate.Client(api_token=os.environ["REPLICATE_API_TOKEN"])
# model = replicate.models.get("stability-ai/stable-diffusion")
# version = model.versions.get("db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")

# # https://replicate.com/stability-ai/stable-diffusion/versions/db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf#input
# inputs = {
#     # Input prompt
#     'prompt': "jackie chan portrait in the art style of claude monet",

#     # pixel dimensions of output image
#     'image_dimensions': "768x768",

#     # Specify things to not see in the output
#     # 'negative_prompt': ...,

#     # Number of images to output.
#     # Range: 1 to 4
#     'num_outputs': 1,

#     # Number of denoising steps
#     # Range: 1 to 500
#     'num_inference_steps': 50,

#     # Scale for classifier-free guidance
#     # Range: 1 to 20
#     'guidance_scale': 7.5,

#     # Choose a scheduler.
#     'scheduler': "DPMSolverMultistep",

#     # Random seed. Leave blank to randomize the seed
#     # 'seed': ...,
# }

# # https://replicate.com/stability-ai/stable-diffusion/versions/db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf#output-schema
# output = version.predict(**inputs)
# print(output)
# Image(url= output[0])

env: REPLICATE_API_TOKEN=<REDACTED>
['https://replicate.delivery/pbxt/zGZFdzwGEZ7pKd6ZODGf2rqmr8b2md4Uk0KCMuqnWIvOwSZIA/out-0.png']
CPU times: user 63.1 ms, sys: 13.9 ms, total: 77.1 ms
Wall time: 7.43 s
