In [1]:
from dotenv import load_dotenv
# Load variables from a .env file into the environment; the boolean result is
# bound to `_` so the cell does not echo it.
_ = load_dotenv()

In [2]:
from typing import TypedDict, Annotated, Sequence, List, Optional
import operator

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate

In [3]:
from langchain_openai import AzureChatOpenAI

# Shared chat model used by every LLM-backed node below.
# No endpoint or key is passed here; presumably they come from the environment
# loaded by load_dotenv() in the first cell -- TODO confirm.
llm = AzureChatOpenAI(
    temperature=0.0,
    azure_deployment="gpt4o",
    openai_api_version="2023-07-01-preview",
)

In [4]:
class VideoInfo(BaseModel):
    """Metadata and subtitle text for one downloaded video (built in download_node)."""

    # Video identifier; download_node appends it to the watch URL.
    video_id: str
    # Full watch URL for the video.
    url: str
    # POSIX path of the video file relative to the pipeline's local root directory.
    relative_video_path: str
    # Raw text content of the downloaded subtitle file.
    subs: str
    # Result of vtt_to_txt() on the subtitle file; download_node treats an empty
    # value as "no transcript".
    transcript: str


class SegmentInfo(BaseModel):
    """A detected segment of a video (built in detect_segments_node)."""

    # Segment boundaries as "HH:MM:SS.mmm" strings.
    start_timestamp: str
    end_timestamp: str
    # Frame rate recorded for the segment.
    fps: float
    # Identifier of the video the segment belongs to (matches VideoInfo.video_id).
    video_id: str


# Schema for one local clue in the structured LLM output (nested inside
# SegmentAnnotation). The docstring and Field descriptions are part of the
# schema text, so they are intentionally left untouched.
class LocalClue(BaseModel):
    """Local clues for a segment"""

    # Clue identifier, e.g. "LC1".
    id: str = Field(description="LC1,LC2...")
    # Transcript quote that backs the clue.
    quote: str = Field(
        description="the quote from the transcript that was used to create this clue."
    )
    # Start / end timestamps of the quote.
    quote_timestamp_start: str = Field(
        description="the exact start timestamp of the quote."
    )
    quote_timestamp_end: str = Field(
        description="the exact end timestamp of the quote."
    )
    # The clue text itself.
    clue: str = Field(description="the main clue data")


# Schema for one global clue in the structured LLM output (nested inside
# SegmentAnnotation). Same fields as LocalClue plus `relevance_to_segment`.
# The docstring and Field descriptions are part of the schema text, so they
# are intentionally left untouched.
class GlobalClue(BaseModel):
    """Global clues for a segment"""

    # Clue identifier, e.g. "GC1".
    id: str = Field(description="GC1,GC2...")
    # Transcript quote that backs the clue.
    quote: str = Field(
        description="the quote from the transcript that was used to create this clue."
    )
    # Start / end timestamps of the quote.
    quote_timestamp_start: str = Field(
        description="the exact start timestamp of the quote."
    )
    quote_timestamp_end: str = Field(
        description="the exact end timestamp of the quote."
    )
    # The clue text itself.
    clue: str = Field(description="the main clue data.")
    # Explanation of why this clue matters for the segment being annotated.
    relevance_to_segment: str = Field(
        description="why do you think this global clue is relevant to the segment you are working with right now."
    )


# Schema for one logical inference in the structured LLM output (nested inside
# SegmentAnnotation). The docstring and Field descriptions are part of the
# schema text, so they are intentionally left untouched.
class LogicalInference(BaseModel):
    """Logical inferences for a segment"""

    # Inference identifier, e.g. "LI1".
    id: str = Field(description="LI1,LI2,...")
    # Short form of the inference.
    description: str = Field(description="A concise form of the logical inference.")
    # Long-form explanation of the inference.
    details: str = Field(
        description="A verbose explanation of what insight about what happens in this segment should be made based on the clues that you found."
    )


# Bundle of the three kinds of evidence produced for one segment: local clues,
# global clues and logical inferences. Used as the `segment_annotation` field of
# SegmentWithClueInfo. No class docstring is added on purpose: this model is part
# of the structured-output schema and a docstring would alter that schema text.
class SegmentAnnotation(BaseModel):
    # Clues whose quotes fall inside the segment's time range.
    local_clues: list[LocalClue] = Field(
        description="Local clues are inside the segment in terms of timestamps."
    )
    # Clues taken from anywhere in the transcript.
    global_clues: list[GlobalClue] = Field(
        description="Global clues are scattered across the entire transcript."
    )
    # Conclusions about the segment.
    logical_inferences: list[LogicalInference] = Field(
        description="What can we infer about the topic, that the user is looking for in the video, can we make based on the clues inside this segment"
    )


# One annotated segment in the structured LLM output (element of
# VideoAnnotation.segments). The docstring and Field descriptions are part of
# the schema text the model sees, so wording changes here change model input.
class SegmentWithClueInfo(BaseModel):
    """
    Annotation for a video segment.
    """

    start_timestamp: str = Field(
        description="start timestamp of the segment in format HH:MM:SS.MS"
    )
    # Fixed copy-paste error: this description previously said "start timestamp",
    # giving the model two identically described fields.
    end_timestamp: str = Field(
        description="end timestamp of the segment in format HH:MM:SS.MS"
    )
    # Local clues, global clues and logical inferences for this segment.
    segment_annotation: SegmentAnnotation = Field(
        description="list of annotations for the segment"
    )


# Top-level structured-output schema requested from the LLM in
# extract_clues_node: one entry per annotated segment. The docstring and Field
# description are part of the schema text, so they are intentionally left untouched.
class VideoAnnotation(BaseModel):
    """
    Segments of a video.
    """

    # Annotated segments; extract_clues_node collects these into the "clues" state key.
    segments: list[SegmentWithClueInfo] = Field(
        description="information about each segment"
    )

In [5]:
# 2. Create the state


class AgentState(TypedDict):
    """State dictionary passed between the nodes of the annotation graph.

    Each node reads the keys it needs and returns a dict containing only the
    keys it produced.
    """

    # User's description of the videos they want (graph input).
    task: str
    # Search queries produced by gen_queries_node.
    search_queries: List[str]
    # Video ids found by get_video_ids_node.
    video_ids: List[str]
    # Per-video metadata, subtitles and transcript produced by download_node.
    video_infos: List[VideoInfo]
    # Text prompts scored against video frames in detect_segments_node (graph input).
    clip_text_prompts: List[str]
    # Segments found by detect_segments_node.
    segment_infos: List[SegmentInfo]
    # Annotated segments from extract_clues_node. These are SegmentWithClueInfo
    # objects (gen_annotations_node calls .json() on each), not plain strings as
    # the previous List[str] annotation claimed.
    clues: List[SegmentWithClueInfo]
    # JSON strings produced by gen_annotations_node, one per clue entry.
    annotations: List[str]

In [6]:
# 3. Set prompts

# System prompt for gen_queries_node.
# The two literals are joined by implicit string concatenation (no commas), so
# this is a single str. With commas it was a 2-tuple, and str() of it produced
# the tuple repr "('You a helping ...', 'A user ...')" as the system prompt.
GEN_QUERIES_PROMPT = (
    "You are helping the user to find a very large and diverse set of videos on a video hosting service. "
    "A user will only describe which videos they are looking for and how many queries they need."
)

# prompt='I want to find instructional videos about how to do squats.',
# num_queries_prompt = f'I need {num_queries} queries'

# System prompt for extract_clues_node. It is rendered with
# template_format="jinja2", so the single braces of the JSON examples stay
# literal; do not introduce "{{ ... }}" here unless it is meant as a variable.
# Changes versus the previous text:
#   * every few-shot example now uses the key "id" (GC3-GC5 used "clue_id"),
#     matching the `id` field of the LocalClue / GlobalClue schema;
#   * the GOOD timestamp example uses 00:05:02, a timestamp that actually occurs
#     in the sample transcript, in line with rule 1 (it used 00:05:03);
#   * typos fixed ("Provied", "trascript", "undestand", "provided segment").
# NOTE(review): the examples still use a single "timestamp" key, while the schema
# asks for quote_timestamp_start / quote_timestamp_end -- consider aligning them.
EXTRACT_CLUES_PROMPT = """You are a highly intelligent data investigator.
You take unstructured damaged data and look for clues that could help restore the initial information
and extract important insights from it.
You are the best one for this job in the world because you are a former detective.
You care about even the smallest details, and your guesses about what happened in the initial file
even at very limited inputs are usually absolutely right.
You use deductive and inductive reasoning at the highest possible quality.

#YOUR TODAY'S JOB
The user needs to learn about what happens in a specific segment of a video file. Your job is to help the user by providing clues that would help the user make the right assumption.
The user will provide you with: 
1. Instructions about what kind of information the user is trying to obtain.
2. A list of time codes of the segments in format "<HH:MM:SS.ms>-<HH:MM:SS.ms>". All the provided segments of the video contain what the user is looking for, but other parts of the video might have different content.
3. A transcript of the *full video* in format of "<HH.MM.SS>\\n<text>"

Your task:
1. Read the transcript.
2. Provide the clues in a given format.
3. Provide any other info requested by the user.

#RULES
!!! VERY IMPORTANT !!!
1. Rely only on the data provided in the transcript. Do not improvise. All the quotes and corresponding timestamps must be taken from the transcript. Quote timestamps must be taken directly from the transcript.
2. Your job is to find the data already provided in the transcript.
3. Analyze every segment. Only skip a segment if there is no information about it in the transcript.
4. For local clues, make sure that the quotes that you provide are located inside the segment. To do this, double check the timestamps from the transcript and the segment.
5. For all clues, make sure that the quotes exactly correspond to the timestamps that you provide.
6. When making clues, try as much as possible to make them describe specifically what is shown in the segment.
7. Follow the format output.
8. Be very careful with details. Don't generalize. Always double check your results.

Please, help the user find relevant clues to reconstruct the information they are looking for, for each provided segment.

WHAT IS A CLUE: A *clue*, in the context of reconstructing narratives from damaged data, 
is a fragment of information extracted from a corrupted or incomplete source that provides 
insight into the original content. These fragments serve as starting points for inference 
and deduction, allowing researchers to hypothesize about the fuller context or meaning of 
the degraded material. The process of identifying and interpreting clues involves both objective analysis of the 
available data and subjective extrapolation based on domain knowledge, contextual understanding, 
and logical reasoning.

Here is what the user expects to have from you:
1. *Local clues* that would help the user understand how the thing they are looking for happens inside the segment. Local clues for a segment are generated from quotes inside a specific segment.
2. *Global clues* that would help the user understand how the thing they are looking for happens inside the segment. Global clues for a segment are generated from quotes all around the video, but are very relevant to the specific segment that they are provided for.
3. *Logical inferences* that could help the user understand how the thing they are looking for happens inside the segment. Logical inferences for a segment are deducted from local and global clues for this segment.

!!!IT IS EXTREMELY IMPORTANT TO DELIVER ALL THREE THINGS!!!

        Good local clues examples: [
      {
        "id": "LC1",
        "timestamp": "00:00:19",
        "quote": "exercises do them wrong and instead of",
        "clue": "This phrase introduces the concept of incorrect exercise form, setting the stage for a demonstration of improper technique."
      },
      {
        "id": "LC2",
        "timestamp": "00:00:21",
        "quote": "growing nice quads and glutes you'll",
        "clue": "Mentions the expected benefits of proper squats (muscle growth), implying that these benefits won't be achieved with incorrect form."
      },
      {
        "id": "LC3",
        "timestamp": "00:00:22",
        "quote": "feel aches and pains in your knees your",
        "clue": "Directly states negative consequences of improper form, strongly suggesting that this segment demonstrates incorrect technique."
      },
      {
        "id": "LC4",
        "timestamp": "00:00:24",
        "quote": "lower back and even your shoulders",
        "clue": "Continuation of LC3, emphasizing multiple areas of potential pain from improper form."
      },
      {
        "id": "LC5",
        "timestamp": "00:00:26",
        "quote": "let's see how to do it correctly",
        "clue": "This phrase suggests a transition is about to occur. The incorrect form has been shown, and correct form will follow."
      }
    ]

    Double check that the timestamp and the quote that you provide exactly correspond to what you found in the transcript.
    For example, if the transcript says:
    "00:05:02
    he took the glasses
    00:05:04
    and gave them to me"
    Then a GOOD output will be:
    - timestamp: 00:05:02
    - quote: "he took the glasses and gave them to me"
    And a BAD output would be:
    - timestamp: 00:04:02
    - quote: "he gave me the glasses"

    Good global clues examples: [
      {
        "id": "GC1",
        "timestamp": "00:01:15",
        "quote": "Before we dive into specific techniques, let's talk about safety.",
        "clue": "Introduces the theme of safety in squatting.",
        "relevance_to_segment": "This earlier emphasis on safety provides context for why proper depth is important and why it's being addressed in our segment. It connects to the fear of knee pain mentioned in LC3."
      },
      {
        "id": "GC2",
        "timestamp": "00:02:30",
        "quote": "Squatting is a fundamental movement pattern in everyday life.",
        "clue": "Emphasizes the importance of squats beyond just exercise.",
        "relevance_to_segment": "This broader context heightens the importance of learning proper squat depth as demonstrated in our segment. It suggests that the techniques shown have applications beyond just gym workouts."
      },
      {
        "id": "GC3",
        "timestamp": "00:05:20",
        "quote": "If you have existing knee issues, consult a physician before attempting deep squats.",
        "clue": "Provides a health disclaimer related to squat depth.",
        "relevance_to_segment": "While this comes after our segment, it's relevant because it addresses the concern about knee pain mentioned in LC3. It suggests that the demonstration in our segment is generally safe but acknowledges individual variations."
      },
      {
        "id": "GC4",
        "timestamp": "00:06:45",
        "quote": "Proper depth ensures full engagement of your quadriceps and glutes.",
        "clue": "Explains the benefit of correct squat depth.",
        "relevance_to_segment": "This later explanation provides justification for the depth guideline given in LC4. It helps viewers understand why the demonstrated technique is important."
      },
      {
        "id": "GC5",
        "timestamp": "00:00:30",
        "quote": "Today, we'll cover squat variations for beginners to advanced lifters.",
        "clue": "Outlines the scope of the entire video.",
        "relevance_to_segment": "This early statement suggests that our segment, focusing on proper depth, is part of a comprehensive guide. It implies that the demonstration might be adaptable for different skill levels."
      }
    ]
    Double check that the timestamp and the quote that you provide exactly correspond to what you found in the transcript.
    For example, if the transcript says:
    "00:05:02
    he took the glasses
    00:05:04
    and gave them to me"
    Then a GOOD output will be:
    - timestamp: 00:05:02
    - quote: "he took the glasses and gave them to me"
    And a BAD output would be:
    - timestamp: 00:04:02
    - quote: "he gave me the glasses"


    Good logical inference examples:
    [
      {
        "id": "LI1",
        "description": "Primary Demonstration of Heel Lift",
        "details": "Given that GC1-GC3 describe the 'most common mistake' as heels lifting off the ground, and this description immediately precedes our segment, it's highly probable that this is the primary error being demonstrated. This is further supported by the segment's focus on incorrect form (LC1-LC4)."
      },
      {
        "id": "LI2",
        "description": "Multiple Error Demonstration",
        "details": "While heel lift is likely the primary focus, the mention of multiple pain points (knees, lower back, shoulders in LC3-LC4) suggests that the demonstrator may be exhibiting several forms of incorrect technique simultaneously. This comprehensive 'what not to do' approach would be pedagogically effective."
      },
      {
        "id": "LI3",
        "description": "Possible Inclusion of 'Butt Wink'",
        "details": "Although 'butt wink' is mentioned after our segment (GC4-GC6), its connection to back pain (which is mentioned in LC4) raises the possibility that this error is also present in the demonstration. The instructor may be showing multiple errors early on, then breaking them down individually later."
      },
      {
        "id": "LI4",
        "description": "Segment Placement in Overall Video Structure",
        "details": "The segment's position (starting at 00:00:19) and the phrase 'let's see how to do it correctly' (LC5) at the end suggest this is an early, foundational part of the video. It likely serves to grab attention by showing common mistakes before transitioning to proper form instruction."
      },
      {
        "id": "LI5",
        "description": "Intentional Exaggeration of Errors",
        "details": "Given the educational nature of the video, it's plausible that the demonstrator is intentionally exaggerating the incorrect form. This would make the errors more obvious to viewers and enhance the contrast with correct form shown later."
      }
    ]
"""


# System prompt for gen_annotations_node: tells the model to answer from the
# provided clue JSON only and shows GOOD/BAD examples of the "right" / "wrong" /
# "correction" feedback fields. The text is sent to the model verbatim, so it is
# left untouched here.
GEN_ANNOTATIONS_PROMPT = """You are a helpful assistant that performs high quality data investigation and transformation.
  You will be given a JSON object with clues and other helpful information about what's going on 
  in a specific part of a video file. This part is called a segment. Your job is to:
  1. Read this JSON object carefully
  2. Answer user's questions about this segment
  3. Provide the answer as a JSON object in a schema provided by the user
  Important rules:
  1. You can only rely on data presented in a provided JSON object. Don't improvise.
  2. Follow user's request carefully.
  3. Don't rush to deliver the answer. Take some time to think. Make a deep breath. Then start writing.
  4. If you want to output field as empty (null), output it as JSON null (without quotes), not as a string "null". 
—> GOOD EXAMPLES:
  "wrong":"Knees caving in: This can stress the knees and reduce effectiveness"
  "correction":"Focus on keeping knees aligned with your toes."
  "wrong":"Rounding the back: This increases the risk of back injuries"
  "correction":"Keep your chest up and maintain a neutral spine throughout the movement."
  "wrong":"Heels are lifting off the ground: this shifts the weight forward, reducing stability"
  "correction":" Keep your weight on your heels and press through them as you rise."
  "right":"Chest and shoulders: The chest is up, and the shoulders are back, maintaining an upright torso."
  "correction":null
—> BAD EXAMPLES:
  "wrong":"knees"
  "correction":"fix knees"
  "wrong":"back looks funny"
  "correction":"make back better"
  "wrong":"feet are doing something"
  "correction":"feet should be different"
  "right":"arms"
  "correction":"arms are fine i think"
—> BAD EXAMPLES END HERE
"""

In [7]:
import scrapetube
import yt_dlp
from datetime import datetime
from pathlib import Path
from collections import defaultdict
from datagen.core.sub_utils import vtt_to_txt
from datagen.detect_segments import get_segments
import torch
from transformers import AutoModel, AutoProcessor
import pandas as pd
from tsmoothie.smoother import LowessSmoother

In [8]:
import decord
import math
import numpy as np

# decord.bridge.set_bridge("torch")


class VideoInferenceDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that yields about one decoded frame per second of video.

    Videos are read one after another with decord. Each item is a dict with:
      * "frame": the decoded frame as a numpy array,
      * "frame_idx": index of that frame inside its video,
      * "video_id": position of the video in ``video_infos`` (an index, not
        the string video id).

    The frame generator is created once in ``__init__`` and ``__iter__`` returns
    ``self``, so an instance can be consumed only once.

    Args:
        video_infos: videos to read; only ``relative_video_path`` is used here.
        local_root: directory that ``relative_video_path`` is relative to.
    """

    def __init__(self, video_infos: List[VideoInfo], local_root: Path):
        # The one-argument form ``super(VideoInferenceDataset)`` creates an
        # unbound super object, so the previous call never reached the parent
        # initializer. The zero-argument form does.
        super().__init__()

        self.video_infos = video_infos
        self.local_root = local_root
        self.frame_generator = self.get_frame_generator(video_infos, local_root)

    @staticmethod
    def get_frame_generator(video_infos, local_root: Path):
        """Yield one frame dict per sampled frame, video by video."""
        for video_idx, video_info in enumerate(video_infos):
            video_path = local_root.joinpath(video_info.relative_video_path)
            vr = decord.VideoReader(str(video_path))
            num_frames = len(vr)
            fps = vr.get_avg_fps()
            # Sample every round(fps)-th frame, i.e. roughly one per second.
            # Clamp to 1: round() of an fps below 0.5 is 0 and range() rejects
            # a zero step with ValueError.
            frame_step = max(1, round(fps))

            for frame_idx in range(0, num_frames, frame_step):
                frame = vr[frame_idx].asnumpy()
                yield {
                    "frame": frame,
                    "frame_idx": frame_idx,
                    "video_id": video_idx,
                }

    def __next__(self):
        return next(self.frame_generator)

    def __iter__(self):
        return self

In [9]:
import time
import math

# 4. Create nodes


def gen_queries_node(state: AgentState):
    """Ask the LLM for search queries that match the user's task.

    Reads ``state["task"]`` and returns ``{"search_queries": [...]}`` with at
    most ``max_queries`` entries (an empty list if the model returned none).
    """

    # Structured-output schema. Its docstring and Field description are part of
    # the schema text, so they are kept as they were.
    class QueryList(BaseModel):
        """A list of queries to find videos on a video hosting service"""

        search_queries: list[str] = Field(default=None, description="a list of queries")

    # Only the first few queries are kept.
    max_queries = 2

    # GEN_QUERIES_PROMPT may be a tuple of sentences. str() of a tuple is its
    # repr ("('...', '...')"), which is not a usable prompt, so join the parts.
    if isinstance(GEN_QUERIES_PROMPT, str):
        system_prompt = GEN_QUERIES_PROMPT
    else:
        system_prompt = " ".join(GEN_QUERIES_PROMPT)

    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=state["task"]),
    ]

    model = llm.with_structured_output(QueryList)
    response: QueryList = model.invoke(messages)

    # search_queries defaults to None in the schema; slicing None would raise.
    search_queries = []
    if response is not None and response.search_queries:
        search_queries = list(response.search_queries)

    return {"search_queries": search_queries[:max_queries]}


def get_video_ids_node(state: AgentState):
    """Search the video host for every query and collect unique video ids.

    Reads ``state["search_queries"]`` and returns ``{"video_ids": [...]}``.
    Ids are de-duplicated and kept in first-seen order, so repeated runs with
    the same search results give the same list. When ``only_creative_commons``
    is enabled, ids whose license does not mention "creative commons" are
    dropped.
    """

    queries = state["search_queries"]
    videos_per_query = 1
    sleep = 0
    sort_by = "relevance"
    results_type = "video"
    only_creative_commons = False

    # Dict keys de-duplicate and, unlike a set, preserve insertion order.
    unique_ids = {}
    for query in queries:
        for video in scrapetube.get_search(
            query=query,
            limit=videos_per_query,
            sleep=sleep,
            sort_by=sort_by,
            results_type=results_type,
        ):
            unique_ids[video["videoId"]] = None
    video_ids = list(unique_ids)

    if only_creative_commons:
        YDL_OPTIONS = {
            "quiet": True,
            "simulate": True,
            "forceurl": True,
        }
        video_ids_cc = []
        # One extractor instance is enough for all lookups.
        with yt_dlp.YoutubeDL(YDL_OPTIONS) as ydl:
            for video_id in video_ids:
                info = ydl.extract_info(
                    f"https://www.youtube.com/watch?v={video_id}", download=False
                )
                # "license" may be absent or present with a None value.
                license_name = (info or {}).get("license") or ""
                if "creative commons" in license_name.lower():
                    video_ids_cc.append(video_id)
        video_ids = video_ids_cc

    return {"video_ids": video_ids}


def download_node(state: AgentState):
    """Download videos and English auto-subtitles, then build VideoInfo records.

    Reads ``state["video_ids"]`` and returns ``{"video_infos": [...]}``.

    Videos whose .mp4 already exists (in the video directory or in the
    "videos_without_subs" directory) are not downloaded again. A video whose
    download fails is logged and skipped. With ``only_with_transcripts``
    enabled, videos with an empty transcript are left out of the result and
    their file, if present, is moved to "videos_without_subs".
    """

    LOCAL_ROOT = Path("./tmp/agent_squats").resolve()
    video_dir = LOCAL_ROOT / "videos"
    sub_dir = LOCAL_ROOT / "subs"

    discard_path = LOCAL_ROOT / "videos_without_subs"
    discard_path.mkdir(parents=True, exist_ok=True)

    video_ids = state["video_ids"]

    downloaded_video_ids = [video_path.stem for video_path in video_dir.glob("*.mp4")]
    downloaded_video_ids += [
        video_path.stem for video_path in discard_path.glob("*.mp4")
    ]

    print(f"Downloaded video ids: {downloaded_video_ids}")

    only_with_transcripts = True

    YDL_OPTIONS = {
        "writeautomaticsub": True,
        "subtitleslangs": ["en"],
        "subtitlesformat": "vtt",
        "overwrites": False,
        "format": "mp4",
        "outtmpl": {
            "default": video_dir.as_posix() + "/%(id)s.%(ext)s",
            "subtitle": sub_dir.as_posix() + "/%(id)s.%(ext)s",
        },
    }

    video_infos = []

    with yt_dlp.YoutubeDL(YDL_OPTIONS) as ydl:
        for video_id in video_ids:
            url = f"https://www.youtube.com/watch?v={video_id}"

            if video_id not in downloaded_video_ids:
                try:
                    # download() takes a list of URLs.
                    ydl.download([url])
                except Exception as e:
                    print(datetime.now(), f"Error at video {video_id}, skipping")
                    print(datetime.now(), e)
                    continue

            video_path = Path(ydl.prepare_filename({"id": video_id, "ext": "mp4"}))
            sub_path = Path(
                ydl.prepare_filename(
                    {"id": video_id, "ext": "en.vtt"}, dir_type="subtitle"
                )
            )

            # A video may have no subtitle file at all. Treat that as an empty
            # transcript so the filter below can handle it, instead of raising
            # FileNotFoundError here.
            if sub_path.exists():
                subs = sub_path.read_text(encoding="utf-8")
                transcript = vtt_to_txt(sub_path)
            else:
                subs = ""
                transcript = ""

            video_info = VideoInfo(
                video_id=video_id,
                url=url,
                relative_video_path=video_path.relative_to(LOCAL_ROOT).as_posix(),
                subs=subs,
                transcript=transcript,
            )

            video_infos.append(video_info)

    if only_with_transcripts:
        filtered_video_infos = []
        for video_info in video_infos:
            if video_info.transcript:
                filtered_video_infos.append(video_info)
                continue

            # VideoInfo has no `video_path` attribute; the stored field is
            # `relative_video_path` (the old access raised AttributeError).
            video_path = LOCAL_ROOT / video_info.relative_video_path
            # The file may already sit in the discard directory from an
            # earlier run, in which case there is nothing to move.
            if video_path.exists():
                video_path.rename(discard_path / video_path.name)
        video_infos = filtered_video_infos

    return {"video_infos": video_infos}


def detect_segments_node(state: AgentState):
    """Score sampled video frames against the text prompts and extract segments.

    Reads ``state["clip_text_prompts"]`` and ``state["video_infos"]``; returns
    ``{"segment_infos": [...]}``.

    Frames from VideoInferenceDataset are scored in batches by an image-text
    model; the per-frame probabilities of each video are smoothed and passed to
    get_segments(). At most ``MAX_BATCHES`` batches are processed, so long
    inputs are truncated.
    """

    LOCAL_ROOT = Path("./tmp/agent_squats").resolve()

    clip_text_prompts = state["clip_text_prompts"]
    video_infos = state["video_infos"]

    CLIP_MODEL_ID = "google/siglip-so400m-patch14-384"

    # Use the GPU when there is one, otherwise fall back to CPU instead of
    # failing on .to("cuda").
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(device)
    processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)

    dataset = VideoInferenceDataset(video_infos, LOCAL_ROOT)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=1,
        batch_size=12,
        # Pinned memory only helps host-to-GPU copies.
        pin_memory=(device == "cuda"),
    )
    batches = iter(dataloader)

    smoother = LowessSmoother(smooth_fraction=0.02, iterations=1)

    clip_results_dict = defaultdict(list)

    print("Init model complete")

    MAX_BATCHES = 50

    # Pure inference: do not record an autograd graph for every batch.
    with torch.no_grad():
        for _ in range(MAX_BATCHES):
            try:
                batch = next(batches)
            except StopIteration:
                break

            # NOTE(review): the SigLIP model documentation recommends
            # padding="max_length" for text inputs; padding=True is kept here
            # to leave the scores unchanged -- confirm which is intended.
            inputs = processor(
                images=batch["frame"],
                text=clip_text_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)

            logits = outputs.logits_per_image
            probs = torch.sigmoid(logits).cpu().numpy()

            for video_idx, frame_idx, prob in zip(
                batch["video_id"], batch["frame_idx"], probs
            ):
                # The dataset yields the position of the video in video_infos.
                video_id = video_infos[video_idx.item()].video_id

                clip_results_dict["video_id"].append(video_id)
                clip_results_dict["frame_idx"].append(frame_idx.item())
                # .item() needs exactly one value per frame, i.e. a single
                # text prompt.
                clip_results_dict["probs"].append(prob.item())

    print("All frames processed")
    clip_results = pd.DataFrame(clip_results_dict)
    print("Dataframe created")
    print(clip_results)

    # With no frames the frame has no columns and groupby("video_id") would
    # raise KeyError.
    if clip_results.empty:
        return {"segment_infos": []}

    max_gap_seconds = 1
    fps_sampling = 1
    min_prob = 0.1
    min_segment_seconds = 3
    # NOTE(review): fixed value stored on every SegmentInfo; it is not read
    # from the video files -- confirm whether the real frame rate is needed.
    fps = 25

    def sec2ts(seconds):
        """Format a number of seconds as HH:MM:SS.mmm."""
        # Work in whole milliseconds so the fractional part can never be
        # formatted as "1000".
        whole_seconds, millis = divmod(round(seconds * 1000), 1000)
        return time.strftime("%H:%M:%S", time.gmtime(whole_seconds)) + f".{millis:03d}"

    segment_infos = []
    for video_id, video_clip_results in clip_results.groupby("video_id"):
        probs = video_clip_results["probs"].values
        probs = smoother.smooth(probs).smooth_data[0]
        segments_start_end = get_segments(
            probs,
            max_gap=round(max_gap_seconds * fps_sampling),
            min_prob=min_prob,
            min_segment=round(min_segment_seconds * fps_sampling),
        )

        print(f"Segments for video {video_id}: {segments_start_end}")

        # start / end are passed to sec2ts as seconds.
        for start, end in segments_start_end:
            segment_infos.append(
                SegmentInfo(
                    start_timestamp=sec2ts(start),
                    end_timestamp=sec2ts(end),
                    fps=fps,
                    video_id=video_id,
                )
            )

    return {"segment_infos": segment_infos}


def extract_clues_node(state: AgentState):
    """Ask the LLM for clues about every detected segment.

    Reads ``state["segment_infos"]`` and ``state["video_infos"]``. Segments are
    grouped per video and sent in batches of five together with the full
    transcript of that video. Returns ``{"clues": [...]}`` holding the
    ``segments`` of every structured response, in request order.
    """

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", EXTRACT_CLUES_PROMPT),
            (
                "user",
                "Segment timecodes: {{ segment_timecodes }}\nTranscript: {{ transcript }}",
            ),
        ],
        template_format="jinja2",
    )
    annotator = chat_prompt | llm.with_structured_output(VideoAnnotation)

    # video_id -> segments of that video, in the order they were detected.
    segments_by_video = {}
    for segment in state["segment_infos"]:
        segments_by_video.setdefault(segment.video_id, []).append(segment)

    # video_id -> VideoInfo, for transcript lookup.
    info_by_video = {}
    for info in state["video_infos"]:
        info_by_video[info.video_id] = info

    collected = []
    for video_id, video_segments in segments_by_video.items():
        transcript = info_by_video[video_id].transcript

        # Walk the segment list five entries at a time.
        for offset in range(0, len(video_segments), 5):
            batch = video_segments[offset : offset + 5]
            timecode_lines = [
                f"{seg.start_timestamp}-{seg.end_timestamp}" for seg in batch
            ]
            request = {
                "segment_timecodes": "\n".join(timecode_lines),
                "transcript": transcript,
            }
            response: VideoAnnotation = annotator.invoke(request)
            collected.extend(response.segments)

    return {"clues": collected}


def gen_annotations_node(state: AgentState):
    """Turn every clue entry into a squat annotation via the LLM.

    Reads ``state["clues"]``, sends each entry as JSON to the model and returns
    ``{"annotations": [...]}`` with one JSON string per entry. The list is also
    printed.
    """

    # Structured-output schemas. No docstrings are added to them on purpose:
    # the schema text must stay exactly as it is.
    class SegmentFeedback(BaseModel):
        right: Optional[str] = Field(description="what was right in the performance")
        wrong: Optional[str] = Field(description="what was wrong in the performance")
        correction: Optional[str] = Field(
            description="how and in what ways it the performance could be improved"
        )

    # Segment timestamps are not part of this schema; they come from the
    # provided clue information.
    class SegmentCompleteAnnotation(BaseModel):
        squats_probability: Optional[str] = Field(
            description="how high is the probability that the person is doing squats in the segment: low, medium, high, unknown(null)"
        )
        squats_technique_correctness: Optional[str] = Field(
            description="correctness of the squat technique."
        )
        squats_feedback: Optional[SegmentFeedback] = Field(
            description="what was right and wrong in the squat perfomance in the segment. When the technique is incorrect, provide instructions how to correct them."
        )

    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", GEN_ANNOTATIONS_PROMPT),
            ("user", "Clues: {{ clues }}"),
        ],
        template_format="jinja2",
    )
    annotator = chat_prompt | llm.with_structured_output(SegmentCompleteAnnotation)

    # One request per clue entry; each response is stored as its JSON string.
    annotations = [
        annotator.invoke({"clues": clue_entry.json()}).json()
        for clue_entry in state["clues"]
    ]

    print(annotations)

    return {"annotations": annotations}

In [10]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, List
import operator
from langgraph.checkpoint.memory import MemorySaver

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage, ChatMessage

# Checkpointer passed to builder.compile() below.
memory = MemorySaver()
# Alternative checkpointer, kept for reference:
# memory = SqliteSaver.from_conn_string(":memory:")

In [11]:
builder = StateGraph(AgentState)

# Pipeline stages in execution order: (node name, node function).
pipeline_stages = [
    ("generate_queries", gen_queries_node),
    ("get_video_ids", get_video_ids_node),
    ("download", download_node),
    ("detect_segments", detect_segments_node),
    ("extract_clues", extract_clues_node),
    ("gen_annotations", gen_annotations_node),
]

for stage_name, stage_fn in pipeline_stages:
    builder.add_node(stage_name, stage_fn)

builder.set_entry_point(pipeline_stages[0][0])

# Strictly linear flow: every stage feeds the next one, and the last feeds END.
stage_names = [stage_name for stage_name, _ in pipeline_stages]
for source_name, target_name in zip(stage_names, stage_names[1:] + [END]):
    builder.add_edge(source_name, target_name)

graph = builder.compile(checkpointer=memory)

In [12]:
thread = {"configurable": {"thread_id": "1"}}

# Inputs the graph starts from: the user's task and the frame-scoring prompt.
initial_state = {
    "task": "i wanna teach people how to do squats",
    "clip_text_prompts": ["person doing squats"],
}

# Nodes whose update is replaced by a short marker line instead of being
# printed (checked in this order).
marker_by_node = {
    "download": "dowload happened",
    "extract_clues": "extract_clues happened",
}

for s in graph.stream(initial_state, thread):
    marker = next(
        (text for node_name, text in marker_by_node.items() if node_name in s),
        None,
    )
    print(s if marker is None else marker)

{'generate_queries': {'search_queries': ['how to do squats', 'squat exercise tutorial']}}
{'get_video_ids': {'video_ids': ['xqvCmoLULNY', 'IB_icWRzi4E']}}
Downloaded video ids: ['IB_icWRzi4E', 'xqvCmoLULNY']
dowload happened
Init model complete
All frames processed
Dataframe created
        video_id  frame_idx         probs
0    xqvCmoLULNY          0  2.199925e-08
1    xqvCmoLULNY         24  1.503990e-01
2    xqvCmoLULNY         48  1.242190e-01
3    xqvCmoLULNY         72  1.302760e-01
4    xqvCmoLULNY         96  1.310861e-01
..           ...        ...           ...
220  IB_icWRzi4E       4275  2.498681e-07
221  IB_icWRzi4E       4300  3.288528e-07
222  IB_icWRzi4E       4325  3.445720e-07
223  IB_icWRzi4E       4350  3.333991e-07
224  IB_icWRzi4E       4375  2.660451e-07

[225 rows x 3 columns]
Segments for video IB_icWRzi4E: [(0, 5), (9, 24), (29, 45), (49, 53), (62, 66), (103, 109), (138, 147)]
Segments for video xqvCmoLULNY: [(1, 44)]
{'detect_segments': {'segment_infos': [Seg

In [13]:
# Show the state values stored for `thread` by the checkpointer.
graph.get_state(thread).values

{'task': 'i wanna teach people how to do squats',
 'search_queries': ['how to do squats', 'squat exercise tutorial'],
 'video_ids': ['xqvCmoLULNY', 'IB_icWRzi4E'],
 'video_infos': [VideoInfo(video_id='xqvCmoLULNY', url='https://www.youtube.com/watch?v=xqvCmoLULNY', relative_video_path='videos/xqvCmoLULNY.mp4', subs="WEBVTT\nKind: captions\nLanguage: en\n\n00:00:00.160 --> 00:00:01.829 align:start position:0%\n \nlet's<00:00:00.399><c> learn</c><00:00:00.560><c> how</c><00:00:00.719><c> to</c><00:00:00.880><c> properly</c><00:00:01.280><c> perform</c><00:00:01.760><c> a</c>\n\n00:00:01.829 --> 00:00:01.839 align:start position:0%\nlet's learn how to properly perform a\n \n\n00:00:01.839 --> 00:00:02.790 align:start position:0%\nlet's learn how to properly perform a\nsquat\n\n00:00:02.790 --> 00:00:02.800 align:start position:0%\nsquat\n \n\n00:00:02.800 --> 00:00:04.470 align:start position:0%\nsquat\nstart<00:00:03.120><c> with</c><00:00:03.199><c> your</c><00:00:03.360><c> feet</c><00