In [None]:
%pip install datasets
%pip install numpy
%pip install tqdm
%pip install transformers
%pip install tqdm
%pip install datasets
%pip install accelerate
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

### Import packages

In [None]:
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast, GPTNeoXConfig, get_linear_schedule_with_warmup
from datasets import load_dataset, Features, Value
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.optim import AdamW
import matplotlib.pyplot as plt
import torch.optim as optim
import numpy as np
import random
import torch
import os
import re
import io

In [None]:
# ---- Paths ----
# Directory and file name of the fine-tuned checkpoint to load.
MODEL_DIR = r"D:\Trained_Model"
MODEL_FILENAME = "70m_20epoch.pt"
# Where the final prediction file is written.
SAVE_LOCATION_WITH_FILENAME = r"D:\answer_dl.txt"
# Folder holding the raw text files to run inference on.
TEST_DATASET_PATH = r"D:\TestDatasets\opendid_test"
# Intermediate TSV (file id, character offset, sentence) built from the test files.
VALIDATE_OUT_PATH = "./valid.tsv"

# ---- Model ----
# Hugging Face model id of the base language model.
LANGUAGE_MODEL = "EleutherAI/pythia-70m"

# ---- Runtime ----
# Device preference: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    TORCH_DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    TORCH_DEVICE = torch.device("mps")
else:
    TORCH_DEVICE = torch.device("cpu")

# Sentences per generation batch. Higher values use more VRAM; the original
# note recommends 32 for GPUs with less than 8 GB.
PREDICT_BATCH_SIZE = 64

In [None]:
# Set seed 0
def set_torch_seed(seed = 0):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``torch`` (CPU, plus every CUDA device when CUDA is available).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Deterministic kernels, and no autotuner: the autotuner may pick different
    # algorithms from run to run. The flag is spelled `benchmark`; a misspelt
    # attribute would be accepted silently and change nothing.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Read file, return: string array
def read_file(path):
    """Return all lines of the text file at ``path`` as a list of strings.

    The file is decoded as ``utf-8-sig``, so a leading byte-order mark is
    dropped. Line terminators are kept on each returned line.
    """
    with open(path, mode = 'r', encoding = 'utf-8-sig') as handle:
        lines = list(handle)
    return lines

# Seed all RNGs with the default seed (0) before the tokenizer/model cells run.
set_torch_seed()

### Dataloader Sample

In [None]:
# Special tokens of the prompt format: beginning/end of sequence, padding,
# and the separator placed between the input sentence and the generated answer.
bos = '<|endoftext|>'
eos = '<|END|>'
pad = '<|pad|>'
sep = '\n\n####\n\n'
special_tokens_dict = dict(eos_token = eos, bos_token = bos, pad_token = pad, sep_token = sep)

# Load the tokenizer from the same checkpoint revision as the model.
tokenizer = GPTNeoXTokenizerFast.from_pretrained(LANGUAGE_MODEL, revision = "step3000")
# Pad on the left so that, in a padded batch, every prompt ends at the last position.
tokenizer.padding_side = "left"
# Register the special tokens; the return value is the number of tokens newly added.
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
# Sanity check: show the pad token and the id it was assigned.
print(f"{tokenizer.pad_token}: {tokenizer.pad_token_id}")

In [None]:
# Ids of the special tokens, taken from the tokenizer so that model config and
# tokenizer agree on them.
special_token_ids = {
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
    "sep_token_id": tokenizer.sep_token_id,
}

# Model config carrying the special-token ids; hidden states are not returned.
config = GPTNeoXConfig.from_pretrained(LANGUAGE_MODEL, output_hidden_states = False, **special_token_ids)

# Base model weights at the "step3000" revision, instantiated with the config above.
model = GPTNeoXForCausalLM.from_pretrained(LANGUAGE_MODEL, config = config, revision = "step3000")

In [None]:
# Grow the embedding matrices to cover the special tokens added to the tokenizer.
# This has to happen before the optimizer is created: depending on the
# transformers version, resizing may replace the embedding parameters, and an
# optimizer built earlier would keep tracking the old ones.
model.resize_token_embeddings(len(tokenizer))
# Move the model before building the optimizer so the optimizer is created
# from the parameters that will actually be used.
model = model.to(TORCH_DEVICE)
optimizer = AdamW(model.parameters(), lr = 3e-5) # YOU CAN ADJUST LEARNING RATE

In [None]:
# Load the fine-tuned weights.
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint
# opened on a CPU/MPS machine) onto TORCH_DEVICE; without it torch.load tries to
# restore them on the device they were saved from and fails if that is missing.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
checkpoint_path = os.path.join(MODEL_DIR, MODEL_FILENAME)
state_dict = torch.load(checkpoint_path, map_location = TORCH_DEVICE)
model.load_state_dict(state_dict)
model = model.to(TORCH_DEVICE)

def sample_text(model, tokenizer, text, n_words = 20):
    """Sample a continuation of ``text`` one token at a time.

    Each step draws the next token from the softmax distribution over the
    model's last-position logits (plain multinomial sampling) and reuses the
    cached key/values, so only the newly sampled token is fed back in.

    Args:
        model: Causal LM. Called as ``model(ids, past_key_values=...)``; must
            expose ``.device`` and return ``.logits`` / ``.past_key_values``.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and ``eos_token_id``.
        text: Prompt string.
        n_words: Maximum number of tokens to sample.

    Returns:
        The decoded prompt plus sampled tokens (including the EOS token if one
        was sampled).

    Raises:
        ValueError: If the prompt encodes to zero tokens.

    Side effects:
        Puts ``model`` in eval mode.
    """
    model.eval()
    token_ids = tokenizer.encode(text)
    if not token_ids:
        raise ValueError("sample_text needs a prompt that encodes to at least one token")

    # Compare ids instead of decoding every sampled token back to a string, and
    # read device / EOS from the arguments rather than from notebook globals.
    eos_id = tokenizer.eos_token_id
    inputs = torch.tensor([token_ids], device = model.device)
    past_key_values = None

    with torch.no_grad():
        for _ in range(n_words):
            out = model(inputs, past_key_values = past_key_values)
            past_key_values = out.past_key_values
            probs = F.softmax(out.logits[:, -1], dim = -1)
            # The sampled token is the only input of the next step (cache holds the rest).
            inputs = torch.multinomial(probs, 1)
            next_id = inputs.item()
            token_ids.append(next_id)
            if next_id == eos_id:
                break

    return tokenizer.decode(token_ids)

# Smoke test: sample a continuation for one date prompt and one time prompt,
# each wrapped as <bos> + sentence + <sep> like the prompts used for prediction.
bos_token, sep_token = special_tokens_dict['bos_token'], special_tokens_dict['sep_token']
text_date = bos_token + "D.O.B:  29/9/2000" + sep_token
text_time = bos_token + "Collected: 09/07/2012 at 09:14" + sep_token
for prompt in (text_date, text_time):
    print(sample_text(model, tokenizer, text = prompt, n_words = 20))

In [None]:
def process_valid_data(test_txts , out_file):
    """Flatten raw text files into a TSV with one non-blank line per row.

    Each output row is ``<fid>\\t<offset>\\t<sentence>`` where ``fid`` is the
    source file name without its extension and ``offset`` is the character
    position of the line's first character within the file as returned by
    ``read_file``. Tabs inside a line are replaced by single spaces so they
    cannot be mistaken for column separators (the line length is unchanged).
    Lines made up only of spaces, tabs and newlines are not written but still
    advance the offset.

    Args:
        test_txts: Iterable of paths to the text files to process.
        out_file: Path of the TSV file to (over)write.

    Returns:
        None. The result is the file written to ``out_file``.
    """
    with open(out_file, 'w', encoding = 'utf-8') as fw:
        for txt in test_txts:
            # basename/splitext instead of splitting on "/" so Windows paths work too.
            fid = os.path.splitext(os.path.basename(txt))[0]
            boundary = 0
            for sent in read_file(txt):
                if sent.replace(' ', '').replace('\n', '').replace('\t', '') != '':
                    # Lines come back with their own trailing newline; drop it
                    # before adding the row terminator, otherwise every record
                    # is followed by an empty line.
                    content = sent.rstrip('\n').replace('\t', ' ')
                    fw.write(f"{fid}\t{boundary}\t{content}\n")
                # Offsets count every character of the original line, newline included.
                boundary += len(sent)

# Collect the test files in a stable (sorted) order. Only regular files are
# kept, so a stray sub-directory in the folder cannot break the file reader.
test_txts = sorted(
    path
    for path in (os.path.join(TEST_DATASET_PATH, name) for name in os.listdir(TEST_DATASET_PATH))
    if os.path.isfile(path)
)
# process_valid_data only writes VALIDATE_OUT_PATH and returns None, so its
# result is deliberately not bound to a name.
process_valid_data(test_txts, VALIDATE_OUT_PATH)

In [None]:
# Column schema of the TSV, in file order. process_valid_data writes only the
# first three columns per row; 'label' is declared here but never written there.
valid_columns = {
    'fid': Value('string'),
    'idx': Value('int64'),
    'content': Value('string'),
    'label': Value('string'),
}

valid_data = load_dataset(
    "csv",
    data_files = VALIDATE_OUT_PATH,
    delimiter = '\t',
    features = Features(valid_columns),
    column_names = list(valid_columns),
)

# Materialise the single 'train' split as a plain list of row dicts so it can be sliced into batches.
valid_list = [row for row in valid_data['train']]

In [None]:
# PHI categories the model was trained to emit; any other key in the model
# output is ignored by get_anno_format.
train_phi_category = ['PATIENT', 'DOCTOR', 'USERNAME', 'PROFESSION', 'ROOM', 'DEPARTMENT',
                    'HOSPITAL', 'ORGANIZATION', 'STREET', 'CITY', 'STATE', 'COUNTRY',
                    'ZIP', 'LOCATION-OTHER', 'AGE', 'DATE', 'TIME', 'DURATION',
                    'SET', 'PHONE', 'FAX', 'EMAIL', 'URL', 'IPADDR',
                    'SSN', 'MEDICALRECORD', 'HEALTHPLAN', 'ACCOUNT', 'LICENSE', 'VEHICLE',
                    'DEVICE', 'BIOID', 'IDNUM']

def get_anno_format(sentence, infos, boundary):
    """Turn generated ``KEY: value`` lines into character-offset annotations.

    ``infos`` holds one entity per line as ``KEY: value``. For the temporal
    keys (DATE, TIME, DURATION, SET) the value may be ``text=>normalized``:
    the part before the first ``=>`` is the surface text and the part after
    the last ``=>`` is the normalized form. Without ``=>`` the value is used
    for both.

    Differences from a naive ``split(":")`` / dict-per-key parser:
      * only the first colon separates key from value, so values that contain
        colons (clock times such as ``09:14``) are kept instead of dropped;
      * several entities of the same category on different lines are all kept
        (identical repeated lines are collapsed);
      * the value is matched literally (``re.escape``), so characters such as
        ``.``, ``(`` or ``+`` neither act as regex operators nor raise;
      * lines without a colon or with an empty value are skipped.

    Args:
        sentence: The source sentence to search in.
        infos: Newline-separated model output for that sentence.
        boundary: Character offset of ``sentence`` in its file (int or numeric
            string); added to every match position.

    Returns:
        A list of dicts with keys ``phi``, ``st_idx``, ``ed_idx``, ``entity``
        and, for temporal keys only, ``normalize_time``. One dict per
        occurrence of each entity in ``sentence``; empty if nothing matches.
    """
    if not sentence or not infos:
        return []

    normalize_keys = ("DATE", "TIME", "DURATION", "SET")
    offset = int(boundary)

    # Pass 1: parse the lines into unique (key, surface text, normalized) triples.
    entities = []
    seen = set()
    for line in infos.split("\n"):
        phi_key, colon, raw_value = line.partition(":")
        phi_key, raw_value = phi_key.strip(), raw_value.strip()
        if not colon or phi_key not in train_phi_category or not raw_value:
            continue

        phi_value, normalize_time = raw_value, None
        if phi_key in normalize_keys:
            if '=>' in raw_value:
                pieces = raw_value.split('=>')
                phi_value, normalize_time = pieces[0].strip(), pieces[-1].strip()
            else:
                normalize_time = raw_value

        entity = (phi_key, phi_value, normalize_time)
        if phi_value and entity not in seen:
            seen.add(entity)
            entities.append(entity)

    # Pass 2: locate every literal occurrence of each surface text in the sentence.
    anno_list = []
    for phi_key, phi_value, normalize_time in entities:
        for match in re.finditer(re.escape(phi_value), sentence):
            item_dict = {
                        'phi' : phi_key,
                        'st_idx' : match.start() + offset,
                        'ed_idx' : match.end() + offset,
                        'entity' : phi_value
                        }
            if normalize_time is not None:
                item_dict['normalize_time'] = normalize_time
            anno_list.append(item_dict)
    return anno_list

def predict_data(model, tokenizer, input, max_new_tokens = 400):
    """Generate PHI annotations for a batch of rows and format them as TSV lines.

    Each usable row is wrapped in the prompt template, the whole batch is
    generated in one ``model.generate`` call, and the text after the separator
    token is parsed by ``get_anno_format``.

    Args:
        model: Causal LM exposing ``eval()``, ``generate(...)`` and ``.device``.
        tokenizer: Tokenizer with ``sep_token``, ``eos_token`` and ``pad_token`` set.
        input: List of row dicts with keys ``fid``, ``idx`` and ``content``.
            Rows that are falsy or whose ``content`` is ``None`` are skipped.
        max_new_tokens: Upper bound on generated tokens per row.

    Returns:
        A list of strings ``fid\\tphi\\tstart\\tend\\tentity[\\tnormalized]``,
        one per annotation. Empty when no row is usable or nothing is found.

    Side effects:
        Puts ``model`` in eval mode.
    """
    from contextlib import nullcontext  # local import keeps this cell self-contained

    template = "<|endoftext|> __CONTENT__\n\n####\n\n"

    # Keep the usable rows themselves (not just their prompts) so each
    # prediction is paired with the row it was generated from, even when some
    # rows of the batch were skipped.
    rows = [data for data in input if data and data['content'] is not None]
    if not rows:
        return []
    seeds = [template.replace("__CONTENT__", data['content']) for data in rows]

    sep, eos, pad = tokenizer.sep_token, tokenizer.eos_token, tokenizer.pad_token
    pad_idx = tokenizer.convert_tokens_to_ids(pad)
    eos_idx = tokenizer.convert_tokens_to_ids(eos)

    # Set model in eval mode
    model.eval()
    device = model.device
    texts = tokenizer(seeds, return_tensors = "pt", padding = True).to(device)

    # Mixed precision only on CUDA; on other devices run in the default precision.
    amp_context = torch.autocast(device_type = "cuda") if device.type == "cuda" else nullcontext()
    with amp_context:
        output_tokens = model.generate(**texts, max_new_tokens = max_new_tokens,
                                        pad_token_id = pad_idx, eos_token_id = eos_idx)
    preds = tokenizer.batch_decode(output_tokens)

    outputs = []
    for data, pred in zip(rows, preds):
        # Everything after the first separator is the model's answer.
        _, found, generated = pred.partition(sep)
        if not found:
            continue
        phi_infos = generated.replace(pad, "").replace(eos, "").strip()
        # "NULL" marks "no PHI in this sentence". It is checked on the answer
        # only, so a sentence that itself contains "NULL" is not discarded.
        if "NULL" in phi_infos:
            continue

        for annotation in get_anno_format(data['content'] , phi_infos , data['idx']):
            normalized_value = f"\t{annotation['normalize_time']}" if 'normalize_time' in annotation else ""
            outputs.append(f"{data['fid']}\t{annotation['phi']}\t{annotation['st_idx']}\t{annotation['ed_idx']}\t{annotation['entity']}{normalized_value}")

    return outputs

In [None]:
# Run inference over the test sentences in batches of PREDICT_BATCH_SIZE and
# write each returned annotation as one line of the answer file.
batch_starts = range(0, len(valid_list), PREDICT_BATCH_SIZE)
with open(SAVE_LOCATION_WITH_FILENAME, 'w', encoding = 'utf8') as f, torch.no_grad():
    for i in tqdm(batch_starts):
        seeds = valid_list[i:i + PREDICT_BATCH_SIZE]
        outputs = predict_data(model, tokenizer, input = seeds)
        f.writelines(f"{o}\n" for o in outputs)