In [1]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
import json
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq
import torch
import os
import evaluate
import numpy as np
import glob
import ast
import re

In [13]:
torch.cuda.empty_cache()

path = "reuters-data-for-ACL-analysis/"
li = []
all_files = glob.glob(os.path.join(path, "*.csv"))
for filename in all_files:
    df = pd.read_csv(filename, index_col=None, header=0)
    li.append(df)

df = pd.concat(li, axis=0, ignore_index=True)

In [14]:
df = df[["date", "description", "keywords"]]
df.drop_duplicates(inplace=True)
print (df.shape)
df.head()

(212, 3)


Unnamed: 0,date,description,keywords
0,2023-05-31T12:00:00.000Z,The president of the Confederation of British ...,"MTPIX,BACT,BIZ,BOSS1,ECO,GEN,MCE,MNGISS,PLCY,P..."
1,2023-05-31T12:00:00.000Z,Sales of own-label products at British superma...,"BACT,BIZ,CMPNY,ECI,ECO,FDRT,GEN,INFL,MCE,NCYC,..."
2,2023-05-30T12:00:00.000Z,Sentiment among British businesses fell for th...,"BOE,BSENT,CEN,ECI,ECO,GDP,INFL,INT,MCE,PMI,GB,..."
3,2023-05-30T12:00:00.000Z,Britain's competition regulator told supermark...,"EF:BUSINESS-MACROMATTERS,CMPNY,ECI,ECO,FDRT,FO..."
4,2023-05-31T12:00:00.000Z,UK's FTSE 100 slid on Wednesday to a two-month...,"MKTREP,REP,CMPNY,DBT,EUB,FIN,FINS,FINS08,HOT,I..."


In [15]:
# df.to_csv("conll23_test_examples.csv", index=False)

In [4]:
prefix = """List all relations of the following type in the given text and provide reasonable explanations for your answers - \n1. Kill: Entity A killed Entity B.\n2. Work_For: Entity A works for Entity B.\n3. Located_In: Entity A is located in Entity B.\n4. Live_In: Entity A lives in Entity B.\n5. OrgBased_In: Entity A is an organization based in Entity B.\n\n"""

print (prefix)

List all relations of the following type in the given text and provide reasonable explanations for your answers - 
1. Kill: Entity A killed Entity B.
2. Work_For: Entity A works for Entity B.
3. Located_In: Entity A is located in Entity B.
4. Live_In: Entity A lives in Entity B.
5. OrgBased_In: Entity A is an organization based in Entity B.




In [5]:
tuned_model = AutoModelForSeq2SeqLM.from_pretrained("/home/wadhwa.s/gpt3/models/conll/checkpoint-2000/", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("/home/wadhwa.s/gpt3/models/conll/checkpoint-2000/")

In [6]:
ip = []
date = []
generated = []
nc_count = 0

for ix, row in tqdm(df.iterrows(), total=df.shape[0]):
    text = prefix + row["description"]
    inputs = tokenizer(text, return_tensors="pt").input_ids.to("cuda")
    outputs = tuned_model.generate(inputs, max_new_tokens=256, do_sample=False)
    out = tokenizer.decode(outputs[0], skip_special_tokens=True)
    torch.cuda.empty_cache()
    try:
        generated.append(out.split(" Explanation: ")[0])
        ip.append(row["description"])
        date.append(row["date"])
        # gold.append(row["gold_labels"])
        # print ("PRED: " + out)
    except:
        nc_count += 1
        print (row["text"])
        # print ("GOLD: " + row["gold_labels"])
        print ("PRED ---- NON CONFORMING OUTPUT: " + out)
        print ("\n----------------\n")

100%|██████████| 212/212 [06:56<00:00,  1.97s/it]


In [7]:
df = pd.DataFrame({"text": ip, 
                   "date": date, 
                   "generated": generated})
# df.to_csv("conll_flan_explanations_generated_ckpt500.csv", index=False)
print (df.shape)

(212, 3)


In [8]:
invalid_count = 0
total_count = 0
valid_parsed = 0
for ix, row in df.iterrows():
    # print ("GOLD: " + row["gold_labels"])
    try:
        pred = ast.literal_eval(row["generated"])
        total_count += len(pred)
        valid_parsed += 1
        for relation in pred:
            if relation[1] not in valid:
                print ("TEXT: ", row["text"])
                print ("PRED: " + relation)
                print ("\n*********************\n")
    except:
        invalid_count += 1
        print ("TEXT: ", row["text"])
        print ("NON CONFORMING PRED: ", row["generated"])
        print ("\n----------------\n")

TEXT:  The president of the Confederation of British Industry (CBI), which is fighting for its survival after a series of workplace misconduct incidents, will step down early next year, the lobbying organisation said on Wednesday.
NON CONFORMING PRED:  [['Confederation of British Industry', 'OrgBased_In', 'Britain'], ['CBI', 'OrgBased_In', 'Britain']]

----------------

TEXT:  Sales of own-label products at British supermarkets have grown at double the speed of branded goods in 2023, data from market researchers NIQ showed on Wednesday, as customers adjust to soaring prices.
NON CONFORMING PRED:  [['NIQ', 'OrgBased_In', 'Britain']]

----------------

TEXT:  Sentiment among British businesses fell for the first time in three months in May as firms were less optimistic about the economy and their trading prospects despite some signs of resilience in the economy, a survey showed on Wednesday.
NON CONFORMING PRED:  [['British', 'Located_In', 'UK']]

----------------

TEXT:  Britain's compe

In [9]:
print (invalid_count, total_count, valid_parsed)

212 250 211


In [10]:
for ix, row in df.iterrows():
    print ("DATE: ", row["date"])
    print ("TEXT: ", row["text"])
    print ("GENERATED RELATIONS: ", row["generated"])
    print ("\n-----------------------------------------\n")

DATE:  2023-05-31T12:00:00.000Z
TEXT:  The president of the Confederation of British Industry (CBI), which is fighting for its survival after a series of workplace misconduct incidents, will step down early next year, the lobbying organisation said on Wednesday.
GENERATED RELATIONS:  [['Confederation of British Industry', 'OrgBased_In', 'Britain'], ['CBI', 'OrgBased_In', 'Britain']]

-----------------------------------------

DATE:  2023-05-31T12:00:00.000Z
TEXT:  Sales of own-label products at British supermarkets have grown at double the speed of branded goods in 2023, data from market researchers NIQ showed on Wednesday, as customers adjust to soaring prices.
GENERATED RELATIONS:  [['NIQ', 'OrgBased_In', 'Britain']]

-----------------------------------------

DATE:  2023-05-30T12:00:00.000Z
TEXT:  Sentiment among British businesses fell for the first time in three months in May as firms were less optimistic about the economy and their trading prospects despite some signs of resilien

In [12]:
df.to_csv("conll23_flan_test_untiljun1.csv", index=False)