# Round 1: Initial model outputs

For Round 1, the following models needed to be run and evaluated:

__NER (named entity recognition):__

- spaCy en_core_web_trf (used as the basis for annotation)
- flair/ner-english-ontonotes-large

__CR (coreference resolution):__

- fastcoref (used as the basis for annotation)
- LingMess

__REX (relation extraction):__

- Babelscape/rebel-large (used as the basis for annotation)
- Flair (only alternate_name included in annotations)

The code below reads the selected sample data into a DataFrame, `df`, with the following columns:

- Id: unique Id for each article, e.g. _ed94c34a-8499-44f9-afb4-f8df96bb8843_
- Permatitle: final part of article URL, e.g. _zondo-commission-to-issue-a-summons-for-jacob-zuma-to-appear-20200110_
- SampleType: one of _general_, _analysis_, or _opinion_ indicating the type of article
- AllText: a combination of _Title_, _Synopsis_ and _CleanBody_ (the article text with HTML tags already stripped)
- Split: one of _train_ or _test_
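The spans printed by the NER, CR and REX cells are character offsets into the article text. Judging by the printed title and the `[13:34]` span of "State capture inquiry", that text starts with the title. As a rough sketch of how `AllText` could be assembled, the hypothetical helper `build_all_text` joins the three fields in order. The newline separator is an assumption, since the separator used upstream is not shown here. Because the title comes first, offsets inside the title do not depend on that choice.

```python
def build_all_text(title: str, synopsis: str, clean_body: str, sep: str = "\n") -> str:
    """Hypothetical helper: join the article fields in Title, Synopsis, CleanBody order.

    Empty fields are skipped; the newline separator is an assumption.
    """
    return sep.join(part for part in (title, synopsis, clean_body) if part)


title = "WATCH LIVE | State capture inquiry continues"
all_text = build_all_text(title, "Placeholder synopsis.", "Placeholder body.")

# Offsets behave like Python slices: [13:34] covers characters 13 to 33 inclusive
print(all_text[13:34])  # State capture inquiry
```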

It then generates model outputs for each of the three tasks listed above (adjust the model names as required) and writes Label Studio-ready JSON files for annotation.

When running on a GPU, use `watch -n 1 nvidia-smi` to monitor GPU usage.

## Import required libraries

In [1]:
import json
import time
import pandas as pd
import torch
from kg_builder import kg
from kg_builder import ner
from kg_builder import cr
from kg_builder import rex
from kg_builder import get_wd_relation_data
from kg_builder import chunk_long_articles

## Import data and make Articles

In [2]:
# Import sample data
df = pd.read_parquet('source_data/sample_text_30.pq')

In [3]:
# Make a list of Article instances
articles = kg.make_articles(df=df)

# To test on CPU with just 3 articles, uncomment the line below
# articles = articles[0:3]

## Run NER on articles

In [4]:
# Set the model name for named entity recognition as required
ner_model_name = 'spacy' # 'flair'

ner_tagger, ner_model_name = ner.setup_ner_tagger(model_name = ner_model_name)

# Then run the model in batches to add named entities to the articles
num_articles = len(articles)
batch_size = 50
start_indices = list(range(0, num_articles, batch_size))
end_indices = start_indices[1:] + [num_articles]
batches = list(zip(start_indices, end_indices))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

start_time = time.time()
for start, end in batches:
    ner.get_entities(articles = articles[start:end], model_name = ner_model_name, ner_tagger = ner_tagger)
    if device == 'cuda':
        torch.cuda.empty_cache()
    print(f'>====== {end} articles processed ======<')
end_time = time.time()
time_difference = end_time - start_time
print(time_difference)

articles[1].print_named_entities()

6.897720575332642

Article b98bba34-c5d7-440b-b7a5-e365fabf4bc3
WATCH LIVE | State capture inquiry continues
--------------------------------------------
PwC, ORG, [68:71]
Pule Mothibe, PERSON, [80:92]
PricewaterhouseCoopers, ORG, [183:205]
Pule Mothibe, PERSON, [214:226]
PwC, ORG, [228:231]
South Africa, GPE, [232:244]
Pule Mothibe, PERSON, [253:265]
SAA, ORG, [367:370]
PwC, ORG, [422:425]
SAA, ORG, [434:437]
Mothibe, PERSON, [484:491]
SAA, ORG, [537:540]
PricewaterhouseCoopers (, ORG, [741:765]
PwC, ORG, [765:768]
Pule Mothibe, PERSON, [778:790]
Mothibe, PERSON, [793:800]
PwC, ORG, [805:808]
South African Airways, ORG, [832:853]
SAA, ORG, [855:858]
Mothibe, PERSON, [916:923]
Kate Hofmeyr, PERSON, [967:979]
PwC, ORG, [1023:1026]
SAA, ORG, [1067:1070]
Mothibe, PERSON, [1126:1133]
SAA, ORG, [1225:1228]
PwC, ORG, [1303:1306]
SAA, ORG, [1425:1428]
Transnet, ORG, [1589:1597]
Eskom, ORG, [1620:1625]
Mafika Mkwanazi, PERSON, [1639:1654]
SABC, ORG, [1674:1678]
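The NER and CR cells build their `(start, end)` batch boundaries by hand with `range`, slicing and `zip`. A small helper such as the hypothetical `make_batches` below (it is not part of `kg_builder`) produces the same boundaries. It also keeps the step size tied to `batch_size` rather than to a repeated literal.

```python
def make_batches(num_items: int, batch_size: int) -> list[tuple[int, int]]:
    """Return (start, end) slice boundaries covering range(num_items) in steps of batch_size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [(start, min(start + batch_size, num_items))
            for start in range(0, num_items, batch_size)]


print(make_batches(30, 50))   # [(0, 30)]: the 30-article sample fits in one batch
print(make_batches(120, 50))  # [(0, 50), (50, 100), (100, 120)]
```

In the cells above, `batches = make_batches(len(articles), batch_size)` would replace the three lines that build the index lists.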


## Run CR on articles

In [5]:
# Set the model name for coreference resolution as required
cr_model_name = 'fastcoref' # 'lingmess'

# Set up the required CR tagger
cr_tagger, cr_model_name = cr.setup_cr_tagger(model_name = cr_model_name)

# Then run the model in batches to add coreference clusters to the articles
num_articles = len(articles)
batch_size = 50
start_indices = list(range(0, num_articles, batch_size))
end_indices = start_indices[1:] + [num_articles]
batches = list(zip(start_indices, end_indices))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

start_time = time.time()
for start, end in batches:
    cr.get_clusters(articles = articles[start:end], model_name = cr_model_name, cr_tagger = cr_tagger)
    if device == 'cuda':
        torch.cuda.empty_cache()
    print(f'>====== {end} articles processed ======<')
end_time = time.time()
time_difference = end_time - start_time
print(time_difference)

articles[1].print_cr_clusters(post_processing = False)

08/02/2024 12:08:32 - INFO - 	 missing_keys: []
08/02/2024 12:08:32 - INFO - 	 unexpected_keys: []
08/02/2024 12:08:32 - INFO - 	 mismatched_keys: []
08/02/2024 12:08:32 - INFO - 	 error_msgs: []
08/02/2024 12:08:32 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M
08/02/2024 12:08:33 - INFO - 	 Tokenize 5 inputs...
08/02/2024 12:08:36 - INFO - 	 ***** Running Inference on 5 texts *****
08/02/2024 12:08:36 - INFO - 	 Tokenize 5 inputs...
08/02/2024 12:08:39 - INFO - 	 ***** Running Inference on 5 texts *****
08/02/2024 12:08:40 - INFO - 	 Tokenize 5 inputs...
08/02/2024 12:08:43 - INFO - 	 ***** Running Inference on 5 texts *****
08/02/2024 12:08:44 - INFO - 	 Tokenize 5 inputs...
08/02/2024 12:08:47 - INFO - 	 ***** Running Inference on 5 texts *****
08/02/2024 12:08:48 - INFO - 	 Tokenize 5 inputs...
08/02/2024 12:08:51 - INFO - 	 ***** Running Inference on 5 texts *****
08/02/2024 12:08:53 - INFO - 	 Tokenize 5 inputs...
08/02/2024 12:08:56 - INFO - 	 ***** Running Inference on 5 texts *****

24.030704498291016

Article b98bba34-c5d7-440b-b7a5-e365fabf4bc3
WATCH LIVE | State capture inquiry continues
--------------------------------------------

Cluster 0: 
State capture inquiry, [13:34]
The state capture inquiry, [93:118]
The state capture inquiry, [651:676]
The state capture inquiry, [1526:1551]

Cluster 1: 
testimony from PwC auditor Pule Mothibe, [53:92]
aviation-related testimony from PricewaterhouseCoopers auditor Pule Mothibe, [151:226]
aviation-related testimony from PricewaterhouseCoopers (PwC) auditor Pule Mothibe, [709:790]

Cluster 2: 
PwC auditor Pule Mothibe, [68:92]
PricewaterhouseCoopers auditor Pule Mothibe, [183:226]
PwC South Africa auditor Pule Mothibe, [228:265]
his, [309:312]
Mothibe, [484:491]

Cluster 3: 
SAA management, [367:381]
it, [417:419]
SAA management, [1225:1239]

Cluster 4: 
PwC, [68:71]
PwC, [422:425]
PwC's, [805:810]
PwC, [1023:1026]
it, [1088:1090]

Cluster 5: 
SAA, [367:370]
SAA's, [434:439]
SAA, [537:540]
SAA, [1067:1070]
SAA, [1225:12
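Several clusters above contain mentions whose spans coincide exactly with entities from the NER step, for example the `PwC` mentions in Cluster 4. The hypothetical `cluster_entity_labels` sketches one way to combine the two outputs. It assumes both are available as plain `(start, end)` character spans, and it counts the NER labels whose spans exactly match a mention in each cluster. The spans are copied from the outputs printed above.

```python
from collections import Counter

Span = tuple[int, int]


def cluster_entity_labels(clusters: list[list[Span]],
                          entities: list[tuple[Span, str]]) -> list[Counter]:
    """For each coreference cluster, count the NER labels of exactly matching spans."""
    label_by_span = {span: label for span, label in entities}
    return [Counter(label_by_span[m] for m in cluster if m in label_by_span)
            for cluster in clusters]


# A subset of the NER spans printed for article b98bba34
entities = [((68, 71), "ORG"), ((80, 92), "PERSON"), ((422, 425), "ORG"),
            ((484, 491), "PERSON"), ((1023, 1026), "ORG")]
clusters = [
    [(68, 92), (183, 226), (228, 265), (309, 312), (484, 491)],      # Cluster 2
    [(68, 71), (422, 425), (805, 810), (1023, 1026), (1088, 1090)],  # Cluster 4
]
print(cluster_entity_labels(clusters, entities))
```

Exact matching misses near-misses such as the mention `PwC's` [805:810] against the entity `PwC` [805:808]. An overlap test would link more mentions, at the cost of more false links.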

## Run RE on articles
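Before extraction, the REBEL branch in the cell below splits each article with `chunk_long_articles(article.article_text, max_chunk_size = 20000)` and passes each chunk to `rex.rebel_get_relations`. That function's implementation is not shown here. The `chunk_boundaries` function below is a hypothetical stand-in, not the `kg_builder` code, and it assumes chunks are `(start, end)` character boundaries. It cuts at the last sentence break inside the size limit, falls back to the last space, and otherwise makes a hard cut.

```python
def chunk_boundaries(text: str, max_chunk_size: int) -> list[tuple[int, int]]:
    """Hypothetical chunker: split text into (start, end) spans of at most max_chunk_size
    characters, preferring to cut after '. ' and falling back to whitespace."""
    if max_chunk_size <= 0:
        raise ValueError("max_chunk_size must be positive")
    boundaries = []
    start = 0
    while start < len(text):
        end = min(start + max_chunk_size, len(text))
        if end < len(text):
            # Prefer the last sentence break that fits entirely inside the window
            cut = text.rfind(". ", start, end)
            if cut != -1:
                end = cut + 2  # keep the full stop and the space in this chunk
            else:
                cut = text.rfind(" ", start, end)
                if cut > start:
                    end = cut + 1
        boundaries.append((start, end))
        start = end
    return boundaries


text = "One. Two three. Four five six."
print([text[s:e] for s, e in chunk_boundaries(text, 12)])
# ['One. ', 'Two three. ', 'Four five ', 'six.']
```

Because every chunk starts where the previous one ended, the character offsets stay valid against the full article text.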

In [6]:
# Set the model name for relation extraction
rex_model_name = 'rebel' # 'flair'

if rex_model_name == 'rebel':
    # Setup the required RE tagger
    rex_tagger, rex_tokenizer, device, rex_model_name = rex.setup_rex_tagger(model_name = rex_model_name)

    # And then run the model to add REs to the articles
    start_time = time.time()
    for i, article in enumerate(articles):
        chunk_boundaries = chunk_long_articles(article.article_text, max_chunk_size = 20000)
        for chunk in chunk_boundaries:
            rex.rebel_get_relations(article = article, rex_tokenizer = rex_tokenizer,
                                    rex_tagger = rex_tagger, device = device, chunk = chunk)
        # Clear the CUDA cache every 5 articles
        if device == 'cuda' and (i + 1) % 5 == 0:
            torch.cuda.empty_cache()
        if (i + 1) % 50 == 0:
            print(f'>====== {i + 1} articles processed ======<')
    print(f'>====== {len(articles)} articles processed ======<')
    end_time = time.time()
    time_difference = end_time - start_time
    print(time_difference)
        
elif rex_model_name == 'flair':
    # Setup the required RE tagger
    rex_tagger, rex_ner_tagger, splitter, device, rex_model_name = rex.setup_rex_tagger(model_name = rex_model_name)
    # (Flair's NER tagger is named rex_ner_tagger so it does not overwrite ner_tagger from the NER step)

    # And then run the model to add REs to the articles
    start_time = time.time()
    for i, article in enumerate(articles):
        rex.flair_get_relations(article = article, splitter = splitter, ner_tagger = rex_ner_tagger,
                                rex_tagger = rex_tagger, device = device)
        if (i + 1) % 50 == 0:
            print(f'>====== {i + 1} articles processed ======<')
    print(f'>====== {len(articles)} articles processed ======<')
    end_time = time.time()
    time_difference = end_time - start_time
    print(time_difference)
    
articles[1].print_relations()

25.44956374168396

Article b98bba34-c5d7-440b-b7a5-e365fabf4bc3
WATCH LIVE | State capture inquiry continues
--------------------------------------------
Pule Mothibe >> employer >> PwC South Africa
[80:92], [228:244]                
Pule Mothibe >> employer >> PricewaterhouseCoopers
[80:92], [183:205]                
Pule Mothibe >> employer >> PricewaterhouseCoopers
[778:790], [741:763]                
Pule Mothibe >> employer >> PwC
[778:790], [765:768]                
Pule Mothibe >> employer >> PricewaterhouseCoopers (PwC)
[778:790], [741:769]                
Transnet >> chairperson >> Mafika Mkwanazi
[1589:1597], [1639:1654]                
Transnet >> parent organization >> Eskom
[1589:1597], [1620:1625]                
Mafika Mkwanazi >> employer >> Transnet
[1639:1654], [1589:1597]                
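The REBEL output above repeats the same relation at different mentions. For example, `Pule Mothibe >> employer >> PricewaterhouseCoopers` appears at two pairs of offsets. When building a knowledge graph, one option is to collapse triples by their surface text while keeping every supporting span. This is an assumption for illustration, not something this notebook does. The hypothetical `group_relations` below does this for a few of the triples printed above.

```python
from collections import defaultdict

Span = tuple[int, int]
Relation = tuple[str, str, str, Span, Span]  # (head, relation, tail, head_span, tail_span)


def group_relations(relations: list[Relation]) -> dict[tuple[str, str, str], list[tuple[Span, Span]]]:
    """Group extracted relations by (head, relation, tail) text, keeping all supporting spans."""
    grouped: dict[tuple[str, str, str], list[tuple[Span, Span]]] = defaultdict(list)
    for head, rel, tail, head_span, tail_span in relations:
        grouped[(head, rel, tail)].append((head_span, tail_span))
    return dict(grouped)


# Three of the triples printed above for article b98bba34
relations = [
    ("Pule Mothibe", "employer", "PricewaterhouseCoopers", (80, 92), (183, 205)),
    ("Pule Mothibe", "employer", "PricewaterhouseCoopers", (778, 790), (741, 763)),
    ("Transnet", "chairperson", "Mafika Mkwanazi", (1589, 1597), (1639, 1654)),
]
for triple, spans in group_relations(relations).items():
    print(" >> ".join(triple), spans)
```

Grouping by exact text still keeps `PwC`, `PwC South Africa` and `PricewaterhouseCoopers` apart. Merging those would need the coreference clusters or an alias table.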


## Write to Label Studio JSON format

In [7]:
# Task is 'named_entities'
filename = f'''outputs/round1/sample_ner_30_{ner_model_name}.json'''
with open(filename, 'w') as f:
    json.dump([article.to_labelstudio('named_entities') for article in articles], f, indent=4)
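`to_labelstudio('named_entities')` belongs to `kg_builder`, and its exact output is not shown here. For reference, Label Studio imports pre-annotations as tasks with a `data` field and a `predictions` list. Each `result` item in that list carries character offsets. The hypothetical `ner_task` below builds one such task from `(start, end, label)` spans. The `from_name`/`to_name` values and the `model_version` string are assumptions. The names must match the labeling config, e.g. `<Labels name="label" toName="text">` and `<Text name="text" value="$text"/>`.

```python
import json


def ner_task(text: str, spans: list[tuple[int, int, str]], model_version: str = "spacy") -> dict:
    """Build one Label Studio task with NER pre-annotations (a sketch, not kg_builder's output)."""
    results = []
    for i, (start, end, label) in enumerate(spans):
        results.append({
            "id": f"ent{i}",       # region id, which relation results can refer to
            "from_name": "label",  # assumed name of the <Labels> tag in the config
            "to_name": "text",     # assumed name of the <Text> tag in the config
            "type": "labels",
            "value": {"start": start, "end": end, "text": text[start:end], "labels": [label]},
        })
    return {"data": {"text": text},
            "predictions": [{"model_version": model_version, "result": results}]}


text = "PwC auditor Pule Mothibe"
task = ner_task(text, [(0, 3, "ORG"), (12, 24, "PERSON")])
print(json.dumps(task, indent=2))
```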

In [8]:
# Task is 'cr_clusters'
filename = f'''outputs/round1/sample_cr_30_{cr_model_name}.json'''
with open(filename, 'w') as f:
    json.dump([article.to_labelstudio('cr_clusters') for article in articles], f, indent=4)

In [9]:
# Task is 'relations'
filename = f'''outputs/round1/sample_re_30_{rex_model_name}.json'''
with open(filename, 'w') as f:
    json.dump([article.to_labelstudio('relations') for article in articles], f, indent=4)
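In Label Studio, relations are extra `result` items of type `relation` that link two region ids (`from_id`, `to_id`). Labelled relations also need a `<Relations>` block in the labeling config. The hypothetical `relation_task` below is a minimal sketch that reuses the assumed `label`/`text` tag names. It creates each span region once and then links regions by id. The generic `ENTITY` region label is an assumption.

```python
import json

Span = tuple[int, int]


def relation_task(text: str, relations: list[tuple[Span, str, Span]],
                  entity_label: str = "ENTITY", model_version: str = "rebel") -> dict:
    """Sketch of a Label Studio task whose pre-annotations are spans plus labelled relations."""
    results: list[dict] = []
    region_ids: dict[Span, str] = {}

    def region_id(span: Span) -> str:
        # Create each span region once and reuse its id for every relation it takes part in
        if span not in region_ids:
            region_ids[span] = f"r{len(region_ids)}"
            start, end = span
            results.append({
                "id": region_ids[span],
                "from_name": "label", "to_name": "text", "type": "labels",
                "value": {"start": start, "end": end, "text": text[start:end],
                          "labels": [entity_label]},
            })
        return region_ids[span]

    for head, rel, tail in relations:
        head_id, tail_id = region_id(head), region_id(tail)  # head region is created first
        results.append({"from_id": head_id, "to_id": tail_id, "type": "relation",
                        "direction": "right", "labels": [rel]})
    return {"data": {"text": text},
            "predictions": [{"model_version": model_version, "result": results}]}


text = "PwC auditor Pule Mothibe"
task = relation_task(text, [((12, 24), "employer", (0, 3))])
print(json.dumps(task, indent=2))
```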