In [1]:
import json
from pathlib import Path

import pandas as pd

from development.datasets.OsdgDataset import load_osdg_data
from development.datasets.uclmodules_dataset import load_uclmodules_data
from development.datasets.videscription_dataset import load_videscription_data
from development.datasets.RelxDataset import load_relx_data
from development.models.Bert import Bert
from development.models.BertMultiLabel import BertMultiLabel
from development.models.RobertaNER import RobertaNER
from development.train_model import fine_tune_transformer
from development.pipelines import full_pipe
from development.scrape.RelxScraper import RelxScraper

In [2]:
%load_ext autoreload
%autoreload 2
pd.set_option('max_colwidth', None)

In [3]:
# Load project configuration. The cells in this notebook read only the
# 'development' section (paths to datasets, training settings).
# Explicit UTF-8 so parsing does not depend on the platform default encoding.
with open('config.json', 'r', encoding='utf-8') as config_file:
    CONFIG = json.load(config_file)
dev_config = CONFIG['development']


In [4]:
# Full OSDG dataset (training=False), presumably a DataFrame: the next cell
# groups it by 'sdg' and selects 'text'.
# NOTE(review): `osdg_data` is reassigned to the training splits further down,
# so this cell must be re-run before the per-SDG count; consider a distinct name.
osdg_data = load_osdg_data(
    dev_config['osdg_data_path'],
    training=False,
    filter_agreement=False
)

In [367]:
# Number of OSDG examples per SDG label (non-null 'text' entries).
# Selecting 'text' before counting avoids counting every other column.
osdg_data.groupby('sdg')['text'].count()

sdg
1     2734
2     2457
3     2689
4     3740
5     4338
6     2815
7     3048
8     1509
9     2105
10    2032
11    2277
12    1108
13    2102
14    1141
15    2143
16    5451
Name: text, dtype: int64

In [15]:
# OSDG data in training form; later cells index it by split
# ('train' / 'valid' / 'test').
# NOTE(review): this overwrites the `osdg_data` loaded with training=False above.
osdg_data = load_osdg_data(
    dev_config['osdg_data_path'],
    training=True,
    filter_agreement=False
)

In [94]:
# UCL module descriptions (default loader options).
# NOTE(review): `ucl_data` is reloaded later with only_labled=True and
# evaluation=True before it is evaluated; this version is not used by the visible cells.
ucl_data = load_uclmodules_data(dev_config['uclmodules_data_path'])

In [5]:
# Video-description dataset.
# NOTE(review): `videscription_data` is not used by any visible cell.
videscription_data = load_videscription_data(
    dev_config['videscription_data_path']
)

In [None]:
# Bert wrapper with its default weights; fine-tuned on OSDG in the next cell.
bert = Bert()

In [None]:
# Fine-tune the Bert model on the OSDG splits using the development config.
# NOTE(review): `result` is not read by any later cell.
result = fine_tune_transformer(
    bert.model,
    bert.tokenizer,
    bert.tokenizer_args,
    data=osdg_data,
    dev_config=dev_config
)

In [18]:
# Load the saved Bert-5 checkpoint instead of retraining.
# NOTE(review): the same path is hard-coded again for BertMultiLabel below;
# consider a shared constant in the config cell.
bert = Bert('./development/weights/Bert-5/checkpoint-6384/')

In [19]:
# Classification report for the single-label Bert model on the OSDG train split.
cls_report = bert.evaluate(osdg_data['train'])
print(cls_report)

In [6]:
# Classification report for the single-label Bert model on the OSDG validation split.
cls_report = bert.evaluate(osdg_data['valid'])
print(cls_report)

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


              precision    recall  f1-score   support

           0       0.89      0.93      0.91       410
           1       0.92      0.95      0.94       368
           2       0.94      0.98      0.96       403
           3       0.95      0.96      0.96       561
           4       0.96      0.94      0.95       650
           5       0.92      0.97      0.94       422
           6       0.94      0.93      0.94       457
           7       0.82      0.80      0.81       227
           8       0.90      0.91      0.91       316
           9       0.92      0.78      0.84       305
          10       0.94      0.96      0.95       342
          11       0.94      0.92      0.93       166
          12       0.94      0.96      0.95       316
          13       0.97      0.91      0.94       171
          14       0.94      0.93      0.93       322
          15       0.99      1.00      1.00       817

    accuracy                           0.94      6253
   macro avg       0.93   

In [None]:
# Classification report for the single-label Bert model on the OSDG test split.
cls_report = bert.evaluate(osdg_data['test'])
print(cls_report)

In [48]:
# NER model; used below to extract entities from OSDG texts.
ner_model = RobertaNER()

In [27]:
# Print the entities found in one training example (index 14).
# Assumes osdg_data['train'][0] holds the texts of the split — TODO confirm.
ner_model.print_entities(osdg_data['train'][0][14])

ORG: The World Resources Institute
ORG: World Resources Institute


In [76]:
# Run the combined pipeline (SDG classifier + NER model) on one training text.
# NOTE(review): an assignment produces no Out[] in Jupyter, so the table shown
# below is either displayed by full_pipe itself or is stale output from an
# earlier version of this cell; end the cell with `df` if display is intended.
df = full_pipe(bert, ner_model, texts=[osdg_data['train'][0][7]])

Unnamed: 0,Text,SDG,Entities,Sentiment
0,"This is why the Sustainable Development of Protected Areas System of Ethiopia was set up, with support from the Global Environment Fund and UNDP. The project is spearheading a suite of interventions, focusing on the national system in terms of capacity building and training, and integrating the protected area system into mainstream development. Since the initiation of the project in 2008, valuation exercises have found that the main value of protected areas is in the environmental services that they provide to poor rural communities, many of which are food-insecure, protected areas were incorporated into the Ethiopia Poverty Strategy, and the legal boundaries of the protected area system were strengthened by supporting the demarcation and gazettement of four areas through a highly consultative process (UNDP, n.d.).",Climate Action,ORG: the Sustainable Development of Protected Areas System of Ethiopia - ORG: the Global Environment Fund - ORG: UNDP - ORG: UNDP -,


In [138]:
# Scrape RELX data to CSV. Scraping hits the network and is slow, so it only
# runs when the CSV is missing or a refresh is explicitly requested. This keeps
# Restart & Run All from re-scraping and overwriting the data every time.
# NOTE(review): the loader below reads dev_config['relx_data_path']; confirm it
# points at this same CSV, otherwise freshly scraped data is not what gets loaded.
RELX_CSV_PATH = './data/relx_data.csv'
REFRESH_RELX_DATA = False  # set True to force a fresh scrape

if REFRESH_RELX_DATA or not Path(RELX_CSV_PATH).exists():
    relx_scraper = RelxScraper()
    relx_scraper.scrape_data(start=0)
    relx_scraper.save_as_csv(RELX_CSV_PATH)

In [4]:
# RELX data in training form; later cells index it by 'train' / 'valid' / 'test'.
relx_training_data = load_relx_data(
    data_path=dev_config['relx_data_path'],
    training=True
)

In [5]:
# Multilabel model constructed from the same Bert-5 checkpoint path used for
# the single-label model above.
bert_multilabel = BertMultiLabel(
    './development/weights/Bert-5/checkpoint-6384/'
)

In [6]:
# Fine-tune the multilabel model on the RELX splits (dataset='relx').
# NOTE(review): `results` is not read by any later cell.
results = fine_tune_transformer(
    bert_multilabel.model,
    bert_multilabel.tokenizer,
    bert_multilabel.tokenizer_args,
    data=relx_training_data,
    dataset='relx',
    dev_config=dev_config
)



  0%|          | 0/1040 [00:00<?, ?it/s]

{'loss': 0.7096, 'learning_rate': 5.2e-06, 'epoch': 1.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.5566602263317773, 'eval_runtime': 10.3519, 'eval_samples_per_second': 17.098, 'eval_steps_per_second': 1.159, 'epoch': 1.0}
{'loss': 0.4709, 'learning_rate': 1.04e-05, 'epoch': 2.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.38561571981570186, 'eval_runtime': 10.5763, 'eval_samples_per_second': 16.736, 'eval_steps_per_second': 1.135, 'epoch': 2.0}
{'loss': 0.3158, 'learning_rate': 1.56e-05, 'epoch': 3.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.24736081546953206, 'eval_runtime': 10.3459, 'eval_samples_per_second': 17.108, 'eval_steps_per_second': 1.16, 'epoch': 3.0}
{'loss': 0.2168, 'learning_rate': 2.08e-05, 'epoch': 4.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.20990913641718967, 'eval_runtime': 10.5386, 'eval_samples_per_second': 16.795, 'eval_steps_per_second': 1.139, 'epoch': 4.0}
{'loss': 0.1782, 'learning_rate': 2.6000000000000002e-05, 'epoch': 5.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.19769713115205234, 'eval_runtime': 10.2418, 'eval_samples_per_second': 17.282, 'eval_steps_per_second': 1.172, 'epoch': 5.0}
{'loss': 0.1534, 'learning_rate': 3.12e-05, 'epoch': 6.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.19583137944380333, 'eval_runtime': 10.68, 'eval_samples_per_second': 16.573, 'eval_steps_per_second': 1.124, 'epoch': 6.0}
{'loss': 0.1311, 'learning_rate': 3.6400000000000004e-05, 'epoch': 7.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.19418140165465042, 'eval_runtime': 10.7401, 'eval_samples_per_second': 16.48, 'eval_steps_per_second': 1.117, 'epoch': 7.0}
{'loss': 0.1164, 'learning_rate': 4.16e-05, 'epoch': 8.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.19473162427852322, 'eval_runtime': 11.5217, 'eval_samples_per_second': 15.362, 'eval_steps_per_second': 1.042, 'epoch': 8.0}
{'loss': 0.1019, 'learning_rate': 4.6800000000000006e-05, 'epoch': 9.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.19021321644743946, 'eval_runtime': 12.0298, 'eval_samples_per_second': 14.713, 'eval_steps_per_second': 0.998, 'epoch': 9.0}
{'loss': 0.0895, 'learning_rate': 4.814814814814815e-05, 'epoch': 10.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.20094610330683907, 'eval_runtime': 11.4055, 'eval_samples_per_second': 15.519, 'eval_steps_per_second': 1.052, 'epoch': 10.0}
{'loss': 0.0763, 'learning_rate': 4.3333333333333334e-05, 'epoch': 11.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.19925854332582454, 'eval_runtime': 11.9099, 'eval_samples_per_second': 14.862, 'eval_steps_per_second': 1.008, 'epoch': 11.0}
{'loss': 0.0634, 'learning_rate': 3.851851851851852e-05, 'epoch': 12.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.20558025992731876, 'eval_runtime': 10.8227, 'eval_samples_per_second': 16.354, 'eval_steps_per_second': 1.109, 'epoch': 12.0}
{'loss': 0.0552, 'learning_rate': 3.3703703703703706e-05, 'epoch': 13.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.20942077404168047, 'eval_runtime': 10.919, 'eval_samples_per_second': 16.21, 'eval_steps_per_second': 1.099, 'epoch': 13.0}
{'loss': 0.0475, 'learning_rate': 2.8888888888888888e-05, 'epoch': 14.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.1962740024392706, 'eval_runtime': 11.0443, 'eval_samples_per_second': 16.026, 'eval_steps_per_second': 1.087, 'epoch': 14.0}
{'loss': 0.0417, 'learning_rate': 2.4074074074074074e-05, 'epoch': 15.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.19395701286341374, 'eval_runtime': 10.9303, 'eval_samples_per_second': 16.193, 'eval_steps_per_second': 1.098, 'epoch': 15.0}
{'loss': 0.0368, 'learning_rate': 1.925925925925926e-05, 'epoch': 16.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.20125444023630482, 'eval_runtime': 11.0403, 'eval_samples_per_second': 16.032, 'eval_steps_per_second': 1.087, 'epoch': 16.0}
{'loss': 0.0339, 'learning_rate': 1.4444444444444444e-05, 'epoch': 17.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.19909817266995616, 'eval_runtime': 10.9534, 'eval_samples_per_second': 16.159, 'eval_steps_per_second': 1.096, 'epoch': 17.0}
{'loss': 0.0322, 'learning_rate': 9.62962962962963e-06, 'epoch': 18.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.20225631147367992, 'eval_runtime': 11.109, 'eval_samples_per_second': 15.933, 'eval_steps_per_second': 1.08, 'epoch': 18.0}
{'loss': 0.0301, 'learning_rate': 4.814814814814815e-06, 'epoch': 19.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.20340080154446336, 'eval_runtime': 12.1752, 'eval_samples_per_second': 14.538, 'eval_steps_per_second': 0.986, 'epoch': 19.0}
{'loss': 0.0296, 'learning_rate': 0.0, 'epoch': 20.0}


  0%|          | 0/12 [00:00<?, ?it/s]

{'eval_loss': 0.20372545840140355, 'eval_runtime': 11.6295, 'eval_samples_per_second': 15.22, 'eval_steps_per_second': 1.032, 'epoch': 20.0}
{'train_runtime': 1251.0258, 'train_samples_per_second': 13.189, 'train_steps_per_second': 0.831, 'train_loss': 0.1465141924527975, 'epoch': 20.0}


In [12]:
# Load the saved multilabel checkpoint for evaluation; to_gpu() presumably
# moves the model onto the GPU — confirm in BertMultiLabel.
bert_multilabel = BertMultiLabel(
    './development/weights/Bert-Multilabel-2/checkpoint-780/'
)
bert_multilabel.to_gpu()

In [14]:
# Multilabel classification report on the RELX train split.
cls_report = bert_multilabel.evaluate(relx_training_data['train'])
print(cls_report)

              precision    recall  f1-score   support

           0       1.00      0.92      0.96        13
           1       1.00      0.97      0.99        73
           2       1.00      0.99      1.00       503
           3       1.00      1.00      1.00        18
           4       1.00      1.00      1.00        46
           5       1.00      0.99      0.99        92
           6       1.00      0.97      0.99        39
           7       1.00      0.95      0.97        19
           8       1.00      0.94      0.97        49
           9       1.00      0.99      1.00       157
          10       0.99      1.00      0.99        88
          11       1.00      0.95      0.97        56
          12       1.00      0.99      0.99       202
          13       0.98      0.96      0.97        51
          14       1.00      0.99      0.99        72
          15       1.00      0.88      0.93        24

   micro avg       1.00      0.98      0.99      1502
   macro avg       1.00   

In [13]:
# Multilabel classification report on the RELX validation split.
cls_report = bert_multilabel.evaluate(relx_training_data['valid'])
print(cls_report)

              precision    recall  f1-score   support

           0       1.00      0.00      0.00         1
           1       0.64      0.60      0.62        15
           2       0.90      0.86      0.88       116
           3       0.67      0.25      0.36         8
           4       1.00      0.60      0.75        15
           5       0.70      0.80      0.74        20
           6       0.30      0.50      0.37         6
           7       1.00      0.00      0.00         4
           8       0.40      0.17      0.24        12
           9       0.59      0.34      0.43        29
          10       0.92      0.57      0.71        21
          11       0.67      0.29      0.40        14
          12       0.87      0.66      0.75        50
          13       0.71      0.71      0.71         7
          14       0.67      0.59      0.62        17
          15       1.00      0.40      0.57         5

   micro avg       0.79      0.64      0.71       340
   macro avg       0.75   

In [11]:
# Multilabel classification report on the RELX test split.
cls_report = bert_multilabel.evaluate(relx_training_data['test'])
print(cls_report)

              precision    recall  f1-score   support

           0       1.00      0.00      0.00         2
           1       0.69      0.69      0.69        16
           2       0.97      0.86      0.91       122
           3       1.00      0.00      0.00         8
           4       0.67      0.55      0.60        11
           5       0.87      0.87      0.87        15
           6       1.00      0.64      0.78        14
           7       1.00      0.00      0.00         3
           8       0.17      0.33      0.22         6
           9       0.77      0.62      0.69        32
          10       0.54      0.41      0.47        17
          11       0.54      0.54      0.54        13
          12       0.86      0.63      0.73        51
          13       0.90      0.82      0.86        11
          14       0.50      0.54      0.52        13
          15       1.00      0.10      0.18        10

   micro avg       0.81      0.67      0.73       344
   macro avg       0.78   

In [16]:
# Score the multilabel model on single-label OSDG test data.
# NOTE(review): the semantics of mode='exact_match' (vs 'included' in the next
# cell) live in BertMultiLabel.evaluate_single_label — presumably an exact
# label-set match vs. "true label is among the predictions"; confirm.
cls_report = bert_multilabel.evaluate_single_label(
    osdg_data['test'],
    mode='exact_match'
)
print(cls_report)

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


              precision    recall  f1-score   support

           0       0.94      0.48      0.63       410
           1       0.90      0.81      0.85       369
           2       0.54      0.95      0.69       404
           3       0.87      0.88      0.87       561
           4       0.93      0.79      0.85       651
           5       0.89      0.86      0.87       423
           6       0.94      0.78      0.86       457
           7       0.72      0.60      0.65       226
           8       0.81      0.80      0.80       316
           9       0.39      0.86      0.53       305
          10       0.72      0.91      0.80       341
          11       0.89      0.56      0.69       166
          12       0.66      0.88      0.76       315
          13       0.93      0.87      0.90       171
          14       0.90      0.83      0.86       321
          15       1.00      0.62      0.77       818

    accuracy                           0.78      6254
   macro avg       0.81   

In [17]:
# Same OSDG test evaluation with mode='included'; this mode returns a single
# accuracy value rather than a classification report.
accuracy = bert_multilabel.evaluate_single_label(
    osdg_data['test'],
    mode='included'
)
accuracy

0.8535337384074193

In [15]:
# Reload UCL modules in evaluation form, keeping only labelled modules.
# NOTE(review): `only_labled` (sic) must match the loader's parameter name;
# rename it in the loader and here together, not here alone.
ucl_data = load_uclmodules_data(
    dev_config['uclmodules_data_path'],
    only_labled=True,
    evaluation=True
)

In [17]:
# Multilabel classification report on the labelled UCL modules.
cls_report = bert_multilabel.evaluate(ucl_data)
print(cls_report)

              precision    recall  f1-score   support

           0       1.00      0.08      0.15        87
           1       0.38      0.50      0.43        12
           2       0.71      0.99      0.83      1015
           3       0.17      0.72      0.28       258
           4       0.56      0.62      0.59       128
           5       0.34      0.78      0.47        18
           6       0.76      0.51      0.61        76
           7       0.63      0.37      0.47       339
           8       0.69      0.42      0.52       588
           9       0.07      0.74      0.13        53
          10       0.60      0.47      0.52       506
          11       0.10      0.25      0.14        24
          12       0.13      0.96      0.23        68
          13       0.53      0.40      0.46        20
          14       0.46      0.61      0.52        18
          15       0.70      0.57      0.63       667

   micro avg       0.46      0.63      0.53      3877
   macro avg       0.49   

In [118]:
# Inspect one UCL module description (index 50).
ucl_data[0][50]

'Advanced Field Techniques (ARCL0032) \n This module is aimed to develop students knowledge of field techniques learnt during their first year and vacation fieldwork. Topics will include research designs, aerial photography, regional sampling, geophysics, site formation and transformation, digital site data recording, recording systems and post-excavation analysis.  A number of sessions are taught by visiting speakers as well as a number of IoA staff.\nFurther information is available here https://www.ucl.ac.uk/archaeology/study/undergraduate/courses/advanced-field-techniques\n '

In [117]:
# Gold label list for the same module (index 50).
ucl_data[1][50]

[8]

'Advanced Field Techniques (ARCL0032) \n This module is aimed to develop students knowledge of field techniques learnt during their first year and vacation fieldwork. Topics will include research designs, aerial photography, regional sampling, geophysics, site formation and transformation, digital site data recording, recording systems and post-excavation analysis.  A number of sessions are taught by visiting speakers as well as a number of IoA staff.\nFurther information is available here https://www.ucl.ac.uk/archaeology/study/undergraduate/courses/advanced-field-techniques\n '


In [124]:
# Predict SDG labels for a single UCL module and keep those whose score
# clears the threshold (at most TOP_K labels).
# Raw model output and parsed labels get separate names so each is inspectable.
# NOTE(review): with threshold 0.7 the recorded output was [[]] (no label
# predicted), while the module's gold labels cell shows [8].
SAMPLE_IDX = 50    # same module inspected in the two preceding cells
TOP_K = 16
THRESHOLD = 0.7

raw_predictions = bert_multilabel.predict(ucl_data[0][SAMPLE_IDX])
predictions = bert_multilabel.parse_predictions(
    raw_predictions, top_k=TOP_K, threshold=THRESHOLD
)
predictions

[[]]