In [None]:
# CLONE REPOSITORY
!git clone --branch main https://github.com/rolysr/medical-knowledge-discoverer

In [None]:
# MOVE TO PROJECT
# Change the kernel's working directory into the cloned repository so the
# relative paths used below (./datasets, ./models, ./output) resolve.
# NOTE(review): running this twice fails, because the second run is already inside the repo.
%cd medical-knowledge-discoverer/

In [None]:
# INSTALLS
%pip install simplet5

In [None]:
# IMPORTS
import os
from pathlib import Path
from simplet5 import SimpleT5
from rich.progress import track
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score

# FROM PROJECT
from utils.anntools import Collection
from models.T5.t5 import T5


In [None]:
# T5 MODEL
# Project helper that converts annotated collections into T5 text-to-text
# input/output pairs and reads/writes them as CSV.
t5 = T5()

# OUTPUT
# SimpleT5 writes its per-epoch checkpoints under this directory.
output_path = Path('./output')
os.makedirs(output_path, exist_ok=True)

# TRAIN PATH
train_path = Path('./datasets/train')
csv_train_file = './models/T5/re_train.csv'

# GENERATE TRAIN DATA
collection = Collection().load_dir(train_path)
train_dataset = t5.generate_t5_input_output_format(collection)
if len(train_dataset) == 0:
    # Fail with a clear message instead of `max() arg is an empty sequence`.
    raise ValueError(f"No training examples were generated from {train_path}")

# NOTE: these are *character* lengths of the longest source/target strings,
# used later as a proxy for the token limits SimpleT5 expects. SentencePiece
# tokens usually span one or more characters, so this is expected to
# over-estimate; the training cell adds extra headroom on top.
MAX_INPUT_TOKENS = max(len(data[0]) for data in train_dataset)
MAX_OUTPUT_TOKENS = max(len(data[1]) for data in train_dataset)
t5.generate_csv(train_dataset, csv_train_file)

In [None]:
# TRAIN MODEL
# Fine-tune the pretrained `t5-base` checkpoint (this is NOT training from
# scratch). SimpleT5 saves each epoch to its own directory under `output_path`.
SPLIT_SEED = 42            # fixes the train/validation split so runs are reproducible
SOURCE_TOKEN_MARGIN = 50   # headroom over the longest training input (character-based estimate)
TARGET_TOKEN_MARGIN = 8    # headroom over the longest training target (character-based estimate)

model = SimpleT5()

# The training CSV is written by the data-preparation cell; just read it back
# here instead of regenerating it a second time.
df = t5.load_csv(str(csv_train_file))
# `test_df` is the held-out validation split passed as `eval_df` during training.
train_df, test_df = train_test_split(df, test_size=0.1, random_state=SPLIT_SEED)

model.from_pretrained(model_type="t5", model_name="t5-base")

print('Training...')
model.train(train_df=train_df,
            eval_df=test_df,
            source_max_token_len=MAX_INPUT_TOKENS + SOURCE_TOKEN_MARGIN,
            target_max_token_len=MAX_OUTPUT_TOKENS + TARGET_TOKEN_MARGIN,
            batch_size=8,
            max_epochs=4,
            use_gpu=True,
            outputdir=str(output_path)
)

In [None]:
# LIST SAVED CHECKPOINTS
# Show the per-epoch checkpoint directories written by the training cell;
# copy one of these names into `trained_model` in the next cell.
!ls ./output

In [None]:
# SELECT MODEL
# Replace the placeholder below with one of the checkpoint directory names
# printed by the listing cell above (a full path also works).
trained_model = "< model name >"

In [None]:
# LOAD TRAINED MODEL
# `trained_model` may be a bare checkpoint directory name (as printed by the
# listing cell) or an explicit path. A bare name is resolved inside
# `output_path`; any other value is passed through to SimpleT5 unchanged.
if trained_model == '< model name >':
    raise ValueError(
        "Set `trained_model` to one of the checkpoint directories listed in "
        f"{output_path} before loading."
    )

model_dir = trained_model
if not Path(trained_model).is_dir() and (output_path / trained_model).is_dir():
    model_dir = str(output_path / trained_model)

model = SimpleT5()
model.load_model('t5', model_dir, use_gpu=False)
# NOTE(review): the evaluation function reads the global `model`; this
# attribute is kept in case T5 helper methods use it (not visible here).
t5.model = model

In [None]:
# EVALUATION
# NOTE: this name shadows the builtin `eval`; it is kept because later cells call it.
def eval(collection, predictor=None, formatter=None):
    """Predict an output for every gold relation in `collection`.

    For each annotated relation, the sentence is rendered in the T5 input
    format with the (lower-cased) origin and destination entities marked.
    The predictor's first prediction is recorded, and the expected output is
    built with the same formatter so the two can be compared directly.

    Only entity pairs that carry a gold relation are evaluated; pairs without
    a relation are not sampled here.

    Args:
        collection: project `Collection` whose `.sentences` expose `.text`
            and `.relations` (each with `from_phrase`, `to_phrase`, `label`).
        predictor: object with a `predict(text)` method whose result is
            indexable; element 0 is taken as the prediction. Defaults to the
            notebook-level `model`, looked up at call time.
        formatter: project `T5` helper used to build the input/output
            strings. Defaults to the notebook-level `t5`.

    Returns:
        (y_test, y_pred): parallel lists holding, for each relation in
        iteration order, the expected output and the prediction.
    """
    # Resolve defaults at call time so the function uses whatever model the
    # load cell bound most recently (same behavior as reading the globals).
    if predictor is None:
        predictor = model
    if formatter is None:
        formatter = t5

    y_test = []
    y_pred = []
    for sentence in track(collection.sentences, description='evaluating...'):
        for relation in sentence.relations:
            origin = relation.from_phrase
            origin_text = origin.text.lower()
            destination = relation.to_phrase
            destination_text = destination.text.lower()

            # Build the marked input sentence and the expected (gold) output.
            input_text = formatter.get_marked_sentence_t5_input_format(
                sentence.text, origin_text, origin.label, destination_text, destination.label
            )
            output_text = formatter.get_t5_output_format(
                origin_text, origin.label, destination_text, destination.label, relation.label
            )
            y_test.append(output_text)
            y_pred.append(predictor.predict(input_text)[0])

    return y_test, y_pred

In [None]:
# RE EVALUATION
# Relation-extraction evaluation on the scenario1-main test split.
test_path = Path('./datasets/test/scenario1-main')
csv_test_file = Path('models/T5/re_test.csv')

collection = Collection().load_dir(test_path)
# Export the test set in the same CSV format as the training data. The
# metrics below are computed from `collection` directly; this CSV is not read
# back anywhere in this notebook. Pass a str path, as the training cells do.
test_dataset = t5.generate_t5_input_output_format(collection)
t5.generate_csv(test_dataset, str(csv_test_file))

y_test, y_pred = eval(collection)

In [None]:
# SHOW RESULTS
# Labels are the full output strings, so a prediction only counts as correct
# on an exact string match with the expected output.
# zero_division=0: a gold label the model never predicts contributes a
# precision of 0 (not 1) to the support-weighted average. With
# zero_division=1, every such label would silently inflate precision.
metric_kwargs = dict(average="weighted", zero_division=0)
print("Precision:", precision_score(y_test, y_pred, **metric_kwargs))
print('Recall:', recall_score(y_test, y_pred, **metric_kwargs))
print('F1 score:', f1_score(y_test, y_pred, **metric_kwargs))
