In [1]:
import pickle
import os
import sys
from datetime import datetime
import math
import json
import re
import csv
import numpy as np
from evaluate import load
import random
from bs4 import BeautifulSoup as BS
import copy
from sklearn.metrics import classification_report

2023-07-05 12:07:47.742090: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-07-05 12:07:47.773423: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-07-05 12:07:47.774150: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

In [4]:
random.seed(13)
torch.manual_seed(13)
np.random.seed(13)
torch.backends.cudnn.benchmarks = False
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [5]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [6]:
fpath = '/homes/rpujari/scratch_ml/DARPA/'

In [7]:
plutchik =  ['anger', 'fear', 'sadness', 'disgust', 'surprise', 'anticipation', 'trust', 'joy', 'neutral']

### Loding LDC data

In [8]:
ldc_dpath = '/homes/rpujari/scratch0_ml/ldc_data/'

In [9]:
def get_ldc_emotion_data(data_version):
    tfnames = os.listdir(ldc_dpath + data_version + '/source_data/text/txt/')
    tfnames = [fn[:-4] for fn in tfnames]

    #read and save all segments from segments.tab 
    #(all_segments: file_ID -> segment_name -> seg_bounds, rev_segments: -> file_ID -> seg_bounds -> segment_name)
    as_c = 0
    all_segments = {}
    rev_segments = {}
    seg_data = list(csv.reader(open(ldc_dpath + data_version + '/docs/segments.tab'), delimiter='\t'))
    #drop header
    for row in seg_data[1:]:
        #each row have file_ID, segment_name, start_char/start_secs, end_char/end_secs
        ann_fname, seg_name, sc, ec = row
        seg_bounds = (float(sc), float(ec))
        if ann_fname in tfnames:
            if ann_fname not in all_segments:
                all_segments[ann_fname] = {}
            all_segments[ann_fname][seg_name] = seg_bounds
            
            if ann_fname not in rev_segments:
                rev_segments[ann_fname] = {}
            rev_segments[ann_fname][seg_bounds] = seg_name

    for tfname in all_segments:
        as_c += len(all_segments[tfname])

    #read and save annotations from perfect_submission for ED (annotations: file_ID -> seg_bounds -> emotion)
    an_c = 0
    annotations = {}
    label_dist = {}
    for tfname in tfnames:
        if tfname not in annotations:
            annotations[tfname] = {}
        tf_anns = list(csv.reader(open(ldc_dpath + data_version + '/perfect_submissions/ED/' + tfname + '.tab'), delimiter='\t'))
        #drop header
        for ann in tf_anns[1:]:
            #each row has file_ID, annotated_emotion, start_char/start_secs, end_char/end_secs, llr (place_holder)
            ann_fname, em, sc, ec, llr = ann
            seg_bounds = (float(sc), float(ec))
            if seg_bounds not in annotations[tfname]:
                annotations[tfname][seg_bounds] = []
            annotations[tfname][seg_bounds].append(em)
            if em not in label_dist:
                label_dist[em] = 0
            label_dist[em] += 1
        an_c += len(annotations[tfname])
        
    #randomly sample neutral segments
    num_neutral_samples = int(np.mean([v for k, v in label_dist.items()]))
    neutral_segments = []
    for tfname in rev_segments:
        for seg_bounds in rev_segments[tfname]:
            if seg_bounds not in annotations[tfname]:
                neutral_segments.append((tfname, seg_bounds))
    sel_neutral_segments = random.sample(neutral_segments, min(num_neutral_samples, len(neutral_segments)))
    label_dist['neutral'] = len(sel_neutral_segments)
    
    #add sampled segments to annotations
    for tfname, seg_bounds in sel_neutral_segments:
        if tfname not in annotations:
            annotations[tfname] = {}
        annotations[tfname][seg_bounds] = ['neutral']

    return all_segments, rev_segments, annotations, label_dist

##### LDC Loading to memory

In [10]:
ldc_train = get_ldc_emotion_data('v5.1')
ldc_test = get_ldc_emotion_data('unsequestered')

In [11]:
ldc_intents_p1 = open('/homes/rpujari/scratch_ml/shared_models/gpt-neox/ldc_prompts/output_files/ldc_intent_outputs_p1.txt', 'r').readlines()
ldc_intents_p2 = open('/homes/rpujari/scratch_ml/shared_models/gpt-neox/ldc_prompts/output_files/ldc_intent_outputs_p2.txt', 'r').readlines()
ldc_intents = ldc_intents_p1 + ldc_intents_p2
ldc_intents = [i for i in ldc_intents if i.strip()]

intent_dict = {}
for idx, intent in enumerate(ldc_intents):
    d = json.loads(intent)
    id_ = d['context'].strip().split('\t')[0].split()[1]
    i = d['text'].strip().split('\t')[0].split('   ')[0].strip().split('\n')[0].strip()
    intent_dict[id_] = i
    
print(len(intent_dict))

29477


In [12]:
ldc_dset = {}
ldc_dset["train"] = {'text': [], 'context': [], 'intent': [], 'label': []}
ldc_dset["valid"] = {'text': [], 'context': [], 'intent': [], 'label': []}
ldc_dset["test"] = {'text': [], 'context': [], 'intent': [], 'label': []}

context_len = 4

segs, rev_segs, anns, label_dist = ldc_train
multi_labels = 0
tot = 0
tlen = len(anns)
fcount = 0
t1 = datetime.now()
for tfname in anns:
    all_bounds = list(anns[tfname].keys())
    all_bounds = sorted(all_bounds, key=lambda x:int(x[0]))
    ftext = open(ldc_dpath + 'v5.1/source_data/text/txt/' + tfname + '.txt').read()
    all_turns = []
    for i, seg_bounds in enumerate(all_bounds):
        if True: #seg_bounds in rev_segs[tfname]:
            b, e = seg_bounds
            b = int(b)
            e = int(e)
            # seg_name = rev_segs[tfname][seg_bounds]
            all_turns.append(ftext[b:e+1])
            toss = random.random()
            if toss <= 0.8:
                if len(anns[tfname][seg_bounds]) > 1:
                    multi_labels += 1
                else:
                    ldc_dset["train"]["text"].append(ftext[b:e+1])
                    ldc_dset["train"]["intent"].append(ftext[b:e+1])
                    #ldc_dset["train"]["intent"].append(intent_dict[seg_name])
                    ldc_dset["train"]["context"].append('\n'.join(all_turns[max(0, i-context_len):i]))
                    ldc_dset["train"]["label"].append(plutchik.index(anns[tfname][seg_bounds][0]))
            else:
                if len(anns[tfname][seg_bounds]) > 1:
                    multi_labels += 1
                else:
                    ldc_dset["valid"]["text"].append(ftext[b:e+1])
                    ldc_dset["valid"]["intent"].append(ftext[b:e+1])
                    #ldc_dset["valid"]["intent"].append(intent_dict[seg_name])
                    ldc_dset["valid"]["context"].append('\n'.join(all_turns[max(0, i-context_len):i]))
                    ldc_dset["valid"]["label"].append(plutchik.index(anns[tfname][seg_bounds][0]))
    tot += len(anns[tfname])
    fcount += 1
    t2 = datetime.now()
print(multi_labels, tot)
print(label_dist)

segs, rev_segs, anns, label_dist = ldc_test
multi_labels = 0
tot = 0
for tfname in anns:
    all_bounds = list(anns[tfname].keys())
    all_bounds = sorted(all_bounds, key=lambda x:int(x[0]))
    ftext = open(ldc_dpath + 'unsequestered/source_data/text/txt/' + tfname + '.txt').read()
    all_turns = []
    for i, seg_bounds in enumerate(all_bounds):
        if True: #seg_bounds in rev_segs[tfname]:
            b, e = seg_bounds
            b = int(b)
            e = int(e)
            # seg_name = rev_segs[tfname][seg_bounds]
            all_turns.append(ftext[b:e+1])
            if len(anns[tfname][seg_bounds]) > 1:
                multi_labels += 1
            else:
                ldc_dset["test"]["text"].append(ftext[b:e+1])
                ldc_dset["test"]["intent"].append(ftext[b:e+1])
                #ldc_dset["test"]["intent"].append(intent_dict[seg_name])
                ldc_dset["test"]["context"].append('\n'.join(all_turns[max(0, i-context_len):i]))
                ldc_dset["test"]["label"].append(plutchik.index(anns[tfname][seg_bounds][0]))
    tot += len(anns[tfname])
print(multi_labels, tot)
print(label_dist)

print(len(ldc_dset['train']['text']), len(ldc_dset['valid']['text']), len(ldc_dset['test']['text'])) 

60 2176
{'anger': 268, 'joy': 622, 'anticipation': 185, 'sadness': 216, 'disgust': 303, 'fear': 118, 'surprise': 255, 'trust': 23, 'neutral': 248}
22 622
{'fear': 38, 'disgust': 114, 'trust': 41, 'joy': 137, 'anticipation': 65, 'surprise': 81, 'sadness': 54, 'anger': 43, 'neutral': 71}
1682 434 600


### Loading MPDD Data

In [13]:
mpdd_dialogues = json.load(open(fpath + 'mpdd/dialogue.json'))
mpdd_metada = json.load(open(fpath + 'mpdd/metadata.json'))

In [14]:
emotion_map = {
    'fear': 'fear',
    'angry': 'anger',
    'disgust': 'disgust',
    'sadness': 'sadness',
    'happiness': 'joy',
    'surprise': 'surprise',
    'neutral': 'neutral'
}

In [15]:
mpdd_dset = {}
mpdd_dset["train"] = {'text': [], 'context': [], 'intent': [], 'label': []}
mpdd_dset["valid"] = {'text': [], 'context': [], 'intent': [], 'label': []}
mpdd_dset["test"] = {'text': [], 'context': [], 'intent': [], 'label': []}

In [16]:
mpdd_intent = {}
with open('/homes/rpujari/scratch_ml/shared_models/gpt-neox/mpdd_prompts/output_files/mpdd_intent_outputs.txt', 'r') as infile:
    fc = infile.readlines()
    for line in fc:
        d = json.loads(line.strip())
        id_ = d['context'].split('\t')[0].split()[1].strip()
        intent = d['text'].split('\t')[0].split('    ')[0].strip().split('\n')[0]
        mpdd_intent[id_] = intent

In [17]:
for id_ in mpdd_dialogues:
    dialogue = mpdd_dialogues[id_]
    context = ''
    for i, turn in enumerate(dialogue):
        toss = random.random()
        if toss <= 0.7:
            mpdd_dset["train"]['text'].append(turn['utterance'])
            mpdd_dset["train"]['label'].append(plutchik.index(emotion_map[turn['emotion']]))
            mpdd_dset["train"]['intent'].append(mpdd_intent[id_ + '-' + str(i)])
            mpdd_dset["train"]['context'].append(context.strip())
        elif toss <= 0.9:
            mpdd_dset["valid"]['text'].append(turn['utterance'])
            mpdd_dset["valid"]['label'].append(plutchik.index(emotion_map[turn['emotion']]))
            mpdd_dset["valid"]['intent'].append(mpdd_intent[id_ + '-' + str(i)])
            mpdd_dset["valid"]['context'].append(context.strip())
        else:
            mpdd_dset["test"]['text'].append(turn['utterance'])
            mpdd_dset["test"]['label'].append(plutchik.index(emotion_map[turn['emotion']]))
            mpdd_dset["test"]['intent'].append(mpdd_intent[id_ + '-' + str(i)])
            mpdd_dset["test"]['context'].append(context.strip())
        context += turn['speaker'] + '(' + emotion_map[turn['emotion']] + '): ' + turn['utterance'] + '\n'

### Loading CPED Data

In [18]:
cped_dset = {}
cped_dset["train"] = {'text': [], 'context': [], 'intent': [], 'label': []}
cped_dset["valid"] = {'text': [], 'context': [], 'intent': [], 'label': []}
cped_dset["test"] = {'text': [], 'context': [], 'intent': [], 'label': []}

In [19]:
# plutchik = ['anger', 'fear', 'sadness', 'disgust', 'surprise', 'anticipation', 'trust', 'joy', 'neutral']

# cped_emotion_map = {
#     'grateful': 'trust',
#     'neutral': 'neutral',
#     'relaxed': 'joy',
#     'positive-other': 'joy',
#     'negative-other': 'sadness',
#     'astonished': 'surprise',
#     'sadness': 'sadness',
#     'fear': 'fear',
#     'worried': 'anticipation',
#     'anger': 'anger',
#     'depress': 'sadness',
#     'disgust': 'disgust',
#     'happy': 'joy'}

cped_emotion_map = {
    'grateful': 'trust',
    'neutral': 'neutral',
    'astonished': 'surprise',
    'sadness': 'sadness',
    'fear': 'fear',
    'worried': 'anticipation',
    'anger': 'anger',
    'depress': 'sadness',
    'disgust': 'disgust',
    'happy': 'joy'}

In [20]:
fp = open(fpath + '/CPED/data/CPED/train_split.csv')
fr = list(csv.reader(fp))
i = 0
header = None
train_data = {}
for row in fr:
    if i > 0:
        tv_id = row[0]
        d_id = row[header.index('Dialogue_ID')]
        ut_id = row[header.index('Utterance_ID')]
        speaker = row[header.index('Speaker')]
        ut = row[header.index('Utterance')]
        em = row[header.index('Emotion')]
        if em in cped_emotion_map:
            if tv_id not in train_data:
                train_data[tv_id] = {}
            if d_id not in train_data[tv_id]:
                train_data[tv_id][d_id] = {}
            train_data[tv_id][d_id][ut_id] = (speaker, ut, em)                                                 
    else:
        header = row
        print(row)
    i += 1

['\ufeffTV_ID', 'Dialogue_ID', 'Utterance_ID', 'Speaker', 'Gender', 'Age', 'Neuroticism', 'Extraversion', 'Openness', 'Agreeableness', 'Conscientiousness', 'Scene', 'FacePosition_LU', 'FacePosition_RD', 'Sentiment', 'Emotion', 'DA', 'Utterance']


In [21]:
fp = open(fpath + '/CPED/data/CPED/valid_split.csv')
fr = list(csv.reader(fp))
i = 0
header = None
valid_data = {}
for row in fr:
    if i > 0:
        tv_id = row[0]
        d_id = row[header.index('Dialogue_ID')]
        ut_id = row[header.index('Utterance_ID')]
        speaker = row[header.index('Speaker')]
        ut = row[header.index('Utterance')]
        em = row[header.index('Emotion')]
        if em in cped_emotion_map:
            if tv_id not in valid_data:
                valid_data[tv_id] = {}
            if d_id not in valid_data[tv_id]:
                valid_data[tv_id][d_id] = {}
            valid_data[tv_id][d_id][ut_id] = (speaker, ut, em)                                                  
    else:
        header = row
        print(row)
    i += 1

['\ufeffTV_ID', 'Dialogue_ID', 'Utterance_ID', 'Speaker', 'Gender', 'Age', 'Neuroticism', 'Extraversion', 'Openness', 'Agreeableness', 'Conscientiousness', 'Scene', 'FacePosition_LU', 'FacePosition_RD', 'Sentiment', 'Emotion', 'DA', 'Utterance']


In [22]:
fp = open(fpath + '/CPED/data/CPED/test_split.csv')
fr = list(csv.reader(fp))
i = 0
header = None
test_data = {}
for row in fr:
    if i > 0:
        tv_id = row[0]
        d_id = row[header.index('Dialogue_ID')]
        ut_id = row[header.index('Utterance_ID')]
        speaker = row[header.index('Speaker')]
        ut = row[header.index('Utterance')]
        em = row[header.index('Emotion')]
        if em in cped_emotion_map:
            if tv_id not in test_data:
                test_data[tv_id] = {}
            if d_id not in test_data[tv_id]:
                test_data[tv_id][d_id] = {}
            test_data[tv_id][d_id][ut_id] = (speaker, ut, em)                                                  
    else:
        header = row
        print(row)
    i += 1

['\ufeffTV_ID', 'Dialogue_ID', 'Utterance_ID', 'Speaker', 'Gender', 'Age', 'Neuroticism', 'Extraversion', 'Openness', 'Agreeableness', 'Conscientiousness', 'Scene', 'FacePosition_LU', 'FacePosition_RD', 'Sentiment', 'Emotion', 'DA', 'Utterance']


In [23]:
cped_intent_path = '/homes/rpujari/scratch_ml/shared_models/gpt-neox/cped_prompts/outputs_intent/'
cped_intents = {}
cped_ifnames = os.listdir(cped_intent_path)
t = 0
for ifname in cped_ifnames:
    ilines = open(cped_intent_path + ifname, 'r').readlines()
    for iline in ilines:
        if iline.strip():
            d = json.loads(iline.strip())
            try:
                id_ = d['context'].strip().split('\t')[0].split()[1].strip()
                intent = d['text'].split('\t')[0].split('    ')[0].strip().split('\n')[0]
                cped_intents[id_] = intent
            except:
                pass
            t += 1
print(t, len(cped_intents))

132770 132763


In [24]:
das = set()
for row in fr[1:]:
    das.add(row[-2])
print(das)

{'irony', 'greeting', 'comfort', 'thanking', 'statement-opinion', 'answer', 'acknowledge', 'agreement', 'command', 'interjection', 'question', 'apology', 'statement-non-opinion', 'disagreement', 'appreciation', 'conventional-closing', 'other', 'reject', 'quotation'}


In [25]:
t = 0
for tv_id in train_data:
    for d_id in train_data[tv_id]:
        ut_ids = sorted(list(train_data[tv_id][d_id].keys()))
        context = []
        for i, ut_id in enumerate(ut_ids):
            turn = train_data[tv_id][d_id][ut_id]
            if ut_id in cped_intents and turn[2] in cped_emotion_map:    
                cped_dset['train']['text'].append(turn[1])
                cped_dset['train']['intent'].append(cped_intents[ut_id])
                cped_dset['train']['context'].append('\n'.join(context[-4:]))
                cped_dset['train']['label'].append(plutchik.index(cped_emotion_map[turn[2]]))
            context.append(turn[0] + '(' + turn[2] + '): ' + turn[1])
            t += 1
print(len(cped_dset['train']['text']), t)            

72850 72850


In [26]:
t = 0
for tv_id in valid_data:
    for d_id in valid_data[tv_id]:
        ut_ids = sorted(list(valid_data[tv_id][d_id].keys()))
        context = []
        for i, ut_id in enumerate(ut_ids):
            turn = valid_data[tv_id][d_id][ut_id]
            if ut_id in cped_intents and turn[2] in cped_emotion_map:    
                cped_dset['valid']['text'].append(turn[1])
                cped_dset['valid']['intent'].append(cped_intents[ut_id])
                cped_dset['valid']['context'].append('\n'.join(context[-4:]))
                cped_dset['valid']['label'].append(plutchik.index(cped_emotion_map[turn[2]]))
            context.append(turn[0] + '(' + turn[2] + '): ' + turn[1])
            t += 1
print(len(cped_dset['valid']['text']), t)            

8126 8126


In [27]:
t = 0
for tv_id in test_data:
    for d_id in test_data[tv_id]:
        ut_ids = sorted(list(test_data[tv_id][d_id].keys()))
        context = []
        for i, ut_id in enumerate(ut_ids):
            turn = test_data[tv_id][d_id][ut_id]
            if ut_id in cped_intents and turn[2] in cped_emotion_map:    
                cped_dset['test']['text'].append(turn[1])
                cped_dset['test']['intent'].append(cped_intents[ut_id])
                cped_dset['test']['context'].append('\n'.join(context[-4:]))
                cped_dset['test']['label'].append(plutchik.index(cped_emotion_map[turn[2]]))
            context.append(turn[0] + '(' + turn[2] + '): ' + turn[1])
            t += 1
print(len(cped_dset['test']['text']), t)            

21208 21208


In [28]:
print(len(cped_dset['train']['text']), len(cped_dset['valid']['text']), len(cped_dset['test']['text']))

72850 8126 21208


### Loading data to datasets

In [29]:
mpdd_train = Dataset.from_dict(mpdd_dset['train'])
mpdd_valid = Dataset.from_dict(mpdd_dset['valid'])
mpdd_test = Dataset.from_dict(mpdd_dset['test'])

cped_train = Dataset.from_dict(cped_dset['train'])
cped_valid = Dataset.from_dict(cped_dset['valid'])
cped_test = Dataset.from_dict(cped_dset['test'])

ldc_train = Dataset.from_dict(ldc_dset['train'])
ldc_valid = Dataset.from_dict(ldc_dset['valid'])
ldc_test = Dataset.from_dict(ldc_dset['test'])

In [30]:
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

def tokenize_function(examples):
    return tokenizer(examples['text'], padding="max_length",\
                     truncation=True, max_length=512)


tokenized_mpdd_train = mpdd_train.map(tokenize_function, batched=True)
tokenized_mpdd_valid = mpdd_valid.map(tokenize_function, batched=True)
tokenized_mpdd_test = mpdd_test.map(tokenize_function, batched=True)

tokenized_cped_train = cped_train.map(tokenize_function, batched=True)
tokenized_cped_valid = cped_valid.map(tokenize_function, batched=True)
tokenized_cped_test = cped_test.map(tokenize_function, batched=True)

tokenized_ldc_train = ldc_train.map(tokenize_function, batched=True)
tokenized_ldc_valid = ldc_valid.map(tokenize_function, batched=True)
tokenized_ldc_test = ldc_test.map(tokenize_function, batched=True)

Downloading:   0%|          | 0.00/19.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/689 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/269k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

  0%|          | 0/18 [00:00<?, ?ba/s]

  0%|          | 0/6 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/73 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/22 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [31]:
tokenized_mpdd_train = tokenized_mpdd_train.remove_columns(["text"])
tokenized_mpdd_train = tokenized_mpdd_train.remove_columns(["context"])
tokenized_mpdd_train = tokenized_mpdd_train.remove_columns(["intent"])
tokenized_mpdd_train = tokenized_mpdd_train.rename_column("label", "labels")
tokenized_mpdd_train.set_format("torch")

tokenized_mpdd_valid = tokenized_mpdd_valid.remove_columns(["text"])
tokenized_mpdd_valid = tokenized_mpdd_valid.remove_columns(["context"])
tokenized_mpdd_valid = tokenized_mpdd_valid.remove_columns(["intent"])
tokenized_mpdd_valid = tokenized_mpdd_valid.rename_column("label", "labels")
tokenized_mpdd_valid.set_format("torch")

tokenized_mpdd_test = tokenized_mpdd_test.remove_columns(["text"])
tokenized_mpdd_test = tokenized_mpdd_test.remove_columns(["context"])
tokenized_mpdd_test = tokenized_mpdd_test.remove_columns(["intent"])
tokenized_mpdd_test = tokenized_mpdd_test.rename_column("label", "labels")
tokenized_mpdd_test.set_format("torch")

tokenized_cped_train = tokenized_cped_train.remove_columns(["text"])
tokenized_cped_train = tokenized_cped_train.remove_columns(["context"])
tokenized_cped_train = tokenized_cped_train.remove_columns(["intent"])
tokenized_cped_train = tokenized_cped_train.rename_column("label", "labels")
tokenized_cped_train.set_format("torch")

tokenized_cped_valid = tokenized_cped_valid.remove_columns(["text"])
tokenized_cped_valid = tokenized_cped_valid.remove_columns(["context"])
tokenized_cped_valid = tokenized_cped_valid.remove_columns(["intent"])
tokenized_cped_valid = tokenized_cped_valid.rename_column("label", "labels")
tokenized_cped_valid.set_format("torch")

tokenized_cped_test = tokenized_cped_test.remove_columns(["text"])
tokenized_cped_test = tokenized_cped_test.remove_columns(["context"])
tokenized_cped_test = tokenized_cped_test.remove_columns(["intent"])
tokenized_cped_test = tokenized_cped_test.rename_column("label", "labels")
tokenized_cped_test.set_format("torch")

tokenized_ldc_train = tokenized_ldc_train.remove_columns(["text"])
tokenized_ldc_train = tokenized_ldc_train.remove_columns(["context"])
tokenized_ldc_train = tokenized_ldc_train.remove_columns(["intent"])
tokenized_ldc_train = tokenized_ldc_train.rename_column("label", "labels")
tokenized_ldc_train.set_format("torch")

tokenized_ldc_valid = tokenized_ldc_valid.remove_columns(["text"])
tokenized_ldc_valid = tokenized_ldc_valid.remove_columns(["context"])
tokenized_ldc_valid = tokenized_ldc_valid.remove_columns(["intent"])
tokenized_ldc_valid = tokenized_ldc_valid.rename_column("label", "labels")
tokenized_ldc_valid.set_format("torch")

tokenized_ldc_test = tokenized_ldc_test.remove_columns(["text"])
tokenized_ldc_test = tokenized_ldc_test.remove_columns(["context"])
tokenized_ldc_test = tokenized_ldc_test.remove_columns(["intent"])
tokenized_ldc_test = tokenized_ldc_test.rename_column("label", "labels")
tokenized_ldc_test.set_format("torch")

In [32]:
mpdd_train_dataset = tokenized_mpdd_train.shuffle(seed=13).select(range(5000))
mpdd_dev_dataset = tokenized_mpdd_valid.shuffle(seed=13).select(range(300))
mpdd_test_dataset = tokenized_mpdd_test

cped_train_dataset = tokenized_cped_train.shuffle(seed=13).select(range(5000))
cped_dev_dataset = tokenized_cped_valid.shuffle(seed=13).select(range(300))
cped_test_dataset = tokenized_cped_test

ldc_train_dataset = tokenized_ldc_train
ldc_dev_dataset = tokenized_ldc_valid
ldc_test_dataset = tokenized_ldc_test

In [33]:
mpdd_train_dataloader = DataLoader(mpdd_train_dataset, shuffle=True, batch_size=8)
mpdd_dev_dataloader = DataLoader(mpdd_dev_dataset, batch_size=8)
mpdd_test_dataloader = DataLoader(mpdd_test_dataset, batch_size=8)

cped_train_dataloader = DataLoader(cped_train_dataset, shuffle=True, batch_size=8)
cped_dev_dataloader = DataLoader(cped_dev_dataset, batch_size=8)
cped_test_dataloader = DataLoader(cped_test_dataset, batch_size=8)

ldc_train_dataloader = DataLoader(ldc_train_dataset, shuffle=True, batch_size=8)
ldc_dev_dataloader = DataLoader(ldc_dev_dataset, batch_size=8)
ldc_test_dataloader = DataLoader(ldc_test_dataset, batch_size=8)

### Classification

In [34]:
save_path = '/homes/rpujari/scratch_ml/DARPA/ta2snapshot_saved_parameters/'

In [None]:
def train_emotion_detector(model, lr, train_dataset, train_dataloader, dev_dataloader, data_tag, num_epochs=5, save_path=save_path):
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )
    
    progress_bar = tqdm(range(num_training_steps))

    model.train()
    best_dev_perf = -1
    loss_fn = nn.CrossEntropyLoss(weight=(1.0 / (1 + train_dataset["labels"].bincount())).to(device))

    for epoch in range(num_epochs):
        bnum = 0
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            loss = loss_fn(logits, batch["labels"])
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            bnum += 1
            if bnum % 30 == 0:
                model.eval()
                metric = load("accuracy")
                for dev_batch in dev_dataloader:
                    dev_batch = {k: v.to(device) for k, v in dev_batch.items()}
                    with torch.no_grad():
                        outputs = model(**dev_batch)

                    logits = outputs.logits
                    predictions = torch.argmax(logits, dim=-1)
                    metric.add_batch(predictions=predictions, references=dev_batch["labels"])
                perf = metric.compute()["accuracy"]
                print('Dev Accuracy: ', perf)
                if perf > best_dev_perf:
                    best_dev_perf = perf
                    model.save_pretrained(save_path + 'emotion_detector_' + data_tag)
                    print('Saving model parameters. Dev Accuracy: ', perf)
                model.train()

In [35]:
model = AutoModelForSequenceClassification.from_pretrained("hfl/chinese-roberta-wwm-ext", num_labels=9)
model = model.to(device)

# for param in model.base_model.parameters():
#     param.requires_grad = False

Downloading:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model che

In [37]:
lr = 1e-5
train_dataset = cped_train_dataset
train_dataloader = cped_train_dataloader
dev_dataloader = cped_dev_dataloader
data_tag = 'cped'
num_epochs=5

optimizer = AdamW(model.parameters(), lr=lr)
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

progress_bar = tqdm(range(num_training_steps))

model.train()
best_dev_perf = -1
loss_fn = nn.CrossEntropyLoss(weight=(1.0 / (1 + train_dataset["labels"].bincount())).to(device))

for epoch in range(num_epochs):
    bnum = 0
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        print(outputs)
        break
    break

  0%|          | 0/3125 [00:00<?, ?it/s]

SequenceClassifierOutput(loss=tensor(2.2723, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[-0.4077, -0.4569,  0.4603,  0.4018,  0.1258,  0.0121,  0.0657,  0.5181,
         -0.0581],
        [-0.3557, -0.5392,  0.4419,  0.7706, -0.0335, -0.2437,  0.0247,  0.3945,
         -0.2664],
        [ 0.0958, -0.2526,  0.2757,  0.3855,  0.1569,  0.0443,  0.0351,  0.4226,
         -0.0539],
        [-0.0991, -0.2664,  0.5650,  0.2519,  0.1277, -0.1910,  0.3592,  0.6872,
         -0.1336],
        [-0.0346, -0.1258,  0.2635,  0.2501, -0.2344, -0.0460,  0.3277,  0.6777,
         -0.0840],
        [-0.4546, -0.2947,  0.0289, -0.0836, -0.0565, -0.0246, -0.0879,  0.3951,
          0.0313],
        [-0.0199, -0.4560,  0.1270, -0.1186, -0.2350, -0.2780,  0.3204,  0.4294,
         -0.1853],
        [-0.2098, -0.4391,  0.3435,  0.3893, -0.3568,  0.1692,  0.3937,  0.4986,
          0.1233]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [None]:
train_emotion_detector(model, 1e-5, cped_train_dataset, cped_train_dataloader, cped_dev_dataloader, 'cped')

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(save_path + "emotion_detector_cped", num_labels=9)
model = model.to(device)
train_emotion_detector(model, 1e-5, mpdd_train_dataset, mpdd_train_dataloader, mpdd_dev_dataloader, 'cped_mpdd')

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(save_path + "emotion_detector_cped_mpdd", num_labels=9)
model = model.to(device)
train_emotion_detector(model, 1e-5, ldc_train_dataset, ldc_train_dataloader, ldc_dev_dataloader, 'cped_mpdd_ldc')

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(save_path + "emotion_detector_cped_mpdd_ldc", num_labels=9)
model = model.to(device)

In [None]:
metric = load("accuracy")
model.eval()
preds = []
golds = []
for batch in ldc_train_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    preds.append(predictions)
    golds.append(batch["labels"])
    metric.add_batch(predictions=predictions, references=batch["labels"])
    probs = torch.nn.functional.softmax(logits, dim=-1)
metric.compute()

pred = torch.cat(preds, dim=0).cpu().data.numpy()
gold = torch.cat(golds, dim=0).cpu().data.numpy()

print(classification_report(gold, pred, labels = list(range(9)), target_names=plutchik, zero_division=0))

In [None]:
metric = load("accuracy")
model.eval()
preds = []
golds = []
for batch in ldc_dev_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    preds.append(predictions)
    golds.append(batch["labels"])
    metric.add_batch(predictions=predictions, references=batch["labels"])
    probs = torch.nn.functional.softmax(logits, dim=-1)
metric.compute()

pred = torch.cat(preds, dim=0).cpu().data.numpy()
gold = torch.cat(golds, dim=0).cpu().data.numpy()

print(classification_report(gold, pred, labels = list(range(9)), target_names=plutchik, zero_division=0))

In [None]:
metric = load("accuracy")
model.eval()
preds = []
golds = []
for batch in ldc_test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    preds.append(predictions)
    golds.append(batch["labels"])
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

pred = torch.cat(preds, dim=0).cpu().data.numpy()
gold = torch.cat(golds, dim=0).cpu().data.numpy()

print(classification_report(gold, pred, labels = list(range(9)), target_names=plutchik, zero_division=0))

### Eval Inference

#### Eval Data Loading

In [None]:
eval_fol_names = os.listdir('/homes/rpujari/scratch0_ml/ldc_data/')
eval_fol_names = [f for f in eval_fol_names if (f.startswith('eval-') and f != 'eval-zips') or '2022e22' in f]
print(eval_fol_names)

c = 0
eval_data = {'id': [], 'text': [], 'label': []}
eval_ids = {}
idx = 0
for fol_name in eval_fol_names:
    tdata = {}
    if os.path.exists(ldc_dpath + fol_name + '/data/text/'):
        fnames = os.listdir(ldc_dpath + fol_name + '/data/text/ltf/')
        for fname in fnames:
            tdata[fname[:-8]] = {}
            lsoup = BS(open(ldc_dpath + fol_name + '/data/text/ltf/' + fname).read(), 'xml')
            ttags = lsoup.find_all('SEG')
            c += len(ttags)
            tups = []
            for ttag in ttags:
                ot = ttag.find('ORIGINAL_TEXT').text
                sc = float(ttag.attrs['start_char'])
                ec = float(ttag.attrs['end_char'])
                tdata[fname[:-8]][(sc, ec)] = ot.strip()
                eval_ids[idx] = (fol_name, fname[:-8], (sc, ec))
                eval_data['id'].append(idx)
                eval_data['text'].append(ot.strip())
                eval_data['label'].append(0)
                idx += 1
print(c)

In [None]:
ldc_eval = Dataset.from_dict(eval_data)
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

def tokenize_function(examples):
    return tokenizer(examples['text'], padding="max_length",\
                     truncation=True, max_length=512)


tokenized_ldc_eval = ldc_eval.map(tokenize_function, batched=True)

tokenized_ldc_eval = tokenized_ldc_eval.remove_columns(["text"])
tokenized_ldc_eval = tokenized_ldc_eval.remove_columns(["id"])
tokenized_ldc_eval = tokenized_ldc_eval.rename_column("label", "labels")
tokenized_ldc_eval.set_format("torch")

ldc_eval_dataloader = DataLoader(tokenized_ldc_eval, batch_size=4)

In [None]:
save_path = '/homes/rpujari/scratch_ml/DARPA/ta2snapshot_saved_parameters/'
preds = []
probs = []
num_steps = len(ldc_eval_dataloader)
progress_bar = tqdm(range(num_steps))
with torch.no_grad():
    model = AutoModelForSequenceClassification.from_pretrained(save_path + "emotion_detector_cped_mpdd_ldc", num_labels=9)
    model = model.to(device)
    for batch in ldc_eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        preds.append(predictions)
        probs.append(logits)
        progress_bar.update(1)
        
ids = sorted(eval_ids.keys())

predictions_out = torch.cat(preds, dim=0)
print(predictions_out.size())
predictions_prob = F.softmax(torch.cat(probs, dim=0), dim=1)
print(predictions_prob.size())

In [None]:
out_fws = {}
out_cws = {}
for id_ in ids:
    if id_ < predictions_out.size(0):
        fol_name, fname, (sc, ec) = eval_ids[id_]
        if os.path.exists(ldc_dpath + fol_name + '/data/text/txt/' + fname + '.txt'):
            if fname not in out_fws:
                out_fw = open(ldc_dpath + 'outputs/text/ED/' + fname + '.tab', 'w')
                out_fws[fname] = out_fw
                out_cw = csv.writer(out_fw, delimiter='\t')
                out_cws[fname] = out_cw
                out_cw.writerow(['file_id', 'emotion', 'start', 'end', 'llr'])
            out_cw = out_cws[fname]
            if plutchik[predictions_out[id_]] != 'neutral':
                out_cw.writerow([fname, plutchik[predictions_out[id_]], sc, ec, float(torch.max(predictions_prob[id_, :]).data)])

for fname in out_fws:
    out_fws[fname].close()

#### NMAP data loading

In [None]:
eval_fol_names = ['ldc2022e18_v6']
print(eval_fol_names)

c = 0
eval_data = {'id': [], 'text': [], 'label': []}
eval_ids = {}
idx = 0
for fol_name in eval_fol_names:
    tdata = {}
    if os.path.exists(ldc_dpath + fol_name + '/source_data/text/'):
        fnames = os.listdir(ldc_dpath + fol_name + '/source_data/text/ltf/')
        for fname in fnames:
            tdata[fname[:-8]] = {}
            lsoup = BS(open(ldc_dpath + fol_name + '/source_data/text/ltf/' + fname).read(), 'xml')
            ttags = lsoup.find_all('SEG')
            c += len(ttags)
            tups = []
            for ttag in ttags:
                ot = ttag.find('ORIGINAL_TEXT').text
                sc = float(ttag.attrs['start_char'])
                ec = float(ttag.attrs['end_char'])
                tdata[fname[:-8]][(sc, ec)] = ot.strip()
                eval_ids[idx] = (fol_name, fname[:-8], (sc, ec))
                eval_data['id'].append(idx)
                eval_data['text'].append(ot.strip())
                eval_data['label'].append(0)
                idx += 1
print(c)

In [None]:
ldc_eval = Dataset.from_dict(eval_data)
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

def tokenize_function(examples):
    return tokenizer(examples['text'], padding="max_length",\
                     truncation=True, max_length=512)


tokenized_ldc_eval = ldc_eval.map(tokenize_function, batched=True)

tokenized_ldc_eval = tokenized_ldc_eval.remove_columns(["text"])
tokenized_ldc_eval = tokenized_ldc_eval.remove_columns(["id"])
tokenized_ldc_eval = tokenized_ldc_eval.rename_column("label", "labels")
tokenized_ldc_eval.set_format("torch")

ldc_eval_dataloader = DataLoader(tokenized_ldc_eval, batch_size=4)

In [None]:
save_path = '/homes/rpujari/scratch_ml/DARPA/ta2snapshot_saved_parameters/'
preds = []
probs = []
num_steps = len(ldc_eval_dataloader)
progress_bar = tqdm(range(num_steps))
with torch.no_grad():
    model = AutoModelForSequenceClassification.from_pretrained(save_path + "emotion_detector_cped_mpdd_ldc", num_labels=9)
    # model = AutoModelForSequenceClassification.from_pretrained(save_path + "emotion_detector_ldc_gpt_17_c_test", num_labels=17)
    model = model.to(device)
    for batch in ldc_eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        preds.append(predictions)
        probs.append(logits)
        progress_bar.update(1)
        
ids = sorted(eval_ids.keys())

predictions_out = torch.cat(preds, dim=0)
print(predictions_out.size())
predictions_prob = F.softmax(torch.cat(probs, dim=0), dim=1)
print(predictions_prob.size())

In [None]:
out_fws = {}
out_cws = {}
for id_ in ids:
    if id_ < predictions_out.size(0):
        fol_name, fname, (sc, ec) = eval_ids[id_]
        if os.path.exists(ldc_dpath + fol_name + '/source_data/text/ltf/' + fname + '.ltf.xml'):
            if fname not in out_fws:
                out_fw = open(ldc_dpath + 'outputs_v6/text/ED/' + fname + '.tab', 'w')
                out_fws[fname] = out_fw
                out_cw = csv.writer(out_fw, delimiter='\t')
                out_cws[fname] = out_cw
                out_cw.writerow(['file_id', 'emotion', 'start', 'end', 'llr'])
            out_cw = out_cws[fname]
            if plutchik[predictions_out[id_]] != 'neutral':
                out_cw.writerow([fname, plutchik[predictions_out[id_]], sc, ec, float(torch.max(predictions_prob[id_, :]).data)])

for fname in out_fws:
    out_fws[fname].close()

In [None]:
save_path = '/homes/rpujari/scratch_ml/DARPA/ta2snapshot_saved_parameters/'
status_preds = []
status_probs = []
num_steps = len(ldc_eval_dataloader)
progress_bar = tqdm(range(num_steps))
with torch.no_grad():
    status_model = AutoModelForSequenceClassification.from_pretrained(save_path + "norm_status_binary_ldc_text", num_labels=2)
    status_model = status_model.to(device)
    for batch in ldc_eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            status_outputs = status_model(**batch)

        status_logits = status_outputs.logits
        status_predictions = torch.argmax(status_logits, dim=-1)
        status_preds.append(status_predictions)
        status_probs.append(status_logits)
        progress_bar.update(1)
        
status_ids = sorted(eval_ids.keys())

status_predictions_out = torch.cat(status_preds, dim=0)
print(status_predictions_out.size())
status_predictions_prob = F.softmax(torch.cat(status_probs, dim=0), dim=1)
print(status_predictions_prob.size())

In [None]:
norm_codes = [101, 102, 103, 105, 104, 107, 106, 201, 202, 203, 204, 205, 206, 1001, 207, 208, 209]
status_codes = ['violate', 'adhere']

out_fws = {}
out_cws = {}
for id_ in ids:
    if id_ < predictions_out.size(0):
        fol_name, fname, (sc, ec) = eval_ids[id_]
        if os.path.exists(ldc_dpath + fol_name + '/source_data/text/ltf/' + fname + '.ltf.xml'):
            if fname not in out_fws:
                out_fw = open(ldc_dpath + 'outputs_v6/text/ND/' + fname + '.tab', 'w')
                out_fws[fname] = out_fw
                out_cw = csv.writer(out_fw, delimiter='\t')
                out_cws[fname] = out_cw
                out_cw.writerow(['file_id', 'norm', 'start', 'end', 'status', 'llr'])
            out_cw = out_cws[fname]
            if norm_codes[predictions_out[id_]] != 1001:
                stat = status_codes[status_predictions_out[id_]]
                out_cw.writerow([fname, norm_codes[predictions_out[id_]], sc, ec, stat, float(torch.max(predictions_prob[id_, :]).data)])

for fname in out_fws:
    out_fws[fname].close()

#### Norm Inference

In [None]:
save_path = '/homes/rpujari/scratch_ml/DARPA/ta2snapshot_saved_parameters/'
preds = []
probs = []
num_steps = len(ldc_eval_dataloader)
progress_bar = tqdm(range(num_steps))
with torch.no_grad():
    model = AutoModelForSequenceClassification.from_pretrained(save_path + "emotion_detector_ldc_gpt_17_c_test", num_labels=17)
    model = model.to(device)
    for batch in ldc_eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        preds.append(predictions)
        probs.append(logits)
        progress_bar.update(1)
        
ids = sorted(eval_ids.keys())

predictions_out = torch.cat(preds, dim=0)
print(predictions_out.size())
predictions_prob = F.softmax(torch.cat(probs, dim=0), dim=1)
print(predictions_prob.size())

In [None]:
all_norm_categories = ['Doing apology', 'Doing criticism', 'Doing greeting', 'Doing persuasion',\
                   'Doing request', 'Doing taking leave', 'Doing thanks', 'acknowledging',\
                   'expressing opinion', 'giving explanation', 'making clarification', 'making invitation',\
                   'making suggestion', 'none', 'offering reassurance', 'questioning', 'responding to request']

In [None]:
save_path = '/homes/rpujari/scratch_ml/DARPA/ta2snapshot_saved_parameters/'
status_preds = []
status_probs = []
num_steps = len(ldc_eval_dataloader)
progress_bar = tqdm(range(num_steps))
with torch.no_grad():
    status_model = AutoModelForSequenceClassification.from_pretrained(save_path + "norm_status_binary_ldc_text", num_labels=2)
    status_model = status_model.to(device)
    for batch in ldc_eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            status_outputs = status_model(**batch)

        status_logits = status_outputs.logits
        status_predictions = torch.argmax(status_logits, dim=-1)
        status_preds.append(status_predictions)
        status_probs.append(status_logits)
        progress_bar.update(1)
        
status_ids = sorted(eval_ids.keys())

status_predictions_out = torch.cat(status_preds, dim=0)
print(status_predictions_out.size())
status_predictions_prob = F.softmax(torch.cat(status_probs, dim=0), dim=1)
print(status_predictions_prob.size())

In [None]:
norm_codes = [101, 102, 103, 105, 104, 107, 106, 201, 202, 203, 204, 205, 206, 1001, 207, 208, 209]
status_codes = ['violate', 'adhere']

out_fws = {}
out_cws = {}
for id_ in ids:
    if id_ < predictions_out.size(0):
        fol_name, fname, (sc, ec) = eval_ids[id_]
        if os.path.exists(ldc_dpath + fol_name + '/data/text/txt/' + fname + '.txt'):
            if fname not in out_fws:
                out_fw = open(ldc_dpath + 'outputs/text/ND/' + fname + '.tab', 'w')
                out_fws[fname] = out_fw
                out_cw = csv.writer(out_fw, delimiter='\t')
                out_cws[fname] = out_cw
                out_cw.writerow(['file_id', 'norm', 'start', 'end', 'status', 'llr'])
            out_cw = out_cws[fname]
            if norm_codes[predictions_out[id_]] != 1001:
                stat = status_codes[status_predictions_out[id_]]
                out_cw.writerow([fname, norm_codes[predictions_out[id_]], sc, ec, stat, float(torch.max(predictions_prob[id_, :]).data)])

for fname in out_fws:
    out_fws[fname].close()

#### ChangePoint Inference

In [None]:
eval_fol_names = os.listdir('/homes/rpujari/scratch0_ml/ldc_data/')
eval_fol_names = [f for f in eval_fol_names if (f.startswith('eval-') and f != 'eval-zips') or '2022e22' in f]
print(eval_fol_names)

c = 0
idx = 0
eval_cp_data = {}
for fol_name in eval_fol_names:
    tdata = {}
    if os.path.exists(ldc_dpath + fol_name + '/data/text/'):
        fnames = os.listdir(ldc_dpath + fol_name + '/data/text/ltf/')
        for fname in fnames:
            tdata[fname[:-8]] = {}
            lsoup = BS(open(ldc_dpath + fol_name + '/data/text/ltf/' + fname).read(), 'xml')
            ttags = lsoup.find_all('SEG')
            c += len(ttags)
            for ttag in ttags:
                ot = ttag.find('ORIGINAL_TEXT').text
                sc = float(ttag.attrs['start_char'])
                ec = float(ttag.attrs['end_char'])
                tdata[fname[:-8]][(sc, ec)] = (idx, ot.strip())
                idx += 1
    eval_cp_data[fol_name] = tdata
print(c)

In [None]:
cp_data = {'id': [], 'text_a': [], 'text_b': []}
cp_ids = {}

dp_num = 0
for fol_name in eval_cp_data:
    for fname in eval_cp_data[fol_name]:
        seg_bos = list(eval_cp_data[fol_name][fname].keys())
        seg_bos = sorted(seg_bos, key=lambda x:int(x[0]))
        l = len(seg_bos)
        if l > 10:
            st = 4
            en = l - 5
            for idx in range(st, en):
                bs_ = ' '.join([eval_cp_data[fol_name][fname][bo][1] for bo in seg_bos[idx-4:idx]])[-50:]
                as_ = ' '.join([eval_cp_data[fol_name][fname][bo][1] for bo in seg_bos[idx:idx+4]])[:50]
                cp_data['id'].append(dp_num)
                cp_data['text_b'].append(bs_)
                cp_data['text_a'].append(as_)
                cp_ids[dp_num] = (fol_name, fname, idx)
                dp_num += 1
print(dp_num)

In [None]:
eval_cp_dset = Dataset.from_dict(cp_data)

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")

def tokenize_function(examples):
    return tokenizer(examples['text_b'], examples['text_a'], padding="max_length",\
                     truncation=True, max_length=512)

tokenized_eval_cp = eval_cp_dset.map(tokenize_function, batched=True)

In [None]:
tokenized_eval_cp = tokenized_eval_cp.remove_columns(["text_a"])
tokenized_eval_cp = tokenized_eval_cp.remove_columns(["text_b"])
tokenized_eval_cp = tokenized_eval_cp.remove_columns(["id"])
tokenized_eval_cp.set_format("torch")

In [None]:
eval_cp_dataloader = DataLoader(tokenized_eval_cp, batch_size=4)

In [None]:
save_path = '/homes/rpujari/scratch_ml/DARPA/ta2snapshot_saved_parameters/'
preds = []
probs = []
num_steps = len(eval_cp_dataloader)
progress_bar = tqdm(range(num_steps))
with torch.no_grad():
    model = AutoModelForSequenceClassification.from_pretrained(save_path + "change_point_ldc_text", num_labels=2)
    model = model.to(device)
    for batch in eval_cp_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        preds.append(predictions)
        probs.append(logits)
        progress_bar.update(1)
        
ids = sorted(eval_ids.keys())

predictions_out = torch.cat(preds, dim=0)
print(predictions_out.size())
predictions_prob = F.softmax(torch.cat(probs, dim=0), dim=1)
print(predictions_prob.size())

In [None]:
ids = sorted(cp_ids.keys())

In [None]:
out_fws = {}
out_cws = {}
p = 0
n = 0
counts = {}
fname_cps = {}
for id_ in ids:
    if id_ < predictions_out.size(0):
        fol_name, fname, idx = cp_ids[id_]
        seg_bos = list(eval_cp_data[fol_name][fname].keys())
        seg_bos = sorted(seg_bos, key=lambda x:int(x[0]))
        ts = seg_bos[idx][0]
        if fname not in counts:
            counts[fname] = 0
        if os.path.exists(ldc_dpath + fol_name + '/data/text/txt/' + fname + '.txt'):
            if predictions_out[id_] == 1:
                if fname not in fname_cps:
                    fname_cps[fname] = []
                counts[fname] += 1
                fname_cps[fname].append((fname, ts, float(torch.max(predictions_prob[id_, :]).data)))
            else:
                n += 1

for fname in fname_cps:
    cp_list = fname_cps[fname]
    cp_list = sorted(cp_list, key=lambda x:x[2], reverse=True)
    out_fw = open(ldc_dpath + 'outputs/text/CD/' + fname + '.tab', 'w')
    out_cw = csv.writer(out_fw, delimiter='\t')
    out_cw.writerow(['file_id', 'timestamp', 'llr'])
    for i in range(3):
        if i < len(cp_list):
            row = cp_list[i]
            out_cw.writerow([row[0], row[1], row[2]])
            p += 1
    out_fw.close()


In [None]:
print(p)