# Installation

In [None]:
# pip install torch transformers datasets dotenv scikit-learn tqdm

# Pre-Processing

In [1]:
MAX_LENGTH = 256
EPOCHS = 10

In [2]:
from transformers import BertTokenizerFast, BertModel, get_scheduler
from torch.utils.data import Dataset, DataLoader, random_split
from pathlib import Path
from tqdm import tqdm
from torch import nn
import torch
import json
import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data_path = "data/ConvAI2/u2t_map_all.json"
with open(data_path, "r") as f:
    raw_data = json.load(f)

# Extract all unique relations
all_relations = sorted({ex["triplets"][0]["label"] for ex in raw_data})
relation2id = {rel: i for i, rel in enumerate(all_relations)}
id2relation = {i: rel for rel, i in relation2id.items()}
print(f"Loaded {len(raw_data)} examples with {len(relation2id)} relation types.")


Loaded 35077 examples with 105 relation types.


In [4]:
def convert_to_bio(example):
    triplet = example["triplets"][0]
    tokens = triplet["tokens"]
    head = triplet["head"]
    tail = triplet["tail"]
    relation = triplet["label"]

    labels = ['O'] * len(tokens)
    for idx in head:
        labels[idx] = 'B-SUB'
    if isinstance(tail, list):
        for i, idx in enumerate(tail):
            labels[idx] = 'B-OBJ' if i == 0 else 'I-OBJ'

    return {
        "tokens": tokens,
        "labels": labels,
        "relation_label": relation2id[relation]
    }

bio_data = [convert_to_bio(ex) for ex in raw_data if "triplets" in ex and ex["triplets"]]

In [5]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

label_list = ['O', 'B-SUB', 'B-OBJ', 'I-OBJ']
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for label, i in label2id.items()}

def tokenize_and_align(example):
    encoding = tokenizer(example["tokens"], is_split_into_words=True, padding="max_length", truncation=True, max_length= MAX_LENGTH)
    word_ids = encoding.word_ids()

    aligned_labels = []
    prev_word_id = None
    for word_id in word_ids:
        if word_id is None:
            aligned_labels.append(-100)
        elif word_id != prev_word_id:
            aligned_labels.append(label2id[example["labels"][word_id]])
        else:
            aligned_labels.append(label2id[example["labels"][word_id]] if 'I' in example["labels"][word_id] else label2id['O'])
        prev_word_id = word_id

    encoding["labels"] = aligned_labels
    encoding["relation_label"] = example["relation_label"]
    return encoding

tokenized_data = [tokenize_and_align(ex) for ex in bio_data]


In [6]:
class JointPersonaDataset(Dataset):
    def __init__(self, encodings):
        self.data = encodings

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"]),
            "attention_mask": torch.tensor(item["attention_mask"]),
            "labels": torch.tensor(item["labels"]),
            "relation_label": torch.tensor(item["relation_label"])
        }

dataset = JointPersonaDataset(tokenized_data)
train_size = int(0.9 * len(dataset))
train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)


In [7]:
class JointBertExtractor(nn.Module):
    def __init__(self, base_model='bert-base-uncased', num_token_labels=4, num_relation_labels=0):
        super(JointBertExtractor, self).__init__()
        self.bert = BertModel.from_pretrained(base_model)

        self.dropout = nn.Dropout(0.1)
        self.token_classifier = nn.Linear(self.bert.config.hidden_size, num_token_labels)
        self.relation_classifier = nn.Linear(self.bert.config.hidden_size, num_relation_labels)

    def forward(self, input_ids, attention_mask, labels=None, relation_label=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        cls_output = self.dropout(outputs.pooler_output)

        token_logits = self.token_classifier(sequence_output)  # [batch_size, seq_len, num_token_labels]
        relation_logits = self.relation_classifier(cls_output)  # [batch_size, num_relation_labels]

        loss = None
        if labels is not None and relation_label is not None:
            loss_fct = nn.CrossEntropyLoss()
            token_loss = loss_fct(token_logits.view(-1, token_logits.shape[-1]), labels.view(-1))
            relation_loss = loss_fct(relation_logits, relation_label)
            loss = token_loss + relation_loss

        return {
            "loss": loss,
            "token_logits": token_logits,
            "relation_logits": relation_logits
        }


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = JointBertExtractor(
    num_token_labels=len(label2id),
    num_relation_labels=len(relation2id)
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
num_epochs = 5
num_training_steps = num_epochs * len(train_loader)

lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)


In [9]:
from transformers import  get_scheduler

from torch.optim import AdamW


optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = EPOCHS
num_training_steps = num_epochs * len(train_loader)

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)


In [10]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            relation_label=batch["relation_label"]
        )

        loss = outputs["loss"]
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} — Train Loss: {avg_train_loss:.4f}")


Epoch 1:   0%|          | 0/1974 [00:00<?, ?it/s]

Epoch 1: 100%|██████████| 1974/1974 [03:45<00:00,  8.75it/s]


Epoch 1 — Train Loss: 2.2909


Epoch 2: 100%|██████████| 1974/1974 [03:48<00:00,  8.63it/s]


Epoch 2 — Train Loss: 1.2030


Epoch 3: 100%|██████████| 1974/1974 [03:48<00:00,  8.62it/s]


Epoch 3 — Train Loss: 0.7514


Epoch 4: 100%|██████████| 1974/1974 [03:49<00:00,  8.62it/s]


Epoch 4 — Train Loss: 0.4748


Epoch 5: 100%|██████████| 1974/1974 [03:48<00:00,  8.63it/s]


Epoch 5 — Train Loss: 0.2989


Epoch 6: 100%|██████████| 1974/1974 [03:47<00:00,  8.68it/s]


Epoch 6 — Train Loss: 0.1930


Epoch 7: 100%|██████████| 1974/1974 [03:47<00:00,  8.67it/s]


Epoch 7 — Train Loss: 0.1175


Epoch 8: 100%|██████████| 1974/1974 [03:48<00:00,  8.65it/s]


Epoch 8 — Train Loss: 0.0785


Epoch 9: 100%|██████████| 1974/1974 [03:48<00:00,  8.64it/s]


Epoch 9 — Train Loss: 0.0499


Epoch 10: 100%|██████████| 1974/1974 [03:48<00:00,  8.63it/s]

Epoch 10 — Train Loss: 0.0345





In [12]:
model_save_path = "PExtractor"
tokenizer.save_pretrained(model_save_path)
torch.save(model.state_dict(), os.path.join(model_save_path, "pytorch_model.bin"))
print(f"Model saved to {model_save_path}")

Model saved to PExtractor


# Load and Test

In [1]:
import torch
from transformers import BertTokenizerFast
from torch import nn
from transformers import BertModel


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class JointBertExtractor(nn.Module):
    def __init__(self, base_model='bert-base-uncased', num_token_labels=4, num_relation_labels=0):
        super(JointBertExtractor, self).__init__()
        self.bert = BertModel.from_pretrained(base_model)

        self.dropout = nn.Dropout(0.1)
        self.token_classifier = nn.Linear(self.bert.config.hidden_size, num_token_labels)
        self.relation_classifier = nn.Linear(self.bert.config.hidden_size, num_relation_labels)

    def forward(self, input_ids, attention_mask, labels=None, relation_label=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        cls_output = self.dropout(outputs.pooler_output)

        token_logits = self.token_classifier(sequence_output)
        relation_logits = self.relation_classifier(cls_output)

        loss = None
        if labels is not None and relation_label is not None:
            loss_fct = nn.CrossEntropyLoss()
            token_loss = loss_fct(token_logits.view(-1, token_logits.shape[-1]), labels.view(-1))
            relation_loss = loss_fct(relation_logits, relation_label)
            loss = token_loss + relation_loss

        return {
            "loss": loss,
            "token_logits": token_logits,
            "relation_logits": relation_logits
        }


In [3]:
model_path = "PExtractor"  # adjust if saved elsewhere
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = JointBertExtractor(num_token_labels=4, num_relation_labels=105)  # 105 = your # of relation classes
model.load_state_dict(torch.load(f"{model_path}/pytorch_model.bin", map_location=device))
model.to(device)
model.eval()


JointBertExtractor(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [4]:
# Token-level labels
label_list = ['O', 'B-SUB', 'B-OBJ', 'I-OBJ']
id2label = {i: label for i, label in enumerate(label_list)}

# Relation labels (recreate from training set)
import json
with open("data/ConvAI2/u2t_map_all.json", "r") as f:
    raw_data = json.load(f)
relation_list = sorted({ex["triplets"][0]["label"] for ex in raw_data})
id2relation = {i: rel for i, rel in enumerate(relation_list)}


In [7]:
def extract_triplet_joint(sentence: str, model, tokenizer, id2label, id2relation, device="cpu"):
    model.eval()
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    token_preds = torch.argmax(outputs["token_logits"], dim=-1).squeeze().cpu().tolist()
    tokens_decoded = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
    attention_mask = inputs["attention_mask"].squeeze().cpu().tolist()

    subject = None
    obj_tokens = []

    for token, label_id, mask in zip(tokens_decoded, token_preds, attention_mask):
        if mask == 0 or token in ["[PAD]", "[CLS]", "[SEP]"]:
            continue  # skip padding tokens
        label = id2label.get(label_id, "O")
        if label == "B-SUB":
            subject = token
        elif label.startswith("B-OBJ") or label.startswith("I-OBJ"):
            obj_tokens.append(token)


    rel_pred_id = torch.argmax(outputs["relation_logits"], dim=-1).item()
    relation = id2relation[rel_pred_id]
    object_str = tokenizer.convert_tokens_to_string(obj_tokens).strip()
    subject = subject if subject else "i"

    return (subject, relation, object_str)


In [8]:
test_sent = "I just got done watching a horror movie."
s, r, o = extract_triplet_joint(test_sent, model, tokenizer, id2label, id2relation, device)
print(f"Extracted Triplet: ({s}, {r}, {o})")


Extracted Triplet: (i, favorite_movie, horror)
