In [None]:
%pip install numpy
%pip install tqdm
%pip install datasets
%pip install transformers
%pip install islab-opendeid
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, GPTNeoXTokenizerFast, get_linear_schedule_with_warmup
from islab.aicup import OpenDeidBatchSampler, collate_batch_with_prompt_template
from datasets import load_dataset, concatenate_datasets, Features, Value
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.optim import AdamW
from tqdm import tqdm, trange
import torch.optim as optim
import numpy as np
import random
import torch
import os

In [None]:
# --- Notebook configuration: everything a reader is expected to tune ---
# Directory where model checkpoints are written.
MODEL_SAVEDIR = r"D:\Trained_Model"
# Root folder that process_folders() scans for annotated training data.
TRAINING_DATASET_PATH = r"D:\Model_Training_Datasets"
# File name used for the checkpoint with the lowest average training loss.
MODEL_FILENAME = "70m_10epoch.pt"
LANGUAGE_MODEL = "EleutherAI/pythia-70m" # Pretrained model identifier passed to from_pretrained(); default: EleutherAI/pythia-70m
DATALOADER_BATCH_SIZE = 8 # default: 8. Higher = more VRAM usage; GPUs with < 8GB VRAM may just use the default value.
TORCH_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") # Prefer CUDA, then MPS, then CPU.
FORCE_CUDA_CLEAR_MEMORY = False # default: False. When True the training loop clears the CUDA cache every epoch; enabling it may increase training time.
EPOCHS = 10 # default: 10. More epochs = more training time.

In [None]:
# Seed every random number generator used by the notebook (default seed: 0)
def set_torch_seed(seed = 0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour.

    Args:
        seed: integer seed applied to `random`, `numpy.random` and `torch`
            (CPU and, when available, every CUDA device).

    Returns:
        None. The function only mutates global RNG / backend state.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # The attribute is spelled `benchmark`; a misspelled name is silently accepted
    # as a new attribute and leaves cuDNN auto-tuning untouched.
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

# Load a text file and hand back its lines as a list of strings
def read_file(path):
    """Read `path` as UTF-8 (a leading BOM is dropped by the 'utf-8-sig' codec).

    Args:
        path: location of the text file to read.

    Returns:
        list[str]: every line of the file, line endings included.
    """
    with open(path, mode = 'r', encoding = 'utf-8-sig') as handle:
        lines = list(handle)
    return lines

# Seed all RNGs before any data processing or model initialisation
set_torch_seed()

# Make sure the model save directory exists. makedirs(exist_ok = True) also
# creates missing parent directories and avoids the check-then-create race
# of a separate isdir()/mkdir() pair.
os.makedirs(MODEL_SAVEDIR, exist_ok = True)

### Data Processing (資料處理)

In [None]:
# Special tokens: beginning-of-sequence, end-of-sequence, padding and separator
bos = '<|endoftext|>'
eos = '<|END|>'
pad = '<|pad|>'
sep = '\n\n####\n\n'

# Mapping later passed to tokenizer.add_special_tokens()
special_tokens_dict = dict(eos_token = eos, bos_token = bos, pad_token = pad, sep_token = sep)

# Process answer.txt, out: annotation dictionary
def process_annotation_file(lines):
    """Parse tab-separated annotation lines into a per-file annotation dictionary.

    A valid line has 5 or 6 tab-separated fields: file id, PHI type, start index,
    end index, entity text and, optionally, a normalized time value. Lines with
    any other field count (for example blank lines) are skipped instead of being
    recorded with the previous line's values.

    Args:
        lines: iterable of raw annotation lines; a trailing newline is allowed.

    Returns:
        dict mapping file id -> list of annotation dicts, in input order. Each
        annotation dict has the keys 'phi', 'st_idx', 'ed_idx', 'entity' and,
        for 6-field lines, 'normalize_time'.

    Raises:
        ValueError: if the start or end index of a 5/6-field line is not an integer.
    """
    print("process annotation file...")
    entity_dict = {}
    skipped = 0
    for line in lines:
        items = line.strip('\n').split('\t')
        if len(items) not in (5, 6):
            # Malformed or empty line: there is nothing reliable to record.
            skipped += 1
            continue
        item_dict = {
            'phi' : items[1],
            'st_idx' : int(items[2]),
            'ed_idx' : int(items[3]),
            'entity' : items[4],
        }
        if len(items) == 6:
            item_dict['normalize_time'] = items[5]
        entity_dict.setdefault(items[0], []).append(item_dict)
    if skipped:
        print(f"skipped {skipped} malformed annotation line(s)")
    # Report a count only: the dictionary itself holds the annotated entities.
    print(f"annotation file done, {len(entity_dict)} annotated file(s)")
    return entity_dict

# Process single report, out: seq_pairs
def process_medical_report(txt_name, medical_report_folder, annos_dict, special_tokens_dict):
    """Turn one medical report into TSV records, one per non-empty line.

    The report is scanned character by character. Annotations whose 'st_idx'
    equals the current character offset are collected for the current line, and
    a record is emitted each time a newline is reached.

    Args:
        txt_name: report id; the file read is `<txt_name>.txt` inside
            `medical_report_folder`, and `annos_dict[txt_name]` supplies its
            annotations.
        medical_report_folder: directory containing the report text files.
        annos_dict: mapping as produced by `process_annotation_file`. The
            annotations are consumed with a single forward cursor, so they are
            expected to be ordered by 'st_idx'.
        special_tokens_dict: unused here; kept so existing callers keep working.

    Returns:
        list[str]: records of the form
        "<txt_name>\\t<offset just past the line's newline>\\t<line text>\\t<labels>\\n".
        Labels are "TYPE:entity" (or "TYPE:entity=>normalized") joined by a literal
        backslash-n, or "PHI:Null" when the line has no annotation.

    Note:
        Text after the final newline of the report is never emitted, because
        records are only produced when a newline character is encountered.
    """
    # Two-character separator (backslash + 'n') so a record stays on one TSV line.
    label_sep = '\\n'

    file_name = txt_name + '.txt'
    sents = read_file(os.path.join(medical_report_folder, file_name))
    article = "".join(sents)

    boundary, item_idx, temp_seq, seq_pairs = 0, 0, "", []
    for w_idx, word in enumerate(article):
        if word == '\n':
            new_line_idx = w_idx + 1
            # A bare blank line produces no record (boundary is left untouched;
            # the following strip() removes the leftover newline).
            if article[boundary:new_line_idx] == '\n':
                continue
            if temp_seq == "":
                temp_seq = "PHI:Null"
            sentence = article[boundary:new_line_idx].strip().replace('\t', ' ')
            # Drop only the trailing separator. str.strip('\\n') would instead
            # strip every leading/trailing '\' and 'n' character, truncating
            # entities that end in "n" (e.g. "Brown" -> "Brow").
            if temp_seq.endswith(label_sep):
                temp_seq = temp_seq[:-len(label_sep)]
            seq_pairs.append(f"{txt_name}\t{new_line_idx}\t{sentence}\t{temp_seq}\n")
            boundary = new_line_idx
            temp_seq = ""
        anno = annos_dict[txt_name][item_idx]
        if w_idx == anno['st_idx']:
            phi_key = anno['phi']
            phi_value = anno['entity']
            if 'normalize_time' in anno:
                temp_seq += f"{phi_key}:{phi_value}=>{anno['normalize_time']}{label_sep}"
            else:
                temp_seq += f"{phi_key}:{phi_value}{label_sep}"
            # Advance the cursor unless this was the last annotation.
            if item_idx < len(annos_dict[txt_name]) - 1:
                item_idx += 1
    return seq_pairs

# Build the .tsv training file for one annotation file / report folder pair
def generate_annotated_medical_report_parallel(anno_file_path, medical_report_folder, tsv_output_path):
    """Convert every annotated report of one dataset into a single TSV file.

    Args:
        anno_file_path: path of the annotation file to parse.
        medical_report_folder: directory holding the `<file id>.txt` reports.
        tsv_output_path: destination file; it is overwritten.

    Returns:
        None. The records are written to `tsv_output_path` and progress
        messages (including the first ten records) are printed.
    """
    annos_dict = process_annotation_file(read_file(anno_file_path))

    print("processing each medical file")

    # One pass over the annotated file ids, flattening the per-report records
    all_seq_pairs = [
        seq_pair
        for txt_name in annos_dict
        for seq_pair in process_medical_report(txt_name, medical_report_folder, annos_dict, special_tokens_dict)
    ]
    print(all_seq_pairs[:10])
    print("All medical file done")
    print("write out to tsv format...")
    with open(tsv_output_path, 'w', encoding = 'utf-8') as fw:
        fw.writelines(all_seq_pairs)
    print("tsv format dataset done")

# Walk the dataset root and build one TSV file per dataset sub-folder
def process_folders(main_folder_path):
    """Generate `./<folder>_train.tsv` for every usable sub-folder of `main_folder_path`.

    A sub-folder is processed only when it contains exactly one
    `answer_*.txt` annotation file and exactly one `MedicalReport_*`
    directory; otherwise an error line is printed and the folder is skipped.
    Entries of `main_folder_path` that are not directories are ignored.

    Args:
        main_folder_path: directory whose immediate sub-folders are scanned.

    Returns:
        None.
    """
    for folder_name in os.listdir(main_folder_path):
        folder_path = os.path.join(main_folder_path, folder_name)
        if not os.path.isdir(folder_path):
            continue

        # Locate the annotation file (answer_*.txt)
        annotation_files = []
        for entry in os.listdir(folder_path):
            if entry.startswith('answer_') and entry.endswith('.txt'):
                annotation_files.append(entry)
        if len(annotation_files) != 1:
            print(f"Error: Found {len(annotation_files)} annotation files in {folder_name}. Expected 1.")
            continue
        anno_info_path = os.path.join(folder_path, annotation_files[0])

        # Locate the report folder (MedicalReport_*)
        report_folders = []
        for entry in os.listdir(folder_path):
            if os.path.isdir(os.path.join(folder_path, entry)) and entry.startswith('MedicalReport_'):
                report_folders.append(entry)
        if len(report_folders) != 1:
            print(f"Error: Found {len(report_folders)} report folders in {folder_name}. Expected 1.")
            continue
        report_folder = os.path.join(folder_path, report_folders[0])

        generate_annotated_medical_report_parallel(anno_info_path, report_folder, f'./{folder_name}_train.tsv')

# Build the TSV training files from TRAINING_DATASET_PATH; the guard keeps this
# from running if the code is imported as a module instead of executed directly.
if __name__ == "__main__":
    process_folders(TRAINING_DATASET_PATH)

### Read TSV Dataset

In [None]:
# TSV files to load (presumably the ones written by process_folders() - verify the names match your folders)
tsv_file_names = ["1_train.tsv", "2_train.tsv", "3_train.tsv"]

def _load_tsv_split(file_name):
    """Load one tab-separated file as a 'train' split with an explicit schema.

    Column names are supplied explicitly (fid, idx, content, label) and
    `keep_default_na = False` is forwarded so the default NA markers are not applied.
    """
    schema = Features({
        'fid': Value('string'),
        'idx': Value('int64'),
        'content': Value('string'),
        'label': Value('string')
    })
    return load_dataset(
        "csv",
        data_files = file_name,
        delimiter = '\t',
        features = schema,
        column_names = ['fid', 'idx', 'content', 'label'],
        split = 'train',
        keep_default_na = False
    )

# One dataset per file, then a single merged dataset
datasets = [_load_tsv_split(file_name) for file_name in tsv_file_names]
merged_dataset = concatenate_datasets(datasets)

print(f"Total datasets: {len(merged_dataset)}, first data is {merged_dataset[0]}")

In [None]:
# Initialize tokenizer
# Load the fast GPT-NeoX tokenizer for LANGUAGE_MODEL at the "step3000" revision.
tokenizer = GPTNeoXTokenizerFast.from_pretrained(LANGUAGE_MODEL, revision = "step3000")
# Pad on the left so the real tokens sit at the end of every padded sequence.
tokenizer.padding_side = 'left'
# Register the bos/eos/pad/sep tokens defined above; the returned value
# (presumably the number of newly added tokens) is not used in the visible cells.
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
# Quick check: show the pad token and the id it was assigned.
print(f"{tokenizer.pad_token}: {tokenizer.pad_token_id}")

In [None]:
# Materialise the merged dataset as a list of rows and wrap it in a DataLoader
train_data = list(merged_dataset)

def _collate_with_prompt_template(batch):
    """Collate one batch with the prompt template, using the notebook's tokenizer."""
    return collate_batch_with_prompt_template(batch, tokenizer)

# Batches are chosen by OpenDeidBatchSampler; collation applies the prompt template
_train_batch_sampler = OpenDeidBatchSampler(train_data, DATALOADER_BATCH_SIZE)
bucket_train_dataloader = DataLoader(
    train_data,
    batch_sampler = _train_batch_sampler,
    collate_fn = _collate_with_prompt_template,
    pin_memory = True,
)

In [None]:
# Test tokenizer: encode two hand-written "<bos> content <sep> labels <eos>" samples
# with padding and print, for each sample, its attention mask and decoded text.
sample_texts = [
    f"{bos} 9364819.RAN\\nMINTANIA, JEFFRY {sep} ID: 9364819.RAN\\nNAME: MINTANIA, JEFFRY {eos}",
    f"{bos} This is a sentence {sep} PHI: NULL {eos}"
]
results = tokenizer(sample_texts, padding = True)

# Iterate over the encoded sequences themselves. len(results) is the number of
# keys in the returned mapping (input_ids, attention_mask, ...), not the number
# of samples, so range(len(results)) is only right when the two happen to match.
for input_ids, attention_mask in zip(results['input_ids'], results['attention_mask']):
    print(attention_mask)
    print(tokenizer.decode(input_ids))

In [None]:
# NOTE(review): this cell repeats the DataLoader construction from the earlier
# "Load dataset" cell verbatim; running it builds a fresh sampler and loader
# over the same train_data. Confirm whether it is still needed.
bucket_train_dataloader = DataLoader(train_data,
                                    batch_sampler = OpenDeidBatchSampler(train_data, DATALOADER_BATCH_SIZE),
                                    collate_fn = lambda batch: collate_batch_with_prompt_template(batch, tokenizer),
                                    pin_memory = True)

In [None]:
# Initialize config and Language model
# The config carries the ids of the special tokens registered on the tokenizer,
# so the model and tokenizer agree on bos/eos/pad/sep.
# NOTE(review): the config is loaded without a revision while the tokenizer and
# model weights use revision "step3000" - confirm this mix is intended.
config = GPTNeoXConfig.from_pretrained(LANGUAGE_MODEL,
                                    bos_token_id = tokenizer.bos_token_id,
                                    eos_token_id = tokenizer.eos_token_id,
                                    pad_token_id = tokenizer.pad_token_id,
                                    sep_token_id = tokenizer.sep_token_id,
                                    output_hidden_states = False)

# Load the pretrained causal language model weights at revision "step3000" using the config above.
model = GPTNeoXForCausalLM.from_pretrained(LANGUAGE_MODEL, revision = "step3000", config = config)

In [None]:
# Resize the embedding matrices to the tokenizer's vocabulary (special tokens were
# added) and move the model to the target device BEFORE building the optimizer.
# Resizing can replace the embedding parameters with new tensors; an optimizer
# created earlier would keep references to the old ones and never update the new ones.
model.resize_token_embeddings(len(tokenizer))
model.to(TORCH_DEVICE)

# Build the optimizer last so it tracks the final set of parameters.
optimizer = AdamW(model.parameters(), lr = 3e-5) # lr default: 3e-5

In [None]:
# Model training
# Runs EPOCHS passes over bucket_train_dataloader. After every epoch the current
# weights are saved as 'Trained_Finial.pt'; whenever the average epoch loss
# improves, the weights are additionally saved as MODEL_FILENAME.
min_loss = float("inf") # lowest average epoch loss seen so far
global_step = 0 # not updated by this loop; kept because the name was defined here originally
total_loss = 0

for _ in trange(EPOCHS, desc = "Epoch"):
    model.train()
    total_loss = 0

    # Clear memory. Only touch the CUDA allocator when training actually runs on
    # CUDA: TORCH_DEVICE may be "mps" or "cpu", where the torch.cuda calls are
    # pointless and ipc_collect() may raise on builds without CUDA.
    if FORCE_CUDA_CLEAR_MEMORY and TORCH_DEVICE.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # Training loop
    for step, (seqs, labels, masks) in enumerate(bucket_train_dataloader):
        seqs = seqs.to(TORCH_DEVICE)
        labels = labels.to(TORCH_DEVICE)
        masks = masks.to(TORCH_DEVICE)
        model.zero_grad()
        outputs = model(seqs, labels = labels, attention_mask = masks)
        loss = outputs.loss.mean() #Combined from: outputs.loss & loss.mean()
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    avg_train_loss = total_loss / len(bucket_train_dataloader)
    print("Average train loss: {}".format(avg_train_loss))
    # Always keep the most recent weights...
    torch.save(model.state_dict(), os.path.join(MODEL_SAVEDIR, 'Trained_Finial.pt'))
    # ...and separately keep the weights of the best epoch so far.
    if avg_train_loss < min_loss:
        min_loss = avg_train_loss
        torch.save(model.state_dict(), os.path.join(MODEL_SAVEDIR, MODEL_FILENAME))