In [None]:
""" Classification Training with Flan-T5-Small on MTS Dialogue Dataset """

In [None]:
!pip install datasets
!pip install evaluate
!pip install transformers
!pip install numpy
!pip install tensorflow
!pip install -U accelerate
!pip install transformers[torch]

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dill, multiprocess, datasets
Successfully installed datasets-2.16.1 dill-0.3.7 multiprocess-0.70.15
Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting responses<0.19 (from

In [None]:
import accelerate
from datasets import load_dataset, load_metric, concatenate_datasets, DatasetDict
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import numpy as np
from sklearn.metrics import classification_report
from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub through the interactive notebook widget.
# The training arguments in this notebook set push_to_hub=True and
# trainer.push_to_hub() is called after training, which presumably needs
# these credentials.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Source CSV files. The model is fitted on the training CSV only, and the
# validation CSV is used as the "test" split that the trainer evaluates on.
# (Alternative setup, currently disabled: additionally fit on the validation
# CSV via concatenate_datasets and evaluate on
# "data/MTS-Dialog-TestSet-1-MEDIQA-Chat-2023.csv" instead.)
train_file_a = "data/MTS-Dialog-TrainingSet.csv"
test_file = "data/MTS-Dialog-ValidationSet.csv"

# Each CSV is loaded on its own, and the rows end up in a split named "train"
# regardless of which file was read, so that split is requested directly.
combined_train_dataset = load_dataset("csv", data_files=train_file_a, split="train")
test_dataset = load_dataset("csv", data_files=test_file, split="train")

# Bundle both splits so they can be preprocessed with a single map() call.
dialog_dataset = DatasetDict(
    {"train": combined_train_dataset, "test": test_dataset}
)


In [None]:
# Start from the Flan-T5 small checkpoint (instead of plain T5); the tokenizer
# and the model weights come from the same Hub repository.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = checkpoint = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

# Instruction prepended to every dialogue before tokenization.
PREFIX_CLASSIFY = "Classify the topic of this dialogue: "


# Define the preprocessing function
def preprocess_function_classify(examples, max_input_length=512, max_target_length=50):
    """Tokenize a batch of dialogues and their section headers for seq2seq training.

    Args:
        examples: A batch of rows exposing a "dialogue" column (the source
            text) and a "section_header" column (the target label text).
        max_input_length: Truncation limit, in tokens, for the prefixed
            dialogue. Defaults to 512.
        max_target_length: Truncation limit, in tokens, for the label text.
            Defaults to 50.

    Returns:
        The tokenizer output for the prefixed dialogues, with an extra
        "labels" entry holding the token ids of the section headers.
    """
    # Model inputs: the classification instruction followed by the raw dialogue.
    prompts = [PREFIX_CLASSIFY + dialogue for dialogue in examples["dialogue"]]
    model_inputs = tokenizer(prompts, max_length=max_input_length, truncation=True)

    # Targets: the section header, tokenized as target-side text.
    targets = tokenizer(
        text_target=examples["section_header"],
        max_length=max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = targets["input_ids"]
    return model_inputs


# Tokenize every split of the dataset; batched=True passes whole batches of
# rows to the preprocessing function instead of one row at a time.
tokenized_dataset_classify = dialog_dataset.map(
    function=preprocess_function_classify,
    batched=True,
)

Map:   0%|          | 0/1201 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [None]:
# Hyperparameters for fine-tuning
L_RATE = 1e-4
BATCH_SIZE = 7  # per-device training batch size
PER_DEVICE_EVAL_BATCH = 7
WEIGHT_DECAY = 0.01
SAVE_TOTAL_LIM = 3  # passed as save_total_limit
NUM_EPOCHS = 20

# Trainer configuration: evaluate once per epoch using generated sequences
# (predict_with_generate) and push the result to the Hugging Face Hub.
training_args = Seq2SeqTrainingArguments(
    output_dir="t5-dialogue-classification-3",
    num_train_epochs=NUM_EPOCHS,
    learning_rate=L_RATE,
    weight_decay=WEIGHT_DECAY,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,
    evaluation_strategy="epoch",
    predict_with_generate=True,
    save_total_limit=SAVE_TOTAL_LIM,
    push_to_hub=True,
)


def compute_metrics(eval_preds):
    """Score generated section headers against the reference labels.

    The predictions and labels are decoded back to text and compared as
    class names. The full per-class report is printed for the reader, while
    only scalar values are returned: the Trainer logs returned metrics as
    scalars and warns about (and drops) string values.

    Args:
        eval_preds: A (predictions, labels) pair of token-id arrays.

    Returns:
        dict with float entries "accuracy", "macro_f1" and "weighted_f1".
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        # Some configurations return extra outputs after the token ids.
        preds = preds[0]

    # -100 is a placeholder for ignored/padded positions, not a real token id,
    # so replace it with the pad token before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = [
        text.strip()
        for text in tokenizer.batch_decode(preds, skip_special_tokens=True)
    ]
    decoded_labels = [
        text.strip()
        for text in tokenizer.batch_decode(labels, skip_special_tokens=True)
    ]

    # Human-readable per-class table. zero_division=0 keeps the score at 0 for
    # classes that are never predicted without emitting a warning per call.
    print(
        classification_report(
            decoded_labels, decoded_preds, digits=4, zero_division=0
        )
    )

    # Same report as a dict, used to pull out the aggregate scalar scores.
    summary = classification_report(
        decoded_labels, decoded_preds, output_dict=True, zero_division=0
    )
    matches = [pred == label for pred, label in zip(decoded_preds, decoded_labels)]

    return {
        "accuracy": float(np.mean(matches)),
        "macro_f1": float(summary["macro avg"]["f1-score"]),
        "weighted_f1": float(summary["weighted avg"]["f1-score"]),
    }


# Collator that assembles tokenized examples into batches for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Wire model, data, tokenizer and metric function into the trainer.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    eval_dataset=tokenized_dataset_classify["test"],
    train_dataset=tokenized_dataset_classify["train"],
)

# Sanity check: number of examples in the evaluation split.
len(tokenized_dataset_classify["test"])


100

In [None]:
# Run fine-tuning with the configuration defined in the trainer.
trainer.train()

# Upload the trained model through the trainer's Hub integration.
trainer.push_to_hub()

# Move the model to the CPU. As the last expression of the cell, the returned
# module is also displayed as the cell's output.
model.to("cpu")

Epoch,Training Loss,Validation Loss,Classification Report
1,No log,0.517647,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 0.2500 0.2500 0.2500 4  CC 0.4286 0.7500 0.5455 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.5000 0.3333 0.4000 3  EXAM 0.3333 1.0000 0.5000 1  FAM/SOCHX 0.8696 0.9091 0.8889 22  GENHX 0.7083 0.8500 0.7727 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7143 0.7143 0.7143 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.6667 1.0000 0.8000 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.0000 0.0000 0.0000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.6900 100  macro avg 0.4173 0.4347 0.4004 100  weighted avg 0.6814 0.6900 0.6680 100
2,No log,0.518285,precision recall f1-score support  ALLERGY 1.0000 0.2500 0.4000 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.3750 0.7500 0.5000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.5000 0.3333 0.4000 3  EXAM 0.5000 1.0000 0.6667 1  FAM/SOCHX 0.8800 1.0000 0.9362 22  GENHX 0.7391 0.8500 0.7907 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7500 0.8571 0.8000 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.4444 1.0000 0.6154 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 1.0000 0.3333 0.5000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.7300 100  macro avg 0.4976 0.4630 0.4343 100  weighted avg 0.7341 0.7300 0.6967 100
3,0.033500,0.548374,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 0.0000 0.0000 0.0000 4  CC 0.5000 0.7500 0.6000 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.5000 0.5000 0.5000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.5000 1.0000 0.6667 1  FAM/SOCHX 0.8800 1.0000 0.9362 22  GENHX 0.6800 0.8500 0.7556 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7500 0.8571 0.8000 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.3750 0.7500 0.5000 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 0.0000 0.0000 0.0000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8571 0.5455 0.6667 11  accuracy 0.7000 100  macro avg 0.3966 0.4376 0.4016 100  weighted avg 0.6475 0.7000 0.6590 100
4,0.033500,0.588064,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.7500 0.7500 0.7500 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.5000 0.5000 0.5000 2  EDCOURSE 0.5000 0.3333 0.4000 3  EXAM 0.2000 1.0000 0.3333 1  FAM/SOCHX 0.8750 0.9545 0.9130 22  GENHX 0.7391 0.8500 0.7907 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 0.7143 0.7692 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.4000 1.0000 0.5714 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 0.0000 0.0000 0.0000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.7200 100  macro avg 0.4781 0.4744 0.4386 100  weighted avg 0.7290 0.7200 0.7001 100
5,0.033500,0.57188,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 1.0000 0.2500 0.4000 4  CC 0.6000 0.7500 0.6667 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 1.0000 0.5000 0.6667 2  EDCOURSE 0.5000 0.3333 0.4000 3  EXAM 0.5000 1.0000 0.6667 1  FAM/SOCHX 0.8462 1.0000 0.9167 22  GENHX 0.7727 0.8500 0.8095 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 0.7143 0.7692 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.4000 1.0000 0.5714 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 0.6667 0.6667 0.6667 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.7500 100  macro avg 0.5441 0.5100 0.4939 100  weighted avg 0.7564 0.7500 0.7280 100
6,0.016800,0.671591,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 0.0000 0.0000 0.0000 4  CC 0.2143 0.7500 0.3333 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.5000 0.5000 0.5000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.5000 1.0000 0.6667 1  FAM/SOCHX 0.9167 1.0000 0.9565 22  GENHX 0.8824 0.7500 0.8108 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.0000 0.0000 0.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.6250 0.7143 0.6667 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.3333 0.7500 0.4615 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 0.5000 0.3333 0.4000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.6800 100  macro avg 0.3618 0.3967 0.3570 100  weighted avg 0.6812 0.6800 0.6627 100
7,0.016800,0.59178,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 0.0000 0.0000 0.0000 4  CC 0.6000 0.7500 0.6667 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.5000 0.5000 0.5000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.3333 1.0000 0.5000 1  FAM/SOCHX 0.8462 1.0000 0.9167 22  GENHX 0.7727 0.8500 0.8095 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.5000 1.0000 0.6667 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.6250 0.7143 0.6667 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.5000 0.7500 0.6000 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 0.5000 0.3333 0.4000 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.7100 100  macro avg 0.3971 0.4517 0.4035 100  weighted avg 0.6691 0.7100 0.6776 100
8,0.016800,0.642595,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 0.0000 0.0000 0.0000 4  CC 0.3333 0.7500 0.4615 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.5000 1.0000 0.6667 1  FAM/SOCHX 0.9130 0.9545 0.9333 22  GENHX 0.7391 0.8500 0.7907 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.0000 0.0000 0.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.6250 0.7143 0.6667 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.4286 0.7500 0.5455 4  PASTSURGICAL 1.0000 1.0000 1.0000 8  PLAN 0.3333 0.3333 0.3333 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.6800 100  macro avg 0.3374 0.3744 0.3401 100  weighted avg 0.6542 0.6800 0.6548 100
9,0.009800,0.620411,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 0.0000 0.0000 0.0000 4  CC 0.6000 0.7500 0.6667 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 1.0000 0.5000 0.6667 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.3333 1.0000 0.5000 1  FAM/SOCHX 0.8462 1.0000 0.9167 22  GENHX 0.8000 0.8000 0.8000 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.6667 0.8571 0.7500 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.4444 1.0000 0.6154 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 0.6667 0.6667 0.6667 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.7300 100  macro avg 0.4561 0.4855 0.4463 100  weighted avg 0.6953 0.7300 0.6968 100
10,0.009800,0.608587,precision recall f1-score support  ALLERGY 1.0000 0.5000 0.6667 4  ASSESSMENT 0.3333 0.2500 0.2857 4  CC 0.6000 0.7500 0.6667 4  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 2  EDCOURSE 0.0000 0.0000 0.0000 3  EXAM 0.5000 1.0000 0.6667 1  FAM/SOCHX 0.8750 0.9545 0.9130 22  GENHX 0.7273 0.8000 0.7619 20  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 0.7143 0.7692 7 OTHER_HISTORY 0.0000 0.0000 0.0000 1 PASTMEDICALHX 0.4000 1.0000 0.5714 4  PASTSURGICAL 0.8889 1.0000 0.9412 8  PLAN 0.3333 0.3333 0.3333 3  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8750 0.6364 0.7368 11  accuracy 0.7000 100  macro avg 0.4183 0.4469 0.4156 100  weighted avg 0.6820 0.7000 0.6777 100


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Trainer is attempting to log a value of "               precision    recall  f1-score   support

      ALLERGY     1.0000    0.2500    0.4000         4
   ASSESSMENT     0.2500    0.2500    0.2500         4
           CC     0.4286    0.7500    0.5455         4
    DIAGNOSIS     0.0000    0.0000    0.0000         1
  DISPOSITION     0.0000    0.0000    0.0000         2
     EDCOURSE     0.5000    0.3333    0.4000         3
         EXAM     0.3333    1.0000    0.5000         1
    FAM/SOCHX     0.8696    0.9091    0.8889        22
        GENHX     0.7083    0.8500    0.7727        20
        GYNHX     0.0000    0.0000    0.0000         1
      IMAGING     0.0000    0.0000    0.0000         1
IMMUNIZATIONS     1.0000    1.0000    1.0000         1
         LABS     0.0000    0.0000    0.0000         1
  MEDICATIONS     0.714

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

events.out.tfevents.1704928221.cc5a0990e75c.1451.2:   0%|          | 0.00/12.0k [00:00<?, ?B/s]

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (wo): 

In [None]:
# Smoke test: classify a single hand-written dialogue with the fine-tuned model.
# The prompt is built from PREFIX_CLASSIFY so that it uses exactly the same
# instruction the training examples were prefixed with.
sample_dialogue = (
    "Doctor: So are you taking any medications at the moment?\n"
    "Patient: yes, I take tirosint every morning.\n"
    "Doctor: what dosage?\n"
    "Patient:125mg."
)

inputs = tokenizer([PREFIX_CLASSIFY + sample_dialogue], return_tensors="pt")
outputs = model.generate(**inputs)

# decode() is called without skip_special_tokens, so the displayed string
# still contains special tokens such as <pad> and </s> around the label.
category = tokenizer.decode(outputs[0])
category



'<pad> MEDICATIONS</s>'