In [1]:
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage, SystemMessage
from langchain_anthropic import ChatAnthropic
from graphdatascience import GraphDataScience
from getpass import getpass
import time
import numpy as np
import nltk

  from .autonotebook import tqdm as notebook_tqdm


# Set up connection information

In [2]:
# Ask for the Anthropic API key without echoing it. The explicit prompt text
# distinguishes this input box from the Neo4j password prompt further down.
anthropic_api_key = getpass("Anthropic API key: ")

 ········


In [4]:
# Ask for the Neo4j password without echoing it. The explicit prompt text
# distinguishes this input box from the Anthropic API key prompt.
neo4j_password = getpass("Neo4j password: ")

 ········


In [5]:
# Neo4j connection settings; the password is the value read by getpass() earlier.
neo4j_uri = "neo4j+s://2fe3bf28.databases.neo4j.io"
neo4j_user = "neo4j"
neo4j_auth = (neo4j_user, neo4j_password)

# Graph Data Science client used for every Cypher call in this notebook.
gds = GraphDataScience(neo4j_uri, auth=neo4j_auth)

I was curious to experiment with the Claude 3 models for this, but other models should work fine.

In [8]:
# Chat model used for theme extraction, configured with temperature 0.
chat = ChatAnthropic(
    model="claude-3-sonnet-20240229",
    temperature=0,
    anthropic_api_key=anthropic_api_key,
)

# Get movie data

In [6]:
# Select movies that have a non-blank overview and no HAS_THEME relationship yet.
unthemed_movies_query = """MATCH (m:Movie) 
       WHERE NOT EXISTS {(m)-[:HAS_THEME]->()} 
       AND trim(m.overview) <> ''
       RETURN m.tmdbId AS tmdbId, m{.title, .overview} AS description"""

# One row per movie: its tmdbId plus a {title, overview} map in `description`.
movie_info = gds.run_cypher(unthemed_movies_query)

In [7]:
# Display the movies returned by the query above.
movie_info

Unnamed: 0,tmdbId,description
0,1234165,"{'title': 'The Buy Bust Queen', 'overview': 'T..."
1,1234196,"{'title': 'Sihinayaki Adare', 'overview': 'Thr..."
2,1234203,"{'title': 'Vertical Money', 'overview': 'The f..."
3,1234240,"{'title': 'God's Lonely Man', 'overview': 'A g..."
4,1234243,{'title': 'Galym Kaliakbarov: When He Grows Up...
...,...,...
2518,1256551,"{'title': 'Agastya – Chapter 1', 'overview': '..."
2519,1256574,"{'title': 'Eye of the Fen', 'overview': 'Far i..."
2520,1256587,"{'title': 'Mushrooms', 'overview': 'n Mushroom..."
2521,1256588,{'title': 'How I Trafficked $5 Million Worth o...


# Create theme extraction pipeline

In [10]:
# System prompt: states what to extract and the output format (one
# pipe-separated list), which the batch step later splits on "|".
# NOTE(review): HAS_THEME relationships are written with prompt "v2"; consider
# bumping that tag whenever this wording changes.
system_message = SystemMessage(
    content="""You are a movie expert. 
    You are given the title and overview of the plot of a movie.
    Summarize the most memorable themes, settings, and public figures in the movie into a list of up to eight one-to-two word phrases. 
    Only include the names of people if the person is a famous public figure.
    Prioritize any phrases that appear in the movie's title.
    You can provide fewer than eight phrases.
    Return the phrases as a pipe separated list. 
    Return only the list without a heading.""")

# The human turn carries the per-movie fields; {title} and {overview} are
# filled from the dict handed to the chain (each row's `description` map).
final_prompt = ChatPromptTemplate.from_messages(
    [system_message,
     ("human", """title: {title}
     overview: {overview}""")])

# Prompt piped into the chat model; invoking the chain returns the model's message.
summary_chain = final_prompt | chat

In [11]:
# Inspect the prompt input for the first movie: its title/overview map.
movie_info.loc[0, 'description']

{'title': 'The Buy Bust Queen',
 'overview': 'This tale centers on seven exceptional women who defy gender norms, showcasing their unwavering dedication and resilience in joining an institution traditionally dominated by men.  Their commitment stems from a lifetime of virtues, forming an unshakable foundation that cannot be easily replicated.'}

In [12]:
# Try the chain on the first movie before running it over the whole frame.
summary_chain.invoke(movie_info.loc[0, 'description'])

AIMessage(content='Exceptional women|Gender norms|Resilience|Male-dominated institution|Lifetime virtues')

# Extract themes and send to Neo4j
Create a node key constraint on `Theme` nodes. It requires every `Theme` node to have a `description` property and requires that value to be unique, so each theme phrase is stored as a single node.

In [9]:
# IF NOT EXISTS lets this cell be re-run without failing once the constraint is in place.
gds.run_cypher(
    """CREATE CONSTRAINT theme_node_key IF NOT EXISTS FOR (n:Theme) REQUIRE n.description IS NODE KEY"""
)

In [13]:
def process_theme_batch(start_index, end_index):
    """Extract themes for one slice of ``movie_info`` and write them to Neo4j.

    Rows from position ``start_index`` (inclusive) to ``end_index``
    (exclusive) are sent through ``summary_chain`` in one batch call. Each
    reply is split on ``"|"``; every non-blank phrase is merged as a ``Theme``
    node and linked to its movie by a ``HAS_THEME`` relationship carrying
    ``prompt: "v2"``.

    Args:
        start_index: Position of the first row to process.
        end_index: Position one past the last row to process. A slice that
            runs past the end of ``movie_info`` is simply shorter.

    Returns:
        None. The function is called for its effect on the database.
    """
    rows = movie_info.iloc[start_index:end_index, :].copy()
    if rows.empty:
        # Nothing in range: skip the model call and the database round trip.
        return

    results = summary_chain.batch(rows['description'].to_list())

    # Drop blank phrases (empty reply, trailing "|", "a||b") so that no Theme
    # node with an empty description is ever merged.
    rows['themes'] = [
        [phrase.strip() for phrase in result.content.split("|") if phrase.strip()]
        for result in results
    ]

    # MERGE keeps the write idempotent: re-processing a movie reuses existing
    # Theme nodes and HAS_THEME relationships instead of duplicating them.
    gds.run_cypher(
        """
    UNWIND $data as d
    MATCH (m:Movie {tmdbId:d['tmdbId']})
    UNWIND d['themes'] AS theme
    MERGE (t:Theme {description:trim(theme)})
    WITH t, m
    MERGE (m)-[:HAS_THEME {prompt: "v2"}]->(t)
    """,
        {"data": rows[['tmdbId', 'themes']].to_dict("records")},
    )

In [14]:
BATCH_SIZE = 5      # movies sent to the model per batch call
REPORT_EVERY = 5    # print progress after this many batches
PAUSE_SECONDS = 8   # delay between batches; presumably to respect API rate limits

total_rows = len(movie_info)

# Step through the frame in whole batches. Iterating over start offsets means
# no empty trailing batch is scheduled when the row count is an exact
# multiple of BATCH_SIZE.
for batch_number, start in enumerate(range(0, total_rows, BATCH_SIZE), start=1):
    end = min(start + BATCH_SIZE, total_rows)
    process_theme_batch(start, end)
    if batch_number % REPORT_EVERY == 0:
        print(f"Finished row {end}")
    if end < total_rows:
        # No need to wait once the last batch has been written.
        time.sleep(PAUSE_SECONDS)

Finished row 25
Finished row 50
Finished row 75
Finished row 100
Finished row 125
Finished row 150
Finished row 175
Finished row 200
Finished row 225
Finished row 250
Finished row 275
Finished row 300
Finished row 325
Finished row 350
Finished row 375
Finished row 400
Finished row 425
Finished row 450
Finished row 475
Finished row 500
Finished row 525
Finished row 550
Finished row 575
Finished row 600
Finished row 625
Finished row 650
Finished row 675
Finished row 700
Finished row 725
Finished row 750
Finished row 775
Finished row 800
Finished row 825
Finished row 850
Finished row 875
Finished row 900
Finished row 925
Finished row 950
Finished row 975
Finished row 1000
Finished row 1025
Finished row 1050
Finished row 1075
Finished row 1100
Finished row 1125
Finished row 1150
Finished row 1175
Finished row 1200
Finished row 1225
Finished row 1250
Finished row 1275
Finished row 1300
Finished row 1325
Finished row 1350
Finished row 1375
Finished row 1400
Finished row 1425
Finished row 145