In [1]:
# Dependencies (Colab). %pip installs into the running kernel's environment.
# TODO: pin exact versions for reproducibility (this run resolved datasets 3.2.0).
%pip install -q transformers
# Quoted: unquoted, the shell parses `>=0.12.0` as a redirect to a file named
# "=0.12.0", so the version constraint was silently dropped.
%pip install -q "accelerate>=0.12.0"
%pip install -q datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m32.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [2]:
# Mount Google Drive and switch into the project folder; later cells load the
# checkpoint and write the results file relative to this directory.
from google.colab import drive

drive.mount('/content/drive', force_remount=True)
# Absolute path: a relative `%cd drive/MyDrive/SCIQA` only resolves from
# /content, so re-running this cell would fail once the cwd had moved.
%cd /content/drive/MyDrive/SCIQA

Mounted at /content/drive
/content/drive/MyDrive/SCIQA


In [11]:
import json
from tqdm import tqdm

from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from datasets import load_dataset
import torch

# Target device name: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def clean(st):
    """Normalise a question/SPARQL string so it can be split on single spaces.

    Newlines become spaces, every ``?`` gets a space in front of it, ``{`` and
    ``}`` are padded with spaces, escaped single quotes (backslash + quote)
    are unescaped, runs of spaces collapse to one space, and surrounding
    whitespace is stripped.
    """
    substitutions = (
        ("\n", " "),
        ("?", " ?"),
        ("{", " { "),
        ("}", " } "),
        ("\\'", "'"),
    )
    for old, new in substitutions:
        st = st.replace(old, new)
    # Splitting on " " yields empty pieces wherever spaces are consecutive (or
    # leading/trailing); dropping them collapses each run to a single space.
    return " ".join(piece for piece in st.split(" ") if piece).strip()


def get_entities(query):
    """Split the ORKG identifiers of a SPARQL query into entities and relations.

    The query is normalised with ``clean`` and split on single spaces. Every
    token that starts with ``"orkg"`` is kept: tokens starting with
    ``"orkgp:"`` are reported as relations, all other ``orkg*`` tokens as
    entities. Both lists keep the order in which the tokens occur.

    Returns:
        dict with keys ``"entities"`` and ``"relations"``, each a list of str.
    """
    orkg_tokens = [tok for tok in clean(query).split(" ") if tok.startswith("orkg")]
    return {
        "entities": [tok for tok in orkg_tokens if not tok.startswith("orkgp:")],
        "relations": [tok for tok in orkg_tokens if tok.startswith("orkgp:")],
    }


# Task prefix placed in front of every prompt sent to the model.
prefix = "translate English to Sparql: "

# Fine-tuned checkpoint, resolved relative to the current working directory.
tokenizer = AutoTokenizer.from_pretrained("sciqa_T5_model_we")
model = AutoModelForSeq2SeqLM.from_pretrained("sciqa_T5_model_we").to(device)

# Evaluation data. Earlier runs loaded a local "test.json" or "orkg/SciQA"
# instead; for SciQA the entity/relation lists were derived with
# get_entities() from the gold query rather than read from dataset fields.
books = load_dataset("awalesushil/DBLP-QuAD")
print(books["test"])


def _build_query(feature):
    """Assemble the model prompt for one dataset record.

    Layout: prefix, the question text, then "\\nentities: " and
    "\\nrelations: " each followed by ``str()`` of the record's list.
    """
    parts = (
        prefix,
        feature.get("question").get("string"),
        "\nentities: ",
        str(feature.get("entities")),
        "\nrelations: ",
        str(feature.get("relations")),
    )
    return "".join(parts)


# Parallel lists: prompt i pairs with gold SPARQL i.
queries = []
sparql = []

for feature in books["test"]:
    query = _build_query(feature)
    queries.append(query)
    gold_sparql = feature.get("query").get("sparql")
    sparql.append(gold_sparql)

print(len(queries))

def divide_chunks(l_, n_):
    """Lazily yield consecutive slices of ``l_`` holding ``n_`` items each.

    The final slice may be shorter. Slices have the same type as ``l_``
    (list -> list, str -> str).
    """
    chunk_starts = range(0, len(l_), n_)
    yield from (l_[start:start + n_] for start in chunk_starts)

# Generation batch size (the 2000 test prompts run as 200 batches).
n = 10
# Seed for the sampling-based decoding below, so reruns reproduce the same
# generations and therefore the same results file.
SEED = 42

q = list(divide_chunks(queries, n))

gs = []   # generations decoded with special tokens skipped
gst = []  # generations decoded with special tokens kept, minus <pad> and </s>

torch.manual_seed(SEED)
for group in tqdm(q):
    inputs = tokenizer(group, max_length=512, truncation=True, return_tensors='pt', padding=True).to(device)
    with torch.no_grad():
        # NOTE(review): top-k / nucleus sampling makes this evaluation
        # stochastic; greedy or beam search is the usual choice when scoring
        # generated SPARQL against gold queries. Kept as-is, but seeded above.
        generated_ids = model.generate(**inputs, max_new_tokens=512, do_sample=True, top_k=30, top_p=0.95)

    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Variant that keeps special tokens, except the padding and end-of-sequence
    # markers, which carry no query content.
    generated_texts2 = [
        text.replace("<pad>", "").replace("</s>", "").strip()
        for text in tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    ]

    gs += generated_texts
    gst += generated_texts2

# One generation per prompt: all parallel lists must line up index-for-index.
assert len(gs) == len(gst) == len(queries) == len(sparql)

result = {"questions": queries, "sparql": sparql, "generated_sparql": gs, "generated_with_special_tokens": gst}

# ensure_ascii=False writes non-ASCII characters as-is into the UTF-8 file
# instead of \uXXXX escapes; the trailing newline matches the earlier
# print()-based output.
with open("SCIQA_we_ft_T5_results_DBLP_we.json", "w", encoding="utf-8") as text_file:
    json.dump(result, text_file, ensure_ascii=False)
    text_file.write("\n")

Dataset({
    features: ['id', 'query_type', 'question', 'paraphrased_question', 'query', 'template_id', 'entities', 'relations', 'temporal', 'held_out'],
    num_rows: 2000
})
2000


100%|██████████| 200/200 [33:24<00:00, 10.02s/it]
