In [22]:
# import faiss
import torch
from glob import glob
from transformers import Swinv2Model, AutoImageProcessor
from PIL import Image

# build case base
cb: dict[str, torch.Tensor] = {}
files = glob("../data/microtexts-images/*.png")

model = Swinv2Model.from_pretrained("../pretrained_model")
processor = AutoImageProcessor.from_pretrained("../pretrained_model")

def embedd(image: torch.Tensor):
    # processed = processor(image, return_tensors="pt")
    return model(image).pooler_output.squeeze()

def embedd_file(file: str):
    image = Image.open(file).convert("RGB")
    return embedd(processor(image, return_tensors="pt").pixel_values)

for f in files:
    print(f"processing {f}")
    graph_id = f.split("/")[-1].split(".")[0]
    image = Image.open(f).convert("RGB")
    cb[graph_id] = embedd(processor(image, return_tensors="pt").pixel_values)

cb

processing ../data/microtexts-images/nodeset6380.png
processing ../data/microtexts-images/nodeset6442.png
processing ../data/microtexts-images/nodeset6395.png
processing ../data/microtexts-images/nodeset6373.png
processing ../data/microtexts-images/nodeset6454.png
processing ../data/microtexts-images/nodeset6420.png
processing ../data/microtexts-images/nodeset6433.png
processing ../data/microtexts-images/nodeset6369.png
processing ../data/microtexts-images/nodeset6440.png
processing ../data/microtexts-images/nodeset6463.png
processing ../data/microtexts-images/nodeset6441.png
processing ../data/microtexts-images/nodeset6456.png
processing ../data/microtexts-images/nodeset6381.png
processing ../data/microtexts-images/nodeset6410.png
processing ../data/microtexts-images/nodeset6423.png
processing ../data/microtexts-images/nodeset6383.png
processing ../data/microtexts-images/nodeset6384.png
processing ../data/microtexts-images/nodeset6392.png
processing ../data/microtexts-images/nodeset64

{'nodeset6380': tensor([ 4.2988e-01,  3.1478e-01, -1.1245e-01, -1.3365e-01,  7.8521e-02,
         -3.9641e-01, -5.0684e-01, -2.6999e-01, -1.4254e-01, -1.5841e-01,
          5.1893e-03, -5.1591e-01, -1.7430e-01, -3.4253e-01, -4.2500e-01,
         -5.1145e-01,  8.1728e-01, -2.1423e-01,  1.8318e-01,  7.7936e-02,
          1.8293e-01,  5.4851e-01,  7.3248e-01, -5.6786e-01, -6.2792e-02,
          1.5863e-02, -5.7833e-01,  1.1929e+00, -6.5258e-01,  6.0501e-01,
         -6.4524e-01,  8.8063e-01, -1.0201e+00,  7.0494e-01, -1.0542e+00,
          5.9044e-01,  1.4098e-01, -1.8210e-01,  9.9270e-01,  4.7395e-01,
         -4.3653e-02,  6.4704e-01, -1.5481e-01, -2.7071e-01,  6.9613e-02,
         -9.0648e-01, -1.3033e+00,  5.1653e-01,  6.7703e-01,  2.6006e-01,
          6.7551e-01,  8.6353e-01,  1.0923e+00, -9.1909e-01,  6.7662e-01,
          2.3897e-01,  4.8081e-02,  3.9897e-01,  2.6406e-01, -9.1345e-01,
         -7.8845e-01,  9.4291e-03, -1.5705e-01,  2.5293e-01,  1.1188e+00,
          3.0941e-01, -

In [12]:
# from data/kialo_dataset.ipynb
# converts graph file to image
import subprocess
from pybars import Compiler
from glob import glob
import arguebuf as ab
from tqdm import tqdm
import os

compiler = Compiler()

def _list(_, options, items):
    result = []
    for item in items:
        result.append(options['fn'](item))
        result.append('\n')
    return result

def get_color(node_label):
    if node_label == "Support":
        return "green"
    elif node_label == "Attack":
        return "red"
    return "blue"


def export_graph(inp: ab.Graph, file: str, return_str = False):
    source = """
        digraph "" {
        nodesep=0.02
        layersep=0.02
        ranksep=0.02
        node [height=0.2,
            label="",
            style=filled,
            width=0.2,
            shape=ellipse,
            penwidth=0,
            color=blue
        ];
        sep=-10
        edge [arrowhead=none,
            style=tapered
        ];
        {{#list nodes}}"{{id}}" [color="{{color}}"] {{/list}}
        {{#list edges}}"{{source}}" -> "{{target}}" {{/list}}
    }
    """
    template = compiler.compile(source)
    helpers = {
        'list': _list,
    }
    output = template({'nodes': [{"id": node.id, "color": get_color(node.label)} for node in inp.nodes.values()], 'edges': [{"source": edge.source.id, "target": edge.target.id} for edge in inp.edges.values()]}, helpers=helpers)
    if return_str:
        return output
    subprocess.run(["dot", "-Tpng", "-o", file], input=output.encode())

EXT = "json"
FOLDER_NAME = "microtexts-retrieval-complex"

output = f"{FOLDER_NAME}-images/"
files = glob(f"{FOLDER_NAME}/*.{EXT}")

os.makedirs(output, exist_ok=True)

for file in tqdm(files):
    inp = ab.load.file(file)
    export_graph(inp, output + file.split("/")[-1].replace(f".{EXT}", ".png"))

100%|██████████| 15/15 [00:00<00:00, 63.58it/s]


In [98]:
from arg_services.cbr.v1beta import retrieval_pb2, retrieval_pb2_grpc, model_pb2
from arg_services.nlp.v1 import nlp_pb2
import grpc

PATH = "../data/microtexts"
QUERY_PATH = "./microtexts-retrieval-complex"
def get_text(graph_id: str):
    graph_text_path = f"{PATH}/{graph_id}.txt"
    with open(graph_text_path) as f:
        return f.read()

def get_text_query(query: ab.Graph):
    texts = [node.label for node in query.nodes.values() if node.label != "Support" and node.label != "Attack"]
    return " ".join(texts)

stub = retrieval_pb2_grpc.RetrievalServiceStub(grpc.insecure_channel("localhost:50200"))
files = glob("../data/microtexts/*.json")
cases = {f.split("/")[-1].split(".")[0]: ab.dump.protobuf(ab.load.file(f)) for f in files}
cases = {k: model_pb2.AnnotatedGraph(graph=v, text=get_text(k)) for k, v in cases.items()}

def retrieve_mac(queryfile: str):
    query = ab.load.file(queryfile)
    query = model_pb2.AnnotatedGraph(graph=ab.dump.protobuf(query), text=get_text_query(query))
    config = nlp_pb2.NlpConfig(
        language="en",
        spacy_model="en_core_web_lg",
        similarity_method=nlp_pb2.SimilarityMethod.SIMILARITY_METHOD_COSINE,
    )
    return stub.Retrieve(retrieval_pb2.RetrieveRequest(semantic_retrieval=True, cases=cases, queries=[query], limit=10, nlp_config=config, scheme_handling=retrieval_pb2.SchemeHandling.SCHEME_HANDLING_BINARY))


In [99]:
mac_responses = retrieve_mac("./microtexts-retrieval-complex/allow_shops_to_open_on_holidays_and_sundays.json")

In [114]:
mac_ids = {mac_graph.id: mac_graph.similarity for mac_graph in mac_responses.query_responses[0].semantic_ranking}


In [115]:
from torch.nn.functional import cosine_similarity
# evaluate first request
def retrieve(file):
    image = Image.open(file).convert("RGB")
    request_embedding = model(processor(image, return_tensors="pt").pixel_values).pooler_output.squeeze()

    similarities = {}
    # for graph_id in mac_ids: for Vorfilterung durch mac
    for graph_id, sim in mac_ids.items():
        graph_embedding = cb[graph_id]
        similarities[graph_id] = (cosine_similarity(request_embedding, graph_embedding, dim=0).item() + sim)/2
    sorted_similarities = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
    return sorted_similarities[:10]

In [116]:
retrieve("./microtexts-retrieval-complex-images/allow_shops_to_open_on_holidays_and_sundays.png")

[('nodeset6375', 0.9750460138855356),
 ('nodeset6433', 0.9615779791055248),
 ('nodeset6468', 0.9550586080760983),
 ('nodeset6454', 0.9528995080539355),
 ('nodeset6361', 0.9528013484108615),
 ('nodeset6427', 0.950915848327102),
 ('nodeset6451', 0.9493817890795266),
 ('nodeset6382', 0.9462347555929025),
 ('nodeset6362', 0.9421126885113453),
 ('nodeset6450', 0.9307562136504464)]