# Event Detection using Prompt Learning
**Author: Rohan Salvi** <br>
**NetID: rcsalvi2**

## Libraries and Functions

Installing OpenPrompt Library

In [None]:
!pip install openprompt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting openprompt
  Downloading openprompt-1.0.1-py3-none-any.whl (146 kB)
[K     |████████████████████████████████| 146 kB 33.7 MB/s 
[?25hCollecting sentencepiece==0.1.96
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 55.8 MB/s 
Collecting rouge==1.0.0
  Downloading rouge-1.0.0-py3-none-any.whl (14 kB)
Collecting transformers>=4.10.0
  Downloading transformers-4.22.2-py3-none-any.whl (4.9 MB)
[K     |████████████████████████████████| 4.9 MB 57.3 MB/s 
[?25hCollecting yacs
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting tensorboardX
  Downloading tensorboardX-2.5.1-py2.py3-none-any.whl (125 kB)
[K     |████████████████████████████████| 125 kB 73.2 MB/s 
[?25hCollecting datasets
  Downloading datasets-2.5.2-py3-none-any.whl (432 kB)
[K     |█████████████████████

Preprocessing, Evaluation and Utility Functions

In [None]:
import json
import re
import tqdm
import numpy as np

from openprompt.data_utils import InputExample, InputFeatures
from torch.utils.data._utils.collate import default_collate
from openprompt.plms import load_plm


def save_div(a, b):
    """Return ``a / b``, or ``0.0`` when ``b`` equals zero (no exception is raised)."""
    return a / b if b != 0 else 0.0


def evaluation(gold_labels, pred_labels, vocab):
    """Score predictions with overall and per-label precision, recall and F1.

    Args:
        gold_labels: sequence with one collection of gold label indices per
            sentence.
        pred_labels: sequence with one collection of predicted label indices
            per sentence, indexed in parallel with ``gold_labels``.
        vocab: mapping from label name to label index. Index 0 is excluded
            from every count.

    Returns:
        A tuple ``(prec, rec, f1, result)`` where the first three values are
        computed from counts pooled over all labels and ``result`` maps each
        label name (index != 0) to a dict with keys ``"prec"``, ``"rec"`` and
        ``"f1"``.
    """
    index_to_label = {index: name for name, index in vocab.items()}

    # Each slot starts out as a raw counter and is turned into a ratio at the end:
    #   "rec"  -> number of gold occurrences
    #   "prec" -> number of predicted occurrences
    #   "f1"   -> number of predictions that also appear in the gold labels
    result = {
        name: {"prec": 0.0, "rec": 0.0, "f1": 0.0}
        for name, index in vocab.items()
        if index != 0
    }

    total_pred_num = 0.0
    total_gold_num = 0.0
    total_correct_num = 0.0

    for position, gold_for_sentence in enumerate(gold_labels):
        pred_for_sentence = pred_labels[position]

        for index in gold_for_sentence:
            if index == 0:
                continue
            total_gold_num += 1
            result[index_to_label[index]]["rec"] += 1

        for index in pred_for_sentence:
            if index == 0:
                continue
            total_pred_num += 1
            stats = result[index_to_label[index]]
            stats["prec"] += 1

            if index in gold_for_sentence:
                total_correct_num += 1
                stats["f1"] += 1

    # Convert the per-label counters into ratios. Order matters: "f1" still
    # holds the correct count while precision and recall are derived from it.
    for stats in result.values():
        stats["prec"] = save_div(stats["f1"], stats["prec"])
        stats["rec"] = save_div(stats["f1"], stats["rec"])
        stats["f1"] = save_div(2*stats["prec"]*stats["rec"], stats["prec"]+stats["rec"])

    prec = save_div(total_correct_num, total_pred_num)
    rec = save_div(total_correct_num, total_gold_num)
    f1 = save_div(2*prec*rec, prec+rec)

    return prec, rec, f1, result


def load_data(input_dir):
    """Read a JSON Lines file and return its records as a list.

    Args:
        input_dir: path of the UTF-8 text file to read; each non-blank line
            must hold one JSON document.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    data_items = []

    with open(input_dir, 'r', encoding='utf-8') as f:
        # Stream the file instead of materialising every line first, and skip
        # whitespace-only lines (e.g. a trailing empty line) that json.loads
        # would otherwise reject.
        for line in f:
            if line.strip():
                data_items.append(json.loads(line))

    return data_items


def get_vocab(train_dir, valid_dir):
    """Build the event-type vocabulary from the training and validation files.

    Args:
        train_dir: path of the training file, read with ``load_data``.
        valid_dir: path of the validation file, read with ``load_data``.

    Returns:
        A dict mapping each event type name to an integer index. ``"None"``
        is always index 0; other types receive consecutive indices in order
        of first appearance (training file first, then validation file).
    """
    # Both files are loaded before any vocabulary entry is created.
    splits = [load_data(train_dir), load_data(valid_dir)]

    vocab = {"None": 0}
    for split in splits:
        for item in split:
            for event in item["events"]:
                # The event type is the last field of each event entry.
                event_type = event[-1]
                if event_type not in vocab:
                    vocab[event_type] = len(vocab)

    return vocab


def process_data(input_dir, vocab):
    """Turn a data file into a list of ``InputExample`` objects.

    Args:
        input_dir: path of the file to read with ``load_data``.
        vocab: mapping from event type name to label index.

    Returns:
        One ``InputExample`` per record. ``text_a`` is the record's tokens
        joined by single spaces, ``label`` is the list of distinct label
        indices in order of first appearance (``[0]`` when the record has no
        events) and ``guid`` is the record's position in the file.
    """
    output_items = []

    for guid, item in enumerate(load_data(input_dir)):
        labels = []
        for event in item["events"]:
            label_id = vocab[event[-1]]
            if label_id not in labels:
                labels.append(label_id)

        # Sentences without any event are labelled with index 0.
        if not labels:
            labels.append(0)

        output_items.append(
            InputExample(text_a=" ".join(item["tokens"]), label=labels, guid=guid)
        )

    return output_items


def my_collate_fn(batch):
    """Collate a batch of feature mappings into a single ``InputFeatures``.

    Every key of the first element is collated across the batch with
    ``default_collate``. ``"encoded_tgt_text"`` is always kept as a plain
    list, and any other field that ``default_collate`` cannot combine falls
    back to a plain list as well.
    NOTE(review): the fallback presumably exists for the variable-length
    ``label`` field -- confirm against the data loader.

    Args:
        batch: non-empty sequence of mapping-like items sharing the same keys.

    Returns:
        An ``InputFeatures`` built from the collated fields.
    """
    elem = batch[0]
    return_dict = {}
    for key in elem:
        values = [d[key] for d in batch]
        if key == "encoded_tgt_text":
            return_dict[key] = values
        else:
            try:
                return_dict[key] = default_collate(values)
            except Exception:
                # Best-effort fallback, but no longer a bare ``except`` so that
                # KeyboardInterrupt / SystemExit are not swallowed.
                return_dict[key] = values

    return InputFeatures(**return_dict)


def convert_labels_to_list(labels):
    """Convert each element of ``labels`` to a fresh Python list via ``.tolist()``.

    Args:
        labels: iterable of objects exposing ``tolist()`` that returns a list.

    Returns:
        A list containing one copied list per input element, in order.
    """
    return [label.tolist().copy() for label in labels]


def to_device(data, device):
    """Move the model-input fields of ``data`` to ``device`` in place.

    Only ``input_ids``, ``attention_mask``, ``decoder_input_ids`` and
    ``loss_ids`` are moved; every other field is left untouched.

    Args:
        data: mapping-like batch supporting item get/set for the four keys.
        device: target passed to each field's ``.to(...)``.

    Returns:
        The same ``data`` object, after the update.
    """
    tensor_fields = ("input_ids", "attention_mask", "decoder_input_ids", "loss_ids")
    for field in tensor_fields:
        moved = data[field].to(device)
        data[field] = moved
    return data


def get_plm():
    """Load the pre-trained language model through OpenPrompt's ``load_plm``.

    Returns:
        Whatever ``load_plm`` returns for the configured model; the caller in
        this notebook unpacks it as ``(plm, tokenizer, model_config, WrapperClass)``.
    """
    # Change these two values to try a different PLM. Currently it's t5-base.
    model_name = "t5"
    model_path = "t5-base"
    return load_plm(model_name, model_path)

def get_template():
    """Return the manual prompt template text.

    The template appends a cloze sentence containing a ``{"mask"}`` slot
    after the ``text_a`` placeholder.
    """
    # You may design your own template here.
    template_text = '{"placeholder":"text_a"} This text describes a {"mask"} event.'
    return template_text

def get_verbalizer(vocab):
    """Return the label words (verbalizers) for every event type.

    Args:
        vocab: dict mapping event type names to indices, e.g.
            ``{'None': 0, 'Catastrophe': 1, 'Causation': 2, ...}``. The output
            follows the iteration order of this dict.

    Returns:
        A 2-dim list with one list of label words per event type. Currently
        each event type is represented only by its lowercased name.
    """
    # Hand-designed alternatives that were tried (kept for reference):
    #   [['none'], ['catastrophe'], ['causation'], ['motion'], ['hostile'], ['start'], ['assault'], ['kill'], ['conquer'], ['celebration'], ['competition']]
    #   [['none'], ['catastrophe','disaster','calamity','crisis','tragedy'], ['causation','causing','causality','causal'], ['motion','movement','shift','transit'], ['hostile'], ['start','begin','commence','initiate'], ['attack','charge','strike','ambush','assault'], ['kill','murder','assassinate','death','slay','annihilate'], ['conquer'], ['social'], ['competition']]
    return [[label.lower()] for label in vocab]

def loss_func(logits, labels):
    """Threshold-style multi-label loss that uses the None class (index 0) as the threshold.

    For every sentence two terms are combined:

    * positive term -- the mean, over the gold event labels ``p``, of
      ``log(exp(l_p) / (exp(l_p) + exp(l_0)))``;
    * negative term -- ``log(exp(l_0) / (exp(l_0) + sum_n exp(l_n)))`` where
      ``n`` runs over every event label (index >= 1) that is not gold.

    The returned loss is the negated batch mean of ``positive + negative``.
    See https://arxiv.org/pdf/2202.07615.pdf for the loss this is modelled on.

    Args:
        logits: ``torch.Tensor`` of shape ``(batch_size, number_of_event_types)``
            holding the output logits for each event type (None included).
        labels: 2-dim list with the ground-truth label indices of each
            sentence, e.g. ``[[0], [2, 3, 4]]`` -- the first sentence has no
            events, the second has events 2, 3 and 4.

    Returns:
        A scalar ``torch.Tensor`` -- the loss.
    """
    if len(labels) == 0:
        # Nothing to average over; return a zero that stays attached to the graph.
        return logits.sum() * 0.0

    # Derive the event label ids from the logits instead of hard-coding 1..10.
    all_event_labels = set(range(1, logits.shape[-1]))

    total = 0.0
    for i, target_labels in enumerate(labels):
        logit = logits[i]
        none_logit = logit[0]
        positive_labels = [label for label in target_labels if label != 0]

        # Positive term: index the logits by the *label id* (the original used
        # the loop counter). log(e^p / (e^p + e^0)) == logsigmoid(p - l_0),
        # which is numerically stable.
        loss_positive = 0.0
        if positive_labels:
            margins = logit[positive_labels] - none_logit
            loss_positive = torch.nn.functional.logsigmoid(margins).mean()

        # Negative term: the None logit competes against every non-gold event
        # logit; logsumexp avoids overflowing exp() for large logits.
        negative_labels = sorted(all_event_labels.difference(positive_labels))
        competitors = logit[[0] + negative_labels]
        loss_negative = none_logit - torch.logsumexp(competitors, dim=0)

        total = total + loss_negative + loss_positive

    return -total / len(labels)



# def predict(logits):
#     # INPUT:
#     ##  logits: a torch.Tensor of (batch_size, number_of_event_types) which is the output logits for each event type (none types are included).
#     # OUTPUT:
#     ##  a 2-dim list which has the same format with the "labels" in "loss_func" --- the predictions for all the sentences in the batch.
#     ##  For example, if predictions == [[0], [2,3,4]] then the batch size is 2, and we predict no events for the first sentence and three events (2,3,and 4) for the second sentence.

#     ##  INSTRUCTIONS: The most straight-forward way for prediction is to select out the indices with maximum of logits. Note that this is a multi-label classification problem, so each sentence could have multiple predicted event indices. Using what threshold for prediction is important here. You can also use the None event (index 0) as the threshold as what https://arxiv.org/pdf/2202.07615.pdf does.

#     ###  YOU NEED TO WRITE YOUR CODE HERE.  ###

#     batch_size = list(logits.size())[0]
#     predictions = []
#     for i in range(0, batch_size):
#       # sent_events = []
#       logit = logits[i]

#       #print(torch.nn.functional.softmax(logit))
#       # denom = 0.0
#       # for j in range(0,11):
#       #   denom = denom + torch.exp(logit[j])
#       # for j in range(1,11):
#       #   if torch.exp(logit[0]) < torch.exp(logit[j]):
#       #     sent_events.append(j)

#       # if len(sent_events) == 0:
#       #   sent_events.append(0)

#       # predictions.append(sent_events)
#       val = torch.argmax(logit, dim=-1).cpu().tolist()
#       #print(val)
#       predictions.append([val])
#     #print(logits.shape)
#     #print(logits)
#     #print(predictions)
#     #raise NotImplementedError

#     return predictions








Loss Function

In [None]:
def sig_loss_func(logits, labels):
    """Per-class sigmoid (binary cross-entropy) loss for multi-label targets.

    Every class, including index 0, is treated as an independent binary
    decision: classes listed in a sentence's labels are positives, all others
    are negatives. The per-class losses are summed per sentence and averaged
    over the batch.

    Args:
        logits: ``torch.Tensor`` of shape ``(batch_size, number_of_event_types)``.
        labels: 2-dim list with the gold label indices of each sentence,
            e.g. ``[[0], [2, 3, 4]]``.

    Returns:
        A scalar ``torch.Tensor`` -- the loss.
    """
    batch_size = len(labels)
    if batch_size == 0:
        # Nothing to average over; return a zero that stays attached to the graph.
        return logits.sum() * 0.0

    # Multi-hot targets sized from the logits rather than a hard-coded 11 classes.
    targets = torch.zeros_like(logits)
    for row, target_labels in enumerate(labels):
        for label in target_labels:
            targets[row, label] = 1.0

    # binary_cross_entropy_with_logits is the numerically stable form of
    # -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]; computing
    # exp(x)/(exp(x)+1) by hand yields nan once exp(x) overflows.
    summed = torch.nn.functional.binary_cross_entropy_with_logits(
        logits, targets, reduction="sum"
    )

    # Divide by the batch size exactly once (the original divided inside the
    # per-sentence loop, repeatedly shrinking earlier sentences' contribution).
    return summed / batch_size

Prediction function

In [None]:
def predict(logits):
    """Predict one label per sentence by taking the arg-max over the logits.

    Args:
        logits: ``torch.Tensor`` of shape ``(batch_size, number_of_event_types)``
            with the output logits for each event type (None included).

    Returns:
        A 2-dim list in the same format as the ``labels`` passed to the loss
        functions, e.g. ``[[0], [6]]``. Each inner list holds exactly one
        index -- the highest-scoring class of that sentence.

    NOTE(review): because only the arg-max is returned, a sentence with
    several gold events can never be fully recovered; a threshold-based rule
    (e.g. using the None logit as threshold, as in
    https://arxiv.org/pdf/2202.07615.pdf) would allow multi-label output.
    """
    best_per_sentence = torch.argmax(logits, dim=-1).cpu().tolist()
    return [[best] for best in best_per_sentence]

## Training & Validation

In [None]:
if __name__ == "__main__":
    # Locations of the three data splits (Colab working directory).
    train_dir, valid_dir, test_dir = (
        "/content/train.json",
        "/content/valid.json",
        "/content/test.json",
    )

    # The vocabulary is built from the training and validation splits only.
    vocabulary = get_vocab(train_dir, valid_dir)
    dataset = {
        "train": process_data(train_dir, vocabulary),
        "validation": process_data(valid_dir, vocabulary),
        "test": process_data(test_dir, vocabulary),
    }
    print(vocabulary)
    inv_vocabulary = {index: name for name, index in vocabulary.items()}
    print(inv_vocabulary)

    from openprompt.prompts import ManualTemplate
    plm, tokenizer, model_config, WrapperClass = get_plm()

    template_text = get_template()
    mytemplate = ManualTemplate(tokenizer=tokenizer, text=template_text)

    from openprompt import PromptDataLoader

    def _build_loader(split, shuffle):
        """Create the PromptDataLoader for one split and install the custom collate function."""
        loader = PromptDataLoader(
            dataset=dataset[split],
            template=mytemplate,
            tokenizer=tokenizer,
            tokenizer_wrapper_class=WrapperClass,
            max_seq_length=256,
            decoder_max_length=3,
            batch_size=10,
            shuffle=shuffle,
            teacher_forcing=False,
            predict_eos_token=False,
            truncate_method="head"
        )
        loader.dataloader.collate_fn = my_collate_fn
        return loader

    # Only the training split is shuffled.
    train_dataloader = _build_loader("train", shuffle=True)
    validation_dataloader = _build_loader("validation", shuffle=False)
    test_dataloader = _build_loader("test", shuffle=False)






{'None': 0, 'Catastrophe': 1, 'Causation': 2, 'Motion': 3, 'Hostile_encounter': 4, 'Process_start': 5, 'Attack': 6, 'Killing': 7, 'Conquering': 8, 'Social_event': 9, 'Competition': 10}
{0: 'None', 1: 'Catastrophe', 2: 'Causation', 3: 'Motion', 4: 'Hostile_encounter', 5: 'Process_start', 6: 'Attack', 7: 'Killing', 8: 'Conquering', 9: 'Social_event', 10: 'Competition'}


Downloading:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
tokenizing: 5378it [00:06, 842.71it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (734 > 512). Running this sequence through the model will result in indexing errors
tokenizing: 32431it [00:42, 767.27it/s]
tokenizing: 8042it [00:09, 859.71it/s]
tokenizing: 9400it [00:11, 842.46it/s]


In [None]:
from openprompt.prompts import ManualVerbalizer
import torch
from openprompt.prompts import SoftVerbalizer

# One list of label words per class, in vocabulary order.
label_words = get_verbalizer(vocabulary)

myverbalizer = ManualVerbalizer(tokenizer,
                                num_classes=len(vocabulary),
                                label_words=label_words)
from openprompt import PromptForClassification

# Derive the CUDA flag from the actual hardware so that the model placement
# and the `device` used for batches always agree (a hard-coded True crashes
# on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"
use_cuda = device == "cuda"

# freeze_plm=False: the PLM weights are fine-tuned together with the prompt.
prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model = prompt_model.cuda()

from transformers import  AdamW, get_linear_schedule_with_warmup
# It's always good practice to apply no weight decay to bias and LayerNorm parameters.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)

# Early-stopping state shared with the training loop: best validation F1 so
# far, the patience budget, and the flag that ends training.
max_f1 = 0.0
max_patience, current_patience = 2, 0
if_exit = False



In [None]:

for epoch in range(10):
    if if_exit:
        break
    tot_loss = 0.0
    progress = tqdm.tqdm(total=len(train_dataloader), ncols=150, desc="Epoch: "+str(epoch))
    for step, inputs in enumerate(train_dataloader):
        if if_exit:
            break
        if use_cuda:
            inputs = to_device(inputs, device)
        logits = prompt_model(inputs)
        labels = inputs['label']
        label_list = convert_labels_to_list(labels)
        loss = sig_loss_func(logits, label_list)
        loss.backward()
        tot_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()

        # Validate at steps 99, 299, 499, ... of every epoch.
        if step % 200 == 99:
            print("\nStep {}, average loss: {}".format(step, tot_loss/(step+1)), flush=True)

            allpreds, alllabels = [], []
            # Validation: uses its own loop variables so the training loop's
            # `step`, `inputs`, `logits`, `labels` and `label_list` are not clobbered.
            valid_progress = tqdm.tqdm(total=len(validation_dataloader), ncols=150, desc="Validation: ")
            prompt_model.eval()
            with torch.no_grad():
                for val_step, val_inputs in enumerate(validation_dataloader):
                    if use_cuda:
                        val_inputs = to_device(val_inputs, device)
                    val_logits = prompt_model(val_inputs)
                    val_label_list = convert_labels_to_list(val_inputs['label'])
                    pred_labels = predict(val_logits)
                    # Print a sample of predictions vs. gold labels every 50 batches.
                    if val_step % 50 == 0:
                        print(pred_labels)
                        print(val_label_list)

                    alllabels.extend(val_label_list)
                    allpreds.extend(pred_labels)
                    valid_progress.update(1)

            valid_progress.close()
            prompt_model.train()

            p, r, f, total = evaluation(alllabels, allpreds, vocabulary)
            print("F1-Score: " + str(f))
            with open("results.json", 'w', encoding='utf-8') as f_out:
                f_out.write(json.dumps(total, indent=4))

            # Early stopping: keep the best checkpoint and stop once the F1 has
            # failed to improve more than `max_patience` validations in a row.
            if f > max_f1:
                max_f1 = f
                torch.save(prompt_model.state_dict(), "./checkpoint_best.pt")
                current_patience = 0
            else:
                current_patience += 1
                if current_patience > max_patience:
                    if_exit = True

        progress.update(1)
    progress.close()

Epoch: 0:  28%|███████████████████████████▋                                                                        | 899/3244 [17:34<22:02,  1.77it/s]


Step 899, average loss: 0.28839744115869204



Validation:   0%|                                                                                                             | 0/805 [00:00<?, ?it/s][A
Validation:   0%|▏                                                                                                    | 1/805 [00:00<02:14,  5.96it/s][A

[[4], [8], [0], [6], [0], [0], [0], [6], [3], [6]]
[[4, 5], [8], [0], [6], [0], [0], [0], [0], [3], [6]]



Validation:   0%|▎                                                                                                    | 2/805 [00:00<02:20,  5.73it/s][A
Validation:   0%|▍                                                                                                    | 3/805 [00:00<02:20,  5.71it/s][A
Validation:   0%|▌                                                                                                    | 4/805 [00:00<02:20,  5.71it/s][A
Validation:   1%|▋                                                                                                    | 5/805 [00:00<02:17,  5.81it/s][A
Validation:   1%|▊                                                                                                    | 6/805 [00:01<02:17,  5.79it/s][A
Validation:   1%|▉                                                                                                    | 7/805 [00:01<02:15,  5.88it/s][A
Validation:   1%|█                                                         

[[0], [0], [0], [0], [0], [6], [0], [0], [0], [0]]
[[9], [9], [9], [0], [9], [5, 6], [0], [0], [0], [0]]



Validation:   7%|██████▌                                                                                             | 53/805 [00:09<02:06,  5.95it/s][A
Validation:   7%|██████▋                                                                                             | 54/805 [00:09<02:06,  5.95it/s][A
Validation:   7%|██████▊                                                                                             | 55/805 [00:09<02:05,  5.98it/s][A
Validation:   7%|██████▉                                                                                             | 56/805 [00:09<02:04,  6.00it/s][A
Validation:   7%|███████                                                                                             | 57/805 [00:09<02:04,  6.01it/s][A
Validation:   7%|███████▏                                                                                            | 58/805 [00:09<02:05,  5.94it/s][A
Validation:   7%|███████▎                                                  

[[0], [7], [1], [1], [0], [0], [1], [3], [1], [2]]
[[0], [7], [2], [0], [2], [0], [7], [3], [0], [0]]



Validation:  13%|████████████▋                                                                                      | 103/805 [00:17<02:00,  5.83it/s][A
Validation:  13%|████████████▊                                                                                      | 104/805 [00:17<01:59,  5.88it/s][A
Validation:  13%|████████████▉                                                                                      | 105/805 [00:17<01:58,  5.91it/s][A
Validation:  13%|█████████████                                                                                      | 106/805 [00:18<01:59,  5.84it/s][A
Validation:  13%|█████████████▏                                                                                     | 107/805 [00:18<01:58,  5.89it/s][A
Validation:  13%|█████████████▎                                                                                     | 108/805 [00:18<01:59,  5.84it/s][A
Validation:  14%|█████████████▍                                            

[[0], [0], [0], [0], [3], [0], [8], [0], [0], [0]]
[[0], [0], [0], [5], [2], [0], [8], [0], [0], [0]]



Validation:  19%|██████████████████▊                                                                                | 153/805 [00:26<01:51,  5.86it/s][A
Validation:  19%|██████████████████▉                                                                                | 154/805 [00:26<01:51,  5.84it/s][A
Validation:  19%|███████████████████                                                                                | 155/805 [00:26<01:51,  5.84it/s][A
Validation:  19%|███████████████████▏                                                                               | 156/805 [00:26<01:51,  5.84it/s][A
Validation:  20%|███████████████████▎                                                                               | 157/805 [00:26<01:52,  5.77it/s][A
Validation:  20%|███████████████████▍                                                                               | 158/805 [00:27<01:52,  5.75it/s][A
Validation:  20%|███████████████████▌                                      

[[6], [0], [0], [0], [0], [0], [0], [0], [4], [4]]
[[6, 7, 5, 8], [0], [0], [0], [0], [0], [0], [0], [4], [0]]



Validation:  25%|████████████████████████▉                                                                          | 203/805 [00:34<01:42,  5.88it/s][A
Validation:  25%|█████████████████████████                                                                          | 204/805 [00:34<01:41,  5.90it/s][A
Validation:  25%|█████████████████████████▏                                                                         | 205/805 [00:34<01:41,  5.94it/s][A
Validation:  26%|█████████████████████████▎                                                                         | 206/805 [00:35<01:41,  5.88it/s][A
Validation:  26%|█████████████████████████▍                                                                         | 207/805 [00:35<01:42,  5.84it/s][A
Validation:  26%|█████████████████████████▌                                                                         | 208/805 [00:35<01:42,  5.84it/s][A
Validation:  26%|█████████████████████████▋                                

[[0], [5], [0], [9], [5], [5], [0], [0], [0], [0]]
[[2], [0], [0], [0], [5], [5], [0], [0], [0], [0]]



Validation:  31%|███████████████████████████████                                                                    | 253/805 [00:43<01:33,  5.92it/s][A
Validation:  32%|███████████████████████████████▏                                                                   | 254/805 [00:43<01:32,  5.93it/s][A
Validation:  32%|███████████████████████████████▎                                                                   | 255/805 [00:43<01:32,  5.93it/s][A
Validation:  32%|███████████████████████████████▍                                                                   | 256/805 [00:43<01:32,  5.92it/s][A
Validation:  32%|███████████████████████████████▌                                                                   | 257/805 [00:43<01:32,  5.95it/s][A
Validation:  32%|███████████████████████████████▋                                                                   | 258/805 [00:44<01:32,  5.93it/s][A
Validation:  32%|███████████████████████████████▊                          

[[4], [0], [0], [0], [5], [5], [4], [6], [5], [0]]
[[1, 4], [0], [8], [0], [5], [5], [4], [0], [0], [0]]



Validation:  38%|█████████████████████████████████████▎                                                             | 303/805 [00:51<01:24,  5.94it/s][A
Validation:  38%|█████████████████████████████████████▍                                                             | 304/805 [00:51<01:24,  5.96it/s][A
Validation:  38%|█████████████████████████████████████▌                                                             | 305/805 [00:52<01:24,  5.90it/s][A
Validation:  38%|█████████████████████████████████████▋                                                             | 306/805 [00:52<01:24,  5.88it/s][A
Validation:  38%|█████████████████████████████████████▊                                                             | 307/805 [00:52<01:25,  5.85it/s][A
Validation:  38%|█████████████████████████████████████▉                                                             | 308/805 [00:52<01:24,  5.85it/s][A
Validation:  38%|██████████████████████████████████████                    

[[0], [0], [0], [0], [0], [0], [5], [0], [0], [0]]
[[0], [0], [0], [0], [0], [0], [5], [0], [0], [0]]



Validation:  44%|███████████████████████████████████████████▍                                                       | 353/805 [01:00<01:16,  5.89it/s][A
Validation:  44%|███████████████████████████████████████████▌                                                       | 354/805 [01:00<01:16,  5.88it/s][A
Validation:  44%|███████████████████████████████████████████▋                                                       | 355/805 [01:00<01:16,  5.87it/s][A
Validation:  44%|███████████████████████████████████████████▊                                                       | 356/805 [01:00<01:16,  5.85it/s][A
Validation:  44%|███████████████████████████████████████████▉                                                       | 357/805 [01:00<01:16,  5.83it/s][A
Validation:  44%|████████████████████████████████████████████                                                       | 358/805 [01:01<01:17,  5.79it/s][A
Validation:  45%|████████████████████████████████████████████▏             

[[0], [8], [0], [2], [0], [0], [0], [0], [0], [0]]
[[0], [0], [0], [2], [0], [0], [0], [0], [0], [0]]



Validation:  50%|█████████████████████████████████████████████████▌                                                 | 403/805 [01:08<01:09,  5.82it/s][A
Validation:  50%|█████████████████████████████████████████████████▋                                                 | 404/805 [01:09<01:08,  5.86it/s][A
Validation:  50%|█████████████████████████████████████████████████▊                                                 | 405/805 [01:09<01:07,  5.91it/s][A
Validation:  50%|█████████████████████████████████████████████████▉                                                 | 406/805 [01:09<01:07,  5.92it/s][A
Validation:  51%|██████████████████████████████████████████████████                                                 | 407/805 [01:09<01:07,  5.91it/s][A
Validation:  51%|██████████████████████████████████████████████████▏                                                | 408/805 [01:09<01:07,  5.90it/s][A
Validation:  51%|██████████████████████████████████████████████████▎       

[[5], [0], [0], [0], [0], [0], [0], [0], [0], [6]]
[[0], [0], [0], [0], [0], [8], [0], [0], [0], [0]]



Validation:  56%|███████████████████████████████████████████████████████▋                                           | 453/805 [01:17<00:59,  5.95it/s][A
Validation:  56%|███████████████████████████████████████████████████████▊                                           | 454/805 [01:17<00:59,  5.90it/s][A
Validation:  57%|███████████████████████████████████████████████████████▉                                           | 455/805 [01:17<00:59,  5.89it/s][A
Validation:  57%|████████████████████████████████████████████████████████                                           | 456/805 [01:17<00:59,  5.85it/s][A
Validation:  57%|████████████████████████████████████████████████████████▏                                          | 457/805 [01:18<00:58,  5.91it/s][A
Validation:  57%|████████████████████████████████████████████████████████▎                                          | 458/805 [01:18<00:59,  5.86it/s][A
Validation:  57%|████████████████████████████████████████████████████████▍ 

[[7], [8], [9], [9], [9], [9], [9], [5], [9], [9]]
[[7, 4], [8], [9], [0], [0], [10, 9], [5], [5], [0], [0]]



Validation:  62%|█████████████████████████████████████████████████████████████▊                                     | 503/805 [01:25<00:52,  5.77it/s][A
Validation:  63%|█████████████████████████████████████████████████████████████▉                                     | 504/805 [01:26<00:51,  5.80it/s][A
Validation:  63%|██████████████████████████████████████████████████████████████                                     | 505/805 [01:26<00:51,  5.82it/s][A
Validation:  63%|██████████████████████████████████████████████████████████████▏                                    | 506/805 [01:26<00:50,  5.89it/s][A
Validation:  63%|██████████████████████████████████████████████████████████████▎                                    | 507/805 [01:26<00:51,  5.83it/s][A
Validation:  63%|██████████████████████████████████████████████████████████████▍                                    | 508/805 [01:26<00:51,  5.82it/s][A
Validation:  63%|██████████████████████████████████████████████████████████

[[7], [3], [0], [0], [7], [5], [0], [7], [7], [0]]
[[7], [0], [0], [0], [0], [5], [0], [7], [7], [0]]



Validation:  69%|████████████████████████████████████████████████████████████████████                               | 553/805 [01:34<00:42,  5.87it/s][A
Validation:  69%|████████████████████████████████████████████████████████████████████▏                              | 554/805 [01:34<00:42,  5.84it/s][A
Validation:  69%|████████████████████████████████████████████████████████████████████▎                              | 555/805 [01:34<00:42,  5.85it/s][A
Validation:  69%|████████████████████████████████████████████████████████████████████▍                              | 556/805 [01:34<00:42,  5.87it/s][A
Validation:  69%|████████████████████████████████████████████████████████████████████▌                              | 557/805 [01:35<00:42,  5.87it/s][A
Validation:  69%|████████████████████████████████████████████████████████████████████▌                              | 558/805 [01:35<00:42,  5.83it/s][A
Validation:  69%|██████████████████████████████████████████████████████████

[[0], [0], [9], [0], [0], [0], [5], [0], [0], [9]]
[[9], [0], [9], [0], [0], [9], [5], [0], [0], [9]]



Validation:  75%|██████████████████████████████████████████████████████████████████████████▏                        | 603/805 [01:42<00:34,  5.94it/s][A
Validation:  75%|██████████████████████████████████████████████████████████████████████████▎                        | 604/805 [01:43<00:33,  5.96it/s][A
Validation:  75%|██████████████████████████████████████████████████████████████████████████▍                        | 605/805 [01:43<00:33,  5.93it/s][A
Validation:  75%|██████████████████████████████████████████████████████████████████████████▌                        | 606/805 [01:43<00:33,  5.94it/s][A
Validation:  75%|██████████████████████████████████████████████████████████████████████████▋                        | 607/805 [01:43<00:33,  5.87it/s][A
Validation:  76%|██████████████████████████████████████████████████████████████████████████▊                        | 608/805 [01:43<00:33,  5.90it/s][A
Validation:  76%|██████████████████████████████████████████████████████████

[[0], [5], [0], [0], [9], [0], [7], [1], [1], [0]]
[[9], [5], [0], [2, 9], [0], [0], [1, 7], [1], [1, 2], [2]]



Validation:  81%|████████████████████████████████████████████████████████████████████████████████▎                  | 653/805 [01:51<00:26,  5.82it/s][A
Validation:  81%|████████████████████████████████████████████████████████████████████████████████▍                  | 654/805 [01:51<00:25,  5.87it/s][A
Validation:  81%|████████████████████████████████████████████████████████████████████████████████▌                  | 655/805 [01:51<00:25,  5.85it/s][A
Validation:  81%|████████████████████████████████████████████████████████████████████████████████▋                  | 656/805 [01:51<00:25,  5.85it/s][A
Validation:  82%|████████████████████████████████████████████████████████████████████████████████▊                  | 657/805 [01:52<00:25,  5.87it/s][A
Validation:  82%|████████████████████████████████████████████████████████████████████████████████▉                  | 658/805 [01:52<00:25,  5.80it/s][A
Validation:  82%|██████████████████████████████████████████████████████████

[[0], [0], [0], [0], [5], [0], [5], [2], [0], [0]]
[[0], [0], [0], [0], [5], [0], [5], [2], [0], [0]]



Validation:  87%|██████████████████████████████████████████████████████████████████████████████████████▍            | 703/805 [02:00<00:17,  5.89it/s][A
Validation:  87%|██████████████████████████████████████████████████████████████████████████████████████▌            | 704/805 [02:00<00:17,  5.82it/s][A
Validation:  88%|██████████████████████████████████████████████████████████████████████████████████████▋            | 705/805 [02:00<00:17,  5.79it/s][A
Validation:  88%|██████████████████████████████████████████████████████████████████████████████████████▊            | 706/805 [02:00<00:17,  5.78it/s][A
Validation:  88%|██████████████████████████████████████████████████████████████████████████████████████▉            | 707/805 [02:00<00:17,  5.76it/s][A
Validation:  88%|███████████████████████████████████████████████████████████████████████████████████████            | 708/805 [02:00<00:16,  5.79it/s][A
Validation:  88%|██████████████████████████████████████████████████████████

[[1], [9], [9], [5], [0], [0], [9], [9], [0], [9]]
[[0], [9], [0], [5], [0], [0], [0], [0], [0], [5]]



Validation:  94%|████████████████████████████████████████████████████████████████████████████████████████████▌      | 753/805 [02:08<00:08,  5.85it/s][A
Validation:  94%|████████████████████████████████████████████████████████████████████████████████████████████▋      | 754/805 [02:08<00:08,  5.83it/s][A
Validation:  94%|████████████████████████████████████████████████████████████████████████████████████████████▊      | 755/805 [02:08<00:08,  5.86it/s][A
Validation:  94%|████████████████████████████████████████████████████████████████████████████████████████████▉      | 756/805 [02:09<00:08,  5.86it/s][A
Validation:  94%|█████████████████████████████████████████████████████████████████████████████████████████████      | 757/805 [02:09<00:08,  5.85it/s][A
Validation:  94%|█████████████████████████████████████████████████████████████████████████████████████████████▏     | 758/805 [02:09<00:07,  5.89it/s][A
Validation:  94%|██████████████████████████████████████████████████████████

[[10], [10], [0], [0], [10], [10], [4], [4], [4], [0]]
[[0], [0], [0], [3], [0], [10, 9], [0], [0], [4], [0]]



Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████▊| 803/805 [02:17<00:00,  5.92it/s][A
Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 805/805 [02:17<00:00,  5.86it/s]


F1-Score: 0.5856307435254803


Epoch: 0:  33%|████████████████████████████████▏                                                                  | 1056/3244 [21:23<20:29,  1.78it/s]

KeyboardInterrupt: ignored

## Test Set Prediction

In [None]:
    ### Dump out the results for the test dataset.

    ### You need to write your code here to dump out the dataset using "test_dataloader".
    ### you need to write out all your model predictions into a file "output.json".
    ### Each line of the "output.json" is the model prediction for the sentence.

# Restore the best checkpoint written during training before predicting.
prompt_model.load_state_dict(torch.load("/content/checkpoint_best.pt"))
# Switch to inference mode so that train-only behaviour (e.g. dropout) is
# disabled; the training loop above was interrupted and may have left the
# model in train mode.
prompt_model.eval()

allpreds, alllabels = [], []

# Testing: no gradients are needed, so skip autograd bookkeeping. The progress
# bar is used as a context manager so it is closed even if a batch fails.
with torch.no_grad(), tqdm.tqdm(total=len(test_dataloader), ncols=150, desc="Testing: ") as test_progress:
  for inputs in test_dataloader:
    inputs = to_device(inputs, device)
    logits = prompt_model(inputs)
    labels = inputs['label']
    label_list = convert_labels_to_list(labels)
    pred_labels = predict(logits)

    alllabels.extend(label_list)
    allpreds.extend(pred_labels)
    test_progress.update(1)

# NOTE(review): this score is only meaningful if the test split carries gold
# labels; the recorded run printed 0.0, so presumably it does not -- confirm.
p, r, f, total = evaluation(alllabels, allpreds, vocabulary)
# Predictions in dataloader order; consumed by the cells below.
test_preds = allpreds
print("F1-Score: " + str(f))
with open("test_results.json", 'w', encoding='utf-8') as f_out:
  f_out.write(json.dumps(total, indent=4))


    ### You may find "inv_vocabulary" useful here.

    ### Each line should be in the following format:
    ###    {"predictions": ["Catastrophe", "Conquering"]}
    ###    {"predictions": ["Social_event"]}
    ###    {"predictions": []}

    ### Note that the sentence order in your output file should be the same as in the original file!

Testing:  56%|█████████████████████████████████████████████████████████▌                                            | 530/940 [01:40<01:17,  5.32it/s][A
Testing:  56%|█████████████████████████████████████████████████████████▌                                            | 531/940 [01:40<01:17,  5.30it/s][A
Testing:  57%|█████████████████████████████████████████████████████████▋                                            | 532/940 [01:40<01:16,  5.32it/s][A
Testing:  57%|█████████████████████████████████████████████████████████▊                                            | 533/940 [01:40<01:17,  5.27it/s][A
Testing:  57%|█████████████████████████████████████████████████████████▉                                            | 534/940 [01:41<01:16,  5.28it/s][A
Testing:  57%|██████████████████████████████████████████████████████████                                            | 535/940 [01:41<01:16,  5.27it/s][A
Testing:  57%|██████████████████████████████████████████████████████████▏   

F1-Score: 0.0


In [None]:
# Event-type name -> label id. Id 0 ('None') stands for "no event" and is
# never written out as a prediction.
class_dict = {'None': 0, 'Catastrophe': 1, 'Causation': 2, 'Motion': 3, 'Hostile_encounter': 4, 'Process_start': 5, 'Attack': 6, 'Killing': 7, 'Conquering': 8, 'Social_event': 9, 'Competition': 10}
# Reverse mapping: label id -> event-type name.
inv_vocab = {v:k for k,v in class_dict.items()}

# One record per test sentence, in the same order as `test_preds`.
# The key is "predictions" (plural), as required by the output format.
output = []
for pred_ids in test_preds:
  # Keep every predicted event type (a sentence may have several), dropping
  # the 'None' id; an empty or all-'None' prediction yields an empty list.
  event_names = [inv_vocab[idx] for idx in pred_ids if idx != 0]
  output.append({"predictions": event_names})



In [None]:
# Display the first ten prediction records.
output[:10]

[{'prediction': ['Process_start']},
 {'prediction': []},
 {'prediction': []},
 {'prediction': ['Process_start']},
 {'prediction': ['Process_start']},
 {'prediction': []},
 {'prediction': ['Process_start']},
 {'prediction': ['Hostile_encounter']},
 {'prediction': []},
 {'prediction': []}]

In [None]:
# Write the predictions to "output_rcsalvi2.json" in the required format:
# one JSON object per line, one line per test sentence, in the same order
# as `output`.
with open("output_rcsalvi2.json", "w", encoding="utf-8") as f_out:
    for record in output:
        f_out.write(json.dumps(record) + "\n")