In [1]:
import json
import os

import pandas as pd
import torch

from development.datasets.osdg_dataset import load_osdg_data
from development.datasets.uclmodules_dataset import load_uclmodules_data
from development.datasets.videscription_dataset import load_videscription_data
from development.datasets.relx_dataset import load_relx_data
from development.models.Bert import Bert
from development.models.BertMultiLabel import BertMultiLabel
from development.models.RobertaNER import RobertaNER
from development.train_model import fine_tune_transformer
from development.pipelines import full_pipe
from development.scrape.RelxScraper import RelxScraper
from development.utils import parse_sdg_id

In [2]:
%load_ext autoreload
%autoreload 2
pd.set_option('max_colwidth', None)

In [3]:
# Load the project configuration. Later cells read their data paths from the
# 'development' section via `dev_config`.
# JSON is UTF-8 by spec; be explicit so decoding does not depend on the
# platform's default locale encoding.
with open('config.json', 'r', encoding='utf-8') as config_file:
    CONFIG = json.load(config_file)

dev_config = CONFIG['development']


In [4]:
# OSDG loaded with training=False. The next cell groups it by 'sdg' and
# counts 'text', so this is presumably a single DataFrame.
# NOTE(review): `osdg_data` is rebound further down to the training-mode
# result, which is indexed by split name instead.
osdg_data = load_osdg_data(dev_config['osdg_data_path'], training=False, filter_agreement=False)

In [367]:
# Number of OSDG texts per SDG label. Selecting 'text' before counting
# aggregates only that column instead of counting every column and then
# discarding all but one.
osdg_data.groupby('sdg')['text'].count()

sdg
1     2734
2     2457
3     2689
4     3740
5     4338
6     2815
7     3048
8     1509
9     2105
10    2032
11    2277
12    1108
13    2102
14    1141
15    2143
16    5451
Name: text, dtype: int64

In [15]:
# Reload OSDG in training mode. Later cells index the result by
# 'train' / 'valid' / 'test'. This rebinds `osdg_data`, replacing the
# training=False table loaded earlier.
osdg_data = load_osdg_data(dev_config['osdg_data_path'], training=True, filter_agreement=False)

In [11]:
# Labelled UCL module descriptions. `only_labled` is the loader's own keyword
# spelling and must match it exactly.
ucl_data = load_uclmodules_data(
    dev_config['uclmodules_data_path'],
    only_labled=True,
)

In [5]:
# Video-description dataset.
# NOTE(review): `videscription_data` is not referenced by any later cell shown here.
videscription_data = load_videscription_data(dev_config['videscription_data_path'])

In [None]:
# BERT classifier constructed without a checkpoint path (presumably the
# default pretrained weights — TODO confirm); it is fine-tuned in the next cell.
bert = Bert()

In [None]:
# Fine-tune the BERT classifier on the OSDG training-mode data.
# The model, tokenizer and tokenizer args are passed positionally, in that order.
model_parts = (bert.model, bert.tokenizer, bert.tokenizer_args)
result = fine_tune_transformer(*model_parts, data=osdg_data, dev_config=dev_config)

In [18]:
# Fine-tuned single-label BERT checkpoint that the following cells evaluate.
BERT_CHECKPOINT = './development/weights/Bert-5/checkpoint-6384/'
bert = Bert(BERT_CHECKPOINT)

In [19]:
# Text classification report for the BERT checkpoint on the OSDG 'train' split.
cls_report = bert.evaluate(osdg_data['train'])
print(cls_report)

              precision    recall  f1-score   support

           0       0.87      0.92      0.89      1914
           1       0.90      0.94      0.92      1720
           2       0.96      0.97      0.96      1882
           3       0.95      0.97      0.96      2618
           4       0.96      0.95      0.95      3037
           5       0.92      0.93      0.93      1970
           6       0.94      0.95      0.94      2134
           7       0.88      0.82      0.85      1056
           8       0.91      0.92      0.91      1473
           9       0.89      0.80      0.84      1422
          10       0.93      0.95      0.94      1594
          11       0.92      0.93      0.92       776
          12       0.93      0.94      0.94      1471
          13       0.97      0.92      0.95       799
          14       0.95      0.91      0.93      1500
          15       1.00      1.00      1.00      3816

    accuracy                           0.94     29182
   macro avg       0.93   

In [6]:
# Text classification report for the BERT checkpoint on the OSDG 'valid' split.
# NOTE(review): this cell's execution count (6) is lower than the checkpoint
# load cell's (18), so the saved output may come from an earlier kernel session.
cls_report = bert.evaluate(osdg_data['valid'])
print(cls_report)

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


              precision    recall  f1-score   support

           0       0.89      0.93      0.91       410
           1       0.92      0.95      0.94       368
           2       0.94      0.98      0.96       403
           3       0.95      0.96      0.96       561
           4       0.96      0.94      0.95       650
           5       0.92      0.97      0.94       422
           6       0.94      0.93      0.94       457
           7       0.82      0.80      0.81       227
           8       0.90      0.91      0.91       316
           9       0.92      0.78      0.84       305
          10       0.94      0.96      0.95       342
          11       0.94      0.92      0.93       166
          12       0.94      0.96      0.95       316
          13       0.97      0.91      0.94       171
          14       0.94      0.93      0.93       322
          15       0.99      1.00      1.00       817

    accuracy                           0.94      6253
   macro avg       0.93   

In [20]:
# Text classification report for the BERT checkpoint on the held-out OSDG 'test' split.
cls_report = bert.evaluate(osdg_data['test'])
print(cls_report)

              precision    recall  f1-score   support

           0       0.86      0.93      0.89       410
           1       0.93      0.96      0.94       369
           2       0.95      0.98      0.96       404
           3       0.94      0.96      0.95       561
           4       0.95      0.93      0.94       651
           5       0.91      0.94      0.93       423
           6       0.96      0.95      0.95       457
           7       0.88      0.81      0.84       226
           8       0.94      0.93      0.93       316
           9       0.90      0.79      0.84       305
          10       0.94      0.95      0.94       341
          11       0.93      0.90      0.92       166
          12       0.91      0.94      0.92       315
          13       0.98      0.96      0.97       171
          14       0.94      0.92      0.93       321
          15       1.00      1.00      1.00       818

    accuracy                           0.94      6254
   macro avg       0.93   

In [48]:
# RoBERTa-based named-entity model, used below to list entities such as ORG.
ner_model = RobertaNER()

In [27]:
# Print the entities found in training text #14.
# osdg_data['train'][0] is indexed as a sequence of texts; presumably the
# split is a (texts, labels) pair — confirm against load_osdg_data.
ner_model.print_entities(osdg_data['train'][0][14])

ORG: The World Resources Institute
ORG: World Resources Institute


In [76]:
# Run the combined pipeline (BERT classifier + NER model) on one training text.
# NOTE(review): an assignment displays nothing by itself; the saved table below
# must come from full_pipe or from an earlier version of this cell — confirm.
sample_text = osdg_data['train'][0][7]
df = full_pipe(bert, ner_model, texts=[sample_text])

Unnamed: 0,Text,SDG,Entities,Sentiment
0,"This is why the Sustainable Development of Protected Areas System of Ethiopia was set up, with support from the Global Environment Fund and UNDP. The project is spearheading a suite of interventions, focusing on the national system in terms of capacity building and training, and integrating the protected area system into mainstream development. Since the initiation of the project in 2008, valuation exercises have found that the main value of protected areas is in the environmental services that they provide to poor rural communities, many of which are food-insecure, protected areas were incorporated into the Ethiopia Poverty Strategy, and the legal boundaries of the protected area system were strengthened by supporting the demarcation and gazettement of four areas through a highly consultative process (UNDP, n.d.).",Climate Action,ORG: the Sustainable Development of Protected Areas System of Ethiopia - ORG: the Global Environment Fund - ORG: UNDP - ORG: UNDP -,


In [138]:
# Scrape RELX data into a CSV. Scraping presumably hits the network and is
# slow, so it only runs when the CSV is missing or FORCE_RESCRAPE is True.
# This keeps "Restart & Run All" from re-scraping every time.
# NOTE(review): the next cell loads from dev_config['relx_data_path'];
# confirm it points at the same file as RELX_CSV_PATH.
RELX_CSV_PATH = './data/relx_data.csv'
FORCE_RESCRAPE = False

if FORCE_RESCRAPE or not os.path.exists(RELX_CSV_PATH):
    relx_scraper = RelxScraper()
    relx_scraper.scrape_data(start=0)
    relx_scraper.save_as_csv(RELX_CSV_PATH)

In [6]:
# RELX data in training mode; later cells index it by 'train' / 'valid' / 'test'.
relx_training_data = load_relx_data(data_path=dev_config['relx_data_path'], training=True)

In [5]:
# Initialise the multi-label model from the same single-label BERT checkpoint
# evaluated above.
SINGLE_LABEL_CHECKPOINT = './development/weights/Bert-5/checkpoint-6384/'
bert_multilabel = BertMultiLabel(SINGLE_LABEL_CHECKPOINT)

In [None]:
# Fine-tune the multi-label model on the RELX data (dataset='relx').
# The model, tokenizer and tokenizer args are passed positionally, in that order.
multilabel_parts = (
    bert_multilabel.model,
    bert_multilabel.tokenizer,
    bert_multilabel.tokenizer_args,
)
results = fine_tune_transformer(
    *multilabel_parts,
    data=relx_training_data,
    dataset='relx',
    dev_config=dev_config,
)

In [43]:
# Reload the fine-tuned multi-label checkpoint and call to_gpu() on it.
MULTILABEL_CHECKPOINT = './development/weights/Bert-Multilabel-5/checkpoint-1400/'
bert_multilabel = BertMultiLabel(MULTILABEL_CHECKPOINT)
bert_multilabel.to_gpu()

In [7]:
# Multi-label classification report on the RELX 'train' split.
cls_report = bert_multilabel.evaluate(relx_training_data['train'])
print(cls_report)

              precision    recall  f1-score   support

           0       1.00      0.36      0.53        14
           1       0.91      0.96      0.94        77
           2       0.99      0.93      0.96       543
           3       1.00      0.61      0.76        23
           4       0.98      0.89      0.93        53
           5       0.97      0.93      0.95        99
           6       0.90      0.88      0.89        42
           7       1.00      0.33      0.50        21
           8       0.80      0.69      0.74        51
           9       0.95      0.62      0.75       166
          10       0.99      0.83      0.90        99
          11       0.90      0.70      0.79        61
          12       0.99      0.83      0.90       222
          13       0.98      0.87      0.92        54
          14       0.80      0.84      0.82        77
          15       0.68      0.48      0.57        27

   micro avg       0.95      0.83      0.89      1629
   macro avg       0.93   

In [10]:
# Multi-label classification report on the RELX 'valid' split.
cls_report = bert_multilabel.evaluate(relx_training_data['valid'])
print(cls_report)

              precision    recall  f1-score   support

           0       1.00      0.00      0.00         1
           1       0.71      0.83      0.77        12
           2       0.92      0.86      0.89        98
           3       0.75      0.60      0.67         5
           4       0.60      0.33      0.43         9
           5       0.60      0.75      0.67        16
           6       0.60      0.43      0.50         7
           7       1.00      0.00      0.00         1
           8       0.33      0.38      0.35         8
           9       0.85      0.48      0.61        23
          10       0.67      0.50      0.57        12
          11       0.50      0.33      0.40        12
          12       0.87      0.60      0.71        43
          13       0.75      0.60      0.67         5
          14       0.47      0.54      0.50        13
          15       0.50      0.40      0.44         5

   micro avg       0.77      0.66      0.71       270
   macro avg       0.69   

In [15]:
# Multi-label classification report on the held-out RELX 'test' split.
cls_report = bert_multilabel.evaluate(relx_training_data['test'])
print(cls_report)

              precision    recall  f1-score   support

           0       1.00      0.00      0.00         1
           1       0.79      0.73      0.76        15
           2       0.97      0.85      0.90       100
           3       1.00      0.00      0.00         6
           4       0.75      0.60      0.67        10
           5       0.91      0.83      0.87        12
           6       0.70      0.70      0.70        10
           7       1.00      0.00      0.00         4
           8       0.17      0.12      0.14         8
           9       0.62      0.45      0.52        29
          10       0.40      0.27      0.32        15
          11       0.67      0.40      0.50        10
          12       0.89      0.63      0.74        38
          13       0.89      0.80      0.84        10
          14       0.60      0.50      0.55        12
          15       1.00      0.29      0.44         7

   micro avg       0.82      0.63      0.71       287
   macro avg       0.77   

In [18]:
# Score the multi-label model on the single-label OSDG 'test' split.
# mode='exact_match' yields a classification report, printed below.
cls_report = bert_multilabel.evaluate_single_label(osdg_data['test'], mode='exact_match')
print(cls_report)

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


              precision    recall  f1-score   support

           0       0.94      0.61      0.74       410
           1       0.92      0.88      0.90       369
           2       0.62      0.99      0.76       404
           3       0.92      0.90      0.91       561
           4       0.93      0.89      0.91       651
           5       0.91      0.85      0.88       423
           6       0.97      0.81      0.88       457
           7       0.84      0.64      0.72       226
           8       0.90      0.84      0.87       316
           9       0.51      0.87      0.65       305
          10       0.78      0.93      0.84       341
          11       0.97      0.61      0.75       166
          12       0.71      0.91      0.80       315
          13       0.96      0.87      0.91       171
          14       0.90      0.88      0.89       321
          15       1.00      0.82      0.90       818

    accuracy                           0.84      6254
   macro avg       0.86   

In [19]:
# Same OSDG 'test' split with mode='included'; this mode returns a single
# accuracy value, displayed as the cell result.
accuracy = bert_multilabel.evaluate_single_label(osdg_data['test'], mode='included')
accuracy

0.9267668692037097

In [4]:
# Reload the labelled UCL modules with evaluation=True for the multi-label
# report below. `only_labled` is the loader's own keyword spelling.
ucl_data = load_uclmodules_data(dev_config['uclmodules_data_path'], only_labled=True, evaluation=True)

In [59]:
# Multi-label classification report on the UCL module evaluation set.
# NOTE(review): the quoted module description saved after this report is a
# repr, which print() cannot produce; it is likely left over from a deleted cell.
cls_report = bert_multilabel.evaluate(ucl_data)
print(cls_report)

              precision    recall  f1-score   support

           0       0.80      0.09      0.16        87
           1       0.30      0.58      0.40        12
           2       0.84      0.90      0.87      1015
           3       0.17      0.88      0.29       258
           4       0.37      0.72      0.49       128
           5       0.26      0.83      0.40        18
           6       0.62      0.62      0.62        76
           7       0.70      0.35      0.47       339
           8       0.62      0.67      0.64       588
           9       0.07      0.60      0.13        53
          10       0.68      0.42      0.52       506
          11       0.10      0.33      0.16        24
          12       0.26      0.74      0.39        68
          13       0.58      0.55      0.56        20
          14       0.21      0.67      0.32        18
          15       0.66      0.62      0.64       667

   micro avg       0.48      0.66      0.56      3877
   macro avg       0.45   

'Cardiac Critical Care (CHLD0081) \n Summary\nThis module will introduce principles of paediatric cardiology and cardiac intensive care. Over the course of the week, we will discuss cardiac anatomy and physiology, before moving on to post-natal management of cardiac disease, and subsequent management at specialist centres from the perspectives of the cardiologist, the anaesthetic and surgical teams, as well as the cardiac intensive care team. We will also spend some time looking at arrhythmia, pulmonary hypertension and heart failure, with a focus on underlying pathophysiology and management strategies (including mechanical circulatory support).\nYou will also critically review literature and latest advancements in the field, and present this in a poster with an oral session.\nLearning objectives and outcomes\nAfter taking this module, you will be able to:\nDescribe the underlying physiology & anatomy of paediatric cardiac disease.\nUnderstand the principles of management – both surgical and medical.\nDescribe and suggest management strategies for arrhythmia\nUnderstand the pathophysiology of pulmonary hypertension and heart failure\nDescribe management strategies for the above.\n\nWho is this module for?\n\nThis is a core module for the MSc Paediatrics and Child Health: Intensive Care and an optional module for all other MSc Paediatric and Child Health Pathways, aimed at those with a genuine interest in cardiology and cardiac critical care.\nTeaching and Learning Methods\nYou will receive 5 days of interactive online lectures/workshops, which will supplement learning through self-directed study.\n\nAssessment\n\nYou will need to create a poster and up to date literature review on a topic relevant to the module. You will also be assessed via oral presentation of the topic to the examiners and student group.'


In [114]:
# Raw text of UCL module #590. ucl_data[0] is indexed for texts here, and
# ucl_data[1] for the matching labels in the next cell.
ucl_data[0][590]

'Conflict of Laws (LAWS0034) \n Knowledge of the Conflict of Laws (also known as Private International Law) is essential for any lawyer who aspires to work in any area of practice that transcends national frontiers, whether as a specialist in dispute resolution or in advisory work. London is one of the leading centres for international commercial dispute resolution, and most of the commercial disputes heard in London involve foreign parties, so Conflict of Laws rules are particularly central to the work of the English commercial courts. It is a fascinating area of the law, and one of enormous practical importance as legal relationships and disputes increasingly cross borders, but also one of the most intellectually demanding.\nYou will deal principally with three separate questions which may arise in cross-border civil and commercial litigation:\njurisdiction, the question of which court may hear a dispute;\napplicable law, the question of which law or laws a court will apply to resolv

In [113]:
# Ground-truth SDGs for UCL module #590. Label indices are 0-based, so +1
# gives the SDG number shown.
for sdg_idx in ucl_data[1][590]:
    sdg_number = sdg_idx + 1
    print(f'SDG ID: {sdg_number}, SDG: {parse_sdg_id(sdg_idx)}')

SDG ID: 16, SDG: Peace, Justice, and Strong Institutions
SDG ID: 8, SDG: Decent Work and Economic Growth


In [49]:
# Predicted SDGs for the same UCL module. parse_predictions presumably keeps
# up to TOP_K labels whose score clears THRESHOLD — confirm in BertMultiLabel.
SAMPLE_IDX = 590
TOP_K = 16
THRESHOLD = 0.7

raw_predictions = bert_multilabel.predict(ucl_data[0][SAMPLE_IDX])
predictions = bert_multilabel.parse_predictions(raw_predictions, top_k=TOP_K, threshold=THRESHOLD)

for sdg_idx in predictions[0]:
    print(f'SDG ID: {sdg_idx + 1}, SDG: {parse_sdg_id(sdg_idx)}')

SDG ID: 13, SDG: Climate Action
SDG ID: 9, SDG: Industry, Innovation, and Infrastructure


In [None]:
# Inspect the training arguments saved alongside a checkpoint
# (presumably a pickled Hugging Face TrainingArguments object).
# It is not a tensor state dict, so weights_only=False is required on PyTorch
# releases where torch.load defaults to weights_only=True.
# SECURITY: this unpickles arbitrary objects — only load files you produced or trust.
args = torch.load(
    './development/weights/Bert-5/checkpoint-1824/training_args.bin',
    weights_only=False,
)
args