In [None]:
# CLONE REPOSITORY
# Shell escape: clone the project from GitHub, checking out its `T5` branch.
!git clone --branch T5 https://github.com/rolysr/medical-knowledge-discoverer

In [None]:
# MOVE TO PROJECT
# Change the kernel's working directory to the cloned repository; the relative
# paths (./datasets, ./models, ./output) and project imports used by the
# following cells are written relative to this directory.
%cd medical-knowledge-discoverer/

In [None]:
# INSTALLS
%pip install simplet5 fasttext
!python -m spacy download es_core_news_sm en_core_web_sm

In [None]:
# IMPORTS
import os
from pathlib import Path
from simplet5 import SimpleT5
from rich.progress import track
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score

# FROM PROJECT
from utils.anntools import Collection


# MODELS
from models.T5.t5 import T5
from models.NER.ner import NER


In [2]:
# T5 MODEL
# Project helper used to build the T5 input/output texts and the training CSV.
t5 = T5()

# OUTPUT
# Directory passed to the trainer as its output location; created if missing.
output_path = Path('./output')
output_path.mkdir(parents=True, exist_ok=True)

# TRAIN PATH
train_path = Path('./datasets/train')
csv_train_file = './models/T5/re_train.csv'

# GENERATE TRAIN DATA
# This collection is loaded from the training directory, so it is named
# `train_collection` (it is not the test set).
train_collection = Collection().load_dir(train_path)
train_dataset = t5.generate_t5_input_output_format(train_collection)

# Longest input (item 0) and output (item 1) of the dataset, measured with len().
# NOTE(review): presumably these items are strings, which would make the values
# character counts used as a loose upper bound for token counts - confirm.
MAX_INPUT_TOKENS = max(len(data[0]) for data in train_dataset)
MAX_OUTPUT_TOKENS = max(len(data[1]) for data in train_dataset)

# Persist the generated pairs as the CSV file used for training.
t5.generate_csv(train_dataset, csv_train_file)

In [None]:
# TRAIN MODEL
# FINE-TUNE THE PRETRAINED t5-base CHECKPOINT (NOT A TRAINING FROM SCRATCH).
# THE TRAINER WRITES ITS OUTPUT UNDER output_path.
model = SimpleT5()

# (Re)generate the training CSV and load it back as a dataframe.
t5.generate_csv(train_dataset, str(csv_train_file))
df = t5.load_csv(str(csv_train_file))

# Hold out 10% of the rows for evaluation during training. The fixed seed makes
# the split identical on every run, so results are reproducible.
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)

model.from_pretrained(model_type="t5", model_name="t5-base")

print('Training...')
model.train(train_df=train_df,
            eval_df=test_df,
            # Margins added on top of the longest example measured on the dataset.
            source_max_token_len=MAX_INPUT_TOKENS + 50,
            target_max_token_len=MAX_OUTPUT_TOKENS + 8,
            batch_size=8,
            max_epochs=4,
            use_gpu=True,  # requires a GPU runtime
            outputdir=output_path
)

In [None]:
# LIST THE CONTENTS OF THE OUTPUT DIRECTORY (TO FIND THE SAVED MODEL TO LOAD)
!ls ./output

In [None]:
# SELECT MODEL
# Placeholder value: replace it before running the next cell, which passes it
# to `load_model`. NOTE(review): presumably this should be the path of one of
# the entries written under ./output - confirm the expected format.
trained_model = '< model name >'

In [None]:
# LOAD TRAINED MODEL
# Create a fresh SimpleT5 wrapper and load the selected model with use_gpu=False.
model = SimpleT5()
model.load_model('t5', trained_model, use_gpu=False)
# Attach the loaded model to the project's T5 helper.
t5.model = model

In [None]:
# NER MODEL
ner = NER()

# TRAINING NER MODEL
# Load the annotated training collection from `train_path` and train the NER
# model on it.
train_collection = Collection().load_dir(train_path)
ner.train(train_collection)


In [16]:
# EVALUATION
def eval(test_collection: Collection, ner_collection: Collection, model):
    """Count relation-extraction hits and misses over two aligned collections.

    Sentences of ``test_collection`` and ``ner_collection`` are walked in
    parallel (pairing stops at the shorter one). For every sentence pair:

    * each relation of the gold sentence is mapped to the output text built by
      ``t5.get_t5_output_format``;
    * each relation of the NER sentence is mapped to the first prediction
      returned by ``model.predict`` for its marked input text;
    * relations present in both mappings count as correct when the two texts
      are equal and as incorrect otherwise; the remaining predicted relations
      count as spurious and the remaining gold relations as missing.

    Relies on the module-level ``t5`` helper for text formatting.

    Returns:
        Tuple ``(correct, missing, spurious, incorrect)`` of integer totals.
    """
    correct = missing = spurious = incorrect = 0

    sentence_pairs = zip(test_collection.sentences, ner_collection.sentences)
    for position, (gold_sentence, pred_sentence) in enumerate(sentence_pairs, start=1):
        # Progress trace: 1-based index of the sentence pair being processed.
        print('n', position)

        # Expected output text for every gold relation.
        expected = {}
        for gold_relation in gold_sentence.relations:
            head = gold_relation.from_phrase
            head_text = head.text.lower()
            tail = gold_relation.to_phrase
            tail_text = tail.text.lower()

            # NOTE(review): the marked input built here is never used for gold
            # relations; the call is kept so behaviour stays unchanged.
            _unused_input = t5.get_marked_sentence_t5_input_format(
                gold_sentence.text, head_text, head.label, tail_text, tail.label
            )
            expected[gold_relation] = t5.get_t5_output_format(
                head_text, head.label, tail_text, tail.label, gold_relation.label
            )

        # Model prediction for every relation of the NER sentence.
        predicted = {}
        for pred_relation in pred_sentence.relations:
            head = pred_relation.from_phrase
            head_text = head.text.lower()
            tail = pred_relation.to_phrase
            tail_text = tail.text.lower()

            marked_input = t5.get_marked_sentence_t5_input_format(
                pred_sentence.text, head_text, head.label, tail_text, tail.label
            )
            predicted[pred_relation] = model.predict(marked_input)[0]

        # Settle relations found on both sides; iterate over a snapshot of the
        # keys because matched entries are removed from both mappings.
        for relation in list(expected):
            guess = predicted.get(relation)
            if guess is None:
                continue
            if guess == expected[relation]:
                correct += 1
            else:
                incorrect += 1
            del predicted[relation]
            del expected[relation]

        # Whatever is left was predicted but not expected, or expected but not predicted.
        spurious += len(predicted)
        missing += len(expected)

    return correct, missing, spurious, incorrect

In [None]:
# RE EVALUATION
# Load the annotated test collection (scenario 1).
test_path = Path('./datasets/test/scenario1-main')
# NOTE(review): csv_test_file is not used by any visible cell.
csv_test_file = Path('models/T5/re_test.csv')
test_collection = Collection().load_dir(test_path)

# RUN NER ON THE TEST COLLECTION
# The result is compared against `test_collection` by `eval` below.
ner_collection = ner.run(test_collection)

# Sanity check: print the sizes of both collections before evaluating.
print(len(test_collection), len(ner_collection))
CORRECT, MISSING, SPURIOUS, INCORRECT = eval(test_collection, ner_collection, model)

In [None]:
# SHOW RESULTS
# Precision is measured against everything the system predicted
# (correct + incorrect + spurious); recall against everything annotated in the
# gold data (correct + incorrect + missing). A zero denominator yields 0.
predicted_total = CORRECT + INCORRECT + SPURIOUS
gold_total = CORRECT + INCORRECT + MISSING

precision = CORRECT / predicted_total if predicted_total > 0 else 0
recall = CORRECT / gold_total if gold_total > 0 else 0
# Harmonic mean of precision and recall.
f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

print("Precision:", precision)
print('Recall:', recall)
print('F1 score:', f1)