In [23]:
# loading the OpenAI API key
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(), override=True)

from typing import List, Sequence
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import END, MessageGraph

from langchain_openai import AzureChatOpenAI


In [13]:
reflection_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a viral Instagram influencer grading a caption for a Instagram post. Generate critique and recommendation for the user's caption."
            "Always provide detailed recommendations, including requests for lengths, virality, style, etc."
        ),
        MessagesPlaceholder(variable_name="messages")
    ]
)

generation_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a viral Instagram influencer assistant tasked with writing excellent, catchy Instagram post captions"
            "Generate the best Instagram caption possible for the user's request."
            "If the user provides critique, respond with a revised version of your previous attempts.",
        ),
        MessagesPlaceholder(variable_name="messages")
    ]
)


In [21]:
llm = AzureChatOpenAI(
        azure_deployment="gpt-4o-mini",  # or your deployment
        api_version="2023-03-15-preview",  # or your api version
        temperature=0)

In [22]:
generate_chain = generation_prompt | llm 
reflect_chain = reflection_prompt | llm 

In [7]:
REFLECT = 'reflect'
GENERATE = 'generate'

In [24]:
def generation_node(state : Sequence[BaseMessage]):
    return generate_chain.invoke({"messages" : state})

In [25]:
def reflection_node(state : Sequence[BaseMessage]) -> List[BaseMessage]:
    res = reflect_chain.invoke({"messages" : state})
    return [HumanMessage(content=res.content)]

In [26]:
def should_continue(state :List[BaseMessage]):
    if len(state) > 6:
        return END
    return REFLECT

In [40]:
from mermaid import Mermaid

builder = MessageGraph()
builder.add_node(GENERATE, generation_node)
builder.add_node(REFLECT, reflection_node)
builder.set_entry_point(GENERATE)
builder.add_conditional_edges(GENERATE,should_continue)
builder.add_edge(REFLECT, GENERATE)

graph = builder.compile()

Mermaid(graph.get_graph().draw_mermaid())

In [41]:
input = HumanMessage(content="Write a instagram post caption for trip photos from meghalaya vacation")

In [42]:
graph.invoke(input)

[HumanMessage(content='Write a instagram post caption for trip photos from meghalaya vacation', additional_kwargs={}, response_metadata={}, id='172fc124-764a-4b76-b184-aed7daa85683'),
 AIMessage(content='"Chasing waterfalls and clouds in the enchanting hills of Meghalaya! 🌧️✨ From breathtaking views to hidden gems, every moment felt like a dream. Who\'s ready to explore this paradise with me? 🌿💚 #MeghalayaDiaries #Wanderlust"', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 56, 'prompt_tokens': 65, 'total_tokens': 121, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'gpt-4o-mini', 'system_fingerprint': 'fp_878413d04d', 'finish_reason': 'stop', 'logprobs': None, 'content_filter_results': {}}, id='run-5e70d51e-7303-4a84-bb2b-38bc1edbbfc6-0', usage_metadata={'input_tokens': 65, 'output_tokens': 56, 'total_tokens': 121, 'input_token_details': {}, 'output_token_details': {}}),
 HumanMessage(content='**Caption Cri