In [None]:
""" Classification Training with Flan-T5-Base on Enriched MTS Dialogue Dataset 
    https://huggingface.co/sarahahatee/t5-dialogue-classification-5"""

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install datasets
!pip install evaluate
!pip install transformers
!pip install numpy
!pip install tensorflow
!pip install -U accelerate
!pip install transformers[torch]

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dill, multiprocess, datasets
Successfully installed datasets-2.16.1 dill-0.3.7 multiprocess-0.70.15
Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m803.0 kB/s[0m eta [36m0:00:00[0m
Collecting responses<0.19 (fr

In [3]:
import accelerate
from datasets import load_dataset, load_metric, concatenate_datasets, DatasetDict
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import numpy as np
from sklearn.metrics import classification_report
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
# Load training datasets
train_file_a = "/content/MTS-Dialog-TrainingSet-enriched-vectorized.csv"
# train_file_b = "data/MTS-Dialog-ValidationSet.csv"

# Load testing dataset
# test_file = "data/MTS-Dialog-TestSet-1-MEDIQA-Chat-2023.csv"
test_file = "/content/MTS-Dialog-ValidationSet.csv"

# Load training datasets from CSV files
# train_dataset_a = load_dataset('csv', data_files=train_file_a)['train']
# train_dataset_b = load_dataset('csv', data_files=train_file_b)['train']
combined_train_dataset= load_dataset('csv', data_files=train_file_a)['train']

# Load testing dataset from CSV file
test_dataset = load_dataset('csv', data_files=test_file)['train']

# Combine the two training datasets
# combined_train_dataset = concatenate_datasets([train_dataset_a, train_dataset_b])

# Create dialog_dataset dictionary
dialog_dataset = DatasetDict({
    "train": combined_train_dataset,
    "test": test_dataset
})

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [8]:
# loading flan-t5 instead of plain t5
checkpoint = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = checkpoint

# We prefix our tasks with "classify"
PREFIX_CLASSIFY = "Classify the topic of this dialogue: "

# Define the preprocessing function
def preprocess_function_classify(examples, include_umls_cui_mappings=False):
    """Add prefix to the sentences, tokenize the text, and set the labels"""
    # Conditionally include 'umls_cui_mappings' in the input
    if include_umls_cui_mappings and 'umls_cui_mappings' in examples:
        inputs = [PREFIX_CLASSIFY + doc + " UMLS CUI mappings: " + cui_mappings
                  for doc, cui_mappings in zip(examples["dialogue"], examples["umls_cui_mappings"])]
    else:
        inputs = [PREFIX_CLASSIFY + doc for doc in examples["dialogue"]]

    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    # The "labels" are the tokenized outputs:
    labels = tokenizer(
        text_target=examples["section_header"], max_length=50, truncation=True
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Process the datasets
tokenized_train_dataset = dialog_dataset["train"].map(
    lambda examples: preprocess_function_classify(examples, include_umls_cui_mappings=True),
    batched=True
)
tokenized_test_dataset = dialog_dataset["test"].map(
    preprocess_function_classify,
    batched=True
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [11]:
# Global Parameters
L_RATE = 1e-4
BATCH_SIZE = 7
PER_DEVICE_EVAL_BATCH = 7
WEIGHT_DECAY = 0.01
SAVE_TOTAL_LIM = 3
NUM_EPOCHS = 20

# Set up training arguments
training_args = Seq2SeqTrainingArguments(
    "t5-dialogue-classification-5",
    evaluation_strategy="epoch",
    learning_rate=L_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,
    weight_decay=WEIGHT_DECAY,
    save_total_limit=SAVE_TOTAL_LIM,
    num_train_epochs=NUM_EPOCHS,
    predict_with_generate=True,
    push_to_hub=True,
)


def compute_metrics(eval_preds):
    preds, labels = eval_preds

    # Decode preds and labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Convert sequences to strings or keep as-is if already strings
    decoded_preds = [
        str(pred) for pred in decoded_preds
    ]  # Convert to string if necessary
    decoded_labels = [
        str(label) for label in decoded_labels
    ]  # Convert to string if necessary

    # Generate classification report
    report = classification_report(decoded_labels, decoded_preds, digits=4)

    return {"classification_report": report}


data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

len_test_dataset = len(tokenized_test_dataset)
print(f"Length of Test Dataset: {len_test_dataset}")


Length of Test Dataset: 100


In [12]:
trainer.train()

trainer.push_to_hub()

model.to("cpu")

Epoch,Training Loss,Validation Loss,Classification Report
1,No log,0.443867,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.0000 0.0000 0.0000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.0000 0.0000 0.0000 1  FAM/SOCHX 0.8400 0.9545 0.8936 22  GENHX 0.4082 1.0000 0.5797 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.0000 0.0000 0.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.6667 0.5714 0.6154 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.5000 1.0000 0.6667 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.0000 0.0000 0.0000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.0000 0.0000 0.0000 11  SYSTEMS 0.0000 0.0000 0.0000 0  accuracy 0.5900 100  macro avg 0.2578 0.2393 0.2169 100  weighted avg 0.4931 0.5900 0.4943 100
2,No log,0.335224,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.0000 0.0000 0.0000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.0000 0.0000 0.0000 1  FAM/SOCHX 0.9167 1.0000 0.9565 22  GENHX 0.4444 1.0000 0.6154 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.5000 1.0000 0.6667 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7500 0.8571 0.8000 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.5714 1.0000 0.7273 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.0000 0.0000 0.0000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 1.0000 0.1818 0.3077 11  accuracy 0.6500 100  macro avg 0.3591 0.3269 0.2937 100  weighted avg 0.6409 0.6500 0.5711 100
3,0.319300,0.39507,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.5000 0.7500 0.6000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EK G 0.0000 0.0000 0.0000 0  EXAM 0.0000 0.0000 0.0000 1  FAM/SOCHX 0.8800 1.0000 0.9362 22  GENHX 0.6333 0.9500 0.7600 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7500 0.8571 0.8000 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.4444 1.0000 0.6154 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.0000 0.0000 0.0000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.6667 0.1818 0.2857 11  accuracy 0.6700 100  macro avg 0.3750 0.3447 0.3237 100  weighted avg 0.6539 0.6700 0.6160 100
4,0.319300,0.370734,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.2727 0.7500 0.4000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.3333 1.0000 0.5000 1  FAM/SOCHX 0.8800 1.0000 0.9362 22  GENHX 0.8571 0.9000 0.8780 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8750 1.0000 0.9333 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.3750 0.7500 0.5000 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.6667 0.6667 0.6667 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8571 0.5455 0.6667 11  accuracy 0.7400 100  macro avg 0.4559 0.4681 0.4274 100  weighted avg 0.7398 0.7400 0.7139 100
5,0.319300,0.389479,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 0.5000 0.2500 0.3333 4  CC 0.3750 0.7500 0.5000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 1.0000 1.0000 1.0000 1  FAM/SOCHX 0.8462 1.0000 0.9167 22  GENHX 0.9000 0.9000 0.9000 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7778 1.0000 0.8750 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.5000 1.0000 0.6667 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 1.0000 0.6667 0.8000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.7500 100  macro avg 0.4887 0.4727 0.4564 100  weighted avg 0.7418 0.7500 0.7240 100
6,0.088500,0.433422,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.4286 0.7500 0.5455 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.3333 0.5000 0.4000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.5000 1.0000 0.6667 1  FAM/SOCHX 0.9565 1.0000 0.9778 22  GENHX 0.8182 0.9000 0.8571 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7778 1.0000 0.8750 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.4000 1.0000 0.5714 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.6667 0.6667 0.6667 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8000 0.3636 0.5000 11  accuracy 0.7300 100  macro avg 0.4841 0.4840 0.4430 100  weighted avg 0.7513 0.7300 0.7041 100
7,0.088500,0.488939,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 0.0000 0.0000 0.0000 4  CC 0.3333 0.7500 0.4615 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.3333 0.5000 0.4000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.5000 1.0000 0.6667 1  FAM/SOCHX 0.8800 1.0000 0.9362 22  GENHX 0.8500 0.8500 0.8500 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.5000 1.0000 0.6667 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.6667 0.8571 0.7500 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.5714 1.0000 0.7273 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.5000 0.6667 0.5714 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8571 0.5455 0.6667 11  accuracy 0.7300 100  macro avg 0.4496 0.5210 0.4548 100  weighted avg 0.7024 0.7300 0.6938 100
8,0.088500,0.554284,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.4286 0.7500 0.5455 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.5000 0.5000 0.5000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.2500 1.0000 0.4000 1  FAM/SOCHX 0.8800 1.0000 0.9362 22  GENHX 0.7200 0.9000 0.8000 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7000 1.0000 0.8235 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.6000 0.7500 0.6667 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.5000 0.3333 0.4000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8571 0.5455 0.6667 11  accuracy 0.7300 100  macro avg 0.4718 0.4639 0.4269 100  weighted avg 0.7195 0.7300 0.6934 100
9,0.036100,0.59142,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 1.0000 0.5000 0.6667 4  CC 0.5000 0.7500 0.6000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.3333 1.0000 0.5000 1  FAM/SOCHX 0.8800 1.0000 0.9362 22  GENHX 0.8095 0.8500 0.8293 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.5000 1.0000 0.6667 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7778 1.0000 0.8750 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.5714 1.0000 0.7273 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 0.3333 0.3333 0.3333 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8571 0.5455 0.6667 11  accuracy 0.7500 100  macro avg 0.4726 0.5239 0.4704 100  weighted avg 0.7265 0.7500 0.7198 100
10,0.036100,0.660692,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 0.5000 0.2500 0.3333 4  CC 0.3333 0.5000 0.4000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.2500 0.5000 0.3333 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 1.0000 1.0000 1.0000 1  FAM/SOCHX 0.9565 1.0000 0.9778 22  GENHX 0.7037 0.9500 0.8085 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8750 1.0000 0.9333 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.5000 0.7500 0.6000 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.6667 0.6667 0.6667 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8571 0.5455 0.6667 11  accuracy 0.7500 100  macro avg 0.4821 0.4831 0.4693 100  weighted avg 0.7250 0.7500 0.7221 100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Trainer is attempting to log a value of "               precision    recall  f1-score   support

      ALLERGY     1.0000    0.2500    0.4000         4
   ASSESSMENT     1.0000    0.2500    0.4000         4
           CC     0.0000    0.0000    0.0000         4
    DIAGNOSIS     0.0000    0.0000    0.0000         1
  DISPOSITION     0.0000    0.0000    0.0000         2
     EDCOURSE     0.0000    0.0000    0.0000         3
         EXAM     0.0000    0.0000    0.0000         1
    FAM/SOCHX     0.8400    0.9545    0.8936        22
        GENHX     0.4082    1.0000    0.5797        20
        GYNHX     0.0000    0.0000    0.0000         1
      IMAGING     0.000

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

events.out.tfevents.1705271741.06d97e0935bf.977.1:   0%|          | 0.00/12.0k [00:00<?, ?B/s]

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

In [13]:
# test
inputs = [
    "classify this dialogue: Doctor: So are you taking any medications at the moment?\nPatient: yes, I take tirosint every morning.\nDoctor: what dosage?\nPatient:125mg."
]
# inputs = ['classify this dialogue: Doctor: \nPatient: \nDoctor: ']

inputs = tokenizer(inputs, return_tensors="pt")
outputs = model.generate(**inputs)
category = tokenizer.decode(outputs[0])
category



'<pad> MEDICATIONS</s>'