In [1]:
%%capture --no-stderr
%pip install -U langchain tavily-python langgraph matplotlib langchain_community langchain-openai scikit-learn langchainhub langchain-ollama tiktoken langchain-nomic chromadb gpt4all firecrawl-py

In [2]:
import os
from getpass import getpass

# LangSmith tracing configuration.
# SECURITY: the API key must never be hardcoded in the notebook. It is taken
# from the environment when already set, otherwise typed in at run time.
# Any key that was previously committed here should be treated as leaked and revoked.
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
if not os.environ.get('LANGCHAIN_API_KEY'):
    os.environ['LANGCHAIN_API_KEY'] = getpass('LangSmith API key: ')

In [3]:
import os
from langchain_ollama import ChatOllama
from langchain_ollama import OllamaEmbeddings

# Ollama model tag used when building chat models in this notebook.
local_llm = 'llama3.1'


def get_llm():
    """Build a ChatOllama client for the ``local_llm`` model with temperature 0."""
    llm_kwargs = dict(model=local_llm, temperature=0)
    return ChatOllama(**llm_kwargs)

In [4]:
from langchain_ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Relevance gate between the user's request and the caption produced upstream
# by the label model (see fix_prompt): the LLM keeps the AI-generated caption
# when it is on-topic, otherwise it writes a new one-sentence meme caption.
prompt = PromptTemplate(
    template="""
    You are an assistant that will receive two inputs. The first one is a prompt that a user gave to an AI meme-generator,
    and the second input is a refined prompt created by an AI.
    If you believe that the AI generated a prompt that is not related to the initial user prompt, take the user prompt and refine the prompt to make it
    a one-sentence funny meme-like caption that references parts of the user query, or else just return the AI-Generated Query. Your return value should
    contain no preamble or explanation.

    User Query: {query}
    \n \n
    AI-Generated Query: {generated_query}
    \n \n

    Answer:
    """,
    # Two separate variable names, matching the {query} and {generated_query}
    # placeholders in the template.
    input_variables=["query", "generated_query"],
)

# Reuse the shared factory so every chat model here is configured the same way.
llm = get_llm()

# prompt -> chat model -> plain string output
prompt_chain = prompt | llm | StrOutputParser()

In [None]:
from typing_extensions import TypedDict, List
from PIL import Image
from IPython.display import Image, display
from langgraph.graph import START, END, StateGraph
from generate_memes import generate_image
from generate_labels import generate_meme_label_wrapper

# State shared between the nodes of the meme pipeline:
#   user_prompt      - request typed by the user (written by accept_prompt)
#   generated_prompt - caption chosen for the image model (written by fix_prompt)
#   image_path       - value returned by generate_image (written by make_image)
AgentState = TypedDict(
    'AgentState',
    {
        'user_prompt': str,
        'generated_prompt': str,
        'image_path': str,
    },
)

def accept_prompt(state: "AgentState") -> dict:
    """Collect the meme request that drives the rest of the pipeline.

    If the incoming state already carries a non-blank ``user_prompt`` it is
    used directly, so the graph can also run non-interactively, e.g.
    ``custom_graph.invoke({'user_prompt': 'cats in space'})``. Otherwise the
    user is asked via ``input()`` until a non-blank answer is given.

    Args:
        state: Current graph state; only ``user_prompt`` is read (may be absent).

    Returns:
        Partial state update ``{"user_prompt": <prompt with surrounding whitespace removed>}``.

    Raises:
        EOFError: propagated from ``input()`` when stdin is closed.
    """
    existing = (state.get("user_prompt") or "").strip()
    if existing:
        return {"user_prompt": existing}

    # A blank prompt would reach the image model as an empty caption, so ask again.
    user_input = ""
    while not user_input:
        user_input = input("Give me some information on a meme you would like to see? ").strip()
    return {"user_prompt": user_input}

def fix_prompt(state: AgentState) -> dict:
    """Turn the user's request into the caption sent to the image model.

    The request is first passed to ``generate_meme_label_wrapper``.
    ``prompt_chain`` then compares that label with the original request and
    either keeps it or writes a new one-sentence caption.

    Args:
        state: Current graph state; only ``user_prompt`` is read.

    Returns:
        Partial state update ``{"generated_prompt": <caption>}``.
    """
    user_prompt = state['user_prompt']
    # NOTE(review): the return type of generate_meme_label_wrapper is not visible
    # here; it is assumed to be a caption string.
    label_prompt = generate_meme_label_wrapper(user_prompt)
    result = prompt_chain.invoke({'query': user_prompt, 'generated_query': label_prompt})

    # The LLM answer may carry leading/trailing newlines; strip them so the image
    # model gets a clean caption. If the LLM returns nothing, fall back to the
    # label-model caption rather than sending an empty prompt downstream.
    caption = result.strip()
    if not caption and label_prompt:
        caption = str(label_prompt).strip()

    print(caption)
    return {"generated_prompt": caption}

def make_image(
    state: AgentState,
    model_path: str = "trained-comp-vis-model",
    output_dir: str = "generated_memes",
    num_inference_steps: int = 200,
    guidance_scale: float = 8.5,
    width: int = 768,
    height: int = 768,
) -> dict:
    """Render an image for ``state['generated_prompt']`` via ``generate_image``.

    The generation settings are keyword parameters whose defaults are the values
    previously hardcoded here. The graph calls the node with ``state`` only, so
    it behaves as before, while the settings can be tuned when calling directly
    (or via ``functools.partial``).

    Args:
        state: Current graph state; only ``generated_prompt`` is read.
        model_path: Model location passed to ``generate_image``.
        output_dir: Output directory passed to ``generate_image``.
        num_inference_steps: Forwarded to ``generate_image``.
        guidance_scale: Forwarded to ``generate_image``.
        width: Requested image width, forwarded to ``generate_image``.
        height: Requested image height, forwarded to ``generate_image``.

    Returns:
        Partial state update ``{"image_path": <whatever generate_image returns>}``.
    """
    output_path = generate_image(
        state['generated_prompt'],
        model_path,
        output_dir,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
    )
    return {"image_path": output_path}

In [None]:
# Graph
# Linear pipeline: START -> accept_prompt -> fix_prompt -> make_image -> END
workflow = StateGraph(AgentState)

# Define the nodes
workflow.add_node('accept_prompt', accept_prompt)
workflow.add_node('fix_prompt', fix_prompt)
workflow.add_node('make_image', make_image)

# Build Graph (an explicit START edge mirrors the END edge below)
workflow.add_edge(START, 'accept_prompt')
workflow.add_edge('accept_prompt', 'fix_prompt')
workflow.add_edge('fix_prompt', 'make_image')
workflow.add_edge('make_image', END)

custom_graph = workflow.compile()

# PNG rendering uses a remote Mermaid renderer by default and can fail offline.
# The diagram is cosmetic, so fall back to the Mermaid source text instead of
# aborting a Run All.
graph_view = custom_graph.get_graph(xray=True)
try:
    display(Image(graph_view.draw_mermaid_png()))
except (ImportError, OSError, ValueError) as err:
    print(f"Could not render graph PNG ({err}); Mermaid source follows:")
    print(graph_view.draw_mermaid())

In [None]:
# Run the full pipeline once. The empty user_prompt is a placeholder: with an
# empty prompt the accept_prompt node asks for the real request interactively.
initial_state = {'user_prompt': ''}
custom_graph.invoke(initial_state)