In [None]:
import os
import pandas as pd
import random
import json
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, LongformerForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from sklearn.metrics import precision_score, recall_score, f1_score
import torch

# Identifier of the pretrained checkpoint; the tokenizer and the classification
# model in the cells below are both loaded from it via `from_pretrained`.
model_id = 'yikuan8/Clinical-Longformer'

In [None]:
# Data preprocessing for fine-tuning on the readmission prediction task:
# load the train / eval / test CSV files into a single DatasetDict.

dataset_path = "./MIMIC/filterd_data/UQ_training"
csv_train = 'training_data_readmission.csv'
csv_eval = 'val_data_readmission.csv'
csv_test = 'testing_data_readmission.csv'

# Split name -> CSV file name; every file lives directly under `dataset_path`.
split_files = {
    'train': csv_train,
    'eval': csv_eval,
    'test': csv_test,
}

raw_datasets = load_dataset(
    'csv',
    data_files={
        split: os.path.join(dataset_path, file_name)
        for split, file_name in split_files.items()
    },
)

In [None]:
# Tokenizer matching the pretrained checkpoint named by `model_id`.
tokenizer = AutoTokenizer.from_pretrained(model_id)


def tokenize_func(data):
    """Tokenize the 'text' field of `data` with truncation enabled.

    Args:
        data: Mapping that provides a 'text' entry (used with ``batched=True``
            in the ``map`` call further down, so presumably a batch of rows).

    Returns:
        Whatever the tokenizer returns for that text.
    """
    texts = data['text']
    return tokenizer(texts, truncation=True)

In [None]:
# Rename the target column from "label" to "labels" in every split.
# The guard makes this cell safe to re-run: once a split has been renamed,
# calling rename_column("label", ...) again would fail because "label" is gone.
# Splits that have neither column still go through rename_column, so a
# genuinely missing target column is reported instead of being skipped silently.
for key in raw_datasets.keys():
  columns = raw_datasets[key].column_names
  if "labels" in columns and "label" not in columns:
    continue  # already renamed by a previous run of this cell
  raw_datasets[key] = raw_datasets[key].rename_column("label", "labels")

In [None]:
# Tokenize every split in batches and drop the raw 'text' column afterwards.
tokenized_datasets = raw_datasets.map(
    tokenize_func,
    batched=True,
    remove_columns='text',
)

# Switch the output format of all splits to "torch".
tokenized_datasets.set_format("torch")

In [None]:
# Training split, shuffled with a fixed seed so the order is repeatable.
train_dataset = tokenized_datasets["train"].shuffle(seed=27)

In [None]:
def compute_metrics(eval_preds):
    """Score one evaluation pass with weighted precision, recall and F1.

    Args:
        eval_preds: A pair ``(logits, labels)``; it is unpacked into exactly
            two values. The predicted class is the argmax of ``logits``
            along axis 1.

    Returns:
        dict with the keys ``'precision'``, ``'recall'`` and ``'f1'``, each
        computed with ``average='weighted'``.
    """
    logits, labels = eval_preds
    predicted_classes = np.argmax(logits, axis=1)

    # (result key, scoring function) pairs, in the order they are reported.
    scorers = (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )
    return {
        metric_name: scorer(labels, predicted_classes, average='weighted')
        for metric_name, scorer in scorers
    }



In [None]:
# --- Hyperparameters -------------------------------------------------------
batch_size = 8
lr = 2e-4  # other values noted by the author: 2e-5, 2e-4, 5e-5, 5e-4
eps = 1e-6
# Warm-up steps and weight decay were left unset by the author.

# Output directory name encodes the learning rate and Adam epsilon of the run.
OUTPUT_PATH = f"./model/UQ_model/readmission_{lr}_{eps}"

# --- Collator and model ----------------------------------------------------
# Pads each batch to the length of its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding='longest')

# Two-class sequence classifier initialised from the pretrained checkpoint,
# then moved to the GPU (`.cuda()` returns the module itself).
model = LongformerForSequenceClassification.from_pretrained(model_id, num_labels=2)
model = model.cuda()

In [None]:
# Training configuration: checkpoint, evaluate and log every 10 steps, keep a
# single checkpoint on disk, and reload the best model when training ends.
training_args = TrainingArguments(
    output_dir=OUTPUT_PATH,
    # optimisation
    num_train_epochs=2,
    learning_rate=lr,
    adam_epsilon=eps,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=32,
    fp16=True,
    # evaluation
    evaluation_strategy="steps",
    eval_steps=10,
    # checkpointing
    save_strategy="steps",
    save_steps=10,
    save_total_limit=1,
    load_best_model_at_end=True,
    # logging
    logging_strategy='steps',
    logging_steps=10,
)
# Build the Trainer from the model, arguments, collator and metric function
# defined in the cells above.
# The training split passed here is `train_dataset`, the seeded-shuffled copy
# created earlier; previously the unshuffled `tokenized_datasets['train']` was
# passed and the shuffled dataset was never used.
trainer = Trainer(
    model,
    training_args,
    train_dataset=train_dataset,
    eval_dataset=tokenized_datasets['eval'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()
# Alternative kept by the author: continue from a saved checkpoint directory.
# trainer.train(resume_from_checkpoint='./UQ_model/readmission/checkpoint-140')

In [None]:
# Persist the trained model to the Trainer's output directory.
# `Trainer` has no `save()` method (the original call raised AttributeError);
# the API for this is `save_model()`, which defaults to `args.output_dir`.
trainer.save_model()

In [None]:
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding='longest')

# model2 = LongformerForSequenceClassification.from_pretrained('./model/UQ_model/readmission_5e-5/checkpoint-140', use_safetensors=True)


# trainer = Trainer(
#     model2,
#     tokenizer=tokenizer,
#     data_collator=data_collator,
# )

# pre = trainer.predict(tokenized_datasets['test'])
# preds = np.argmax(pre.predictions, axis=1)
# labels = pre.label_ids

In [None]:
# Evaluate on the held-out test split.
# `labels` and `preds` are computed here so the cell works on a fresh kernel;
# previously they were only defined in a commented-out cell, so running the
# notebook top to bottom raised NameError at this point.
pre = trainer.predict(tokenized_datasets['test'])
preds = np.argmax(pre.predictions, axis=1)
labels = pre.label_ids

precision = precision_score(labels, preds)
recall = recall_score(labels, preds)
f1 = f1_score(labels, preds)

# Display the scores as the cell output (they were computed but never shown).
{'precision': precision, 'recall': recall, 'f1': f1}

