In [None]:
""" Classification Training with Flan-T5-Base on MTS Dialogue Dataset 
    https://huggingface.co/sarahahatee/t5-dialogue-classification-3"""

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install datasets
!pip install evaluate
!pip install transformers
!pip install numpy
!pip install tensorflow
!pip install -U accelerate
!pip install transformers[torch]

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dill, multiprocess, datasets
Successfully installed datasets-2.16.1 dill-0.3.7 multiprocess-0.70.15
Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting responses<0.19 (from e

In [3]:
import accelerate
from datasets import load_dataset, load_metric, concatenate_datasets, DatasetDict
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import numpy as np
from sklearn.metrics import classification_report
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
# Load training datasets
train_file_a = "data/MTS-Dialog-TrainingSet.csv"
# train_file_b = "data/MTS-Dialog-ValidationSet.csv"

# Load testing dataset
# test_file = "data/MTS-Dialog-TestSet-1-MEDIQA-Chat-2023.csv"
test_file = "data/MTS-Dialog-ValidationSet.csv"

# Load training datasets from CSV files
# train_dataset_a = load_dataset('csv', data_files=train_file_a)['train']
# train_dataset_b = load_dataset('csv', data_files=train_file_b)['train']
combined_train_dataset= load_dataset('csv', data_files=train_file_a)['train']

# Load testing dataset from CSV file
test_dataset = load_dataset('csv', data_files=test_file)['train']

# Combine the two training datasets
# combined_train_dataset = concatenate_datasets([train_dataset_a, train_dataset_b])

# Create dialog_dataset dictionary
dialog_dataset = DatasetDict({
    "train": combined_train_dataset,
    "test": test_dataset
})

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [10]:
# loading flan-t5 instead of plain t5
checkpoint = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = checkpoint

# We prefix our tasks with "classify"
PREFIX_CLASSIFY = "Classify the topic of this dialogue: "


# Define the preprocessing function
def preprocess_function_classify(examples):
    """Add prefix to the sentences, tokenize the text, and set the labels"""
    # The "inputs" are the tokenized answer:
    inputs = [PREFIX_CLASSIFY + doc for doc in examples["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)

    # The "labels" are the tokenized outputs:
    labels = tokenizer(
        text_target=examples["section_header"], max_length=50, truncation=True
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# Map the preprocessing function across our dataset
tokenized_dataset_classify = dialog_dataset.map(
    preprocess_function_classify, batched=True
)

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Map:   0%|          | 0/1201 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [11]:
# Global Parameters
L_RATE = 1e-4
BATCH_SIZE = 7
PER_DEVICE_EVAL_BATCH = 7
WEIGHT_DECAY = 0.01
SAVE_TOTAL_LIM = 3
NUM_EPOCHS = 20

# Set up training arguments
training_args = Seq2SeqTrainingArguments(
    "t5-dialogue-classification-3",
    evaluation_strategy="epoch",
    learning_rate=L_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,
    weight_decay=WEIGHT_DECAY,
    save_total_limit=SAVE_TOTAL_LIM,
    num_train_epochs=NUM_EPOCHS,
    predict_with_generate=True,
    push_to_hub=True,
)


def compute_metrics(eval_preds):
    preds, labels = eval_preds

    # Decode preds and labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Convert sequences to strings or keep as-is if already strings
    decoded_preds = [
        str(pred) for pred in decoded_preds
    ]  # Convert to string if necessary
    decoded_labels = [
        str(label) for label in decoded_labels
    ]  # Convert to string if necessary

    # Generate classification report
    report = classification_report(decoded_labels, decoded_preds, digits=4)

    return {"classification_report": report}


data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset_classify["train"],
    eval_dataset=tokenized_dataset_classify["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

len(tokenized_dataset_classify["test"])


200

In [12]:
trainer.train()

trainer.push_to_hub()

model.to("cpu")

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Classification Report
1,No log,0.322163,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.4000 0.1818 0.2500 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.0000 0.0000 0.0000 5  FAM/SOCHX 0.8302 0.9778 0.8980 45  GENHX 0.5844 0.8491 0.6923 53  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.0000 0.0000 0.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 1.0000 0.9091 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.4000 0.7143 0.5128 14  PASTSURGICAL 0.8750 1.0000 0.9333 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.6667 0.3529 0.4615 17  accuracy 0.6750 200  macro avg 0.2795 0.2996 0.2807 200  weighted avg 0.5806 0.6750 0.6099 200
2,No log,0.235559,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.3333 0.5455 0.4138 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.2000 1.0000 0.3333 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 1.0000 0.2000 0.3333 5  FAM/SOCHX 0.8936 0.9333 0.9130 45  GENHX 0.7377 0.8491 0.7895 53  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.0000 0.0000 0.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7692 1.0000 0.8696 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.5217 0.8571 0.6486 14  PASTSURGICAL 1.0000 0.8571 0.9231 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8462 0.6471 0.7333 17  accuracy 0.7250 200  macro avg 0.3651 0.3903 0.3457 200  weighted avg 0.6828 0.7250 0.6883 200
3,0.371400,0.299359,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.4000 0.3636 0.3810 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.1818 0.4000 0.2500 5  FAM/SOCHX 0.8958 0.9556 0.9247 45  GENHX 0.7818 0.8113 0.7963 53  GYNHX 1.0000 1.0000 1.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7692 1.0000 0.8696 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.5217 0.8571 0.6486 14  PASTSURGICAL 1.0000 1.0000 1.0000 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.7647 0.7647 0.7647 17  accuracy 0.7350 200  macro avg 0.4158 0.4535 0.4296 200  weighted avg 0.6803 0.7350 0.7026 200
4,0.371400,0.303622,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.4167 0.4545 0.4348 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.5000 1.0000 0.6667 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.2857 0.4000 0.3333 5  FAM/SOCHX 0.8958 0.9556 0.9247 45  GENHX 0.7667 0.8679 0.8142 53  GYNHX 1.0000 1.0000 1.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.7692 1.0000 0.8696 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.5238 0.7857 0.6286 14  PASTSURGICAL 0.8750 1.0000 0.9333 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8462 0.6471 0.7333 17  accuracy 0.7450 200  macro avg 0.4440 0.5014 0.4647 200  weighted avg 0.6850 0.7450 0.7093 200
5,0.371400,0.336601,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.3333 0.4545 0.3846 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.2000 1.0000 0.3333 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.3333 0.4000 0.3636 5  FAM/SOCHX 0.8800 0.9778 0.9263 45  GENHX 0.8077 0.7925 0.8000 53  GYNHX 1.0000 1.0000 1.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 1.0000 0.9091 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.5714 0.8571 0.6857 14  PASTSURGICAL 1.0000 0.7143 0.8333 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.7222 0.7647 0.7429 17  accuracy 0.7350 200  macro avg 0.4341 0.4939 0.4468 200  weighted avg 0.6878 0.7350 0.7055 200
6,0.084300,0.38521,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.4000 0.5455 0.4615 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.3333 1.0000 0.5000 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.4000 0.4000 0.4000 5  FAM/SOCHX 0.8936 0.9333 0.9130 45  GENHX 0.8723 0.7736 0.8200 53  GYNHX 1.0000 1.0000 1.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.5000 1.0000 0.6667 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 1.0000 0.9091 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.5238 0.7857 0.6286 14  PASTSURGICAL 0.8750 1.0000 0.9333 7  PLAN 0.2500 1.0000 0.4000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.7059 0.7059 0.7059 17  accuracy 0.7300 200  macro avg 0.4294 0.5530 0.4647 200  weighted avg 0.7036 0.7300 0.7105 200
7,0.084300,0.463219,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.7778 0.6364 0.7000 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.0000 0.0000 0.0000 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.4000 0.4000 0.4000 5  FAM/SOCHX 0.8936 0.9333 0.9130 45  GENHX 0.8696 0.7547 0.8081 53  GYNHX 0.0000 0.0000 0.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.5000 1.0000 0.6667 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 1.0000 0.9091 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.4286 0.8571 0.5714 14  PASTSURGICAL 0.8750 1.0000 0.9333 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.6842 0.7647 0.7222 17  accuracy 0.7250 200  macro avg 0.3631 0.4131 0.3790 200  weighted avg 0.7072 0.7250 0.7083 200
8,0.084300,0.42868,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.5000 0.4545 0.4762 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.2500 1.0000 0.4000 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.4000 0.4000 0.4000 5  FAM/SOCHX 0.9130 0.9333 0.9231 45  GENHX 0.7759 0.8491 0.8108 53  GYNHX 1.0000 1.0000 1.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.5000 1.0000 0.6667 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 1.0000 0.9091 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.5714 0.8571 0.6857 14  PASTSURGICAL 1.0000 1.0000 1.0000 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8000 0.7059 0.7500 17  accuracy 0.7450 200  macro avg 0.4272 0.5058 0.4489 200  weighted avg 0.7020 0.7450 0.7187 200
9,0.030900,0.486309,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.3571 0.4545 0.4000 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.2500 1.0000 0.4000 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.3333 0.4000 0.3636 5  FAM/SOCHX 0.8936 0.9333 0.9130 45  GENHX 0.8182 0.8491 0.8333 53  GYNHX 1.0000 1.0000 1.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 1.0000 1.0000 1.0000 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 1.0000 0.9091 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.5500 0.7857 0.6471 14  PASTSURGICAL 0.8750 1.0000 0.9333 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8571 0.7059 0.7742 17  accuracy 0.7400 200  macro avg 0.4384 0.5023 0.4565 200  weighted avg 0.7008 0.7400 0.7160 200
10,0.030900,0.519099,precision recall f1-score support  ALLERGY 1.0000 0.9167 0.9565 12  ASSESSMENT 0.0000 0.0000 0.0000 11  CC 0.4286 0.5455 0.4800 11  DIAGNOSIS 0.0000 0.0000 0.0000 1  DISPOSITION 0.2500 1.0000 0.4000 1  EDCOURSE 0.0000 0.0000 0.0000 4  EXAM 0.4000 0.4000 0.4000 5  FAM/SOCHX 0.8936 0.9333 0.9130 45  GENHX 0.8070 0.8679 0.8364 53  GYNHX 1.0000 1.0000 1.0000 1  IMAGING 0.0000 0.0000 0.0000 1 IMMUNIZATIONS 0.5000 1.0000 0.6667 1  LABS 0.0000 0.0000 0.0000 1  MEDICATIONS 0.8333 1.0000 0.9091 10 OTHER_HISTORY 0.0000 0.0000 0.0000 3 PASTMEDICALHX 0.6111 0.7857 0.6875 14  PASTSURGICAL 1.0000 1.0000 1.0000 7  PLAN 0.0000 0.0000 0.0000 1  PROCEDURES 0.0000 0.0000 0.0000 1  ROS 0.8000 0.7059 0.7500 17  accuracy 0.7500 200  macro avg 0.4262 0.5077 0.4500 200  weighted avg 0.7047 0.7500 0.7235 200


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Trainer is attempting to log a value of "               precision    recall  f1-score   support

      ALLERGY     1.0000    0.9167    0.9565        12
   ASSESSMENT     0.0000    0.0000    0.0000        11
           CC     0.4000    0.1818    0.2500        11
    DIAGNOSIS     0.0000    0.0000    0.0000         1
  DISPOSITION     0.0000    0.0000    0.0000         1
     EDCOURSE     0.0000    0.0000    0.0000         4
         EXAM     0.0000    0.0000    0.0000         5
    FAM/SOCHX     0.8302    0.9778    0.8980        45
        GENHX     0.5844    0.8491    0.6923        53
        GYNHX     0.0000    0.0000    0.0000         1
      IMAGING     0.0000    0.0000    0.0000         1
IMMUNIZATIONS     0.0000    0.0000    0.0000         1
         LABS     0.0000    0.0000    0.0000         1
  MEDICATIONS     0.833

events.out.tfevents.1705169676.c51df5b7c6ee.1498.1:   0%|          | 0.00/12.0k [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

In [13]:
# test
inputs = [
    "classify this dialogue: Doctor: So are you taking any medications at the moment?\nPatient: yes, I take tirosint every morning.\nDoctor: what dosage?\nPatient:125mg."
]
# inputs = ['classify this dialogue: Doctor: \nPatient: \nDoctor: ']

inputs = tokenizer(inputs, return_tensors="pt")
outputs = model.generate(**inputs)
category = tokenizer.decode(outputs[0])
category



'<pad> MEDICATIONS</s>'