In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


实验


In [None]:
import os
import xml.etree.ElementTree as ET
import spacy
from transformers import BertTokenizerFast

# Initialize spacy and BERT tokenizer
nlp = spacy.load("en_core_web_sm")
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Function to read and parse XML files
def parse_xml_files(directory):
    data = []
    for filename in os.listdir(directory):
        if filename.endswith(".xml"):
            try:
                tree = ET.parse(os.path.join(directory, filename))
                root = tree.getroot()
                text_content = root.find("TEXT").text
                data.append((root, text_content))
            except ET.ParseError:
                print(f"Error parsing {filename}, skipping...")
    return data

# Function to extract EVENT and TIMEX3 entities
def extract_entities(xml_data):
    data = []
    for entry, text in xml_data:
        events = []
        timex3s = []
        tlinks = []

        for tag in entry.find("TAGS"):
            if tag.tag == "EVENT" and tag.attrib["type"] == "TREATMENT":
                events.append({
                    "id": tag.attrib["id"],
                    "text": tag.attrib["text"],
                    "start": int(tag.attrib["start"]),
                    "end": int(tag.attrib["end"]),
                    "modality": tag.attrib["modality"],
                    "polarity": tag.attrib["polarity"]
                })
            elif tag.tag == "TIMEX3" and tag.attrib["type"] == "DATE":
                timex3s.append({
                    "id": tag.attrib["id"],
                    "text": tag.attrib["text"],
                    "start": int(tag.attrib["start"]),
                    "end": int(tag.attrib["end"]),
                    "val": tag.attrib["val"]
                })
            elif tag.tag == "TLINK":
                tlinks.append({
                    "fromID": tag.attrib["fromID"],
                    "toID": tag.attrib["toID"],
                    "type": tag.attrib["type"]
                })

        data.append((events, timex3s, tlinks, text))
    return data

# Function to link treatments with dates and annotate all time expressions
def link_treatments_with_dates(treatments, timex3s, tlinks, text):
    linked_data = {}
    for treatment in treatments:
        treatment_id = treatment["id"]
        if treatment_id not in linked_data:
            linked_data[treatment_id] = {
                "treatment": treatment,
                "dates": [],
                "text": text,
                "bio_time": ["O"] * len(text),
                "bio_treatment": ["O"] * len(text),
                "bio_tlink": ["O"] * len(text)
            }

        # Annotate all time expressions
        for timex in timex3s:
            bio_label = "B-DATE"
            linked_data[treatment_id]["bio_time"][timex["start"]] = bio_label
            for i in range(timex["start"] + 1, timex["end"]):
                linked_data[treatment_id]["bio_time"][i] = "I-DATE"

        # Annotate linked dates with TLINK types
        for tlink in tlinks:
            if tlink["fromID"] == treatment_id or tlink["toID"] == treatment_id:
                date_id = tlink["toID"] if tlink["fromID"] == treatment_id else tlink["fromID"]
                date = next((t for t in timex3s if t["id"] == date_id), None)
                if date:
                    linked_data[treatment_id]["dates"].append({
                        "date": date,
                        "tlink_type": tlink["type"]
                    })
                    bio_label = f"B-{tlink['type']}"
                    linked_data[treatment_id]["bio_tlink"][date["start"]] = bio_label
                    for i in range(date["start"] + 1, date["end"]):
                        linked_data[treatment_id]["bio_tlink"][i] = f"I-{tlink['type']}"

        # Annotate treatments
        bio_label_treatment = "B-TREATMENT"
        linked_data[treatment_id]["bio_treatment"][treatment["start"]] = bio_label_treatment
        for i in range(treatment["start"] + 1, treatment["end"]):
            linked_data[treatment_id]["bio_treatment"][i] = "I-TREATMENT"

    return linked_data.values()

# Function to align BIO annotations with tokenized text
def align_bio_with_tokens(text, bio_annotations, tokenizer):
    tokens = tokenizer.tokenize(text)
    token_offsets = tokenizer(text, return_offsets_mapping=True)["offset_mapping"]
    # print("Token Offsets and Corresponding Tokens:")
    # for token, (token_start, token_end) in zip(tokens, token_offsets):
    #     print(f"Token: {token}, Start: {token_start}, End: {token_end}")
    aligned_bio = []
    prev_tag = "O"
    for token_start, token_end in token_offsets:
        if token_start == token_end:
            continue
        if token_start >= len(bio_annotations):
            aligned_bio.append("O")
            # print(f"Token start {token_start} exceeds BIO annotations length {len(bio_annotations)}, appending 'O'")
            continue

        bio_tag = bio_annotations[token_start]
        if bio_tag.startswith("I-") and prev_tag == "O":
            bio_tag = "B-" + bio_tag[2:]
        aligned_bio.append(bio_tag)
        prev_tag = bio_tag
    # print(aligned_bio)
    # Debugging: Print tokens and corresponding BIO tags that are not "O"
    # for token, bio_tag in zip(tokens, aligned_bio):
    #     if bio_tag != "O":
    #         print(f"Token: {token}, BIO Tag: {bio_tag}")

    return tokens, aligned_bio

# Function to chunk annotations around target sentences
def chunk_annotations(text, bio_time, bio_treatment, bio_tlink, tokenizer, window_size):
    # Tokenize text into sentences
    doc = nlp(text)
    sentences = list(doc.sents)

    # Find sentence boundaries
    sentence_boundaries = [(sent.start_char, sent.end_char) for sent in sentences]

    # Identify sentences containing target entities
    target_indices = []
    for i, (start, end) in enumerate(sentence_boundaries):
        if (any(tag.startswith("B-") or tag.startswith("I-") for tag in bio_time[start:end]) or
                any(tag.startswith("B-") or tag.startswith("I-") for tag in bio_treatment[start:end]) or
                    any(tag.startswith("B-") or tag.startswith("I-") for tag in bio_tlink[start:end])):
            target_indices.append(i)

    # Extract sentences around target indices based on window size
    selected_indices = set()
    for idx in target_indices:
        start_idx = max(0, idx - window_size)
        end_idx = min(len(sentences), idx + window_size + 1)
        selected_indices.update(range(start_idx, end_idx))

    # Combine selected sentences and update BIO annotations
    selected_indices = sorted(selected_indices)
    selected_text = " ".join([sentences[i].text for i in selected_indices])

    selected_bio_time = []
    selected_bio_treatment = []
    selected_bio_tlink = []

    for idx in selected_indices:
        start, end = sentence_boundaries[idx]
        selected_bio_time.extend(bio_time[start:end])
        selected_bio_treatment.extend(bio_treatment[start:end])
        selected_bio_tlink.extend(bio_tlink[start:end])

    # Debugging: Print selected sentences and BIO annotations
    # print("Selected Sentences:\n", selected_text)
    # print("Selected BIO Time (non-O):")
    # for i, tag in enumerate(selected_bio_time):
    #     if tag != "O":
    #         print(f"Token: {selected_text[i]}, BIO Tag: {tag}")

    # print("Selected BIO Treatment (non-O):")
    # for i, tag in enumerate(selected_bio_treatment):
    #     if tag != "O":
    #         print(f"Token: {selected_text[i]}, BIO Tag: {tag}")

    # print("Selected BIO TLINK (non-O):")
    # for i, tag in enumerate(selected_bio_tlink):
    #     if tag != "O":
    #         print(f"Token: {selected_text[i]}, BIO Tag: {tag}")

    # Align new selected text with BERT tokens

    tokens, aligned_bio_time = align_bio_with_tokens(selected_text, selected_bio_time, tokenizer)
    _, aligned_bio_treatment = align_bio_with_tokens(selected_text, selected_bio_treatment, tokenizer)
    _, aligned_bio_tlink = align_bio_with_tokens(selected_text, selected_bio_tlink, tokenizer)

    return selected_text, tokens, aligned_bio_time, aligned_bio_treatment, aligned_bio_tlink


# Main processing
xml_dir = "/content/drive/MyDrive/UOM_year3/3rd_year_project/2012_Temporal_Relations_Challenge/2012-07-06.release-fix"
xml_data = parse_xml_files(xml_dir)
extracted_data = extract_entities(xml_data)
temp = extracted_data[0:1]
# print(temp)
linked_treatments = []
for events, timex3s, tlinks, text in temp:
    linked_treatments.extend(link_treatments_with_dates(events, timex3s, tlinks, text))
print("/////////")
print(linked_treatments)
# Align BIO annotations with BERT tokens for each entry
window_size = 0
processed_data = []
for entry in linked_treatments:
    text = entry["text"]
    bio_time = entry["bio_time"]
    bio_treatment = entry["bio_treatment"]
    bio_tlink = entry["bio_tlink"]

    try:
        selected_text, tokens, aligned_bio_time, aligned_bio_treatment, aligned_bio_tlink = chunk_annotations(
            text, bio_time, bio_treatment, bio_tlink, tokenizer, window_size
        )

        processed_data.append({
            "selected_text": selected_text,
            "tokens": tokens,
            "aligned_bio_time": aligned_bio_time,
            "aligned_bio_treatment": aligned_bio_treatment,
            "aligned_bio_tlink": aligned_bio_tlink
        })
    except Exception as e:
        print(f"Error processing entry with text: {text[:50]}... Error: {e}")

# Example output
for entry in processed_data[:1]:  # printing only the first entry for brevity
    print("Selected Text:", entry["selected_text"])
    print("Tokens:", entry["tokens"])
    print(len(entry["tokens"]))
    print("Aligned BIO Time:", entry["aligned_bio_time"])
    print(len(entry["aligned_bio_time"]))
    print("Aligned BIO Treatment:", entry["aligned_bio_treatment"])
    print(len(entry["aligned_bio_treatment"]))
    print("Aligned BIO TLINK:", entry["aligned_bio_tlink"])
    print(len(entry["aligned_bio_tlink"]))
    print("\n")

    # Collect non-O entities for BIO time annotations
    bio_time_entities = []
    current_entity = []
    current_label = None
    for token, bio_tag in zip(entry["tokens"], entry["aligned_bio_time"]):
        if bio_tag.startswith("B-"):
            if current_entity:
                bio_time_entities.append((current_label, current_entity))
                current_entity = []
            current_label = bio_tag[2:]
            current_entity.append(token)
        elif bio_tag.startswith("I-") and current_label == bio_tag[2:]:
            current_entity.append(token)
        else:
            if current_entity:
                bio_time_entities.append((current_label, current_entity))
                current_entity = []
            current_label = None

    if current_entity:
        bio_time_entities.append((current_label, current_entity))

    # Collect non-O entities for BIO treatment annotations
    bio_treatment_entities = []
    current_entity = []
    current_label = None
    for token, bio_tag in zip(entry["tokens"], entry["aligned_bio_treatment"]):
        if bio_tag.startswith("B-"):
            if current_entity:
                bio_treatment_entities.append((current_label, current_entity))
                current_entity = []
            current_label = bio_tag[2:]
            current_entity.append(token)
        elif bio_tag.startswith("I-") and current_label == bio_tag[2:]:
            current_entity.append(token)
        else:
            if current_entity:
                bio_treatment_entities.append((current_label, current_entity))
                current_entity = []
            current_label = None

    if current_entity:
        bio_treatment_entities.append((current_label, current_entity))

    # Collect non-O entities for BIO tlink annotations
    bio_tlink_entities = []
    current_entity = []
    current_label = None
    for token, bio_tag in zip(entry["tokens"], entry["aligned_bio_tlink"]):
        if bio_tag.startswith("B-"):
            if current_entity:
                bio_tlink_entities.append((current_label, current_entity))
                current_entity = []
            current_label = bio_tag[2:]
            current_entity.append(token)
        elif bio_tag.startswith("I-") and current_label == bio_tag[2:]:
            current_entity.append(token)
        else:
            if current_entity:
                bio_tlink_entities.append((current_label, current_entity))
                current_entity = []
            current_label = None

    if current_entity:
        bio_tlink_entities.append((current_label, current_entity))

    print("BIO Time Entities:")
    for label, tokens in bio_time_entities:
        print(f"Entity: {' '.join(tokens)}, Label: {label}")

    print("\nBIO Treatment Entities:")
    for label, tokens in bio_treatment_entities:
        print(f"Entity: {' '.join(tokens)}, Label: {label}")

    print("\nBIO TLINK Entities:")
    for label, tokens in bio_tlink_entities:
        print(f"Entity: {' '.join(tokens)}, Label: {label}")


KeyboardInterrupt: 

试验后运行

In [2]:
import os
import xml.etree.ElementTree as ET
import spacy
from transformers import BertTokenizerFast

# Initialize spacy and BERT tokenizer
nlp = spacy.load("en_core_web_sm")
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Function to read and parse XML files
def parse_xml_files(directory):
    data = []
    for filename in os.listdir(directory):
        if filename.endswith(".xml"):
            try:
                tree = ET.parse(os.path.join(directory, filename))
                root = tree.getroot()
                text_content = root.find("TEXT").text
                data.append((root, text_content))
            except ET.ParseError:
                print(f"Error parsing {filename}, skipping...")
    return data

# Function to extract EVENT and TIMEX3 entities
def extract_entities(xml_data):
    data = []
    for entry, text in xml_data:
        events = []
        timex3s = []
        tlinks = []

        for tag in entry.find("TAGS"):
            if tag.tag == "EVENT" and tag.attrib["type"] == "TREATMENT":
                events.append({
                    "id": tag.attrib["id"],
                    "text": tag.attrib["text"],
                    "start": int(tag.attrib["start"]),
                    "end": int(tag.attrib["end"]),
                    "modality": tag.attrib["modality"],
                    "polarity": tag.attrib["polarity"]
                })
            elif tag.tag == "TIMEX3" and tag.attrib["type"] == "DATE":
                timex3s.append({
                    "id": tag.attrib["id"],
                    "text": tag.attrib["text"],
                    "start": int(tag.attrib["start"]),
                    "end": int(tag.attrib["end"]),
                    "val": tag.attrib["val"]
                })
            elif tag.tag == "TLINK":
                tlinks.append({
                    "fromID": tag.attrib["fromID"],
                    "toID": tag.attrib["toID"],
                    "type": tag.attrib["type"]
                })

        data.append((events, timex3s, tlinks, text))
    return data

# Function to link treatments with dates and annotate all time expressions
def link_treatments_with_dates(treatments, timex3s, tlinks, text):
    linked_data = {}
    for treatment in treatments:
        treatment_id = treatment["id"]
        if treatment_id not in linked_data:
            linked_data[treatment_id] = {
                "treatment": treatment,
                "dates": [],
                "text": text,
                "bio_time": ["O"] * len(text),
                "bio_treatment": ["O"] * len(text),
                "bio_tlink": ["O"] * len(text)
            }

        # Annotate all time expressions
        for timex in timex3s:
            bio_label = "B-DATE"
            linked_data[treatment_id]["bio_time"][timex["start"]] = bio_label
            for i in range(timex["start"] + 1, timex["end"]):
                linked_data[treatment_id]["bio_time"][i] = "I-DATE"

        # Annotate linked dates with TLINK types
        for tlink in tlinks:
            if tlink["fromID"] == treatment_id or tlink["toID"] == treatment_id:
                date_id = tlink["toID"] if tlink["fromID"] == treatment_id else tlink["fromID"]
                date = next((t for t in timex3s if t["id"] == date_id), None)
                if date:
                    linked_data[treatment_id]["dates"].append({
                        "date": date,
                        "tlink_type": tlink["type"]
                    })
                    bio_label = f"B-{tlink['type']}"
                    linked_data[treatment_id]["bio_tlink"][date["start"]] = bio_label
                    for i in range(date["start"] + 1, date["end"]):
                        linked_data[treatment_id]["bio_tlink"][i] = f"I-{tlink['type']}"

        # Annotate treatments
        bio_label_treatment = "B-TREATMENT"
        linked_data[treatment_id]["bio_treatment"][treatment["start"]] = bio_label_treatment
        for i in range(treatment["start"] + 1, treatment["end"]):
            linked_data[treatment_id]["bio_treatment"][i] = "I-TREATMENT"

    return linked_data.values()

# Function to align BIO annotations with tokenized text
def align_bio_with_tokens(text, bio_annotations, tokenizer):
    tokens = tokenizer.tokenize(text)
    token_offsets = tokenizer(text, return_offsets_mapping=True)["offset_mapping"]
    # print("Token Offsets and Corresponding Tokens:")
    # for token, (token_start, token_end) in zip(tokens, token_offsets):
    #     print(f"Token: {token}, Start: {token_start}, End: {token_end}")
    aligned_bio = []
    prev_tag = "O"
    for token_start, token_end in token_offsets:
        if token_start == token_end:
            continue
        if token_start >= len(bio_annotations):
            aligned_bio.append("O")
            # print(f"Token start {token_start} exceeds BIO annotations length {len(bio_annotations)}, appending 'O'")
            continue

        bio_tag = bio_annotations[token_start]
        if bio_tag.startswith("I-") and prev_tag == "O":
            bio_tag = "B-" + bio_tag[2:]
        aligned_bio.append(bio_tag)
        prev_tag = bio_tag
    # print(aligned_bio)
    # Debugging: Print tokens and corresponding BIO tags that are not "O"
    # for token, bio_tag in zip(tokens, aligned_bio):
    #     if bio_tag != "O":
    #         print(f"Token: {token}, BIO Tag: {bio_tag}")

    return tokens, aligned_bio

# Function to chunk annotations around target sentences
def chunk_annotations(text, bio_time, bio_treatment, bio_tlink, tokenizer, window_size):
    # Tokenize text into sentences
    doc = nlp(text)
    sentences = list(doc.sents)

    # Find sentence boundaries
    sentence_boundaries = [(sent.start_char, sent.end_char) for sent in sentences]

    # Identify sentences containing target entities
    target_indices = []
    for i, (start, end) in enumerate(sentence_boundaries):
        if (any(tag.startswith("B-") or tag.startswith("I-") for tag in bio_time[start:end]) or
                any(tag.startswith("B-") or tag.startswith("I-") for tag in bio_treatment[start:end]) or
                    any(tag.startswith("B-") or tag.startswith("I-") for tag in bio_tlink[start:end])):
            target_indices.append(i)

    # Extract sentences around target indices based on window size
    selected_indices = set()
    for idx in target_indices:
        start_idx = max(0, idx - window_size)
        end_idx = min(len(sentences), idx + window_size + 1)
        selected_indices.update(range(start_idx, end_idx))

    # Combine selected sentences and update BIO annotations
    selected_indices = sorted(selected_indices)
    selected_text = " ".join([sentences[i].text for i in selected_indices])

    selected_bio_time = []
    selected_bio_treatment = []
    selected_bio_tlink = []

    for idx in selected_indices:
        start, end = sentence_boundaries[idx]
        selected_bio_time.extend(bio_time[start:end])
        selected_bio_treatment.extend(bio_treatment[start:end])
        selected_bio_tlink.extend(bio_tlink[start:end])

    # Debugging: Print selected sentences and BIO annotations
    # print("Selected Sentences:\n", selected_text)
    # print("Selected BIO Time (non-O):")
    # for i, tag in enumerate(selected_bio_time):
    #     if tag != "O":
    #         print(f"Token: {selected_text[i]}, BIO Tag: {tag}")

    # print("Selected BIO Treatment (non-O):")
    # for i, tag in enumerate(selected_bio_treatment):
    #     if tag != "O":
    #         print(f"Token: {selected_text[i]}, BIO Tag: {tag}")

    # print("Selected BIO TLINK (non-O):")
    # for i, tag in enumerate(selected_bio_tlink):
    #     if tag != "O":
    #         print(f"Token: {selected_text[i]}, BIO Tag: {tag}")

    # Align new selected text with BERT tokens

    tokens, aligned_bio_time = align_bio_with_tokens(selected_text, selected_bio_time, tokenizer)
    _, aligned_bio_treatment = align_bio_with_tokens(selected_text, selected_bio_treatment, tokenizer)
    _, aligned_bio_tlink = align_bio_with_tokens(selected_text, selected_bio_tlink, tokenizer)

    return selected_text, tokens, aligned_bio_time, aligned_bio_treatment, aligned_bio_tlink


# Main processing
xml_dir = "/content/drive/MyDrive/UOM_year3/3rd_year_project/2012_Temporal_Relations_Challenge/2012-07-06.release-fix"
xml_data = parse_xml_files(xml_dir)
extracted_data = extract_entities(xml_data)
# temp = extracted_data[0:1]
# print(temp)
linked_treatments = []
for events, timex3s, tlinks, text in extracted_data:
    linked_treatments.extend(link_treatments_with_dates(events, timex3s, tlinks, text))
# print("/////////")
# print(linked_treatments)
# Align BIO annotations with BERT tokens for each entry
window_size = 2
processed_data = []
for entry in linked_treatments:
    text = entry["text"]
    bio_time = entry["bio_time"]
    bio_treatment = entry["bio_treatment"]
    bio_tlink = entry["bio_tlink"]

    try:
        selected_text, tokens, aligned_bio_time, aligned_bio_treatment, aligned_bio_tlink = chunk_annotations(
            text, bio_time, bio_treatment, bio_tlink, tokenizer, window_size
        )

        # processed_data.append({
        #     "selected_text": selected_text,
        #     "tokens": tokens,
        #     "aligned_bio_time": aligned_bio_time,
        #     "aligned_bio_treatment": aligned_bio_treatment,
        #     "aligned_bio_tlink": aligned_bio_tlink
        # })
        processed_data.append({
            "tokens": tokens,
            "bio_time_aligned": aligned_bio_tlink,
            "bio_treatment_aligned": aligned_bio_treatment,
        })
    except Exception as e:
        print(f"Error processing entry with text: {text[:50]}... Error: {e}")

# Example output
# for entry in processed_data[:1]:  # printing only the first entry for brevity
#     print("Selected Text:", entry["selected_text"])
#     print("Tokens:", entry["tokens"])
#     print(len(entry["tokens"]))
#     print("Aligned BIO Time:", entry["aligned_bio_time"])
#     print(len(entry["aligned_bio_time"]))
#     print("Aligned BIO Treatment:", entry["aligned_bio_treatment"])
#     print(len(entry["aligned_bio_treatment"]))
#     print("Aligned BIO TLINK:", entry["aligned_bio_tlink"])
#     print(len(entry["aligned_bio_tlink"]))
#     print("\n")

#     # Collect non-O entities for BIO time annotations
#     bio_time_entities = []
#     current_entity = []
#     current_label = None
#     for token, bio_tag in zip(entry["tokens"], entry["aligned_bio_time"]):
#         if bio_tag.startswith("B-"):
#             if current_entity:
#                 bio_time_entities.append((current_label, current_entity))
#                 current_entity = []
#             current_label = bio_tag[2:]
#             current_entity.append(token)
#         elif bio_tag.startswith("I-") and current_label == bio_tag[2:]:
#             current_entity.append(token)
#         else:
#             if current_entity:
#                 bio_time_entities.append((current_label, current_entity))
#                 current_entity = []
#             current_label = None

#     if current_entity:
#         bio_time_entities.append((current_label, current_entity))

#     # Collect non-O entities for BIO treatment annotations
#     bio_treatment_entities = []
#     current_entity = []
#     current_label = None
#     for token, bio_tag in zip(entry["tokens"], entry["aligned_bio_treatment"]):
#         if bio_tag.startswith("B-"):
#             if current_entity:
#                 bio_treatment_entities.append((current_label, current_entity))
#                 current_entity = []
#             current_label = bio_tag[2:]
#             current_entity.append(token)
#         elif bio_tag.startswith("I-") and current_label == bio_tag[2:]:
#             current_entity.append(token)
#         else:
#             if current_entity:
#                 bio_treatment_entities.append((current_label, current_entity))
#                 current_entity = []
#             current_label = None

#     if current_entity:
#         bio_treatment_entities.append((current_label, current_entity))

#     # Collect non-O entities for BIO tlink annotations
#     bio_tlink_entities = []
#     current_entity = []
#     current_label = None
#     for token, bio_tag in zip(entry["tokens"], entry["aligned_bio_tlink"]):
#         if bio_tag.startswith("B-"):
#             if current_entity:
#                 bio_tlink_entities.append((current_label, current_entity))
#                 current_entity = []
#             current_label = bio_tag[2:]
#             current_entity.append(token)
#         elif bio_tag.startswith("I-") and current_label == bio_tag[2:]:
#             current_entity.append(token)
#         else:
#             if current_entity:
#                 bio_tlink_entities.append((current_label, current_entity))
#                 current_entity = []
#             current_label = None

#     if current_entity:
#         bio_tlink_entities.append((current_label, current_entity))

#     print("BIO Time Entities:")
#     for label, tokens in bio_time_entities:
#         print(f"Entity: {' '.join(tokens)}, Label: {label}")

#     print("\nBIO Treatment Entities:")
#     for label, tokens in bio_treatment_entities:
#         print(f"Entity: {' '.join(tokens)}, Label: {label}")

#     print("\nBIO TLINK Entities:")
#     for label, tokens in bio_tlink_entities:
#         print(f"Entity: {' '.join(tokens)}, Label: {label}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Error parsing 807.xml, skipping...
Error parsing 422.xml, skipping...


Token indices sequence length is longer than the specified maximum sequence length for this model (637 > 512). Running this sequence through the model will result in indexing errors


In [3]:
import os
import xml.etree.ElementTree as ET
import spacy
from transformers import BertTokenizerFast


# Directory containing the XML files
xml_dir = "/content/drive/MyDrive/UOM_year3/3rd_year_project/2012_Temporal_Relations_Challenge/ground_truth/merged_xml"
# xml_dir = "/content/drive/MyDrive/UOM_year3/3rd_year_project/2012_Temporal_Relations_Challenge/2012-07-06.release-fix"
# Initialize BERT tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Function to read and parse XML files
def parse_xml_files(directory):
    data = []
    for filename in os.listdir(directory):
        if filename.endswith(".xml"):
            try:
                tree = ET.parse(os.path.join(directory, filename))
                root = tree.getroot()
                text_content = root.find("TEXT").text
                data.append((root, text_content))
            except ET.ParseError:
                print(f"Error parsing {filename}, skipping...")
    return data

# Function to extract EVENT and TIMEX3 entities
def extract_entities(xml_data):
    data = []
    for entry, text in xml_data:
        events = []
        timex3s = []
        tlinks = []

        for tag in entry.find("TAGS"):
            if tag.tag == "EVENT" and tag.attrib["type"] == "TREATMENT":
                events.append({
                    "id": tag.attrib["id"],
                    "text": tag.attrib["text"],
                    "start": int(tag.attrib["start"]),
                    "end": int(tag.attrib["end"]),
                    "modality": tag.attrib["modality"],
                    "polarity": tag.attrib["polarity"]
                })
            elif tag.tag == "TIMEX3" and tag.attrib["type"] == "DATE":
                timex3s.append({
                    "id": tag.attrib["id"],
                    "text": tag.attrib["text"],
                    "start": int(tag.attrib["start"]),
                    "end": int(tag.attrib["end"]),
                    "val": tag.attrib["val"]
                })
            elif tag.tag == "TLINK":
                tlinks.append({
                    "fromID": tag.attrib["fromID"],
                    "toID": tag.attrib["toID"],
                    "type": tag.attrib["type"]
                })

        data.append((events, timex3s, tlinks, text))
    return data

# Function to link treatments with dates and include TLINK type information
def link_treatments_with_dates(treatments, timex3s, tlinks, text):
    linked_data = {}
    for treatment in treatments:
        treatment_id = treatment["id"]
        if treatment_id not in linked_data:
            linked_data[treatment_id] = {
                "treatment": treatment,
                "dates": [],
                "text": text,
                "bio_time": ["O"] * len(text),
                "bio_treatment": ["O"] * len(text)
            }

        for tlink in tlinks:
            if tlink["fromID"] == treatment_id or tlink["toID"] == treatment_id:
                date_id = tlink["toID"] if tlink["fromID"] == treatment_id else tlink["fromID"]
                date = next((t for t in timex3s if t["id"] == date_id), None)
                if date:
                    linked_data[treatment_id]["dates"].append({
                        "date": date,
                        "tlink_type": tlink["type"]
                    })
                    # Update BIO annotation for dates
                    bio_label = f"B-{tlink['type']}"
                    linked_data[treatment_id]["bio_time"][date["start"]] = bio_label
                    for i in range(date["start"] + 1, date["end"]):
                        linked_data[treatment_id]["bio_time"][i] = f"I-{tlink['type']}"

        # Update BIO annotation for treatments
        bio_label_treatment = "B-TREATMENT"
        linked_data[treatment_id]["bio_treatment"][treatment["start"]] = bio_label_treatment
        for i in range(treatment["start"] + 1, treatment["end"]):
            linked_data[treatment_id]["bio_treatment"][i] = "I-TREATMENT"

    return linked_data.values()

# Function to align BIO annotations with tokenized text
def align_bio_with_tokens(text, bio_annotations, tokenizer):
    tokens = tokenizer.tokenize(text)
    token_offsets = tokenizer(text, return_offsets_mapping=True)["offset_mapping"]

    aligned_bio = []
    for token_start, token_end in token_offsets:
        if token_start == token_end:
            # there will be '[CLS]' and '[SEP]' at the begin and the end, so here we jump those. But in the model the '[CLS]' and '[SEP]' are tokens?
            # aligned_bio.append("O")
            continue
        bio_tag = bio_annotations[token_start]
        aligned_bio.append(bio_tag)

    return tokens, aligned_bio

# Main processing
xml_data = parse_xml_files(xml_dir)
print(len(xml_data))
extracted_data = extract_entities(xml_data)
print(len(extracted_data))
linked_treatments = []
for events, timex3s, tlinks, text in extracted_data:
    linked_treatments.extend(link_treatments_with_dates(events, timex3s, tlinks, text))
print(len(linked_treatments))
# Align BIO annotations with BERT tokens for each entry
alignment_annotations=[]
for entry in linked_treatments:
    temp={}
    tokens, bio_time = align_bio_with_tokens(entry['text'], entry['bio_time'], tokenizer)
    _, bio_treatment = align_bio_with_tokens(entry['text'], entry['bio_treatment'], tokenizer)
    temp['tokens'] = tokens
    temp['bio_time_aligned'] = bio_time
    temp['bio_treatment_aligned'] = bio_treatment
    alignment_annotations.append(temp)

# Example output
# for entry in linked_treatments:
#     print("Text:", entry['text'])
#     print("Tokens:", entry['tokens'])
#     print("BIO Time Aligned:", entry['bio_time_aligned'])
#     print("BIO Treatment Aligned:", entry['bio_treatment_aligned'])
#     print("\n")
# linked_treatments[0]
print(len(alignment_annotations))

Error parsing 397.xml, skipping...
Error parsing 527.xml, skipping...
Error parsing 627.xml, skipping...
Error parsing 53.xml, skipping...
Error parsing 687.xml, skipping...
Error parsing 802.xml, skipping...
114
114


Token indices sequence length is longer than the specified maximum sequence length for this model (2061 > 512). Running this sequence through the model will result in indexing errors


2979
2979


bert:[CLS] + {treatmeng} + [SEP] + {tokens}

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizerFast
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import numpy as np

# Define a custom dataset class
class SequenceLabelingDataset(Dataset):
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return {
            'input_ids': self.tokens[idx],
            'labels': self.labels[idx]
        }

# Custom collate function to pad sequences
def collate_fn(batch):
    max_len = max([len(x['input_ids']) for x in batch])

    input_ids = []
    labels = []

    for item in batch:
        input_ids.append(item['input_ids'] + [0] * (max_len - len(item['input_ids'])))
        labels.append(item['labels'] + [0] * (max_len - len(item['labels'])))

    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'labels': torch.tensor(labels, dtype=torch.long)
    }

# Prepare the data
def prepare_data(alignment_annotations, tokenizer, max_len=512):
    label_to_id = {
        "O": 0,
        "B-BEFORE": 1,
        "I-BEFORE": 2,
        "B-AFTER": 3,
        "I-AFTER": 4,
        "B-OVERLAP": 5,
        "I-OVERLAP": 6
    }

    tokens, labels = [], []

    for entry in alignment_annotations:
        modified_tokens = []
        modified_labels = []

        treatment_started = False
        for token, treatment_tag, time_tag in zip(entry['tokens'], entry['bio_treatment_aligned'], entry['bio_time_aligned']):
            if treatment_tag == "B-TREATMENT" and not treatment_started:
                modified_tokens.append("#")
                modified_labels.append("O")
                treatment_started = True

            modified_tokens.append(token)
            modified_labels.append(time_tag)

            if treatment_tag == "I-TREATMENT" and treatment_started:
                next_index = entry['tokens'].index(token) + 1
                if next_index >= len(entry['tokens']) or entry['bio_treatment_aligned'][next_index] != "I-TREATMENT":
                    modified_tokens.append("#")
                    modified_labels.append("O")
                    treatment_started = False

        token_ids = tokenizer.convert_tokens_to_ids(modified_tokens)
        label_ids = [label_to_id.get(tag, 0) for tag in modified_labels]

        if len(token_ids) > max_len:
            token_ids = token_ids[:max_len]
            label_ids = label_ids[:max_len]

        tokens.append(token_ids)
        labels.append(label_ids)

    return tokens, labels

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
test_tokens, test_labels = prepare_data(alignment_annotations, tokenizer)
train_tokens, train_labels = prepare_data(processed_data, tokenizer)
# Split data into training and test sets
# train_tokens, test_tokens, train_labels, test_labels = train_test_split(
#     tokens, labels, test_size=0.2, random_state=42
# )

train_dataset = SequenceLabelingDataset(train_tokens, train_labels)
test_dataset = SequenceLabelingDataset(test_tokens, test_labels)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)
unique_labels, counts = np.unique(np.concatenate(train_labels), return_counts=True)

print("Counts for each label:", counts)
# Define the BERT model without additional positional encoding
class BertForSequenceLabeling(nn.Module):
    def __init__(self, num_labels=7):
        super(BertForSequenceLabeling, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids):
        outputs = self.bert(input_ids)
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits

# Initialize the model
model = BertForSequenceLabeling(num_labels=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# Training function
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids)
        loss = criterion(outputs.view(-1, 7), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

# Evaluation function
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids)
            loss = criterion(outputs.view(-1, 7), labels.view(-1))

            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=-1).cpu().numpy()
            all_preds.extend(preds.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    return total_loss / len(dataloader), all_preds, all_labels

# Training loop

epochs = 20
for epoch in range(epochs):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)

    val_loss, val_preds, val_labels = evaluate(model, test_dataloader, criterion, device)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(classification_report(val_labels, val_preds, target_names=['O', 'B-BEFORE', 'I-BEFORE', 'B-AFTER', 'I-AFTER', 'B-OVERLAP', 'I-OVERLAP']))

# Output an example of evaluation
sample_input = test_tokens[0]
sample_labels = test_labels[0]

model.eval()
with torch.no_grad():
    input_ids = torch.tensor(sample_input).unsqueeze(0).to(device)
    outputs = model(input_ids)
    preds = torch.argmax(outputs, dim=-1).cpu().numpy().flatten()

print("Sample tokens:", tokenizer.convert_ids_to_tokens(sample_input))
print("Sample true labels:", sample_labels)
print("Sample predictions:", preds)


Counts for each label: [871298   1690   6699    168    547    160    428]


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch 1/20
Train Loss: 0.0697
Validation Loss: 0.0236


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.76      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522704
   macro avg       0.22      0.25      0.23   1522704
weighted avg       0.99      0.99      0.99   1522704

Epoch 2/20
Train Loss: 0.0226
Validation Loss: 0.0198


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.62      0.79      0.70      2722
    I-BEFORE       0.69      0.75      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522704
   macro avg       0.33      0.36      0.34   1522704
weighted avg       0.99      0.99      0.99   1522704

Epoch 3/20
Train Loss: 0.0196
Validation Loss: 0.0198


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.59      0.82      0.69      2722
    I-BEFORE       0.64      0.80      0.71     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522704
   macro avg       0.32      0.37      0.34   1522704
weighted avg       0.99      0.99      0.99   1522704

Epoch 4/20
Train Loss: 0.0182
Validation Loss: 0.0166


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.69      0.86      0.77      2722
    I-BEFORE       0.71      0.83      0.77     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522704
   macro avg       0.34      0.38      0.36   1522704
weighted avg       0.99      0.99      0.99   1522704

Epoch 5/20
Train Loss: 0.0157
Validation Loss: 0.0219


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.73      0.79      0.76      2722
    I-BEFORE       0.74      0.79      0.76     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522704
   macro avg       0.35      0.37      0.36   1522704
weighted avg       0.99      0.99      0.99   1522704

Epoch 6/20
Train Loss: 0.0142
Validation Loss: 0.0219


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.76      0.82      0.79      2722
    I-BEFORE       0.77      0.82      0.79     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           1.00   1522704
   macro avg       0.36      0.38      0.37   1522704
weighted avg       0.99      1.00      1.00   1522704

Epoch 7/20
Train Loss: 0.0121
Validation Loss: 0.0200


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.71      0.79      0.75      2722
    I-BEFORE       0.73      0.79      0.76     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.27      0.01      0.01       490

    accuracy                           1.00   1522704
   macro avg       0.39      0.37      0.36   1522704
weighted avg       0.99      1.00      0.99   1522704

Epoch 8/20
Train Loss: 0.0110
Validation Loss: 0.0203


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.68      0.79      0.73      2722
    I-BEFORE       0.70      0.78      0.74     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.04      0.01      0.01       163
   I-OVERLAP       0.11      0.03      0.05       490

    accuracy                           0.99   1522704
   macro avg       0.36      0.37      0.36   1522704
weighted avg       0.99      0.99      0.99   1522704

Epoch 9/20
Train Loss: 0.0097
Validation Loss: 0.0203


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.79      0.84      0.81      2722
    I-BEFORE       0.79      0.86      0.82     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.23      0.04      0.06       163
   I-OVERLAP       0.55      0.04      0.08       490

    accuracy                           1.00   1522704
   macro avg       0.48      0.40      0.40   1522704
weighted avg       1.00      1.00      1.00   1522704

Epoch 10/20
Train Loss: 0.0079
Validation Loss: 0.0172


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.77      0.84      0.81      2722
    I-BEFORE       0.78      0.85      0.81     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.56      0.06      0.11       163
   I-OVERLAP       0.53      0.06      0.10       490

    accuracy                           1.00   1522704
   macro avg       0.52      0.40      0.40   1522704
weighted avg       1.00      1.00      1.00   1522704

Epoch 11/20
Train Loss: 0.0073
Validation Loss: 0.0220


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.76      0.83      0.79      2722
    I-BEFORE       0.76      0.84      0.80     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.45      0.06      0.11       163
   I-OVERLAP       0.59      0.04      0.08       490

    accuracy                           1.00   1522704
   macro avg       0.51      0.40      0.40   1522704
weighted avg       1.00      1.00      1.00   1522704

Epoch 12/20
Train Loss: 0.0066
Validation Loss: 0.0194


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.77      0.83      0.80      2722
    I-BEFORE       0.78      0.84      0.81     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.29      0.06      0.10       163
   I-OVERLAP       0.41      0.08      0.14       490

    accuracy                           1.00   1522704
   macro avg       0.46      0.40      0.41   1522704
weighted avg       1.00      1.00      1.00   1522704

Epoch 13/20
Train Loss: 0.0062
Validation Loss: 0.0205


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.77      0.84      0.80      2722
    I-BEFORE       0.77      0.84      0.81     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.24      0.07      0.11       163
   I-OVERLAP       0.41      0.10      0.16       490

    accuracy                           1.00   1522704
   macro avg       0.46      0.41      0.41   1522704
weighted avg       1.00      1.00      1.00   1522704

Epoch 14/20
Train Loss: 0.0059
Validation Loss: 0.0199
              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.76      0.81      0.79      2722
    I-BEFORE       0.77      0.81      0.79     10762
     B-AFTER       0.33      0.01      0.03       224
     I-AFTER       0.47      0.01      0.02       856
   B-OVERLAP       0.33

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.78      0.87      0.82      2722
    I-BEFORE       0.79      0.87      0.83     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.18      0.07      0.11       163
   I-OVERLAP       0.28      0.05      0.09       490

    accuracy                           1.00   1522704
   macro avg       0.43      0.41      0.41   1522704
weighted avg       1.00      1.00      1.00   1522704

Epoch 16/20
Train Loss: 0.0047
Validation Loss: 0.0190
              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507487
    B-BEFORE       0.80      0.85      0.82      2722
    I-BEFORE       0.80      0.86      0.83     10762
     B-AFTER       0.25      0.01      0.02       224
     I-AFTER       0.17      0.01      0.01       856
   B-OVERLAP       0.16

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizerFast
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import numpy as np

# Define a custom dataset class
class SequenceLabelingDataset(Dataset):
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return {
            'input_ids': self.tokens[idx],
            'labels': self.labels[idx]
        }

# Custom collate function to pad sequences
def collate_fn(batch):
    max_len = max([len(x['input_ids']) for x in batch])

    input_ids = []
    labels = []

    for item in batch:
        input_ids.append(item['input_ids'] + [0] * (max_len - len(item['input_ids'])))
        labels.append(item['labels'] + [0] * (max_len - len(item['labels'])))

    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'labels': torch.tensor(labels, dtype=torch.long)
    }

# Prepare the data
def prepare_data(alignment_annotations, tokenizer, max_len=512):
    label_to_id = {
        "O": 0,
        "B-BEFORE": 1,
        "I-BEFORE": 2,
        "B-AFTER": 3,
        "I-AFTER": 4,
        "B-OVERLAP": 5,
        "I-OVERLAP": 6
    }
    id_to_label = {v: k for k, v in label_to_id.items()}

    tokens, labels = [], []

    for entry in alignment_annotations:
        modified_tokens = []
        modified_labels = []

        treatment_started = False
        for i, (token, treatment_tag, time_tag) in enumerate(zip(entry['tokens'], entry['bio_treatment_aligned'], entry['bio_time_aligned'])):
            if treatment_tag == "B-TREATMENT" and not treatment_started:
                modified_tokens.append("#")
                modified_labels.append("O")
                treatment_started = True

            modified_tokens.append(token)
            modified_labels.append(time_tag)

            if (treatment_tag == "I-TREATMENT" or treatment_tag == "B-TREATMENT") and treatment_started:
                next_index = i + 1
                if next_index >= len(entry['tokens']) or entry['bio_treatment_aligned'][next_index] not in ["I-TREATMENT", "B-TREATMENT"]:
                    modified_tokens.append("#")
                    modified_labels.append("O")
                    treatment_started = False

        token_ids = tokenizer.convert_tokens_to_ids(modified_tokens)
        label_ids = [label_to_id.get(tag, 0) for tag in modified_labels]

        if len(token_ids) > max_len:
            token_ids = token_ids[:max_len]
            label_ids = label_ids[:max_len]

        tokens.append(token_ids)
        labels.append(label_ids)

    return tokens, labels, id_to_label

# Load data and prepare datasets


tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# tokens, labels, id_to_label = prepare_data(alignment_annotations, tokenizer)

test_tokens, test_labels = prepare_data(alignment_annotations, tokenizer)
train_tokens, train_labels = prepare_data(processed_data, tokenizer)

# Split data into training and test sets
# train_tokens, test_tokens, train_labels, test_labels = train_test_split(
#     tokens, labels, test_size=0.2, random_state=42
# )

train_dataset = SequenceLabelingDataset(train_tokens, train_labels)
test_dataset = SequenceLabelingDataset(test_tokens, test_labels)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Define the BERT model without additional positional encoding
class BertForSequenceLabeling(nn.Module):
    def __init__(self, num_labels=7):
        super(BertForSequenceLabeling, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids):
        outputs = self.bert(input_ids)
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits

# Initialize the model
model = BertForSequenceLabeling(num_labels=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# Training function
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids)
        loss = criterion(outputs.view(-1, 7), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

# Helper function to extract entities from BIO labels
def get_entities(seq, id_to_label):
    """Gets entities from sequence of label IDs.
    Args:
        seq (list): sequence of label IDs.
    Returns:
        list: list of (chunk_type, chunk_start, chunk_end).
    """
    chunks = []
    chunk_type, chunk_start = None, None

    for i, label_id in enumerate(seq):
        label = id_to_label[label_id]
        if label.startswith("B-"):
            if chunk_type is not None:
                chunks.append((chunk_type, chunk_start, i - 1))
            chunk_type = label.split("-")[1]
            chunk_start = i
        elif label.startswith("I-") and chunk_type is not None:
            continue
        else:
            if chunk_type is not None:
                chunks.append((chunk_type, chunk_start, i - 1))
                chunk_type, chunk_start = None, None

    if chunk_type is not None:
        chunks.append((chunk_type, chunk_start, len(seq) - 1))

    return chunks

# Evaluation function at entity level
def evaluate(model, dataloader, criterion, device, id_to_label):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids)
            loss = criterion(outputs.view(-1, 7), labels.view(-1))

            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=-1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    all_true_entities = []
    all_pred_entities = []

    for true_seq, pred_seq in zip(all_labels, all_preds):
        true_entities = get_entities(true_seq, id_to_label)
        pred_entities = get_entities(pred_seq, id_to_label)
        all_true_entities.extend(true_entities)
        all_pred_entities.extend(pred_entities)

    # Only keep entities for the target labels
    target_labels = ['BEFORE', 'AFTER', 'OVERLAP']
    filtered_true_entities = [entity for entity in all_true_entities if entity[0] in target_labels]
    filtered_pred_entities = [entity for entity in all_pred_entities if entity[0] in target_labels]

    true_entity_labels = [entity[0] for entity in filtered_true_entities]
    pred_entity_labels = [entity[0] for entity in filtered_pred_entities]

    report = classification_report(
        true_entity_labels,
        pred_entity_labels,
        labels=target_labels,
        target_names=target_labels
    )

    return total_loss / len(dataloader), report

# Training loop
epochs = 10
for epoch in range(epochs):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    val_loss, val_report = evaluate(model, test_dataloader, criterion, device, id_to_label)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(val_report)

# Output an example of evaluation
sample_input = test_tokens[0]
sample_labels = test_labels[0]

model.eval()
with torch.no_grad():
    input_ids = torch.tensor(sample_input).unsqueeze(0).to(device)
    outputs = model(input_ids)
    preds = torch.argmax(outputs, dim=-1).cpu().numpy().flatten()

print("Sample tokens:", tokenizer.convert_ids_to_tokens(sample_input))
print("Sample true labels:", sample_labels)
print("Sample predictions:", preds)


ValueError: too many values to unpack (expected 2)

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizerFast
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import numpy as np

# Define a custom dataset class
class SequenceLabelingDataset(Dataset):
    def __init__(self, tokens, pos_encodings, labels):
        self.tokens = tokens
        self.pos_encodings = pos_encodings
        self.labels = labels

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return {
            'input_ids': self.tokens[idx],
            'pos_encodings': self.pos_encodings[idx],
            'labels': self.labels[idx]}

# Custom collate function to pad sequences
def collate_fn(batch):
    max_len = max([len(x['input_ids']) for x in batch])

    input_ids = []
    pos_encodings = []
    labels = []

    for item in batch:
        input_ids.append(item['input_ids'] + [0] * (max_len - len(item['input_ids'])))
        pos_encodings.append(item['pos_encodings'] + [0] * (max_len - len(item['pos_encodings'])))
        labels.append(item['labels'] + [0] * (max_len - len(item['labels'])))

    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'pos_encodings': torch.tensor(pos_encodings, dtype=torch.float),
        'labels': torch.tensor(labels, dtype=torch.long)
    }

# Prepare the data
def prepare_data(alignment_annotations, tokenizer, max_len=512):
    label_to_id = {
        "O": 0,
        "B-BEFORE": 1,
        "I-BEFORE": 2,
        "B-AFTER": 3,
        "I-AFTER": 4,
        "B-OVERLAP": 5,
        "I-OVERLAP": 6
    }

    tokens, pos_encodings, labels = [], [], []

    for entry in alignment_annotations:
        token_ids = tokenizer.convert_tokens_to_ids(entry['tokens'])
        pos_encoding = [1 if tag == "B-TREATMENT" or tag == "I-TREATMENT" else 0 for tag in entry['bio_treatment_aligned']]
        label = [label_to_id.get(tag, 0) for tag in entry['bio_time_aligned']]  # default to 0 if tag not found

        if len(token_ids) > max_len:
            token_ids = token_ids[:max_len]
            pos_encoding = pos_encoding[:max_len]
            label = label[:max_len]

        tokens.append(token_ids)
        pos_encodings.append(pos_encoding)
        labels.append(label)

    return tokens, pos_encodings, labels

# Load data and prepare datasets
# tokens, pos_encodings, labels = prepare_data(alignment_annotations, tokenizer)

# Split data into training and test sets
# train_tokens, test_tokens, train_pos, test_pos, train_labels, test_labels = train_test_split(
#     tokens, pos_encodings, labels, test_size=0.2, random_state=42
# )

test_tokens, test_pos, test_labels = prepare_data(alignment_annotations, tokenizer)
train_tokens, train_pos, train_labels = prepare_data(processed_data, tokenizer)

train_dataset = SequenceLabelingDataset(train_tokens, train_pos, train_labels)
test_dataset = SequenceLabelingDataset(test_tokens, test_pos, test_labels)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Define the BERT model with additional input for positional encoding
class BertForSequenceLabeling(nn.Module):
    def __init__(self, num_labels=7):
        super(BertForSequenceLabeling, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size + 1, num_labels)

    def forward(self, input_ids, pos_encodings):
        outputs = self.bert(input_ids)
        sequence_output = outputs.last_hidden_state
        # sequence_output torch.Size([16, 512, 768])
        # pos_encodings torch.Size([16, 512])
        # Concatenate positional encoding

        pos_encodings = pos_encodings.unsqueeze(-1)
        # pos_encodings.unsqueeze(-1) torch.Size([16, 512, 1])
        sequence_output = torch.cat((sequence_output, pos_encodings), dim=-1)
        # torch.cat torch.Size([16, 512, 769])

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        return logits

# Initialize the model
model = BertForSequenceLabeling(num_labels=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# Training function
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        pos_encodings = batch['pos_encodings'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, pos_encodings)
        loss = criterion(outputs.view(-1, 7), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

# Evaluation function
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            pos_encodings = batch['pos_encodings'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, pos_encodings)
            loss = criterion(outputs.view(-1, 7), labels.view(-1))

            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=-1).cpu().numpy()
            all_preds.extend(preds.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    return total_loss / len(dataloader), all_preds, all_labels

# Training loop
epochs = 10
for epoch in range(epochs):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    val_loss, val_preds, val_labels = evaluate(model, test_dataloader, criterion, device)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(classification_report(val_labels, val_preds, target_names=['O', 'B-BEFORE', 'I-BEFORE', 'B-AFTER', 'I-AFTER', 'B-OVERLAP', 'I-OVERLAP']))

# Output an example of evaluation
sample_input = test_tokens[0]
sample_pos = test_pos[0]
sample_labels = test_labels[0]

model.eval()
with torch.no_grad():
    input_ids = torch.tensor(sample_input).unsqueeze(0).to(device)
    pos_encodings = torch.tensor(sample_pos).unsqueeze(0).to(device)
    outputs = model(input_ids, pos_encodings)
    preds = torch.argmax(outputs, dim=-1).cpu().numpy().flatten()

print("Sample tokens:", tokenizer.convert_ids_to_tokens(sample_input))
print("Sample true labels:", sample_labels)
print("Sample predictions:", preds)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch 1/10
Train Loss: 0.0695
Validation Loss: 0.0235


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.75      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.22      0.25      0.23   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 2/10
Train Loss: 0.0228
Validation Loss: 0.0212


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.68      0.74      0.71      2722
    I-BEFORE       0.69      0.75      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.34      0.36      0.35   1522512
weighted avg       0.99      0.99      0.99   1522512



KeyboardInterrupt: 

bert: tokens +{treatmeng postional one-hot} -> classifier

In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizerFast
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import numpy as np

# Define a custom dataset class
class SequenceLabelingDataset(Dataset):
    def __init__(self, tokens, pos_encodings, labels):
        self.tokens = tokens
        self.pos_encodings = pos_encodings
        self.labels = labels

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return {
            'input_ids': self.tokens[idx],
            'pos_encodings': self.pos_encodings[idx],
            'labels': self.labels[idx]}

# Custom collate function to pad sequences
def collate_fn(batch):
    max_len = max([len(x['input_ids']) for x in batch])

    input_ids = []
    pos_encodings = []
    labels = []

    for item in batch:
        input_ids.append(item['input_ids'] + [0] * (max_len - len(item['input_ids'])))
        pos_encodings.append(item['pos_encodings'] + [0] * (max_len - len(item['pos_encodings'])))
        labels.append(item['labels'] + [0] * (max_len - len(item['labels'])))

    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'pos_encodings': torch.tensor(pos_encodings, dtype=torch.float),
        'labels': torch.tensor(labels, dtype=torch.long)
    }

# Prepare the data
def prepare_data(alignment_annotations, tokenizer, max_len=512):
    label_to_id = {
        "O": 0,
        "B-BEFORE": 1,
        "I-BEFORE": 2,
        "B-AFTER": 3,
        "I-AFTER": 4,
        "B-OVERLAP": 5,
        "I-OVERLAP": 6
    }

    tokens, pos_encodings, labels = [], [], []

    for entry in alignment_annotations:
        token_ids = tokenizer.convert_tokens_to_ids(entry['tokens'])
        pos_encoding = [1 if tag == "B-TREATMENT" or tag == "I-TREATMENT" else 0 for tag in entry['bio_treatment_aligned']]
        label = [label_to_id.get(tag, 0) for tag in entry['bio_time_aligned']]  # default to 0 if tag not found

        if len(token_ids) > max_len:
            token_ids = token_ids[:max_len]
            pos_encoding = pos_encoding[:max_len]
            label = label[:max_len]

        tokens.append(token_ids)
        pos_encodings.append(pos_encoding)
        labels.append(label)

    return tokens, pos_encodings, labels

# Load data and prepare datasets
# tokens, pos_encodings, labels = prepare_data(alignment_annotations, tokenizer)

# Split data into training and test sets
# train_tokens, test_tokens, train_pos, test_pos, train_labels, test_labels = train_test_split(
#     tokens, pos_encodings, labels, test_size=0.2, random_state=42
# )

test_tokens, test_pos, test_labels = prepare_data(alignment_annotations, tokenizer)
train_tokens, train_pos, train_labels = prepare_data(processed_data, tokenizer)


train_dataset = SequenceLabelingDataset(train_tokens, train_pos, train_labels)
test_dataset = SequenceLabelingDataset(test_tokens, test_pos, test_labels)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Define the BERT model with additional input for positional encoding
class BertForSequenceLabeling(nn.Module):
    def __init__(self, num_labels=7):
        super(BertForSequenceLabeling, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size + 1, num_labels)

    def forward(self, input_ids, pos_encodings):
        outputs = self.bert(input_ids)
        sequence_output = outputs.last_hidden_state
        # sequence_output torch.Size([16, 512, 768])
        # pos_encodings torch.Size([16, 512])
        # Concatenate positional encoding

        pos_encodings = pos_encodings.unsqueeze(-1)
        # pos_encodings.unsqueeze(-1) torch.Size([16, 512, 1])
        sequence_output = torch.cat((sequence_output, pos_encodings), dim=-1)
        # torch.cat torch.Size([16, 512, 769])

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        return logits

# Initialize the model
model = BertForSequenceLabeling(num_labels=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# Training function
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        pos_encodings = batch['pos_encodings'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, pos_encodings)
        loss = criterion(outputs.view(-1, 7), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

# Evaluation function
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            pos_encodings = batch['pos_encodings'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, pos_encodings)
            loss = criterion(outputs.view(-1, 7), labels.view(-1))

            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=-1).cpu().numpy()
            all_preds.extend(preds.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    return total_loss / len(dataloader), all_preds, all_labels

# Training loop
epochs = 10
for epoch in range(epochs):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    val_loss, val_preds, val_labels = evaluate(model, test_dataloader, criterion, device)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(classification_report(val_labels, val_preds, target_names=['O', 'B-BEFORE', 'I-BEFORE', 'B-AFTER', 'I-AFTER', 'B-OVERLAP', 'I-OVERLAP']))

# Output an example of evaluation
sample_input = test_tokens[0]
sample_pos = test_pos[0]
sample_labels = test_labels[0]

model.eval()
with torch.no_grad():
    input_ids = torch.tensor(sample_input).unsqueeze(0).to(device)
    pos_encodings = torch.tensor(sample_pos).unsqueeze(0).to(device)
    outputs = model(input_ids, pos_encodings)
    preds = torch.argmax(outputs, dim=-1).cpu().numpy().flatten()

print("Sample tokens:", tokenizer.convert_ids_to_tokens(sample_input))
print("Sample true labels:", sample_labels)
print("Sample predictions:", preds)

Epoch 1/10
Train Loss: 0.0713
Validation Loss: 0.0235


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.58      0.75      0.65     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.23      0.25      0.24   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 2/10
Train Loss: 0.0222
Validation Loss: 0.0198


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.70      0.74      0.72      2722
    I-BEFORE       0.70      0.75      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.34      0.35      0.35   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 3/10
Train Loss: 0.0195
Validation Loss: 0.0192


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.69      0.75      0.72      2722
    I-BEFORE       0.70      0.75      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.34      0.36      0.35   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 4/10
Train Loss: 0.0185
Validation Loss: 0.0193


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.70      0.71      0.71      2722
    I-BEFORE       0.70      0.74      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.34      0.35      0.35   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 5/10
Train Loss: 0.0182
Validation Loss: 0.0193


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.69      0.70      0.70      2722
    I-BEFORE       0.70      0.67      0.69     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.34      0.34      0.34   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 6/10
Train Loss: 0.0172
Validation Loss: 0.0194


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.68      0.76      0.72      2722
    I-BEFORE       0.70      0.75      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.34      0.36      0.35   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 7/10
Train Loss: 0.0164
Validation Loss: 0.0190


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.65      0.75      0.70      2722
    I-BEFORE       0.68      0.75      0.71     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.33      0.36      0.34   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 8/10
Train Loss: 0.0151
Validation Loss: 0.0194


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.69      0.65      0.67      2722
    I-BEFORE       0.68      0.72      0.70     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.34      0.34      0.34   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 9/10
Train Loss: 0.0142
Validation Loss: 0.0189


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.65      0.74      0.70      2722
    I-BEFORE       0.67      0.75      0.71     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.33      0.36      0.34   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 10/10
Train Loss: 0.0134
Validation Loss: 0.0188


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.68      0.72      0.70      2722
    I-BEFORE       0.69      0.73      0.71     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.34      0.35      0.34   1522512
weighted avg       0.99      0.99      0.99   1522512

Sample tokens: ['admission', 'date', ':', '2014', '-', '11', '-', '29', 'discharge', 'date', ':', '2014', '-', '12', '-', '05', 'service', ':', 'medicine', 'history', 'of', 'present', 'illness', ':', '47', 'yo', 'f', 'w', '/', 'h', '/', 'o', 'ste', '##roid', '-', 'induced', 'hyper', '##gly', '##ce', '##mia', ',', 'sl', '##e', 'w', '/', 'h', '/', 'o', 'per', '##ica', '##rdi', '##tis', ',', 'transverse

  _warn_prf(average, modifier, msg_start, len(result))


bert: tokens +{treatmeng postional one-hot} -> 4FCN + classifier

In [6]:
# Define a custom dataset class
class SequenceLabelingDataset(Dataset):
    def __init__(self, tokens, pos_encodings, labels):
        self.tokens = tokens
        self.pos_encodings = pos_encodings
        self.labels = labels

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return {
            'input_ids': self.tokens[idx],
            'pos_encodings': self.pos_encodings[idx],
            'labels': self.labels[idx]}

# Custom collate function to pad sequences
def collate_fn(batch):
    max_len = max([len(x['input_ids']) for x in batch])

    input_ids = []
    pos_encodings = []
    labels = []

    for item in batch:
        input_ids.append(item['input_ids'] + [0] * (max_len - len(item['input_ids'])))
        pos_encodings.append(item['pos_encodings'] + [0] * (max_len - len(item['pos_encodings'])))
        labels.append(item['labels'] + [0] * (max_len - len(item['labels'])))

    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'pos_encodings': torch.tensor(pos_encodings, dtype=torch.float),
        'labels': torch.tensor(labels, dtype=torch.long)
    }

# Prepare the data
def prepare_data(alignment_annotations, tokenizer, max_len=512):
    label_to_id = {
        "O": 0,
        "B-BEFORE": 1,
        "I-BEFORE": 2,
        "B-AFTER": 3,
        "I-AFTER": 4,
        "B-OVERLAP": 5,
        "I-OVERLAP": 6
    }

    tokens, pos_encodings, labels = [], [], []

    for entry in alignment_annotations:
        token_ids = tokenizer.convert_tokens_to_ids(entry['tokens'])
        pos_encoding = [1 if tag == "B-TREATMENT" or tag == "I-TREATMENT" else 0 for tag in entry['bio_treatment_aligned']]
        label = [label_to_id.get(tag, 0) for tag in entry['bio_time_aligned']]  # default to 0 if tag not found

        if len(token_ids) > max_len:
            token_ids = token_ids[:max_len]
            pos_encoding = pos_encoding[:max_len]
            label = label[:max_len]

        tokens.append(token_ids)
        pos_encodings.append(pos_encoding)
        labels.append(label)

    return tokens, pos_encodings, labels

# Load data and prepare datasets
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
tokens, pos_encodings, labels = prepare_data(alignment_annotations, tokenizer)

# Split data into training and test sets
# train_tokens, test_tokens, train_pos, test_pos, train_labels, test_labels = train_test_split(
#     tokens, pos_encodings, labels, test_size=0.2, random_state=42
# )

test_tokens, test_pos, test_labels = prepare_data(alignment_annotations, tokenizer)
train_tokens, train_pos, train_labels = prepare_data(processed_data, tokenizer)


train_dataset = SequenceLabelingDataset(train_tokens, train_pos, train_labels)
test_dataset = SequenceLabelingDataset(test_tokens, test_pos, test_labels)


# Compute sample weights
label_counts = np.bincount([label for sublist in train_labels for label in sublist])
class_weights = 1. / label_counts
sample_weights = []
for label_list in train_labels:
    for label in label_list:
        sample_weights.append(class_weights[label])

# sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Create dataloaders with sampler
# train_dataloader = DataLoader(train_dataset, batch_size=16, sampler=sampler, collate_fn=collate_fn)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)


train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# Define the BERT model with additional input for positional encoding and FC layers
class BertForSequenceLabeling(nn.Module):
    def __init__(self, num_labels=7):
        super(BertForSequenceLabeling, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(self.bert.config.hidden_size + 1, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.classifier = nn.Linear(64, num_labels)

    def forward(self, input_ids, pos_encodings):
        outputs = self.bert(input_ids)
        sequence_output = outputs.last_hidden_state
        pos_encodings = pos_encodings.unsqueeze(-1)
        sequence_output = torch.cat((sequence_output, pos_encodings), dim=-1)

        sequence_output = self.dropout(sequence_output)
        sequence_output = torch.relu(self.fc1(sequence_output))
        sequence_output = torch.relu(self.fc2(sequence_output))
        sequence_output = torch.relu(self.fc3(sequence_output))
        sequence_output = torch.relu(self.fc4(sequence_output))
        logits = self.classifier(sequence_output)

        return logits

# Initialize the model
model = BertForSequenceLabeling(num_labels=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# Training function
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        pos_encodings = batch['pos_encodings'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, pos_encodings)
        loss = criterion(outputs.view(-1, 7), labels.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

# Evaluation function
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            pos_encodings = batch['pos_encodings'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, pos_encodings)
            loss = criterion(outputs.view(-1, 7), labels.view(-1))

            total_loss += loss.item()

            preds = torch.argmax(outputs, dim=-1).cpu().numpy()
            all_preds.extend(preds.flatten())
            all_labels.extend(labels.cpu().numpy().flatten())

    return total_loss / len(dataloader), all_preds, all_labels

# Training loop
epochs = 10
for epoch in range(epochs):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    val_loss, val_preds, val_labels = evaluate(model, test_dataloader, criterion, device)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(classification_report(val_labels, val_preds, target_names=['O', 'B-BEFORE', 'I-BEFORE', 'B-AFTER', 'I-AFTER', 'B-OVERLAP', 'I-OVERLAP']))

# Output an example of evaluation
sample_input = test_tokens[0]
sample_pos = test_pos[0]
sample_labels = test_labels[0]

model.eval()
with torch.no_grad():
    input_ids = torch.tensor(sample_input).unsqueeze(0).to(device)
    pos_encodings = torch.tensor(sample_pos).unsqueeze(0).to(device)
    outputs = model(input_ids, pos_encodings)
    preds = torch.argmax(outputs, dim=-1).cpu().numpy().flatten()

print("Sample tokens:", tokenizer.convert_ids_to_tokens(sample_input))
print("Sample true labels:", sample_labels)
print("Sample predictions:", preds)

Epoch 1/10
Train Loss: 0.9490
Validation Loss: 0.0818


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       0.99      1.00      0.99   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.00      0.00      0.00     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.14      0.14      0.14   1522512
weighted avg       0.98      0.99      0.99   1522512

Epoch 2/10
Train Loss: 0.0655
Validation Loss: 0.0429


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       0.99      1.00      0.99   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.00      0.00      0.00     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.14      0.14      0.14   1522512
weighted avg       0.98      0.99      0.99   1522512

Epoch 3/10
Train Loss: 0.0375
Validation Loss: 0.0322


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       0.99      1.00      0.99   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.00      0.00      0.00     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.14      0.14      0.14   1522512
weighted avg       0.98      0.99      0.99   1522512

Epoch 4/10
Train Loss: 0.0307
Validation Loss: 0.0264


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.75      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.22      0.25      0.23   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 5/10
Train Loss: 0.0260
Validation Loss: 0.0246


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.75      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.22      0.25      0.23   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 6/10
Train Loss: 0.0246
Validation Loss: 0.0240


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.74      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.22      0.25      0.23   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 7/10
Train Loss: 0.0244
Validation Loss: 0.0236


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.75      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.22      0.25      0.23   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 8/10
Train Loss: 0.0239
Validation Loss: 0.0235


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.75      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.22      0.25      0.23   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 9/10
Train Loss: 0.0237
Validation Loss: 0.0236


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.73      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.22      0.25      0.23   1522512
weighted avg       0.99      0.99      0.99   1522512

Epoch 10/10
Train Loss: 0.0232
Validation Loss: 0.0240


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1507295
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.75      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1522512
   macro avg       0.22      0.25      0.23   1522512
weighted avg       0.99      0.99      0.99   1522512

Sample tokens: ['admission', 'date', ':', '2014', '-', '11', '-', '29', 'discharge', 'date', ':', '2014', '-', '12', '-', '05', 'service', ':', 'medicine', 'history', 'of', 'present', 'illness', ':', '47', 'yo', 'f', 'w', '/', 'h', '/', 'o', 'ste', '##roid', '-', 'induced', 'hyper', '##gly', '##ce', '##mia', ',', 'sl', '##e', 'w', '/', 'h', '/', 'o', 'per', '##ica', '##rdi', '##tis', ',', 'transverse

  _warn_prf(average, modifier, msg_start, len(result))


bert: tokens +{treatmeng postional one-hot} -> CRF

In [7]:
!pip install pytorch-crf
from torchcrf import CRF

class BertForSequenceLabeling(nn.Module):
    def __init__(self, num_labels=7):
        super(BertForSequenceLabeling, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size + 1, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, pos_encodings, labels=None):
        outputs = self.bert(input_ids)
        sequence_output = outputs.last_hidden_state

        pos_encodings = pos_encodings.unsqueeze(-1)
        sequence_output = torch.cat((sequence_output, pos_encodings), dim=-1)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss = -self.crf(logits, labels, mask=input_ids.ne(0))
            return loss
        else:
            predictions = self.crf.decode(logits, mask=input_ids.ne(0))
            return predictions

# Initialize the model
model = BertForSequenceLabeling(num_labels=7)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# Training function
def train(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        pos_encodings = batch['pos_encodings'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        loss = model(input_ids, pos_encodings, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

# Evaluation function
def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            pos_encodings = batch['pos_encodings'].to(device)
            labels = batch['labels'].to(device)

            loss = model(input_ids, pos_encodings, labels)
            total_loss += loss.item()

            predictions = model(input_ids, pos_encodings)

            input_mask = input_ids.ne(0)  # Mask to identify non-padding tokens

            for pred, label, mask in zip(predictions, labels.cpu().numpy(), input_mask.cpu().numpy()):
                filtered_pred = [p for p, m in zip(pred, mask) if m]
                filtered_label = [l for l, m in zip(label, mask) if m]

                all_preds.extend(filtered_pred)
                all_labels.extend(filtered_label)

    return total_loss / len(dataloader), all_preds, all_labels

# Training loop
epochs = 10
for epoch in range(epochs):
    train_loss = train(model, train_dataloader, optimizer, device)
    val_loss, val_preds, val_labels = evaluate(model, test_dataloader, device)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")

    print(classification_report(val_labels, val_preds, target_names=['O', 'B-BEFORE', 'I-BEFORE', 'B-AFTER', 'I-AFTER', 'B-OVERLAP', 'I-OVERLAP']))

# Output an example of evaluation
sample_input = test_tokens[0]
sample_pos = test_pos[0]
sample_labels = test_labels[0]

model.eval()
with torch.no_grad():
    input_ids = torch.tensor(sample_input).unsqueeze(0).to(device)
    pos_encodings = torch.tensor(sample_pos).unsqueeze(0).to(device)
    predictions = model(input_ids, pos_encodings)

print("Sample tokens:", tokenizer.convert_ids_to_tokens(sample_input))
print("Sample true labels:", sample_labels)
print("Sample predictions:", predictions)


Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl.metadata (2.4 kB)
Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2
Epoch 1/10
Train Loss: 675.1723
Validation Loss: 195.6593


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.75      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.22      0.25      0.23   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 2/10
Train Loss: 202.1845
Validation Loss: 199.1550


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.00      0.00      0.00      2722
    I-BEFORE       0.56      0.75      0.64     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.22      0.25      0.23   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 3/10
Train Loss: 178.1308
Validation Loss: 170.4568


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.66      0.80      0.72      2722
    I-BEFORE       0.59      0.85      0.69     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.32      0.38      0.34   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 4/10
Train Loss: 159.7342
Validation Loss: 158.9218


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.67      0.78      0.72      2722
    I-BEFORE       0.70      0.75      0.73     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.34      0.36      0.35   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 5/10
Train Loss: 153.7757
Validation Loss: 156.4159


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.64      0.82      0.72      2722
    I-BEFORE       0.70      0.75      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.33      0.37      0.35   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 6/10
Train Loss: 152.1240
Validation Loss: 173.6645


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.70      0.74      0.72      2722
    I-BEFORE       0.69      0.76      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.34      0.36      0.35   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 7/10
Train Loss: 145.3716
Validation Loss: 160.0598


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.69      0.55      0.61      2722
    I-BEFORE       0.71      0.56      0.63     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.34      0.30      0.32   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 8/10
Train Loss: 139.0273
Validation Loss: 160.6238


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.67      0.77      0.71      2722
    I-BEFORE       0.69      0.75      0.72     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.34      0.36      0.35   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 9/10
Train Loss: 133.8508
Validation Loss: 200.7482


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       0.99      1.00      1.00   1474343
    B-BEFORE       0.50      0.59      0.54      2722
    I-BEFORE       0.48      0.47      0.48     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.28      0.29      0.29   1489560
weighted avg       0.99      0.99      0.99   1489560

Epoch 10/10
Train Loss: 129.2791
Validation Loss: 198.2737


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           O       1.00      1.00      1.00   1474343
    B-BEFORE       0.56      0.56      0.56      2722
    I-BEFORE       0.57      0.66      0.61     10762
     B-AFTER       0.00      0.00      0.00       224
     I-AFTER       0.00      0.00      0.00       856
   B-OVERLAP       0.00      0.00      0.00       163
   I-OVERLAP       0.00      0.00      0.00       490

    accuracy                           0.99   1489560
   macro avg       0.30      0.32      0.31   1489560
weighted avg       0.99      0.99      0.99   1489560

Sample tokens: ['admission', 'date', ':', '2014', '-', '11', '-', '29', 'discharge', 'date', ':', '2014', '-', '12', '-', '05', 'service', ':', 'medicine', 'history', 'of', 'present', 'illness', ':', '47', 'yo', 'f', 'w', '/', 'h', '/', 'o', 'ste', '##roid', '-', 'induced', 'hyper', '##gly', '##ce', '##mia', ',', 'sl', '##e', 'w', '/', 'h', '/', 'o', 'per', '##ica', '##rdi', '##tis', ',', 'transverse

  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Install the transformers library
!pip install transformers

# Import the necessary modules
from transformers import BertTokenizer

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Define the sentence to be tokenized
sentence = "The patient took insulin after May 6th"

# Tokenize the sentence
tokens = tokenizer.tokenize(sentence)

# Print the tokens
print(tokens)





The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

['the', 'patient', 'took', 'insulin', 'after', 'may', '6th']
