In [11]:
import os
import hashlib
import networkx as nx
import matplotlib.pyplot as plt
from typing import List
from pydantic import BaseModel
from IPython.display import Image, Audio, display

import torch
from diffusers import StableDiffusionPipeline
from transformers import pipeline
import soundfile as sf

from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain.schema import OutputParserException

In [3]:
OUTPUT_DIR = "poc_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class Triple(BaseModel):
    head: str
    relation: str
    tail: str

class TripleOutput(BaseModel):
    triples: List[Triple]

parser = PydanticOutputParser(pydantic_object=TripleOutput)

TEMPLATE = """You are an information extraction system.
Extract all (head_entity, relation, tail_entity) triples from the text below.

{format_instructions}

Text: {input_text}
"""

prompt = PromptTemplate(
    template=TEMPLATE,
    input_variables=["input_text"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

llm = Ollama(model="llama3")
chain = LLMChain(llm=llm, prompt=prompt)

def extract_triples(text: str, retries: int = 2) -> TripleOutput:
    for attempt in range(retries + 1):
        raw_output = chain.run({"input_text": text})
        try:
            return parser.parse(raw_output)
        except OutputParserException as e:
            print(f"Parse failed ({attempt+1}/{retries+1}): {e}")
            if attempt == retries:
                return TripleOutput(triples=[])
    return TripleOutput(triples=[])

  llm = Ollama(model="llama3")
  chain = LLMChain(llm=llm, prompt=prompt)


In [5]:
def build_graph_from_triples(parsed: TripleOutput):
    G = nx.DiGraph()
    for t in parsed.triples:
        G.add_node(t.head, type="Entity")
        G.add_node(t.tail, type="Entity")
        G.add_edge(t.head, t.tail, relation=t.relation)
    return G

def draw_graph(G, outpath):
    plt.figure(figsize=(10,6))
    pos = nx.spring_layout(G, seed=42)
    nx.draw_networkx_nodes(G, pos, node_size=900, node_color="skyblue")
    nx.draw_networkx_labels(G, pos, font_size=9)
    edge_labels = {(u,v): d.get("relation","") for u,v,d in G.edges(data=True)}
    nx.draw_networkx_edges(G, pos, arrows=True)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(outpath, dpi=150)
    plt.close()

In [6]:
def decide_modalities(G):
    sound_relations = {"sounds like", "emits", "produces", "plays", "sings"}
    visual_relations = {"occurs over", "is located on", "appears in", "shows", "illustrates"}

    want_audio, want_image = False, False
    audio_prompts, image_prompts = [], []

    for u, v, data in G.edges(data=True):
        rel = data.get("relation", "").lower().strip()
        if rel in sound_relations or "sound" in rel:
            want_audio = True
            audio_prompts.append(f"{u} {rel} {v}")
        if rel in visual_relations:
            want_image = True
            image_prompts.append(f"{u} {rel} {v}")

    entity_text = " ".join(G.nodes()).lower()
    if any(word in entity_text for word in ["storm","sky","forest","ocean","meadow","lighthouse","city","coast"]):
        want_image = True
        image_prompts.append(entity_text)

    return {"image": want_image, "audio": want_audio,
            "image_prompts": image_prompts, "audio_prompts": audio_prompts}

In [None]:
def init_diffusion(pipe_model="stabilityai/sd-turbo", device="auto"):
    if device == "cpu":
        pipe = StableDiffusionPipeline.from_pretrained(
            pipe_model,
            torch_dtype=torch.float32 
        )
        pipe = pipe.to("cpu")
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            pipe_model,
            torch_dtype=torch.float16 
        )
        pipe = pipe.to("CUDA")
        pipe.enable_attention_slicing()
    return pipe

def generate_image(pipe, prompt, outpath):
    seed = int(hashlib.sha1(prompt.encode()).hexdigest()[:8], 16) % (2**31)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    image = pipe(prompt, height=256, width=256, num_inference_steps=10).images[0]

    image.save(outpath)
    return outpath


In [8]:
from transformers import pipeline
import soundfile as sf

musicgen = pipeline(
    "text-to-audio",
    model="facebook/musicgen-small",
    device="cpu"
)


def generate_audio(prompt: str, out_wav: str):
    audio = musicgen(prompt, max_new_tokens=256)
    audio_array = audio[0]["array"]
    sampling_rate = audio[0]["sampling_rate"]
    sf.write(out_wav, audio_array, sampling_rate)
    return out_wav


Device set to use cpu


In [9]:
def multimodal_poc(user_prompt):
    print("Extracting triples with Llama3...")
    triples = extract_triples(user_prompt)
    print("Extracted triples:", triples)

    G = build_graph_from_triples(triples)
    graph_path = os.path.join(OUTPUT_DIR, "kg_graph.png")
    draw_graph(G, graph_path)

    decisions = decide_modalities(G)
    print("Decisions:", decisions)

    results = {"graph": graph_path, "json": triples.model_dump(), "outputs": {}}

    if decisions["image"]:
        print("Generating image...")
        pipe = init_diffusion()
        img_path = os.path.join(OUTPUT_DIR, "generated.png")
        img_prompt = ", ".join(decisions["image_prompts"]) or user_prompt
        generate_image(pipe, img_prompt, img_path)
        results["outputs"]["image"] = img_path

    if decisions["audio"]:
        print("Generating audio...")
        wav_path = os.path.join(OUTPUT_DIR, "generated.wav")
        audio_prompt = ", ".join(decisions["audio_prompts"]) or user_prompt
        generate_audio(audio_prompt, wav_path)
        results["outputs"]["audio"] = wav_path

    return results

In [None]:
if __name__ == "__main__":
    prompt = "Show me a dramatic thunderstorm over a lighthouse and also provide how thunder sounds."
    out = multimodal_poc(prompt)

    print("Done. Results:", out)

    display(Image(filename=out["graph"]))
    if "image" in out["outputs"]:
        display(Image(filename=out["outputs"]["image"]))
    if "audio" in out["outputs"]:
        display(Audio(filename=out["outputs"]["audio"]))

  raw_output = chain.run({"input_text": text})


Extracting triples with Llama3...
Extracted triples: triples=[Triple(head='Show me a dramatic thunderstorm', relation='describes', tail='over a lighthouse'), Triple(head='how thunder sounds', relation='explains', tail='thunderstorms')]
Decisions: {'image': True, 'audio': False, 'image_prompts': ['show me a dramatic thunderstorm over a lighthouse how thunder sounds thunderstorms'], 'audio_prompts': []}
Generating image...


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


RuntimeError: Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, maia, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: DEVICE