# Imports and Loading Data

In [2]:
!pip3 install pytorch-pretrained-bert
!pip3 install simplediff
!pip3 install tensorboardX


In [81]:
#from pytorch_pretrained_bert.modeling import PreTrainedBertModel, BertModel, BertSelfAttention
import pytorch_pretrained_bert.modeling as modeling
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
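import copy
from collections import defaultdict
from random import shuffle  # used by noise_seq for the global shuffle
from simplediff import diff  # token-level diff used by get_examples
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler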

from tqdm import tqdm
import sys

import pickle
import os
from pytorch_pretrained_bert.modeling import BertForTokenClassification
from torch.nn import CrossEntropyLoss
from tensorboardX import SummaryWriter
import argparse
import sklearn.metrics as metrics

In [4]:
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.modeling import BertModel, BertSelfAttention
from pytorch_pretrained_bert.modeling import BertPreTrainedModel

## Data Study
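Each line of the raw data files holds one edit as seven tab-separated fields: the revision ID, the WordPiece-tokenized pre-edit sentence, the tokenized post-edit sentence, the raw pre-edit sentence, the raw post-edit sentence, the part-of-speech tags, and the dependency relations (both aligned to the pre-edit tokens). Two examples, shown with one field per line: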


318427508	

in april 2009 a brazilian human rights group , torture never again , awarded the five its chico mendes medal , because their rights had been violated .

in april 2009 a brazilian human rights group , torture never again , awarded the five its chico mendes medal , under the pre ##text that their rights had been violated .

in april 2009 a brazilian human rights group, torture never again, awarded the five its chico mendes medal, because their rights had been violated.

in april 2009 a brazilian human rights group, torture never again, awarded the five its chico mendes medal, under the pretext that their rights had been violated.

ADP NOUN NUM DET ADJ ADJ NOUN NOUN PUNCT NOUN ADV ADV PUNCT VERB DET NUM ADJ NOUN NOUN NOUN PUNCT ADP ADJ NOUN VERB VERB VERB PUNCT	

prep pobj nummod det amod amod compound nsubj punct appos neg advmod punct ROOT det nummod poss compound compound dobj punct mark poss nsubjpass aux auxpass advcl punct

--------------------

235640083	

the 51 day stand ##off and ensuing murder of 76 men , women , and children - - the branch david ##ians - - in wa ##co , texas .	

the 51 day stand ##off and ensuing deaths of 76 men , women , and children - - the branch david ##ians - - in wa ##co , texas .

the 51 day standoff and ensuing murder of 76 men, women, and children--the branch davidians--in waco, texas.	

the 51 day standoff and ensuing deaths of 76 men, women, and children--the branch davidians--in waco, texas.	

DET NUM NOUN NOUN NOUN CCONJ VERB NOUN ADP NUM NOUN PUNCT NOUN PUNCT CCONJ NOUN PUNCT PUNCT DET NOUN NOUN NOUN PUNCT PUNCT PART NOUN NOUN PUNCT NOUN PUNCT	

det nummod compound nsubj nsubj cc amod conj prep nummod pobj punct conj punct cc conj punct punct det nsubj conj conj punct punct prep pobj pobj punct ROOT punct
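
The field layout above can be checked by splitting a line on tabs. The sketch below is only an illustration: `parse_line` and `FIELDS` are hypothetical names, and the example line is made up in the same seven-field format rather than taken from the corpus.

```python
# Hypothetical helper: split one raw line into named fields.
FIELDS = ['revid', 'pre_tok', 'post_tok', 'pre_raw', 'post_raw', 'pos', 'rels']

def parse_line(line):
    parts = line.strip().split('\t')
    if len(parts) != len(FIELDS):
        raise ValueError('expected %d fields, got %d' % (len(FIELDS), len(parts)))
    return dict(zip(FIELDS, parts))

# Made-up line in the same format (not from the dataset)
line = '\t'.join([
    '123',
    'the brutal murder .',      # tokenized pre-edit sentence
    'the murder .',             # tokenized post-edit sentence
    'the brutal murder.',       # raw pre-edit sentence
    'the murder.',              # raw post-edit sentence
    'DET ADJ NOUN PUNCT',       # POS tags, one per pre-edit token
    'det amod ROOT punct',      # dependency labels, one per pre-edit token
])

example = parse_line(line)
n_tokens = len(example['pre_tok'].split())
print(example['revid'], n_tokens)
```

The POS and dependency fields should have exactly as many entries as the tokenized pre-edit sentence. `get_examples` skips any line where they do not.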



In [5]:
# Universal Dependencies Scheme used in all languages trained on Universal Dependency Corpora
# some english dependency labels use CLEAR style by ClearNLP
# SYNTACTIC DEPENDENCIES
RELATIONS = [
  'det', # determiner (the, a)
  'amod', # adjectival modifier
  'nsubj', # nominal subject
  'prep', # prepositional modifier
  'pobj', # object of preposition
  'ROOT', # root
  'attr', # attribute
  'punct', # punctuation
  'advmod', # adverbial modifier
  'compound', # compound
  'acl', # clausal modifier of noun (adjectival clause)
  'agent', # agent
  'aux', # auxiliary
  'ccomp', # clausal complement
  'dobj', # direct object
  'cc', # coordinating conjunction 
  'conj', # conjunct
  'appos', # appositional modifier
  'nsubjpass', # nominal subject (passive)
  'auxpass', # auxiliary (passive)
  'poss', # possession modifier
  'nummod', # numeric modifier
  'nmod', # nominal modifier
  'relcl', # relative clause modifier
  'mark', # marker
  'advcl', # adverbial clause modifier
  'pcomp', # complement of preposition
  'npadvmod', # noun phrase as adverbial modifier
  'preconj', # pre-correlative conjunction
  'neg', # negation modifier
  'xcomp', # open clausal complement
  'csubj', # clausal subject
  'prt', # particle
  'parataxis', # parataxis
  'expl', # expletive
  'case', # case marking
  'acomp', # adjectival complement
  'predet', # predeterminer
  'quantmod', # modifier of quantifier
  'dep', # unspecified dependency
  'oprd', # object predicate
  'intj', # interjection
  'dative', # dative
  'meta', # meta modifier
  'csubjpass', # clausal subject (passive)
  '<UNK>' # unknown
]

REL2ID = {x: i for i,x in enumerate(RELATIONS)}

# PARTS OF SPEECH
POS_TAGS = [
  'DET', # determiner (a, an, the)
  'ADJ', # adjective (big, old, green, first)
  'NOUN', # noun (girl, cat, tree)
  'ADP', # adposition (in, to, during)
  'NUM', # numeral (1, 2017, one, IV)
  'VERB', # verb (runs, running, eat, ate)
  'PUNCT', # punctuation (., (, ), ?)
  'ADV', # adverb (very, tomorrow, down)
  'PART', # particle ('s, not)
  'CCONJ', # coordinating conjunction (and, or, but)
  'PRON', # pronoun(I, you, he, she)
  'X', # other (fhefkoskjsdods)
  'INTJ', # interjection (hello, psst, ouch, bravo)
  'PROPN', # proper noun (Mary, John, London, HBO) 
  'SYM', # symbol ($, %, +, -, =)
  '<UNK>' # unknown
]

POS2ID = {x: i for i, x in enumerate(POS_TAGS)}
POS2ID

{'<UNK>': 15,
 'ADJ': 1,
 'ADP': 3,
 'ADV': 7,
 'CCONJ': 9,
 'DET': 0,
 'INTJ': 12,
 'NOUN': 2,
 'NUM': 4,
 'PART': 8,
 'PRON': 10,
 'PROPN': 13,
 'PUNCT': 6,
 'SYM': 14,
 'VERB': 5,
 'X': 11}
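
Tags that are missing from these lists fall back to the `<UNK>` id when sequences are encoded later. A minimal sketch of that lookup, using a shortened tag list and a hypothetical `encode` helper:

```python
# Shortened tag list for illustration only; the notebook's list is longer.
POS_TAGS = ['DET', 'ADJ', 'NOUN', '<UNK>']
POS2ID = {tag: i for i, tag in enumerate(POS_TAGS)}

def encode(tags, tag2id):
    # Map each tag to its id; unseen tags map to the '<UNK>' id.
    unk = tag2id['<UNK>']
    return [tag2id.get(t, unk) for t in tags]

ids = encode(['DET', 'NOUN', 'SCONJ'], POS2ID)
print(ids)
```

Here `'SCONJ'` is not in the shortened list, so it receives the same id as `'<UNK>'`.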

In [6]:
EDIT_TYPE2ID = {'0':0, '1':1, 'mask':2}

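# numerically stable softmax: outputs along `axis` are non-negative and sum to 1
def softmax(x, axis=None):
  x = x - x.max(axis=axis, keepdims=True)  # subtract the max to avoid overflow in exp
  y = np.exp(x)
  return y / y.sum(axis=axis, keepdims=True)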
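
To see why the maximum is subtracted before exponentiating, here is a dependency-free sketch for a single list of scores. `softmax_1d` is a hypothetical name, and this version is not vectorized like the NumPy one.

```python
import math

def softmax_1d(scores):
    # Subtracting the max leaves the result unchanged but keeps exp() from overflowing.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# math.exp(1000.0) would overflow, yet the shifted version is fine.
probs = softmax_1d([1000.0, 1001.0])
same = softmax_1d([0.0, 1.0])
print(probs, sum(probs))
```

Both calls give the same distribution, because softmax depends only on the differences between scores.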

In [34]:
# turn a token-level diff (a list of (tag, chunk) pairs) into binary edit labels
# for the pre-edit and post-edit token sequences
def get_tok_labels(s_diff):
  pre_tok_labels = []
  post_tok_labels = []
  for tag, chunk in s_diff:
    if tag == '=':
      pre_tok_labels += [0] * len(chunk)
      post_tok_labels += [0] * len(chunk)
    elif tag == '-':
      pre_tok_labels += [1] * len(chunk) # 1 in pre if word deleted in post
    elif tag == '+':
      post_tok_labels += [1] * len(chunk) # 1 in post if word added in post
    else: 
      pass
  # one label per token: 0 = unchanged, 1 = deleted (pre) or added (post)
  return pre_tok_labels, post_tok_labels
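
A small worked example may make the labeling concrete. The diff below is written by hand in the `(tag, chunk)` format that `get_tok_labels` iterates over; it is assumed here to match what `simplediff.diff` returns, and the function is repeated in compact form so the sketch runs on its own.

```python
def get_tok_labels(s_diff):
    pre_tok_labels, post_tok_labels = [], []
    for tag, chunk in s_diff:
        if tag == '=':    # unchanged tokens appear in both sentences
            pre_tok_labels += [0] * len(chunk)
            post_tok_labels += [0] * len(chunk)
        elif tag == '-':  # tokens removed from the pre-edit sentence
            pre_tok_labels += [1] * len(chunk)
        elif tag == '+':  # tokens added in the post-edit sentence
            post_tok_labels += [1] * len(chunk)
    return pre_tok_labels, post_tok_labels

# Hand-written diff for "ensuing murder of 76 men" -> "ensuing deaths of 76 men"
tok_diff = [
    ('=', ['ensuing']),
    ('-', ['murder']),
    ('+', ['deaths']),
    ('=', ['of', '76', 'men']),
]

pre_labels, post_labels = get_tok_labels(tok_diff)
print(pre_labels, post_labels)
```

The pre-edit labels mark `murder` as the token to change, which is the target the tagger is trained on.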


In [9]:
# create a randomly sampled noisy version of a sequence
# drops out every word in the sentence with a probability p_wd (drop_prob)
# slightly shuffle the input sentence 
def noise_seq(seq, drop_prob=0.25, shuf_dist=3, drop_set=None, keep_bigrams=False):
    # from https://arxiv.org/pdf/1711.00043.pdf
    def perm(i):
        return i[0] + (shuf_dist + 1) * np.random.random()
    
    if drop_set is None:
        dropped_seq = [x for x in seq if np.random.random() > drop_prob]
    else:
        dropped_seq = [x for x in seq if not (x in drop_set and np.random.random() < drop_prob)]

    if keep_bigrams:
        i = 0
        original = ' '.join(seq)
        tmp = []
        while i < len(dropped_seq)-1:
            if ' '.join(dropped_seq[i : i+2]) in original:
                tmp.append(dropped_seq[i : i+2])
                i += 2
            else:
                tmp.append([dropped_seq[i]])
                i += 1

        dropped_seq = tmp

    # global shuffle
    if shuf_dist == -1:
        shuffle(dropped_seq)
    # local shuffle
    elif shuf_dist > 0:
        dropped_seq = [x for _, x in sorted(enumerate(dropped_seq), key=perm)]
    # shuf_dist of 0 = no shuffle

    if keep_bigrams:
        dropped_seq = [z for y in dropped_seq for z in y]
    
    return dropped_seq
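
The local shuffle sorts tokens by `index + (shuf_dist + 1) * u`, with `u` drawn uniformly from [0, 1), so two tokens can only swap if they are at most `shuf_dist` positions apart. The sketch below isolates that step with the standard `random` module; `local_shuffle` is a hypothetical name and the seed is arbitrary.

```python
import random

def local_shuffle(seq, shuf_dist, rng):
    # Noisy sort key per position: index plus uniform noise in [0, shuf_dist + 1).
    keys = [i + (shuf_dist + 1) * rng.random() for i in range(len(seq))]
    order = sorted(range(len(seq)), key=lambda i: keys[i])
    return [seq[i] for i in order], order

rng = random.Random(0)  # fixed seed so the run is repeatable
tokens = list('abcdefghij')
shuffled, order = local_shuffle(tokens, 3, rng)

# How far did any token move from its original position?
max_shift = max(abs(new - old) for new, old in enumerate(order))
print(shuffled, max_shift)
```

Whatever the seed, `max_shift` stays at or below `shuf_dist`, so the noised sentence keeps its rough word order.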

In [25]:
# pad end of a sequence so it is the max sequence length
def pad(id_arr, pad_idx):
  max_seq_len = 80
  return id_arr + ([pad_idx] * (max_seq_len - len(id_arr)))
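
Note that `pad` only extends a sequence and never truncates it. The sketch below uses a variant with a hypothetical `max_seq_len` parameter (the notebook hardcodes 80) so the behavior is easy to see on short lists; the token ids are arbitrary.

```python
def pad(id_arr, pad_idx, max_seq_len=80):
    # A negative repeat count yields an empty list, so long inputs pass through unchanged.
    return id_arr + [pad_idx] * (max_seq_len - len(id_arr))

ids = pad([101, 2023, 102], 0, max_seq_len=6)    # token ids padded with 0
mask = pad([0] * 3, 1, max_seq_len=6)            # 0 = real token, 1 = padding
too_long = pad(list(range(8)), 0, max_seq_len=6)  # longer than the limit
print(ids, mask, len(too_long))
```

This is why `get_examples` has to skip over-long sentences before calling `pad`.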

In [99]:
# read tab-separated examples from data_path and convert them to padded id sequences
def get_examples(data_path, tok2id, max_seq_len, 
                 noise=False, add_del_tok=False,
                 categories_path=None):
    #global REL2ID
    #global POS2ID
    #global EDIT_TYPE2ID

    #if ARGS.drop_words is not None:
    #    drop_set = set([l.strip() for l in open(ARGS.drop_words)])
    #else:
    #    drop_set = None

    skipped = 0 
    out = defaultdict(list)
    #print(out)
    #input()
    '''if categories_path is not None:
        category_fp = open(categories_path)
        next(category_fp) # ignore header
        revid2topic = {
            l.strip().split(',')[0]: [float(x) for x in l.strip().split(',')[1:]]
            for l in category_fp
        }'''
    for i, (line) in enumerate(tqdm(open(data_path))):
        parts = line.strip().split('\t')
        #print(parts, len(parts))
        #input("continue....")

        # line includes pos/rel info
        if len(parts) == 7:
            [revid, pre, post, _, _, pos, rels] = parts
        # no pos/rel info
        elif len(parts) == 5:
            [revid, pre, post, _, _] = parts
            pos = ' '.join(['<UNK>'] * len(pre.strip().split()))
            rels = ' '.join(['<UNK>'] * len(pre.strip().split()))
        # broken line
        else:
            skipped += 1
            continue

        # break up tokens
        tokens = pre.strip().split()
        post_tokens = post.strip().split()
        rels = rels.strip().split()
        pos = pos.strip().split()

        # get diff + binary diff masks
        tok_diff = diff(tokens, post_tokens)
        pre_tok_labels, post_tok_labels = get_tok_labels(tok_diff)
                   
        # make sure everything lines up    
        if len(tokens) != len(pre_tok_labels) \
            or len(tokens) != len(rels) \
            or len(tokens) != len(pos) \
            or len(post_tokens) != len(post_tok_labels):
            skipped += 1
            continue

        # leave room in the post for start/stop and possible category/class token
        if len(tokens) > max_seq_len - 1 or len(post_tokens) > max_seq_len - 1:
            skipped += 1
            continue

        # category info if provided
        # NOTE: the category lookup below is disabled, so every example gets a
        # random, normalized category vector
        '''if categories_path is not None and revid in revid2topic:
            categories = revid2topic[revid]
        else:'''
        categories = np.random.uniform(size=43)   # 43 = number of categories
        categories = categories / sum(categories) # normalize

        '''if ARGS.category_input:
            category_id = np.argmax(categories)
            tokens = ['[unused%d]' % category_id] + tokens
            pre_tok_labels = [EDIT_TYPE2ID['mask']] + pre_tok_labels
            post_tok_labels = [EDIT_TYPE2ID['mask']] + post_tok_labels
        '''

        # add start + end symbols to post in/out
        post_input_tokens = ['行'] + post_tokens
        post_output_tokens = post_tokens + ['止'] 

        # shuffle + convert to ids + pad
        try:
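            if noise:
                # NOTE: ARGS and drop_set are not defined in this notebook (the ARGS
                # import is commented out), so noise=True would raise a NameError here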
                pre_toks = noise_seq(
                    tokens[:], 
                    drop_prob=ARGS.noise_prob, 
                    shuf_dist=ARGS.shuf_dist,
                    drop_set=drop_set,
                    keep_bigrams=ARGS.keep_bigrams)
            else:
                pre_toks = tokens

            pre_ids = pad([tok2id[x] for x in pre_toks], 0)
            post_in_ids = pad([tok2id[x] for x in post_input_tokens], 0)
            post_out_ids = pad([tok2id[x] for x in post_output_tokens], 0)
            pre_tok_label_ids = pad(pre_tok_labels, EDIT_TYPE2ID['mask'])
            post_tok_label_ids = pad(post_tok_labels, EDIT_TYPE2ID['mask'])
            rel_ids = pad([REL2ID.get(x, REL2ID['<UNK>']) for x in rels], 0)
            pos_ids = pad([POS2ID.get(x, POS2ID['<UNK>']) for x in pos], 0)
        except KeyError:
            # TODO: some tokens are missing from the BERT vocab (encoding issue); skip the example
            skipped += 1
            continue

        input_mask = pad([0] * len(tokens), 1)
        pre_len = len(tokens)

        out['pre_ids'].append(pre_ids)
        out['pre_masks'].append(input_mask)
        out['pre_lens'].append(pre_len)
        out['post_in_ids'].append(post_in_ids)
        out['post_out_ids'].append(post_out_ids)
        out['pre_tok_label_ids'].append(pre_tok_label_ids)
        out['post_tok_label_ids'].append(post_tok_label_ids)
        out['rel_ids'].append(rel_ids)
        out['pos_ids'].append(pos_ids)
        out['categories'].append(categories)

    print('SKIPPED ', skipped)
    return out


In [109]:
# get data loader for data in data_path
def get_dataloader(data_path, tok2id, batch_size, 
                   pickle_path=None, test=False, noise=False, add_del_tok=False, 
                   categories_path=None, sort_batch=True):
    #global ARGS

    def collate(data):
        if sort_batch:
            # sort by length for packing/padding
            data.sort(key=lambda x: x[2], reverse=True)
        # group by datatype
        [
            src_id, src_mask, src_len, 
            post_in_id, post_out_id, 
            pre_tok_label, post_tok_label,
            rel_ids, pos_ids, categories
        ] = [torch.stack(x) for x in zip(*data)]

        # cut off at max len of this batch for unpacking/repadding
        max_len = max(src_len)
        data = [
            src_id[:, :max_len], src_mask[:, :max_len], src_len, 
            post_in_id[:, :max_len+10], post_out_id[:, :max_len+10],    # +10 for wiggle room
            pre_tok_label[:, :max_len], post_tok_label[:, :max_len+10], # +10 for post_toks_labels too (it's just gonna be matched up with post ids)
            rel_ids[:, :max_len], pos_ids[:, :max_len], categories
        ]

        return data

    if pickle_path is not None and os.path.exists(pickle_path):
        print("pickle file exists!")
        with open(pickle_path, 'rb') as f:
            examples = pickle.load(f)
    else:
        examples = get_examples(
            data_path=data_path, 
            tok2id=tok2id,
            max_seq_len=80, #ARGS.max_seq_len,
            noise=False, #noise,
            add_del_tok=False, #add_del_tok,
            categories_path=None)#categories_path)

        # cache the processed examples only when a pickle path was given
        if pickle_path is not None:
            with open(pickle_path, 'wb') as f:
                pickle.dump(examples, f)

    data = TensorDataset(
        torch.tensor(examples['pre_ids'], dtype=torch.long),
        torch.tensor(examples['pre_masks'], dtype=torch.uint8), # byte for masked_fill()
        torch.tensor(examples['pre_lens'], dtype=torch.long),
        torch.tensor(examples['post_in_ids'], dtype=torch.long),
        torch.tensor(examples['post_out_ids'], dtype=torch.long),
        torch.tensor(examples['pre_tok_label_ids'], dtype=torch.float),  # for comparing to enrichment stuff
        torch.tensor(examples['post_tok_label_ids'], dtype=torch.float),  # for loss multiplying
        torch.tensor(examples['rel_ids'], dtype=torch.long),
        torch.tensor(examples['pos_ids'], dtype=torch.long),
        torch.tensor(examples['categories'], dtype=torch.float))

    dataloader = DataLoader(
        data,
        sampler=(SequentialSampler(data) if test else RandomSampler(data)),
        collate_fn=collate,
        batch_size=batch_size)

    return dataloader, len(examples['pre_ids'])
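
The `collate` function sorts each batch by source length and trims the padded columns to the longest sequence in that batch. A plain-list sketch of the same idea, with no tensors; `collate_lists` and the toy ids are illustrative only:

```python
def collate_lists(batch):
    # batch: list of (padded_ids, true_length) pairs
    batch = sorted(batch, key=lambda ex: ex[1], reverse=True)  # longest first
    max_len = batch[0][1]
    ids = [padded[:max_len] for padded, _ in batch]  # drop columns that are all padding
    lens = [n for _, n in batch]
    return ids, lens

batch = [
    ([7, 8, 0, 0, 0, 0], 2),
    ([4, 5, 6, 9, 0, 0], 4),
    ([3, 0, 0, 0, 0, 0], 1),
]
ids, lens = collate_lists(batch)
print(ids, lens)
```

Trimming to the per-batch maximum keeps short batches cheap, since BERT's cost grows with sequence length.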

In [12]:
#from shared.args import ARGS 
#from shared.constants import CUDA 
#import seq2seq.model as seq2seq_model
CUDA = (torch.cuda.device_count() > 0)

In [13]:
# Mount google drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [87]:
## Update the experiments directory
DATA_DIRECTORY = '/content/drive/Shared drives/EC463 464 Senior Design Project/data/'
LEXICON_DIRECTORY = DATA_DIRECTORY + 'lexicons/'
PRYZANT_DATA = DATA_DIRECTORY + 'pryzant_data/WNC/'
#IMPORTS = 
training_data = PRYZANT_DATA + 'biased.word.train'
testing_data = PRYZANT_DATA + 'biased.word.test'
categories_file = PRYZANT_DATA + 'revision_topics.csv'
pickle_directory = '/content/drive/Shared drives/EC463 464 Senior Design Project/data/pickle_data/'
cache_dir = DATA_DIRECTORY + 'cache/'
model_save_dir = '/content/drive/Shared drives/EC463 464 Senior Design Project/models/'

In [15]:
# Load imports
!cp '/content/drive/Shared drives/EC463 464 Senior Design Project/imports/data.py' .
import data

In [16]:
os.getcwd()

'/content'

In [17]:
print('LOADING DATA...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', os.getcwd() + '/cache')
tok2id = tokenizer.vocab
tok2id['<del>'] = len(tok2id)

LOADING DATA...


In [110]:
train_dataloader, num_train_examples = get_dataloader(
    data_path=training_data,
    tok2id=tok2id,
    batch_size=32,
    pickle_path=pickle_directory + 'train_data4.p',
    categories_path=None #categories_file
  )

eval_dataloader, num_eval_examples = get_dataloader(
    data_path=testing_data,
    tok2id=tok2id,
    batch_size=32,
    pickle_path=pickle_directory + 'test_data4.p',
    categories_path=None #categories_file
  )

print(num_train_examples, num_eval_examples)

SKIPPED  1503
SKIPPED  32
52300 968


# Define Model

In [102]:
# BERT initialization params
config = 'bert-base-uncased'
cls_num_labels = 43  # number of topic categories
tok_num_labels = 3   # EDIT_TYPE2ID: unchanged (0), edited (1), mask/padding (2)

class BertForMultitask(BertPreTrainedModel):

    def __init__(self, config, cls_num_labels=2, tok_num_labels=2, tok2id=None):
        super(BertForMultitask, self).__init__(config)
        self.bert = BertModel(config)

        self.cls_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.cls_classifier = nn.Linear(config.hidden_size, cls_num_labels)
        
        self.tok_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.tok_classifier = nn.Linear(config.hidden_size, tok_num_labels)
        
        self.apply(self.init_bert_weights)


    def forward(self, input_ids, token_type_ids=None, attention_mask=None, 
        labels=None, rel_ids=None, pos_ids=None, categories=None, pre_len=None):
        global ARGS
        sequence_output, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

        cls_logits = self.cls_classifier(pooled_output)
        cls_logits = self.cls_dropout(cls_logits)

        # NOTE -- dropout is after proj, which is non-standard
        tok_logits = self.tok_classifier(sequence_output)
        tok_logits = self.tok_dropout(tok_logits)

        return cls_logits, tok_logits

In [103]:
# instantiate the multitask model from the pretrained bert-base-uncased weights
model = BertForMultitask.from_pretrained(
    'bert-base-uncased',
    cls_num_labels=cls_num_labels,
    tok_num_labels=tok_num_labels, 
    cache_dir=cache_dir,
    tok2id=tok2id)


In [111]:
def build_optimizer(model, num_train_steps, learning_rate):
    #global ARGS

    '''if ARGS.tagger_from_debiaser:
        parameters = list(model.cls_classifier.parameters()) + list(
            model.tok_classifier.parameters())
        parameters = list(filter(lambda p: p.requires_grad, parameters))
        return optim.Adam(parameters, lr=ARGS.learning_rate)
    else:'''
    param_optimizer = list(model.named_parameters())
    param_optimizer = list(filter(lambda name_param: name_param[1].requires_grad, param_optimizer))
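    # pytorch-pretrained-bert 0.6.2 reads the 'weight_decay' key from each group, and
    # its LayerNorm parameters are named 'weight'/'bias' rather than 'gamma'/'beta'
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]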
    return BertAdam(optimizer_grouped_parameters,
                          lr=learning_rate,
                          warmup=0.1,
                          t_total=num_train_steps)


def build_loss_fn(debias_weight=None):
    global ARGS
    
    if debias_weight is None:
        debias_weight = 1 # default #ARGS.debias_weight
    
    weight_mask = torch.ones(3) #ARGS.num_tok_labels)
    weight_mask[-1] = 0

    if CUDA:
        weight_mask = weight_mask.cuda()
        criterion = CrossEntropyLoss(weight=weight_mask).cuda()
        per_tok_criterion = CrossEntropyLoss(weight=weight_mask, reduction='none').cuda()
    else:
        criterion = CrossEntropyLoss(weight=weight_mask)
        per_tok_criterion = CrossEntropyLoss(weight=weight_mask, reduction='none')


    def cross_entropy_loss(logits, labels, apply_mask=None):
        return criterion(
            logits.contiguous().view(-1, 3), #ARGS.num_tok_labels), 
            labels.contiguous().view(-1).type('torch.cuda.LongTensor' if CUDA else 'torch.LongTensor'))

    def weighted_cross_entropy_loss(logits, labels, apply_mask=None):
        # weight mask = where to apply weight (post_tok_label_id from the batch)
        weights = apply_mask.contiguous().view(-1)
        weights = ((debias_weight - 1) * weights) + 1.0

        per_tok_losses = per_tok_criterion(
            logits.contiguous().view(-1, 3), # ARGS.num_tok_labels), 
            labels.contiguous().view(-1).type('torch.cuda.LongTensor' if CUDA else 'torch.LongTensor'))
        per_tok_losses = per_tok_losses * weights

        loss = torch.mean(per_tok_losses[torch.nonzero(per_tok_losses)].squeeze())

        return loss

    if debias_weight == 1.0:
        loss_fn = cross_entropy_loss
    else:
        loss_fn = weighted_cross_entropy_loss

    return loss_fn
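
The weighted loss can be followed by hand on three tokens. The sketch below mirrors the arithmetic in pure Python rather than PyTorch: a per-token cross-entropy, a class weight of 0 for the mask label, a multiplier of `debias_weight` on edited tokens, and a mean over the non-zero losses. `token_loss` and `weighted_mean_loss` are hypothetical names, and the uniform logits are chosen so each raw loss equals log 3.

```python
import math

CLASS_WEIGHTS = [1.0, 1.0, 0.0]  # the mask label (2) contributes no loss

def token_loss(logits, label):
    # Cross-entropy for one token, computed via log-sum-exp for stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return CLASS_WEIGHTS[label] * (log_z - logits[label])

def weighted_mean_loss(all_logits, labels, debias_weight):
    losses = []
    for logits, label in zip(all_logits, labels):
        weight = (debias_weight - 1) * label + 1.0  # 1 for label 0, debias_weight for label 1
        losses.append(token_loss(logits, label) * weight)
    nonzero = [l for l in losses if l != 0.0]  # drops the masked positions
    return sum(nonzero) / len(nonzero)

logits = [[0.0, 0.0, 0.0]] * 3
labels = [0, 1, 2]  # unchanged, edited, padding
loss = weighted_mean_loss(logits, labels, debias_weight=3)
print(loss, 2 * math.log(3))
```

With `debias_weight=3` the edited token counts three times as much as the unchanged one, so the mean is (1 + 3) / 2 times log 3.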

In [105]:
epochs = 4
train_batch_size = 32 
learning_rate = 3e-5
optimizer = build_optimizer(
    model, int((num_train_examples * epochs) / train_batch_size),
    learning_rate)
loss_fn = build_loss_fn()
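
As a quick check on the step count passed to `build_optimizer`, the arithmetic below uses the 52300 training examples reported by the data-loading cell. The variable names are illustrative.

```python
import math

num_train_examples = 52300
epochs = 4
train_batch_size = 32

# Value passed to build_optimizer as num_train_steps
t_total = int((num_train_examples * epochs) / train_batch_size)

# Batches the DataLoader yields per epoch (the last batch is partial)
steps_per_epoch = math.ceil(num_train_examples / train_batch_size)
actual_steps = steps_per_epoch * epochs

print(t_total, steps_per_epoch, actual_steps)
```

The per-epoch count matches the 1635 batches shown in the training progress output. The total passed to the optimizer is three steps short of the steps actually taken, because the partial batch in each epoch is rounded away.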

In [88]:
from tensorboardX import SummaryWriter
writer = SummaryWriter(model_save_dir)

In [113]:
def to_probs(logits, lens):
    per_tok_probs = softmax(np.array(logits)[:, :, :2], axis=2)
    pos_scores = per_tok_probs[:, :, -1]
    
    out = []
    for score_seq, l in zip(pos_scores, lens):
        out.append(score_seq[:l].tolist())
    return out

def run_inference(model, eval_dataloader, loss_fn, tokenizer):
    #global ARGS

    out = {
        'input_toks': [],
        'post_toks': [],

        'tok_loss': [],
        'tok_logits': [],
        'tok_probs': [],
        'tok_labels': [],

        'labeling_hits': []
    }

    for step, batch in enumerate(tqdm(eval_dataloader)):
        #if False and step > 2:
        #    continue

        if CUDA:
            batch = tuple(x.cuda() for x in batch)

        ( 
            pre_id, pre_mask, pre_len, 
            post_in_id, post_out_id, 
            tok_label_id, _,
            rel_ids, pos_ids, categories
        ) = batch

        with torch.no_grad():
            _, tok_logits = model(pre_id, attention_mask=1.0-pre_mask,
                rel_ids=rel_ids, pos_ids=pos_ids, categories=categories,
                pre_len=pre_len)
            tok_loss = loss_fn(tok_logits, tok_label_id, apply_mask=tok_label_id)
        out['input_toks'] += [tokenizer.convert_ids_to_tokens(seq) for seq in pre_id.cpu().numpy()]
        out['post_toks'] += [tokenizer.convert_ids_to_tokens(seq) for seq in post_in_id.cpu().numpy()]
        out['tok_loss'].append(float(tok_loss.cpu().numpy()))
        logits = tok_logits.detach().cpu().numpy()
        labels = tok_label_id.cpu().numpy()
        out['tok_logits'] += logits.tolist()
        out['tok_labels'] += labels.tolist()
        out['tok_probs'] += to_probs(logits, pre_len)
        out['labeling_hits'] += tag_hits(logits, labels)

    return out

In [115]:
def train_for_epoch(model, train_dataloader, loss_fn, optimizer):
    global ARGS
    
    losses = []
    
    for step, batch in enumerate(tqdm(train_dataloader)):
        #if ARGS.debug_skip and step > 2:
        #    continue
    
        if CUDA:
            batch = tuple(x.cuda() for x in batch)
        ( 
            pre_id, pre_mask, pre_len, 
            post_in_id, post_out_id, 
            tok_label_id, _,
            rel_ids, pos_ids, categories
        ) = batch
        _, tok_logits = model(pre_id, attention_mask=1.0-pre_mask,
            rel_ids=rel_ids, pos_ids=pos_ids, categories=categories,
            pre_len=pre_len)
        loss = loss_fn(tok_logits, tok_label_id, apply_mask=tok_label_id)
        loss.backward()
        optimizer.step()
        model.zero_grad()

        losses.append(loss.detach().cpu().numpy())

    return losses

def is_ranking_hit(probs, labels, top=1):
    global ARGS
    
    # get rid of padding idx
    [probs, labels] = list(zip(*[(p, l)  for p, l in zip(probs, labels) if l != 3 - 1 ]))
    probs_indices = list(zip(np.array(probs)[:, 1], range(len(labels))))
    [_, top_indices] = list(zip(*sorted(probs_indices, reverse=True)[:top]))
    if sum([labels[i] for i in top_indices]) > 0:
        return 1
    else:
        return 0

def tag_hits(logits, tok_labels, top=1):
    #global ARGS
    
    probs = softmax(np.array(logits)[:, :, : 3 - 1], axis=2)

    hits = [
        is_ranking_hit(prob_dist, tok_label, top=top) 
        for prob_dist, tok_label in zip(probs, tok_labels)
    ]
    return hits
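
The accuracy reported during evaluation is a ranking hit rate: a sentence counts as a hit when at least one of its `top` highest-scoring tokens is labeled as edited. A simplified pure-Python sketch follows; it takes the per-token probability of the "edited" class directly, and it breaks ties by position rather than by index as the notebook's version does. `ranking_hit` and `pad_label` are hypothetical names.

```python
def ranking_hit(edit_probs, labels, top=1, pad_label=2):
    # Drop padded positions, then rank the remaining tokens by edit probability.
    pairs = [(p, l) for p, l in zip(edit_probs, labels) if l != pad_label]
    ranked = sorted(pairs, key=lambda pl: pl[0], reverse=True)[:top]
    return int(any(l == 1 for _, l in ranked))

# The highest-scoring real token (0.7) is the edited one; 0.9 sits on padding.
hit = ranking_hit([0.1, 0.7, 0.2, 0.9], [0, 1, 0, 2])

# The top-1 token is unchanged, but the edited token is within the top 2.
miss_top1 = ranking_hit([0.6, 0.3, 0.1], [0, 1, 0], top=1)
hit_top2 = ranking_hit([0.6, 0.3, 0.1], [0, 1, 0], top=2)
print(hit, miss_top1, hit_top2)
```

Averaging these 0/1 values over the evaluation set gives the figure logged as `eval/tok_acc`.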

In [None]:
# Train the model: evaluate once before training, then after every epoch
# (uses run_inference and train_for_epoch defined above)

print('INITIAL EVAL...')
model.eval()
results = run_inference(model, eval_dataloader, loss_fn, tokenizer)
writer.add_scalar('eval/tok_loss', np.mean(results['tok_loss']), 0)
writer.add_scalar('eval/tok_acc', np.mean(results['labeling_hits']), 0)

print('TRAINING...')
model.train()
for epoch in range(epochs):
    print('STARTING EPOCH ', epoch)
    losses = train_for_epoch(model, train_dataloader, loss_fn, optimizer)
    writer.add_scalar('train/loss', np.mean(losses), epoch + 1)

    # eval
    print('EVAL...')
    model.eval()
    results = run_inference(model, eval_dataloader, loss_fn, tokenizer)
    writer.add_scalar('eval/tok_loss', np.mean(results['tok_loss']), epoch + 1)
    writer.add_scalar('eval/tok_acc', np.mean(results['labeling_hits']), epoch + 1)

    model.train()

    print('SAVING...')
    torch.save(model.state_dict(), model_save_dir + 'model_%d.ckpt' % epoch)

INITIAL EVAL...

TRAINING...
STARTING EPOCH  0

