In [18]:
import os
import torch
import json
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

Set random seeds (Python `random`, NumPy, PyTorch CPU and CUDA) for reproducibility

In [2]:
def set_seed(seed=7):
    """Seed every global RNG this notebook relies on.

    Seeds, in order: Python's ``random``, NumPy's legacy global RNG, PyTorch's
    CPU generator, and the generators of all CUDA devices (the CUDA call is a
    no-op when CUDA is unavailable).

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


set_seed(7)

In [3]:
# Prefer the GPU when PyTorch can see one; everything below moves tensors/models to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [4]:
# pip install git+https://github.com/Babelscape/rebel.git

In [5]:
# pip install sentencepiece  # needed for tokenizer

In [6]:
# pip install transformers

In [19]:
class TripletExtractor:
    """Extract ``(head, relation, tail)`` triplets from text with a REBEL seq2seq model.

    The tokenizer and model are loaded once at construction; the model is put in
    eval mode and moved to the notebook-level ``device`` defined in an earlier cell.
    """

    def __init__(self, model_name="Babelscape/rebel-large", prefix=""):
        """Load the REBEL tokenizer and model.

        Args:
            model_name: Hugging Face model id (or local path) of a REBEL checkpoint.
            prefix: Text prepended to every input sentence. REBEL is fed raw text
                (it has no T5-style task prefix), so this defaults to ``""``; any
                non-empty prefix becomes part of the text the model extracts from.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(device).eval()
        self.prefix = prefix

    def extract(self, sentence):
        """Generate and parse triplets for a single sentence.

        Args:
            sentence: Input text. Non-string values (e.g. a pandas NaN) and blank
                strings yield no triplets instead of raising.

        Returns:
            List of ``(head, relation, tail)`` string tuples (possibly empty).
        """
        if not isinstance(sentence, str) or not sentence.strip():
            return []

        inputs = self.tokenizer(
            self.prefix + sentence,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to(device)

        outputs = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=256,
            num_beams=4,
            early_stopping=True,
        )

        # Special tokens must be kept: <triplet>/<subj>/<obj> carry the structure
        # that _parse reads, and skip_special_tokens=True would strip them.
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
        return self._parse(decoded)

    def _parse(self, text):
        """Parse REBEL's linearized output into ``(head, relation, tail)`` tuples.

        REBEL emits ``<triplet> HEAD <subj> TAIL <obj> RELATION``, and one HEAD may
        be followed by several ``<subj> TAIL <obj> RELATION`` groups. Note the
        order: the span after ``<subj>`` is the *tail* entity and the span after
        ``<obj>`` is the *relation*.

        Args:
            text: Decoded generation, possibly containing ``<s>``, ``</s>`` and
                ``<pad>`` tokens.

        Returns:
            List of ``(head, relation, tail)`` tuples in generation order.
            Triplets with an empty head, relation or tail are dropped.
        """
        if "<triplet>" not in text:
            return []

        # Remove sequence/padding tokens and pad the structure markers with spaces
        # so a plain whitespace split yields markers and words as separate tokens.
        for special in ("<s>", "</s>", "<pad>"):
            text = text.replace(special, " ")
        for marker in ("<triplet>", "<subj>", "<obj>"):
            text = text.replace(marker, f" {marker} ")

        triplets = []
        head, tail, relation = [], [], []
        current = None  # list collecting the words of the field being read

        def flush():
            # Emit the pending triplet only when all three parts are present.
            if head and tail and relation:
                triplets.append((" ".join(head), " ".join(relation), " ".join(tail)))

        for token in text.split():
            if token == "<triplet>":
                flush()
                head, tail, relation = [], [], []
                current = head
            elif token == "<subj>":
                # A new tail for the same head; finish the previous (tail, relation).
                flush()
                tail, relation = [], []
                current = tail
            elif token == "<obj>":
                relation = []
                current = relation
            elif current is not None:
                current.append(token)
        flush()
        return triplets

In [20]:
def extract_triplets_from_dataframe(df, extractor, output_path, keep_empty=False):
    """Run ``extractor`` over ``df["text"]`` and save the results as JSON.

    Args:
        df: DataFrame with a ``"text"`` column.
        extractor: Object exposing ``extract(text)`` that returns a list of
            triplets (e.g. ``TripletExtractor``).
        output_path: Destination JSON file. Missing parent directories are
            created before extraction starts.
        keep_empty: If True, rows where no triplet was found are also written
            (with an empty ``"triplets"`` list), keeping the output row-aligned
            with ``df``. Defaults to False, which drops such rows.

    Writes a JSON list of ``{"text": ..., "triplets": [...]}`` records (tuple
    triplets serialize as JSON arrays) and prints a one-line summary.
    """
    # Create the directory up front: extraction can take a long time, and a
    # missing folder would otherwise only fail at the final write, losing all work.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    results = []
    for text in tqdm(df["text"], total=len(df)):
        triplets = extractor.extract(text)
        if triplets or keep_empty:
            results.append({"text": text, "triplets": triplets})

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"✅ Saved {len(results)} extracted triplet entries to {output_path}")

In [None]:
import pandas as pd

In [None]:
# Collapse LIAR's six rating strings into a binary target:
# pants-fire / false / barely-true -> 0, half-true / mostly-true / true -> 1.
label_map = {
    **dict.fromkeys(("pants-fire", "false", "barely-true"), 0),
    **dict.fromkeys(("half-true", "mostly-true", "true"), 1),
}

In [None]:
def load_liar_data(file_path, label_mapping=None):
    """Load one LIAR split and map its rating labels to targets.

    Args:
        file_path: Path to a header-less, tab-separated LIAR file in which
            column 1 holds the rating label and column 2 the statement text.
        label_mapping: Mapping from rating string to target value. Defaults to
            the notebook-level ``label_map``. Rows whose label is not a key of
            the mapping are dropped.

    Returns:
        DataFrame with columns ``["label", "text"]`` and a fresh 0..n-1 index.
    """
    mapping = label_map if label_mapping is None else label_mapping
    raw = pd.read_csv(file_path, sep="\t", header=None)
    # One chain of new frames: assigning a column on a boolean-filtered slice
    # (the previous `df["label"] = ...`) triggers pandas' SettingWithCopyWarning.
    return (
        raw[[1, 2]]
        .rename(columns={1: "label", 2: "text"})
        .loc[lambda d: d["label"].isin(list(mapping))]
        .assign(label=lambda d: d["label"].map(mapping))
        .reset_index(drop=True)
    )

In [None]:
# Load the three LIAR splits; each keeps only rows whose label is in `label_map`
# and maps those labels to 0/1 via `load_liar_data`.
train_df = load_liar_data("data/liar_dataset/train.tsv")
val_df = load_liar_data("data/liar_dataset/valid.tsv")
test_df = load_liar_data("data/liar_dataset/test.tsv")

Train: 10240 | Val: 1284 | Test: 1267


In [None]:
print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")

In [None]:
# Build the REBEL extractor once (loads Babelscape/rebel-large onto `device`,
# which may download the checkpoint on first use) and reuse it for every split.
extractor = TripletExtractor()

In [None]:
# # 🧪 Sample 100 entries from each
# train_sample = train_df.sample(100, random_state=42).reset_index(drop=True)
# val_sample = val_df.sample(100, random_state=42).reset_index(drop=True)
# test_sample = test_df.sample(100, random_state=42).reset_index(drop=True)

# # 🏃 Run triplet extraction
# extract_triplets_from_dataframe(train_sample, extractor, "data/triplets/sample_triplets_train.json")
# extract_triplets_from_dataframe(val_sample, extractor, "data/triplets/sample_triplets_val.json")
# extract_triplets_from_dataframe(test_sample, extractor, "data/triplets/sample_triplets_test.json")

In [22]:
# Run triplet extraction over every split. This is the expensive step: the
# recorded run took ~1.5 h for the train split alone.
# NOTE(review): the saved outputs below print `data/triplets_*.json` paths (no
# `triplets/` subfolder), so they predate the current output paths — re-run to refresh.
extract_triplets_from_dataframe(train_df, extractor, "data/triplets/triplets_train.json")
extract_triplets_from_dataframe(val_df, extractor, "data/triplets/triplets_val.json")
extract_triplets_from_dataframe(test_df, extractor, "data/triplets/triplets_test.json")

100%|██████████| 10240/10240 [1:36:06<00:00,  1.78it/s]    


✅ Saved 10240 extracted triplet entries to data/triplets_train.json


100%|██████████| 1284/1284 [02:52<00:00,  7.43it/s]


✅ Saved 1284 extracted triplet entries to data/triplets_val.json


100%|██████████| 1267/1267 [02:44<00:00,  7.68it/s]

✅ Saved 1267 extracted triplet entries to data/triplets_test.json



