https://colab.research.google.com/drive/1jCkpikz0J2o20FBQmYmAGdiKmJGOMo-o?usp=sharing#scrollTo=T-gy-LxM0yAi

In [1]:
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
checkpoint = "microsoft/biogpt"
from relations import relations
from datasets import DatasetDict, Dataset
import pandas as pd
from tqdm.notebook import trange, tqdm
from labels import get_labels

In [2]:
# load labels for bert_w_ner
additional_tokens, _, _, _ = get_labels(mode='GPT_w_ner')
print(additional_tokens, "\n", additional_tokens)

{'additional_special_tokens': ['[learn1]', '[learn2]', '[learn3]', '[learn4]', '[learn5]', '[learn6]']} 
 {'additional_special_tokens': ['[learn1]', '[learn2]', '[learn3]', '[learn4]', '[learn5]', '[learn6]']}


# load the model

In [3]:
# load the model in 8-bit quantization configuration
# the max length of the input is 1024
model = AutoModelForCausalLM.from_pretrained(checkpoint, 
    # load_in_8bit=True, 
    device_map={'':torch.cuda.current_device()},)

In [4]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [5]:
print_trainable_parameters(model)

trainable params: 346763264 || all params: 346763264 || trainable%: 100.0


# Tokenizer

In [6]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")

In [7]:
# adding new tokens to the tokenizer
# since I haven't load the model so I will resize the embedding of the model later]
num_added_toks = tokenizer.add_special_tokens(additional_tokens)
print('We have added', num_added_toks, 'tokens')

# save the tokenizer
tokenizer.save_pretrained("GPT_without_ner/GPT_w_ner_tokenizer")

We have added 6 tokens


('GPT_without_ner/GPT_w_ner_tokenizer/tokenizer_config.json',
 'GPT_without_ner/GPT_w_ner_tokenizer/special_tokens_map.json',
 'GPT_without_ner/GPT_w_ner_tokenizer/vocab.json',
 'GPT_without_ner/GPT_w_ner_tokenizer/merges.txt',
 'GPT_without_ner/GPT_w_ner_tokenizer/added_tokens.json')

In [8]:
model.resize_token_embeddings(len(tokenizer))

Embedding(42390, 1024)

# PEFT

Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in float32 for stability.

We also cast the output of the last layer and embedding layer in float32 for the same reasons.

In [9]:
for param in model.parameters():
  param.requires_grad = False  # freeze the model - train adapters later
  if param.ndim == 1:
    # cast the small parameters (e.g. layernorm) to fp32 for stability
    param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()

class CastOutputToFloat(nn.Sequential):
  def forward(self, x): return super().forward(x).to(torch.float32)

model.biogpt.embed_tokens = CastOutputToFloat(model.biogpt.embed_tokens)
model.output_projection = CastOutputToFloat(model.output_projection)

In [10]:
# more with LoRAconfig: https://huggingface.co/docs/peft/conceptual_guides/lora

from peft import get_peft_config, get_peft_model, LoraConfig, TaskType, PeftType

peft_config = LoraConfig(
    # r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters.
    r=16,
    # alpha: LoRA scaling factor.
    lora_alpha=32, 
    # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
    target_modules=["q_proj", "v_proj"],
    fan_in_fan_out=True,
    lora_dropout=0.05,
    bias="none", 
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, peft_config)
print_trainable_parameters(model)


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/tian/mambaforge/envs/BioRED/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121.so
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 121
CUDA SETUP: Loading binary /home/tian/mambaforge/envs/BioRED/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)


trainable params: 1572864 || all params: 348342272 || trainable%: 0.4515283175278825


In [11]:
# make model's embed_tokens layer also trainable

model.biogpt.embed_tokens[0].weight.requires_grad = True
model.output_projection[0].weight.requires_grad = True

print_trainable_parameters(model)

trainable params: 44980224 || all params: 348342272 || trainable%: 12.912651611803232


In [12]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): BioGptForCausalLM(
      (biogpt): BioGptModel(
        (embed_tokens): CastOutputToFloat(
          (0): Embedding(42390, 1024)
        )
        (embed_positions): BioGptLearnedPositionalEmbedding(1026, 1024)
        (layers): ModuleList(
          (0-23): 24 x BioGptDecoderLayer(
            (self_attn): BioGptAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(
                in_features=1024, out_features=1024, bias=True
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1024, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=1024, bias=False)
                )
                (lora_embedding_A): Pa

In [13]:
# for model, print the layer's name if the layer is trainable, and print the precision of the layer

for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape, param.dtype)

base_model.model.biogpt.embed_tokens.0.weight torch.Size([42390, 1024]) torch.float32
base_model.model.biogpt.layers.0.self_attn.v_proj.lora_A.default.weight torch.Size([16, 1024]) torch.float32
base_model.model.biogpt.layers.0.self_attn.v_proj.lora_B.default.weight torch.Size([1024, 16]) torch.float32
base_model.model.biogpt.layers.0.self_attn.q_proj.lora_A.default.weight torch.Size([16, 1024]) torch.float32
base_model.model.biogpt.layers.0.self_attn.q_proj.lora_B.default.weight torch.Size([1024, 16]) torch.float32
base_model.model.biogpt.layers.1.self_attn.v_proj.lora_A.default.weight torch.Size([16, 1024]) torch.float32
base_model.model.biogpt.layers.1.self_attn.v_proj.lora_B.default.weight torch.Size([1024, 16]) torch.float32
base_model.model.biogpt.layers.1.self_attn.q_proj.lora_A.default.weight torch.Size([16, 1024]) torch.float32
base_model.model.biogpt.layers.1.self_attn.q_proj.lora_B.default.weight torch.Size([1024, 16]) torch.float32
base_model.model.biogpt.layers.2.self_attn

# pre-process the text

In [14]:
from data_preprocessing import make_GPT_re_data, GPT_w_ner_preprocess_function

# from data_preprocessing import all_line_of_pmid, get_original_text, get_identifier_and_entity, reorder_list, get_relations

In [15]:
# train and valid file paths
train_file_path = 'data/BioRED/processed/train.tsv'
valid_file_path = 'data/BioRED/processed/dev.tsv'

In [33]:
# make bert_re data
train_data_raw = make_GPT_re_data(file_path=train_file_path, lower=True)
valid_data_raw = make_GPT_re_data(file_path=valid_file_path, lower=True)

Dropped 8 line:
 [6646, 6758, 6776, 6866, 10222, 11775, 18818, 21689]
Dropped 8 line:
 [941, 2220, 2233, 2261, 5335, 5337, 5378, 5490]


In [34]:
relation_static = {k.lower(): 0 for k in relations}

In [35]:
for relation in train_data_raw['relation']:
    relation_static[relation] += 1

for k, v in relation_static.items():
    print(f"{k}: {v}")

none: 18720
association: 2183
bind: 60
comparison: 28
conversion: 3
cotreatment: 31
drug_interaction: 11
negative_correlation: 763
positive_correlation: 1088


In [36]:
# save the raw data
# import json

# with open('GPT_w_ner/data/train_data_dict.json', 'w') as f:
#     json.dump(train_data_raw, f)

# with open('GPT_w_ner/data/valid_data_dict.json', 'w') as f:
#     json.dump(valid_data_raw, f)

In [41]:
import json

with open('GPT_w_ner/data/train_data_dict.json', 'r') as f:
    train_data_raw = json.load(f)

with open('GPT_w_ner/data/valid_data_dict.json', 'r') as f:
    valid_data_raw = json.load(f)

print(train_data_raw.keys())
for k, v in train_data_raw.items():
    print(k, len(v))

# # make into Dataset type
train_data_raw = Dataset.from_dict(train_data_raw)
valid_data_raw = Dataset.from_dict(valid_data_raw)

dict_keys(['pmids', 'text', 'entities', 'outputs', 'relation'])
pmids 22887
text 22887
entities 22887
outputs 22887
relation 22887


In [43]:
from torch.utils.data import Subset
"""
for the train_dataset:
{'[None]': 18720,
 '[Association]': 2183,
 '[Bind]': 60,
 '[Comparison]': 28,
 '[Conversion]': 3,
 '[Cotreatment]': 31,
 '[Drug_Interaction]': 11,
 '[Negative_Correlation]': 763,
 '[Positive_Correlation]': 1088}

so it is neccessary to balance the dataset, we randomly choose 3000 samples from the [None] class with the seed 42
"""
import random
random.seed(42)

# get the index of the [None] class of the datasets type of train_data_raw
none_index = [i for i, example in enumerate(train_data_raw) if example['relation'] == 'none']

# randomly choose 18720-3000 samples from the [None] class
none_index = random.sample(none_index, 18720-3000)
keep_indices = [i for i in range(len(train_data_raw)) if i not in none_index]

# delete the [None] class samples from the train_data_raw
train_data_raw_balanced = train_data_raw.select(keep_indices)

In [45]:
# train_data_raw_balanced

Dataset({
    features: ['pmids', 'text', 'entities', 'outputs', 'relation'],
    num_rows: 7167
})

In [21]:
"""train_data_raw_balanced[0]"""

'train_data_raw_balanced[0]'

In [47]:
# dataset = DatasetDict({
#     "train": train_data_raw_balanced,
#     "valid": valid_data_raw
# })

In [48]:
# tokenized_datasets = dataset.map(lambda example: GPT_w_ner_preprocess_function(example, tokenizer, mode="gpt_w_ner"), batched=True, remove_columns=['pmids', 'text', 'entities', 'outputs', 'relation'])

Map:   0%|          | 0/7167 [00:00<?, ? examples/s]

the relation between the source entity 1 and the target entity 2 is None .
the relation between source entity 1 and target entity 2 is Association .
the relation between the source entity 1 and the target entity 2 is None .
the relation between source entity 1 and target entity 2 is Association .
the relation between the source entity 1 and the target entity 2 is None .
the relation between source entity 1 and target entity 2 is Positive_Correlation .
the relation between the source entity 1 and the target entity 2 is None .
the relation between source entity 1 and target entity 2 is Bind .
the relation between source entity 1 and target entity 2 is Positive_Correlation .
the relation between source entity 1 and target entity 2 is Positive_Correlation .
the relation between source entity 1 and target entity 2 is Association .
the relation between the source entity 1 and the target entity 2 is None .
the relation between source entity 1 and target entity 2 is Association .
the relation 

Map:   0%|          | 0/6650 [00:00<?, ? examples/s]

the relation between source entity 1 and target entity 2 is Positive_Correlation .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation between the source entity 1 and the target entity 2 is None .
the relation betw

In [49]:
# tokenized_datasets.save_to_disk('GPT_w_ner/data/tokenized_dataset_w_ner')

Saving the dataset (0/1 shards):   0%|          | 0/7167 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/6650 [00:00<?, ? examples/s]

In [50]:
from datasets import load_from_disk

tokenized_datasets = load_from_disk('GPT_w_ner/data/tokenized_dataset_w_ner')

tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 7167
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 6650
    })
})

In [51]:
# to tensor
tokenized_datasets.set_format(type='torch', columns=['input_ids'])

In [52]:
tokenizer.decode(tokenized_datasets['train']['input_ids'][0])

'hepatocyte nuclear factor-6: associations between genetic variability and type ii diabetes and between genetic variability and estimates of insulin secretion. the transcription factor hepatocyte nuclear factor (hnf) -6 is an upstream regulator of several genes involved in the pathogenesis of maturity-onset diabetes of the young. we therefore tested the hypothesis that variability in the hnf-6 gene is associated with subsets of type ii (non-insulin-dependent) diabetes mellitus and estimates of insulin secretion in glucose tolerant subjects. we cloned the coding region as well as the intron-exon boundaries of the hnf-6 gene. w e then examined them on genomic dna in six mody probands without mutations in the mody1, mody3 and mody4 genes and in 54 patients with late-onset type ii diabetes by combined single strand conformational polymorphism-heteroduplex analysis followed by direct sequencing of identified variants. an identified missense variant was examined in association studies and ge

# Training

wandb

In [53]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="GPT2",
    # notes="PubmedBERT-FT-NER_w_NERin_10epochs",
    name="BioGPT_w_ner_epoch_5_balanced_train_data_no_[]",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m309439737[0m ([33mtian1995[0m). Use [1m`wandb login --relogin`[0m to force relogin


training

In [56]:
from transformers import DataCollatorForLanguageModeling

In [57]:
import transformers

trainer = transformers.Trainer(
    model=model, 
    train_dataset=tokenized_datasets['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=8, 
        gradient_accumulation_steps=8,
        warmup_steps=1000, 
        num_train_epochs=15,
        learning_rate=2e-4, 
        fp16=True,
        logging_steps=1, 
        report_to="wandb",
        save_strategy="epoch",
        output_dir='GPT_w_ner'
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()



  0%|          | 0/1680 [00:00<?, ?it/s]

{'loss': 3.2441, 'learning_rate': 2.0000000000000002e-07, 'epoch': 0.01}
{'loss': 3.2454, 'learning_rate': 4.0000000000000003e-07, 'epoch': 0.02}
{'loss': 3.2134, 'learning_rate': 6.000000000000001e-07, 'epoch': 0.03}
{'loss': 3.3333, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.04}
{'loss': 3.2727, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.04}
{'loss': 3.2727, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.05}
{'loss': 3.1926, 'learning_rate': 1.4000000000000001e-06, 'epoch': 0.06}
{'loss': 3.1918, 'learning_rate': 1.6000000000000001e-06, 'epoch': 0.07}
{'loss': 3.1955, 'learning_rate': 1.8e-06, 'epoch': 0.08}
{'loss': 3.1896, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.09}
{'loss': 3.3858, 'learning_rate': 2.2e-06, 'epoch': 0.1}
{'loss': 3.2057, 'learning_rate': 2.4000000000000003e-06, 'epoch': 0.11}
{'loss': 3.249, 'learning_rate': 2.6e-06, 'epoch': 0.12}
{'loss': 3.2322, 'learning_rate': 2.8000000000000003e-06, 'epoch': 0.12}
{'loss': 3.2265, 'learning_

TrainOutput(global_step=1680, training_loss=0.8302740742674186, metrics={'train_runtime': 34351.9508, 'train_samples_per_second': 3.13, 'train_steps_per_second': 0.049, 'train_loss': 0.8302740742674186, 'epoch': 15.0})

In [58]:
import wandb
wandb.finish()
trainer.save_model("GPT_w_ner/models/GPT_w_ner_epoch_15_balanced")

0,1
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,▁▁▂▂▂▃▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇███▇▇▆▆▅▅▅▄▄▃▃▂▂▁▁
train/loss,██▇▆▆▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁
train/train_samples_per_second,▁
train/train_steps_per_second,▁

0,1
train/epoch,15.0
train/global_step,1680.0
train/learning_rate,0.0
train/loss,0.1781
train/total_flos,2.0071882986356736e+17
train/train_loss,0.83027
train/train_runtime,34351.9508
train/train_samples_per_second,3.13
train/train_steps_per_second,0.049


In [61]:
model.save_pretrained("GPT_w_ner/models/GPT_w_ner_epoch_15_balanced.peft")

In [62]:
# Since there are key-unmatches in the trainer.save_model(), we need to rename the keys and load the paras in the model

embed_tokens_state_dict = torch.load("GPT_w_ner/models/GPT_w_ner_epoch_15_balanced/pytorch_model.bin")

old_keys = ["base_model.model.biogpt.embed_tokens.0.weight", "base_model.model.output_projection.0.weight"]
new_keys = ["base_model.model.biogpt.embed_tokens.weight", "base_model.model.output_projection.weight"]

for old_key, new_key in zip(old_keys, new_keys):
    # Get the value of the old key
    value = embed_tokens_state_dict[old_key]

    # Create a new key-value pair with the updated name
    embed_tokens_state_dict[new_key] = value

    # Delete the old key if desired
    del embed_tokens_state_dict[old_key]

torch.save(embed_tokens_state_dict, "GPT_w_ner/models/GPT_w_ner_epoch_15_balanced/pytorch_model-af.bin")

In [35]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): BioGptForCausalLM(
      (biogpt): BioGptModel(
        (embed_tokens): CastOutputToFloat(
          (0): Embedding(42390, 1024)
        )
        (embed_positions): BioGptLearnedPositionalEmbedding(1026, 1024)
        (layers): ModuleList(
          (0-23): 24 x BioGptDecoderLayer(
            (self_attn): BioGptAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(
                in_features=1024, out_features=1024, bias=True
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1024, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=1024, bias=False)
                )
                (lora_embedding_A): Pa

# load model and inference

In [107]:
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "microsoft/biogpt"

peft_model_id = "GPT_w_ner/models/GPT_w_ner_epoch_15_balanced.peft"
# config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained("GPT_w_ner/GPT_w_ner_tokenizer")

# resize the token embeddings to match the tokenizer
model.resize_token_embeddings(len(tokenizer))

# Load the Lora model
# the resized embedding layer are still uncorrected, need to load the weights manually
model = PeftModel.from_pretrained(model, peft_model_id)


In [66]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): BioGptForCausalLM(
      (biogpt): BioGptModel(
        (embed_tokens): Embedding(42390, 1024)
        (embed_positions): BioGptLearnedPositionalEmbedding(1026, 1024)
        (layers): ModuleList(
          (0-23): 24 x BioGptDecoderLayer(
            (self_attn): BioGptAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(
                in_features=1024, out_features=1024, bias=True
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1024, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embeddin

In [108]:
model.load_state_dict(torch.load("GPT_w_ner/models/GPT_w_ner_epoch_15_balanced/pytorch_model-af.bin"))

<All keys matched successfully>

In [65]:
model.eval()
model.to("cpu")
inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])

Tweet text: @ HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label: the right drug is the right patient, the right


In [66]:
import pandas as pd
import re
from tqdm.notebook import trange, tqdm
from torch import nn
from labels import get_labels
from relations import relations
from datasets import DatasetDict, Dataset

from data_preprocessing import make_GPT_re_data, GPT_w_ner_preprocess_function
additional_tokens, _, _, _ = get_labels(mode='GPT_w_ner')


In [111]:
import json

# load test data and preprocess
# test_file_path = 'data/BioRED/processed/test.tsv'
# test_data = make_GPT_re_data(file_path=test_file_path, lower=True)
# # save the raw data

# with open('GPT_w_ner/data/test_data_dict.json', 'w') as f:
#     # json.dump(test_data, f)
# with open('GPT_w_ner/data/test_data_dict.json', 'r') as f:
#     test_data= json.load(f)

# test_dataset_raw = Dataset.from_dict(test_data)

# test_dataset = test_dataset_raw.map(lambda example: GPT_w_ner_preprocess_function(example, tokenizer, infer=True), batched=True, remove_columns=['pmids', 'text', 'entities', 'outputs', 'relation'])

# test_dataset.save_to_disk('GPT_w_ner/data/test_tokenized_dataset_no_ner')



from datasets import load_from_disk

test_dataset = load_from_disk('GPT_w_ner/data/test_tokenized_dataset_no_ner')

test_dataset.set_format(type='torch', columns=['input_ids'])

In [81]:
relation_static = {k.lower(): 0 for k in relations}

for i in range(len(data)):
    relation_static[data.iloc[i, 8].lower()] += 1

print(relation_static)

{'none': 6428, 'association': 635, 'bind': 9, 'comparison': 6, 'conversion': 1, 'cotreatment': 14, 'drug_interaction': 2, 'negative_correlation': 171, 'positive_correlation': 325}


In [112]:
tokenizer.batch_decode(test_dataset['input_ids'][:80])

['a novel scn5a mutation manifests as a malignant form of long qt syndrome with perinatal onset of tachycardia / bradycardia. objective: congenital long qt syndrome (lqts) with in utero onset of the rhythm disturbances is associated with a poor prognosis. in this study we investigated a newborn patient with fetal bradycardia, 2: 1 atrioventricular block and ventricular tachycardia soon after birth. methods: mutational analysis and dna sequencing were conducted in a newborn. the 2: 1 atrioventricular block improved to 1: 1 conduction only after intravenous lidocaine infusion or a high dose of mexiletine, which also controlled the ventricular tachycardia. results: a novel, spontaneous lqts-3 mutation was identified in the transmembrane segment 6 of domain iv of the na (v) 1.5 cardiac sodium channel, with a g-- > a substitution at codon 1763, which changed a valine (gtg) to a methionine (atg). the proband was heterozygous but the mutation was absent in the parents and the sister. expressi

In [114]:
from tqdm.notebook import trange, tqdm
import torch


model.eval()
outputs = []
model.to("cuda")
with torch.no_grad():
    for i in tqdm(range(len(test_dataset))):
    # for i in range(1):
        output = model.generate(input_ids=test_dataset[i]["input_ids"].unsqueeze(0).to("cuda"), max_new_tokens=50, eos_token_id=tokenizer.eos_token_id)
        output_text = tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=False)[0]
        try:
            outputs.append(output_text.split("[learn6]")[1].strip())
        except:
            outputs.append(output_text.strip())
        if i % 10 == 0:
            print(outputs[-1])

    # print(tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=False)[0])

  0%|          | 0/7590 [00:00<?, ?it/s]

the relation between the source entity 1 and the target entity 2 is None. </s>
the relation between source entity 1 and target entity 2 is Association. </s>
the relation between the source entity 1 and the target entity 2 is None. </s>
the relation between source entity 1 and target entity 2 is Association. </s>
the relation between source entity 1 and target entity 2 is Association. </s>
the relation between the source entity 1 and the target entity 2 is None. </s>
the relation between source entity 1 and target entity 2 is Association. </s>
the relation between the source entity 1 and the target entity 2 is None. </s>
the relation between the source entity 1 and the target entity 2 is None. </s>
the relation between source entity 1 and target entity 2 is Association. </s>
the relation between the source entity 1 and the target entity 2 is None. </s>
the relation between the source entity 1 and the target entity 2 is None. </s>
the relation between the source entity 1 and the target e

In [115]:
test_dataset['labels'][30:80]

['the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between source entity 1 and target entity 2 is Association .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between source entity 1 and target entity 2 is Positive_Correlation .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between source entity 1 and target entity 2 is Negative_Correlation .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source ent

In [90]:
test_dataset['labels'][30:80]

['the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between source entity 1 and target entity 2 is Association .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between source entity 1 and target entity 2 is Positive_Correlation .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between source entity 1 and target entity 2 is Negative_Correlation .',
 'the relation between the source entity 1 and the target entity 2 is None .',
 'the relation between the source ent

In [85]:
with open('GPT_w_ner/data/test_data_dict.json', 'r') as f:
    test_data= json.load(f)

test_dataset_raw = Dataset.from_dict(test_data)

In [86]:
test_dataset_raw[0]

{'pmids': '15485686',
 'text': 'a novel scn5a mutation manifests as a malignant form of long qt syndrome with perinatal onset of tachycardia/bradycardia . objective : congenital long qt syndrome ( lqts ) with in utero onset of the rhythm disturbances is associated with a poor prognosis . in this study we investigated a newborn patient with fetal bradycardia , 2:1 atrioventricular block and ventricular tachycardia soon after birth . methods : mutational analysis and dna sequencing were conducted in a newborn . the 2:1 atrioventricular block improved to 1:1 conduction only after intravenous lidocaine infusion or a high dose of mexiletine , which also controlled the ventricular tachycardia . results : a novel , spontaneous lqts-3 mutation was identified in the transmembrane segment 6 of domain iv of the na(v)1.5 cardiac sodium channel , with a g-->a substitution at codon 1763 , which changed a valine ( gtg ) to a methionine ( atg ) . the proband was heterozygous but the mutation was absen

In [97]:
outputs[0]

'the relation between the source entity 1 and the target entity 2 is None. </s>'

post-processing and evaluation

In [116]:
# post processing for the outputs w ner
# (source, target, relation)
# (2, 1, relation)
pairs = []
count = 0
for output in outputs:
    # if the output doesn't end with "<|endoftext|>", find the lastest ";" of the output and only take the previous part
    source = output.split(" source ")[1].strip()
    source = source.split(" target ")[0].strip()
    source = source.split("entity")[1].strip()
    source = source.split("and")[0].strip()

    target = output.split(" target ")[1].strip()
    target = target.split(" is ")[0].strip()
    target = target.split("entity")[1].strip()

    relation = output.split(" is ")[-1].strip()
    relation = relation.split(". </s>")[0].lower().strip()

    pairs.append((source, target, relation))

output_pairs = pairs

print(f"{count} / {len(outputs)}")

0 / 7590


In [117]:
# post processing for the outputs w ner
# (source, target, relation)
# (2, 1, relation)
pairs = []
count = 0
for i, output in enumerate(test_dataset_raw['outputs']):
    # if the output doesn't end with "<|endoftext|>", find the lastest ";" of the output and only take the previous part
    output = output.lower()
    source = output.split(" source ")[1].strip()
    source = source.split(" target ")[0].strip()
    source = source.split("entity")[1].strip()
    source = source.split("and")[0].strip()

    target = output.split(" target ")[1].strip()
    target = target.split(" is ")[0].strip()
    target = target.split("entity")[1].strip()

    relation = test_dataset_raw['relation'][i].lower()

    pairs.append((source, target, relation))

label_pairs = pairs

print(f"{count} / {len(outputs)}")

0 / 7590


In [118]:
result = {
    "output": [],
    "label": []
}

for output, label in zip(output_pairs, label_pairs):
    result['output'].append(output)
    result['label'].append(label)

In [90]:
count = 0
for label in result['label']:
    if label[2] != 'none':
        count += 1

print(count)

1163


In [91]:
# save the result dictionary
import pickle
with open("GPT_w_ner/result/epoch_15_result.pkl", "wb") as f:
    pickle.dump(result, f)

In [92]:
print(f'the length: {len(result["output"])}, {len(result["label"])}')
print(f'instance:\n{result["output"][0]}\n{result["label"][0]}')

the length: 7590, 7590
instance:
('1', '2', 'association')
('1', '2', 'none')


In [119]:
# source and target, relation
st_tp = 0
st_fp = 0
st_fn = 0
st_tn = 0

r_tp = 0
r_fp = 0
r_fn = 0
r_tn = 0

tuple_tp = 0
tuple_fp = 0  
tuple_fn = 0
tuple_tn = 0


for output, label in zip(result['output'], result['label']):
    pair = False
    relation = False
    if output[0] == label[0] and output[1] == label[1]:
        st_tp += 1
        pair = True
    else:
        st_fn += 1
        st_fp += 1
    
    if output[2] == label[2]:
        r_tp += 1
        relation = True
    else:
        r_fn += 1
        r_fp += 1

    if pair and relation:
        tuple_tp += 1
    else:
        tuple_fn += 1
        tuple_fp += 1

In [95]:
# calculate the precision, recall and f1 score

# for source and target
st_precision = st_tp / (st_tp + st_fp)
st_recall = st_tp / (st_tp + st_fn)
st_f1 = 2 * st_precision * st_recall / (st_precision + st_recall)
print(f"source and target precision: {st_precision}, recall: {st_recall}, f1: {st_f1}")

# for relation
r_precision = r_tp / (r_tp + r_fp)
r_recall = r_tp / (r_tp + r_fn)
r_f1 = 2 * r_precision * r_recall / (r_precision + r_recall)
print(f"relation precision: {r_precision}, recall: {r_recall}, f1: {r_f1}")

# for tuple
tuple_precision = tuple_tp / (tuple_tp + tuple_fp)
tuple_recall = tuple_tp / (tuple_tp + tuple_fn)
tuple_f1 = 2 * tuple_precision * tuple_recall / (tuple_precision + tuple_recall)
print(f"tuple precision: {tuple_precision}, recall: {tuple_recall}, f1: {tuple_f1}")

source and target precision: 0.9982872200263505, recall: 0.9982872200263505, f1: 0.9982872200263505
relation precision: 0.5562582345191041, recall: 0.5562582345191041, f1: 0.5562582345191041
tuple precision: 0.555467720685112, recall: 0.555467720685112, f1: 0.555467720685112


In [101]:
relation_static = {k.lower(): 0 for k in relations}
label_relation_static = {k.lower(): 0 for k in relations}

In [102]:
# source and target, relation
st_tp = 0
st_fp = 0
st_fn = 0
st_tn = 0

r_tp = 0
r_fp = 0
r_fn = 0
r_tn = 0

tuple_tp = 0
tuple_fp = 0  
tuple_fn = 0
tuple_tn = 0


for output, label in zip(result['output'], result['label']):
    pair = False
    relation = False
    label_relation_static[label[2]] += 1
    if output[0] == label[0] and output[1] == label[1]:
        st_tp += 1
        pair = True
    else:
        st_fn += 1
        st_fp += 1
    
    if output[2] == label[2]:
        r_tp += 1
        relation = True
    else:
        r_fn += 1
        r_fp += 1

    if pair and relation:
        tuple_tp += 1
        relation_static [output[2]] += 1
    else:
        tuple_fn += 1
        tuple_fp += 1

In [103]:
relation_static

{'none': 3793,
 'association': 423,
 'bind': 0,
 'comparison': 0,
 'conversion': 0,
 'cotreatment': 0,
 'drug_interaction': 0,
 'negative_correlation': 0,
 'positive_correlation': 0}

In [104]:
label_relation_static

{'none': 6427,
 'association': 635,
 'bind': 9,
 'comparison': 6,
 'conversion': 1,
 'cotreatment': 14,
 'drug_interaction': 2,
 'negative_correlation': 171,
 'positive_correlation': 325}