In [None]:
import numpy as np
import os
import pandas as pd
import random
import json
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, LongformerForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from sklearn.metrics import precision_score, recall_score, f1_score
import torch
from torchmetrics.classification import MultilabelF1Score

# Checkpoint name passed to `from_pretrained` for both the tokenizer and the model.
model_id = "yikuan8/Clinical-Longformer"

In [None]:
# Directory holding the pre-split CSV files, and the file name of each split.
dataset_path = "./MIMIC/filterd_data/UQ_training"
csv_train = 'training_data_diagnosis.csv'
csv_eval = 'val_data_diagnosis.csv'
csv_test = 'testing_data_diagnosis.csv'

# Map each split name to its CSV file, then load all splits in one call.
split_files = {'train': csv_train, 'eval': csv_eval, 'test': csv_test}
raw_datasets = load_dataset(
    'csv',
    data_files={
        split_name: os.path.join(dataset_path, file_name)
        for split_name, file_name in split_files.items()
    },
)


In [None]:
# Tokenizer loaded from the same checkpoint name (`model_id`) that is later used for the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

def tokenize_func(data):
    """Tokenize the ``text`` field of ``data`` with the module-level ``tokenizer``.

    Args:
        data: mapping with a ``'text'`` entry (a single example or a batch).

    Returns:
        Whatever ``tokenizer`` returns for that text with ``truncation=True``.
    """
    texts = data['text']
    return tokenizer(texts, truncation=True)

def adjust_and_tokenize_datasets(raw_datasets):
    """Pack the label columns into one ``labels`` column and tokenize every split.

    For each split, every column other than ``hadm_id`` and ``text`` is treated
    as a label; the values are gathered row by row into a list that is stored
    in a new ``labels`` column. The ``text`` column is then tokenized with
    ``tokenize_func``, all original columns are dropped, the output format is
    set to ``"torch"`` and ``labels`` is converted to ``torch.float``.

    Args:
        raw_datasets: mapping of split name to dataset. It must contain a
            ``"train"`` split, whose column names decide which columns are
            removed from every split. The mapping itself is not modified, so
            the function can be called again on the same object.

    Returns:
        The tokenized datasets, torch-formatted, with float ``labels``.
    """
    non_label_columns = ('hadm_id', 'text')

    # Work on a shallow copy: writing the augmented splits back into the
    # caller's mapping would make a second call see an existing 'labels'
    # column and fail.
    labelled_datasets = type(raw_datasets)(raw_datasets)
    for key, split in raw_datasets.items():
        label_columns = [column for column in split.column_names
                         if column not in non_label_columns]
        all_labels = [[row[column] for column in label_columns] for row in split]
        # Summary of what was packed; safe for empty splits and for splits
        # without label columns.
        print(f"{key}: {len(all_labels)} rows, {len(label_columns)} label columns")
        labelled_datasets[key] = split.add_column('labels', all_labels)

    # Drop every original column after tokenization, keeping only 'labels'.
    columns = [column for column in labelled_datasets["train"].column_names
               if column != "labels"]
    tokenized_datasets = labelled_datasets.map(tokenize_func, batched=True, remove_columns=columns)
    tokenized_datasets.set_format("torch")
    # Convert the labels to float via a temporary column that is renamed back.
    tokenized_datasets = (tokenized_datasets
              .map(lambda x : {"float_labels": x["labels"].to(torch.float)}, remove_columns=["labels"])
              .rename_column("float_labels", "labels"))
    return tokenized_datasets

def compute_metrics(eval_preds):
    """Compute the multilabel F1 score for a ``(logits, labels)`` pair.

    Args:
        eval_preds: two-item sequence ``(logits, labels)``. ``logits`` is
            array-like with one column per label in its last dimension; if it
            is a tuple of several model outputs, the first element is used.
            ``labels`` is array-like and is cast to ``int8``.

    Returns:
        ``{'f1_macro': score}`` where ``score`` is the value returned by
        ``MultilabelF1Score`` (with its default averaging) for the
        sigmoid-activated logits.
    """
    logits, labels = eval_preds
    if isinstance(logits, tuple):
        # Extra model outputs may be bundled with the logits; the logits come first.
        logits = logits[0]
    logits = torch.tensor(logits)

    # Take the number of labels from the data instead of hard-coding it, so
    # the metric always matches the model head.
    metric = MultilabelF1Score(num_labels=logits.shape[-1])
    preds = torch.sigmoid(logits)
    target = torch.tensor(labels, dtype=torch.int8)
    f1_macro = metric(preds, target)
    return {'f1_macro': f1_macro}


In [None]:
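# One-off preprocessing, kept commented out so it is not repeated on every run:
# it builds the tokenized datasets and saves them to the path that the next cell loads.
# tokenized_datasets = adjust_and_tokenize_datasets(raw_datasets)
# tokenized_datasets.save_to_disk("./MIMIC/filterd_data/UQ_training/tokenized_datasets_diagnosis")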

In [None]:
# The tokenized datasets are cached on disk. Build the cache when it does not
# exist yet, so this cell also works on a fresh checkout, then always load
# from disk so later cells see the same object either way.
tokenized_path = "./MIMIC/filterd_data/UQ_training/tokenized_datasets_diagnosis"
if not os.path.isdir(tokenized_path):
    adjust_and_tokenize_datasets(raw_datasets).save_to_disk(tokenized_path)
tokenized_datasets = load_from_disk(tokenized_path)

# Fixed seed keeps the training order reproducible.
train_dataset = tokenized_datasets["train"].shuffle(seed=27)

In [None]:
# Hyperparameters for this run; the output directory name encodes lr and eps.
batch_size = 8
lr = 2e-4  # candidate values listed by the author: 2e-5, 2e-4, 5e-5, 5e-4
eps = 1e-6
# warmup steps and weight decay are not set here
OUTPUT_PATH = f"./model/UQ_diagnosis/{lr}_{eps}"

# Pad each batch to the length of its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding='longest')

# Classification model with a 25-output head, moved to the GPU.
model = LongformerForSequenceClassification.from_pretrained(model_id, num_labels=25).cuda()

In [None]:
# Training configuration: checkpoint, evaluate and log every 10 steps, keep at
# most 10 checkpoints, and reload the best checkpoint when training ends.
training_args = TrainingArguments(
    output_dir=OUTPUT_PATH,
    # optimisation
    num_train_epochs=2,
    learning_rate=lr,
    adam_epsilon=eps,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=160,
    fp16=True,
    # evaluation
    evaluation_strategy="steps",
    eval_steps=10,
    # checkpointing
    save_strategy="steps",
    save_steps=10,
    save_total_limit=10,
    load_best_model_at_end=True,
    # logging
    logging_strategy='steps',
    logging_steps=10,
)

# Trainer for the shuffled training split, evaluated on the 'eval' split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=tokenized_datasets['eval'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Resume from the latest checkpoint in OUTPUT_PATH only when one exists.
# Passing True unconditionally makes the first run fail, because there is no
# checkpoint to resume from yet.
has_checkpoint = os.path.isdir(OUTPUT_PATH) and any(
    entry.startswith("checkpoint-") and os.path.isdir(os.path.join(OUTPUT_PATH, entry))
    for entry in os.listdir(OUTPUT_PATH)
)
trainer.train(resume_from_checkpoint=has_checkpoint)

In [None]:
# Run the model on the test split with a fresh Trainer that uses default
# arguments. The collator and metric function are passed explicitly so the
# test run is padded like training and reports the same F1 metric.
trainer = Trainer(
    model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
predictions = trainer.predict(tokenized_datasets['test'])
# Display the test metrics as the cell output.
predictions.metrics