In [None]:
!pip install flair

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [None]:
# Mount Google Drive at /content/drive; the dataset folder and the model output
# folder used further down both live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Name of the annotation layer to predict; used when building the label dictionary.
label_type = 'ner'

In [None]:
# Column index -> annotation layer name, handed to ColumnCorpus when reading the data files.
columns = {
    0: 'text',
    1: 'pos',
    2: 'ner',
}

In [None]:
from flair.data import Corpus
from flair.datasets import ColumnCorpus

# Google Drive folder that contains the train/dev/test column files.
datafolder = '/content/drive/MyDrive/Subinay/7_tags/'

In [None]:
# Read the train/dev/test splits from `datafolder` using the `columns` layout.
# NOTE(review): the dev file name says "8tags" while the test file says "7tags" —
# confirm that both splits were annotated with the same tag set.
corpus: Corpus = ColumnCorpus(
    datafolder,
    columns,
    train_file='tag_sequence_5_riot_10_evidence_new_flair.txt',
    dev_file='tag_sequence_dev_8tags_flair.txt',
    test_file='tag_sequence_test_7tags_flair.txt',
)

2023-06-24 10:04:21,847 Reading data from /content/drive/MyDrive/Subinay/7_tags
2023-06-24 10:04:21,849 Train: /content/drive/MyDrive/Subinay/7_tags/tag_sequence_5_riot_10_evidence_new_flair.txt
2023-06-24 10:04:21,850 Dev: /content/drive/MyDrive/Subinay/7_tags/tag_sequence_dev_8tags_flair.txt
2023-06-24 10:04:21,851 Test: /content/drive/MyDrive/Subinay/7_tags/tag_sequence_test_7tags_flair.txt


In [None]:
# Build the tag vocabulary for the `label_type` layer; add_unk=False means no
# unknown-label entry is requested.
label_dict = corpus.make_label_dictionary(
    label_type=label_type,
    add_unk=False,
)

2023-06-24 11:09:16,401 Computing label dictionary. Progress:


246it [00:00, 3663.69it/s]

2023-06-24 11:09:16,480 Dictionary created for label 'ner' with 8 values: Otag (seen 3330 times), wittest (seen 1130 times), riot (seen 915 times), evidence (seen 762 times), assault (seen 734 times), imprisonment (seen 509 times), expwittest (seen 502 times), homicide (seen 501 times)





In [None]:
# Show the tag vocabulary that is later passed to the tagger as its tag dictionary.
print(label_dict)

Dictionary with 8 tags: Otag, wittest, riot, evidence, assault, imprisonment, expwittest, homicide


In [None]:
# XLM-RoBERTa (large) word embeddings that are trained together with the tagger.
# Per-argument notes follow the Flair documentation — TODO confirm for the installed version.
embeddings = TransformerWordEmbeddings(
    model="xlm-roberta-large",
    layers="-1",               # use only the last transformer layer
    subtoken_pooling="first",  # represent each word by its first subtoken
    fine_tune=True,            # let training update the transformer weights
    use_context=True,          # presumably adds surrounding-sentence context
)

In [None]:
# Sequence tagger on top of the transformer embeddings, configured without a CRF,
# without an RNN and without embedding re-projection.
# NOTE(review): with use_rnn=False the hidden_size value appears to be unused —
# confirm against the Flair documentation before relying on it.
tagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=label_dict,
    # Reuse the layer name the label dictionary was built from instead of
    # repeating the literal, so the two can never drift apart.
    tag_type=label_type,
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)

2023-06-24 11:09:51,562 SequenceTagger predicts: Dictionary with 8 tags: Otag, wittest, riot, evidence, assault, imprisonment, expwittest, homicide


In [None]:
# Bind the tagger to the corpus; the trainer drives the fine-tuning run below.
trainer = ModelTrainer(tagger, corpus)

In [None]:
# Fine-tune the model, writing results to the Drive output folder given as the
# first argument. The call is deliberately the last expression of the cell so
# that the returned training summary is displayed.
trainer.fine_tune(
    '/content/drive/MyDrive/Subinay/7_tags_gpt3.5_new/',
    learning_rate=5.0e-6,
    mini_batch_size=3,
    # mini_batch_chunk_size=1,  # optional; leave it out to speed up computation on a big GPU
)

2023-06-24 11:21:49,834 ----------------------------------------------------------------------------------------------------
2023-06-24 11:21:49,839 Model: "SequenceTagger(
  (embeddings): TransformerWordEmbeddings(
    (model): XLMRobertaModel(
      (embeddings): XLMRobertaEmbeddings(
        (word_embeddings): Embedding(250003, 1024)
        (position_embeddings): Embedding(514, 1024, padding_idx=1)
        (token_type_embeddings): Embedding(1, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): XLMRobertaEncoder(
        (layer): ModuleList(
          (0-23): 24 x XLMRobertaLayer(
            (attention): XLMRobertaAttention(
              (self): XLMRobertaSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out

100%|██████████| 4/4 [00:00<00:00,  5.65it/s]

2023-06-24 11:22:39,930 Evaluating as a multi-label problem: False
2023-06-24 11:22:39,945 DEV : loss 1.2344506978988647 - f1-score (micro avg)  0.423
2023-06-24 11:22:39,951 ----------------------------------------------------------------------------------------------------





2023-06-24 11:22:44,542 epoch 2 - iter 8/82 - loss 0.42431572 - time (sec): 4.59 - samples/sec: 151.73 - lr: 0.000005
2023-06-24 11:22:49,823 epoch 2 - iter 16/82 - loss 0.59176144 - time (sec): 9.87 - samples/sec: 179.07 - lr: 0.000005
2023-06-24 11:22:54,124 epoch 2 - iter 24/82 - loss 0.64928036 - time (sec): 14.17 - samples/sec: 181.25 - lr: 0.000005
2023-06-24 11:22:58,647 epoch 2 - iter 32/82 - loss 0.60031831 - time (sec): 18.69 - samples/sec: 188.11 - lr: 0.000005
2023-06-24 11:23:03,802 epoch 2 - iter 40/82 - loss 0.58203047 - time (sec): 23.85 - samples/sec: 180.32 - lr: 0.000005
2023-06-24 11:23:08,686 epoch 2 - iter 48/82 - loss 0.53469041 - time (sec): 28.73 - samples/sec: 180.61 - lr: 0.000005
2023-06-24 11:23:13,360 epoch 2 - iter 56/82 - loss 0.53398848 - time (sec): 33.40 - samples/sec: 178.00 - lr: 0.000005
2023-06-24 11:23:17,521 epoch 2 - iter 64/82 - loss 0.52636568 - time (sec): 37.57 - samples/sec: 176.44 - lr: 0.000005
2023-06-24 11:23:21,815 epoch 2 - iter 72/8

100%|██████████| 4/4 [00:00<00:00,  5.62it/s]

2023-06-24 11:23:28,291 Evaluating as a multi-label problem: False
2023-06-24 11:23:28,306 DEV : loss 1.1192089319229126 - f1-score (micro avg)  0.6586
2023-06-24 11:23:28,314 ----------------------------------------------------------------------------------------------------





2023-06-24 11:23:33,101 epoch 3 - iter 8/82 - loss 0.57385772 - time (sec): 4.78 - samples/sec: 170.83 - lr: 0.000004
2023-06-24 11:23:37,534 epoch 3 - iter 16/82 - loss 0.60074043 - time (sec): 9.22 - samples/sec: 158.53 - lr: 0.000004
2023-06-24 11:23:42,457 epoch 3 - iter 24/82 - loss 0.57062827 - time (sec): 14.14 - samples/sec: 158.43 - lr: 0.000004
2023-06-24 11:23:47,097 epoch 3 - iter 32/82 - loss 0.57434144 - time (sec): 18.78 - samples/sec: 166.04 - lr: 0.000004
2023-06-24 11:23:51,801 epoch 3 - iter 40/82 - loss 0.54254948 - time (sec): 23.48 - samples/sec: 161.05 - lr: 0.000004
2023-06-24 11:23:56,641 epoch 3 - iter 48/82 - loss 0.51225732 - time (sec): 28.32 - samples/sec: 163.72 - lr: 0.000004
2023-06-24 11:24:01,787 epoch 3 - iter 56/82 - loss 0.48919101 - time (sec): 33.47 - samples/sec: 167.83 - lr: 0.000004
2023-06-24 11:24:06,639 epoch 3 - iter 64/82 - loss 0.48250107 - time (sec): 38.32 - samples/sec: 165.86 - lr: 0.000004
2023-06-24 11:24:11,254 epoch 3 - iter 72/8

100%|██████████| 4/4 [00:00<00:00,  5.64it/s]

2023-06-24 11:24:17,952 Evaluating as a multi-label problem: False
2023-06-24 11:24:17,966 DEV : loss 1.6205319166183472 - f1-score (micro avg)  0.6888
2023-06-24 11:24:17,973 ----------------------------------------------------------------------------------------------------





2023-06-24 11:24:22,271 epoch 4 - iter 8/82 - loss 0.24134932 - time (sec): 4.29 - samples/sec: 202.57 - lr: 0.000004
2023-06-24 11:24:27,082 epoch 4 - iter 16/82 - loss 0.33422091 - time (sec): 9.11 - samples/sec: 175.26 - lr: 0.000004
2023-06-24 11:24:31,627 epoch 4 - iter 24/82 - loss 0.31107386 - time (sec): 13.65 - samples/sec: 180.65 - lr: 0.000004
2023-06-24 11:24:36,694 epoch 4 - iter 32/82 - loss 0.37380907 - time (sec): 18.72 - samples/sec: 182.61 - lr: 0.000004
2023-06-24 11:24:41,088 epoch 4 - iter 40/82 - loss 0.40038743 - time (sec): 23.11 - samples/sec: 177.05 - lr: 0.000004
2023-06-24 11:24:45,990 epoch 4 - iter 48/82 - loss 0.38745946 - time (sec): 28.01 - samples/sec: 176.56 - lr: 0.000004
2023-06-24 11:24:50,601 epoch 4 - iter 56/82 - loss 0.38347205 - time (sec): 32.62 - samples/sec: 175.97 - lr: 0.000004
2023-06-24 11:24:55,993 epoch 4 - iter 64/82 - loss 0.39224655 - time (sec): 38.02 - samples/sec: 177.92 - lr: 0.000003
2023-06-24 11:25:00,876 epoch 4 - iter 72/8

100%|██████████| 4/4 [00:00<00:00,  5.70it/s]

2023-06-24 11:25:07,594 Evaluating as a multi-label problem: False
2023-06-24 11:25:07,614 DEV : loss 1.3565596342086792 - f1-score (micro avg)  0.6828
2023-06-24 11:25:07,622 ----------------------------------------------------------------------------------------------------





2023-06-24 11:25:11,962 epoch 5 - iter 8/82 - loss 0.42449676 - time (sec): 4.34 - samples/sec: 187.65 - lr: 0.000003
2023-06-24 11:25:16,708 epoch 5 - iter 16/82 - loss 0.35859105 - time (sec): 9.08 - samples/sec: 180.86 - lr: 0.000003
2023-06-24 11:25:21,508 epoch 5 - iter 24/82 - loss 0.32875017 - time (sec): 13.88 - samples/sec: 184.46 - lr: 0.000003
2023-06-24 11:25:26,357 epoch 5 - iter 32/82 - loss 0.37966673 - time (sec): 18.73 - samples/sec: 177.60 - lr: 0.000003
2023-06-24 11:25:30,895 epoch 5 - iter 40/82 - loss 0.42482600 - time (sec): 23.27 - samples/sec: 177.60 - lr: 0.000003
2023-06-24 11:25:34,486 epoch 5 - iter 48/82 - loss 0.40084004 - time (sec): 26.86 - samples/sec: 173.74 - lr: 0.000003
2023-06-24 11:25:39,461 epoch 5 - iter 56/82 - loss 0.38521037 - time (sec): 31.84 - samples/sec: 175.99 - lr: 0.000003
2023-06-24 11:25:43,659 epoch 5 - iter 64/82 - loss 0.41024314 - time (sec): 36.04 - samples/sec: 180.27 - lr: 0.000003
2023-06-24 11:25:48,416 epoch 5 - iter 72/8

100%|██████████| 4/4 [00:00<00:00,  5.69it/s]

2023-06-24 11:25:55,494 Evaluating as a multi-label problem: False
2023-06-24 11:25:55,514 DEV : loss 1.3751381635665894 - f1-score (micro avg)  0.7372
2023-06-24 11:25:55,521 ----------------------------------------------------------------------------------------------------





2023-06-24 11:26:00,219 epoch 6 - iter 8/82 - loss 0.47015832 - time (sec): 4.70 - samples/sec: 208.46 - lr: 0.000003
2023-06-24 11:26:05,024 epoch 6 - iter 16/82 - loss 0.47315920 - time (sec): 9.50 - samples/sec: 180.82 - lr: 0.000003
2023-06-24 11:26:09,773 epoch 6 - iter 24/82 - loss 0.42195996 - time (sec): 14.25 - samples/sec: 179.86 - lr: 0.000003
2023-06-24 11:26:14,493 epoch 6 - iter 32/82 - loss 0.43107521 - time (sec): 18.97 - samples/sec: 170.90 - lr: 0.000003
2023-06-24 11:26:19,026 epoch 6 - iter 40/82 - loss 0.43304868 - time (sec): 23.50 - samples/sec: 170.10 - lr: 0.000003
2023-06-24 11:26:23,580 epoch 6 - iter 48/82 - loss 0.42404800 - time (sec): 28.06 - samples/sec: 175.93 - lr: 0.000002
2023-06-24 11:26:28,310 epoch 6 - iter 56/82 - loss 0.42603071 - time (sec): 32.79 - samples/sec: 172.45 - lr: 0.000002
2023-06-24 11:26:32,795 epoch 6 - iter 64/82 - loss 0.40096056 - time (sec): 37.27 - samples/sec: 171.84 - lr: 0.000002
2023-06-24 11:26:37,301 epoch 6 - iter 72/8

100%|██████████| 4/4 [00:00<00:00,  5.72it/s]

2023-06-24 11:26:43,955 Evaluating as a multi-label problem: False
2023-06-24 11:26:43,978 DEV : loss 1.8125778436660767 - f1-score (micro avg)  0.6918
2023-06-24 11:26:43,986 ----------------------------------------------------------------------------------------------------





2023-06-24 11:26:48,771 epoch 7 - iter 8/82 - loss 0.46568279 - time (sec): 4.78 - samples/sec: 186.90 - lr: 0.000002
2023-06-24 11:26:53,794 epoch 7 - iter 16/82 - loss 0.31661363 - time (sec): 9.81 - samples/sec: 185.91 - lr: 0.000002
2023-06-24 11:26:58,278 epoch 7 - iter 24/82 - loss 0.40072259 - time (sec): 14.29 - samples/sec: 171.24 - lr: 0.000002
2023-06-24 11:27:02,861 epoch 7 - iter 32/82 - loss 0.39516716 - time (sec): 18.87 - samples/sec: 170.77 - lr: 0.000002
2023-06-24 11:27:07,357 epoch 7 - iter 40/82 - loss 0.36014481 - time (sec): 23.37 - samples/sec: 172.15 - lr: 0.000002
2023-06-24 11:27:12,082 epoch 7 - iter 48/82 - loss 0.34976207 - time (sec): 28.09 - samples/sec: 171.50 - lr: 0.000002
2023-06-24 11:27:16,722 epoch 7 - iter 56/82 - loss 0.35343329 - time (sec): 32.73 - samples/sec: 177.92 - lr: 0.000002
2023-06-24 11:27:21,504 epoch 7 - iter 64/82 - loss 0.33006572 - time (sec): 37.52 - samples/sec: 177.28 - lr: 0.000002
2023-06-24 11:27:25,892 epoch 7 - iter 72/8

100%|██████████| 4/4 [00:00<00:00,  5.61it/s]

2023-06-24 11:27:32,426 Evaluating as a multi-label problem: False
2023-06-24 11:27:32,451 DEV : loss 1.5548163652420044 - f1-score (micro avg)  0.6918
2023-06-24 11:27:32,458 ----------------------------------------------------------------------------------------------------





2023-06-24 11:27:36,932 epoch 8 - iter 8/82 - loss 0.20787724 - time (sec): 4.47 - samples/sec: 189.86 - lr: 0.000002
2023-06-24 11:27:41,850 epoch 8 - iter 16/82 - loss 0.28176048 - time (sec): 9.39 - samples/sec: 182.75 - lr: 0.000002
2023-06-24 11:27:46,598 epoch 8 - iter 24/82 - loss 0.23519999 - time (sec): 14.14 - samples/sec: 189.28 - lr: 0.000002
2023-06-24 11:27:51,681 epoch 8 - iter 32/82 - loss 0.26665447 - time (sec): 19.22 - samples/sec: 186.05 - lr: 0.000001
2023-06-24 11:27:56,254 epoch 8 - iter 40/82 - loss 0.29523344 - time (sec): 23.79 - samples/sec: 179.72 - lr: 0.000001
2023-06-24 11:28:00,803 epoch 8 - iter 48/82 - loss 0.28052654 - time (sec): 28.34 - samples/sec: 174.90 - lr: 0.000001
2023-06-24 11:28:05,382 epoch 8 - iter 56/82 - loss 0.27625906 - time (sec): 32.92 - samples/sec: 169.16 - lr: 0.000001
2023-06-24 11:28:10,257 epoch 8 - iter 64/82 - loss 0.27757397 - time (sec): 37.80 - samples/sec: 171.89 - lr: 0.000001
2023-06-24 11:28:15,116 epoch 8 - iter 72/8

100%|██████████| 4/4 [00:00<00:00,  5.72it/s]

2023-06-24 11:28:22,239 Evaluating as a multi-label problem: False
2023-06-24 11:28:22,255 DEV : loss 1.5753391981124878 - f1-score (micro avg)  0.7039
2023-06-24 11:28:22,262 ----------------------------------------------------------------------------------------------------





2023-06-24 11:28:26,234 epoch 9 - iter 8/82 - loss 0.44948017 - time (sec): 3.97 - samples/sec: 182.37 - lr: 0.000001
2023-06-24 11:28:30,787 epoch 9 - iter 16/82 - loss 0.41979212 - time (sec): 8.52 - samples/sec: 180.32 - lr: 0.000001
2023-06-24 11:28:35,308 epoch 9 - iter 24/82 - loss 0.34782642 - time (sec): 13.04 - samples/sec: 180.38 - lr: 0.000001
2023-06-24 11:28:39,893 epoch 9 - iter 32/82 - loss 0.30266348 - time (sec): 17.63 - samples/sec: 171.31 - lr: 0.000001
2023-06-24 11:28:44,779 epoch 9 - iter 40/82 - loss 0.32878151 - time (sec): 22.52 - samples/sec: 173.52 - lr: 0.000001
2023-06-24 11:28:49,526 epoch 9 - iter 48/82 - loss 0.30352130 - time (sec): 27.26 - samples/sec: 167.04 - lr: 0.000001
2023-06-24 11:28:55,155 epoch 9 - iter 56/82 - loss 0.29435662 - time (sec): 32.89 - samples/sec: 167.52 - lr: 0.000001
2023-06-24 11:29:00,000 epoch 9 - iter 64/82 - loss 0.30106267 - time (sec): 37.74 - samples/sec: 170.02 - lr: 0.000001
2023-06-24 11:29:04,532 epoch 9 - iter 72/8

100%|██████████| 4/4 [00:00<00:00,  5.69it/s]

2023-06-24 11:29:11,055 Evaluating as a multi-label problem: False
2023-06-24 11:29:11,069 DEV : loss 1.7115142345428467 - f1-score (micro avg)  0.6888
2023-06-24 11:29:11,074 ----------------------------------------------------------------------------------------------------





2023-06-24 11:29:15,046 epoch 10 - iter 8/82 - loss 0.29426885 - time (sec): 3.97 - samples/sec: 156.41 - lr: 0.000001
2023-06-24 11:29:20,302 epoch 10 - iter 16/82 - loss 0.21123209 - time (sec): 9.23 - samples/sec: 200.29 - lr: 0.000000
2023-06-24 11:29:25,359 epoch 10 - iter 24/82 - loss 0.23413686 - time (sec): 14.28 - samples/sec: 195.96 - lr: 0.000000
2023-06-24 11:29:29,869 epoch 10 - iter 32/82 - loss 0.25713329 - time (sec): 18.79 - samples/sec: 187.14 - lr: 0.000000
2023-06-24 11:29:34,298 epoch 10 - iter 40/82 - loss 0.22479840 - time (sec): 23.22 - samples/sec: 190.93 - lr: 0.000000
2023-06-24 11:29:38,562 epoch 10 - iter 48/82 - loss 0.24837954 - time (sec): 27.49 - samples/sec: 183.98 - lr: 0.000000
2023-06-24 11:29:43,704 epoch 10 - iter 56/82 - loss 0.24844737 - time (sec): 32.63 - samples/sec: 180.36 - lr: 0.000000
2023-06-24 11:29:48,143 epoch 10 - iter 64/82 - loss 0.24235509 - time (sec): 37.07 - samples/sec: 178.35 - lr: 0.000000
2023-06-24 11:29:52,796 epoch 10 - 

100%|██████████| 4/4 [00:00<00:00,  5.71it/s]

2023-06-24 11:29:59,063 Evaluating as a multi-label problem: False
2023-06-24 11:29:59,080 DEV : loss 1.5872673988342285 - f1-score (micro avg)  0.6707





2023-06-24 11:30:12,056 ----------------------------------------------------------------------------------------------------
2023-06-24 11:30:12,063 Testing using last state of model ...


100%|██████████| 246/246 [00:39<00:00,  6.29it/s]


2023-06-24 11:30:51,239 Evaluating as a multi-label problem: False
2023-06-24 11:30:51,389 0.7042	0.7042	0.7042	0.7042
2023-06-24 11:30:51,393 
Results:
- F-score (micro) 0.7042
- F-score (macro) 0.5399
- Accuracy 0.7042

By class:
              precision    recall  f1-score   support

  expwittest     0.9576    0.7535    0.8434      7789
        Otag     0.6737    0.7904    0.7274      4918
     wittest     0.6589    0.6735    0.6662      3841
     assault     0.3339    0.8026    0.4716       755
    homicide     0.6241    0.0801    0.1420      1036
    evidence     0.1887    0.6381    0.2913       257
        riot     0.7561    0.6049    0.6721       410
imprisonment     0.4816    0.5320    0.5055       344

    accuracy                         0.7042     19350
   macro avg     0.5843    0.6094    0.5399     19350
weighted avg     0.7610    0.7042    0.7097     19350

2023-06-24 11:30:51,395 ---------------------------------------------------------------------------------------------

{'test_score': 0.7042377260981912,
 'dev_score_history': [0.4229607250755287,
  0.6586102719033232,
  0.6888217522658611,
  0.6827794561933535,
  0.7371601208459214,
  0.6918429003021148,
  0.6918429003021148,
  0.7039274924471299,
  0.6888217522658611,
  0.6706948640483383],
 'train_loss_history': [0.5559685698958239,
  0.5286811039724513,
  0.44707634105796706,
  0.39176827357006744,
  0.3942988349570731,
  0.3659606898625881,
  0.33259668719984675,
  0.273283056970799,
  0.29694189396469656,
  0.2561686593125099],
 'dev_loss_history': [1.2344506978988647,
  1.1192089319229126,
  1.6205319166183472,
  1.3565596342086792,
  1.3751381635665894,
  1.8125778436660767,
  1.5548163652420044,
  1.5753391981124878,
  1.7115142345428467,
  1.5872673988342285]}

In [None]:
# Load a saved tagger from Drive.
# NOTE(review): this path is not the output folder passed to trainer.fine_tune in
# this notebook, so the loaded checkpoint is presumably from an earlier run —
# confirm this is the model that should be used.
model = SequenceTagger.load("/content/drive/MyDrive/Subinay/tagger_gpt3.5_5_riot_10_evidence/final-model.pt")

2023-06-17 11:45:15,536 SequenceTagger predicts: Dictionary with 8 tags: Otag, riot, evidence, assault, wittest, imprisonment, expwittest, homicide
