# Round 2 Model outputs

For Round 2 the following adjustments were made from Round 1:

__NER & CR:__

- Because both models are using spaCy, both components can be run simultaneously
- **TODO:** add more descriptions of the NER & CR adjustments

__REX:__

- Self-relations were removed
- The `alternate_name` relation for Flair was re-included, as it proved very useful in disambiguation

When running on GPU, one can use `watch -n 1 nvidia-smi` to monitor GPU usage.

## Import required libraries

In [1]:
import json
import time
import pandas as pd
import pickle
import torch
from kg_builder import kg
from kg_builder import ner
from kg_builder import cr
from kg_builder import rex
from kg_builder import get_wd_relation_data
from kg_builder import chunk_long_articles

## Import data and make Articles

In [2]:
# Load the (sample) source data from its parquet file
df = pd.read_parquet(path = 'source_data/sample_text_30.pq')

In [3]:
# Build the list of Article instances from the loaded DataFrame
articles = kg.make_articles(df = df)

# For a quick test on CPU keep only the first 3 articles - leave commented out for a full run
# articles = articles[:3]

## Run NER and CR on articles

In [4]:
# Set the model name for named entity recognition and coreference resolution
cr_model_name = 'fastcoref' # 'lingmess'

cr_tagger, cr_model_name = cr.setup_cr_tagger(model_name = cr_model_name)

# Split the articles into consecutive (start, end) index pairs of at most
# batch_size articles each. The last batch ends at num_articles, so it may be
# shorter than the others; an empty article list yields no batches at all.
num_articles = len(articles)
batch_size = 50
start_indices = list(range(0, num_articles, batch_size))  # stride follows batch_size
end_indices = start_indices[1:] + [num_articles]
batches = list(zip(start_indices, end_indices))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Run the tagger over each batch to add the NER / CR data to the articles,
# timing the whole run
start_time = time.time()
for batch_start, batch_end in batches:
    cr.get_ner_cr_data(articles = articles[batch_start:batch_end], cr_tagger = cr_tagger)
    # Release cached GPU memory between batches
    if device == 'cuda':
        torch.cuda.empty_cache()
    print(f'''>====== {batch_end} articles processed ======<''')
end_time = time.time()
time_difference = end_time - start_time
print(time_difference)

08/02/2024 11:57:47 - INFO - 	 missing_keys: []
08/02/2024 11:57:47 - INFO - 	 unexpected_keys: []
08/02/2024 11:57:47 - INFO - 	 mismatched_keys: []
08/02/2024 11:57:47 - INFO - 	 error_msgs: []
08/02/2024 11:57:47 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M
08/02/2024 11:57:50 - INFO - 	 Tokenize 5 inputs...


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

08/02/2024 11:57:53 - INFO - 	 ***** Running Inference on 5 texts *****


Inference:   0%|          | 0/5 [00:00<?, ?it/s]

08/02/2024 11:57:54 - INFO - 	 Tokenize 5 inputs...


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

08/02/2024 11:57:57 - INFO - 	 ***** Running Inference on 5 texts *****


Inference:   0%|          | 0/5 [00:00<?, ?it/s]

08/02/2024 11:57:57 - INFO - 	 Tokenize 5 inputs...


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

08/02/2024 11:58:00 - INFO - 	 ***** Running Inference on 5 texts *****


Inference:   0%|          | 0/5 [00:00<?, ?it/s]

08/02/2024 11:58:01 - INFO - 	 Tokenize 5 inputs...


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

08/02/2024 11:58:04 - INFO - 	 ***** Running Inference on 5 texts *****


Inference:   0%|          | 0/5 [00:00<?, ?it/s]

08/02/2024 11:58:05 - INFO - 	 Tokenize 5 inputs...


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

08/02/2024 11:58:08 - INFO - 	 ***** Running Inference on 5 texts *****


Inference:   0%|          | 0/5 [00:00<?, ?it/s]

08/02/2024 11:58:10 - INFO - 	 Tokenize 5 inputs...


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

08/02/2024 11:58:13 - INFO - 	 ***** Running Inference on 5 texts *****


Inference:   0%|          | 0/5 [00:00<?, ?it/s]

27.239612579345703


## Run RE on articles

In [5]:
# Run Rebel to get the main relations of interest
rex_model_name = 'rebel'

if rex_model_name == 'rebel':
    # Setup the required RE tagger
    rex_tagger, rex_tokenizer, device, rex_model_name = rex.setup_rex_tagger(model_name = rex_model_name)

    # And then run the model to add REs to the articles
    num_articles = len(articles)
    start_time = time.time()
    for i, article in enumerate(articles):
        # Long articles are processed chunk by chunk
        chunk_boundaries = chunk_long_articles(article.article_text, max_chunk_size = 20000)
        for chunk in chunk_boundaries:
            rex.rebel_get_relations(article = article, rex_tokenizer = rex_tokenizer,
                                    rex_tagger = rex_tagger, device = device, chunk = chunk)
            rex.remove_self_relations(article = article)
        # Clear the CUDA cache every 5 articles
        if device == 'cuda' and (i + 1) % 5 == 0:
            torch.cuda.empty_cache()
        # Progress line every 50 articles; skipped on the last article because
        # the summary line below reports the same count
        if (i + 1) % 50 == 0 and (i + 1) != num_articles:
            print(f'''>====== {i + 1} articles processed ======<''')
    # Report the total from len(articles), not the loop index, so an empty
    # article list prints 0 instead of failing on (or reusing a stale) `i`
    print(f'''>====== {num_articles} articles processed ======<''')
    end_time = time.time()
    time_difference = end_time - start_time
    print(time_difference)

25.729116439819336


In [6]:
# Run Flair to get alternate_name relations
rex_model_name = 'flair'

if rex_model_name == 'flair':
    # Setup the required RE tagger
    rex_tagger, ner_tagger, splitter, device, model_name = rex.setup_rex_tagger(model_name = rex_model_name)

    # And then run the model to add REs to the articles
    num_articles = len(articles)
    start_time = time.time()
    for i, article in enumerate(articles):
        rex.flair_get_relations(article = article, splitter = splitter, ner_tagger = ner_tagger,
                                rex_tagger = rex_tagger, device = device, restricted = True)
        rex.remove_self_relations(article = article)
        # Clear the CUDA cache every 5 articles
        if device == 'cuda' and (i + 1) % 5 == 0:
            torch.cuda.empty_cache()
        # Progress line every 50 articles; skipped on the last article because
        # the summary line below reports the same count
        if (i + 1) % 50 == 0 and (i + 1) != num_articles:
            print(f'''>====== {i + 1} articles processed ======<''')
    # Report the total from len(articles), not the loop index, so an empty
    # article list prints 0 instead of failing on (or reusing a stale) `i`
    print(f'''>====== {num_articles} articles processed ======<''')
    end_time = time.time()
    time_difference = end_time - start_time
    print(time_difference)

2024-08-02 11:58:56,783 SequenceTagger predicts: Dictionary with 76 tags: <unk>, O, B-CARDINAL, E-CARDINAL, S-PERSON, S-CARDINAL, S-PRODUCT, B-PRODUCT, I-PRODUCT, E-PRODUCT, B-WORK_OF_ART, I-WORK_OF_ART, E-WORK_OF_ART, B-PERSON, E-PERSON, S-GPE, B-DATE, I-DATE, E-DATE, S-ORDINAL, S-LANGUAGE, I-PERSON, S-EVENT, S-DATE, B-QUANTITY, E-QUANTITY, S-TIME, B-TIME, I-TIME, E-TIME, B-GPE, E-GPE, S-ORG, I-GPE, S-NORP, B-FAC, I-FAC, E-FAC, B-NORP, E-NORP, S-PERCENT, B-ORG, E-ORG, B-LANGUAGE, E-LANGUAGE, I-CARDINAL, I-ORG, S-WORK_OF_ART, I-QUANTITY, B-MONEY
43.187543869018555


## Write to Pickle

(We don't need the JSON format now as we are done with Label Studio)

In [7]:
# Serialise the processed articles into the Round 2 results pickle
with open('model_outputs/round2/results.pkl', mode = 'wb') as results_file:
    pickle.dump(articles, results_file)