## Fine-tuning

In [1]:
from __future__ import print_function, absolute_import, division
%load_ext autoreload
%autoreload 2
%matplotlib widget

import sys, os, json, time, datetime, logging, multiprocessing, itertools
import sqlite3
import tqdm
import json
import ast
from pathlib import Path
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

In [2]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=7
%env OMP_NUM_THREADS=15
%env OPENBLAS_NUM_THREADS=15
%env OPENMP_NUM_THREADS=15
%env MKL_NUM_THREADS=15
%env HF_HOME=/shared/zjiayao/cache
%env ALLENNLP_CACHE_ROOT=/shared/zjiayao/cache/allennlp


env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=7
env: OMP_NUM_THREADS=15
env: OPENBLAS_NUM_THREADS=15
env: OPENMP_NUM_THREADS=15
env: MKL_NUM_THREADS=15
env: HF_HOME=/shared/zjiayao/cache
env: ALLENNLP_CACHE_ROOT=/shared/zjiayao/cache/allennlp


In [3]:
def console_log(msg, end='\n'):
    """
    Write `msg` directly to file descriptor 1, prefixed with a
    `[LOG/<process name>]` tag naming the current process.

    msg: text to emit
    end: terminator appended after `msg` (a newline by default)
    """
    tag = '[LOG/{}]'.format(multiprocessing.current_process().name)
    payload = (tag + msg + end).encode('utf-8')
    os.write(1, payload)


In [4]:
import torch
import spacy
from spacy.matcher import PhraseMatcher
import transformers
import allennlp
import lemminflect

# Report whether CUDA is usable, then pick the compute device: the first
# CUDA device when available, otherwise the CPU.
# NOTE(review): index 0 is relative to CUDA_VISIBLE_DEVICES, which an earlier
# cell sets to 7 — confirm that is the intended physical GPU.
print(torch.cuda.is_available())
TORCH_DEV = torch.device(f'cuda:0') if torch.cuda.is_available() \
                                    else torch.device("cpu")


import nltk
import src
import src.utils as utils


True


In [5]:


# Filesystem locations used by this notebook.
# NOTE(review): these are absolute, machine-specific paths; DATA_PATH and
# LOCAL_DATA_PATH currently point at the same directory.
DATA_PATH = Path('/shared/zjiayao/data')
LOCAL_DATA_PATH = Path('/shared/zjiayao/data')
SAVING_PATH = Path('/shared/zjiayao/experiment_data')

# nltk.data.path = [str(LOCAL_DATA_PATH / 'nltk_data')] + nltk.data.path
# NOTE(review): presumably the SRL-processed NYT corpus — confirm.
nyt_path =  Path('/shared/corpora-tmp/nyt_allennlp_srl/')


## Load data into model

In [7]:
# Load the spaCy `en_core_web_md` pipeline.
spacy_model = spacy.load('en_core_web_md')

In [10]:
# NOTE(review): despite the `.csv` name, DB_PATH is later handed to
# sqlite3.connect — confirm what format this file really has.
DB_PATH = 'nyt_fine_tune.csv'
MODEL_PATH = "/shared/zjiayao/cache/tmp" # where to save checkpoints

In [11]:
# Raise the logging threshold to ERROR via the project helper.
# NOTE(review): presumably this silences library loggers globally — verify in src.utils.
utils.set_global_logging_level(logging.ERROR)

Steps:

    * extract the events beside the temporal connector
    * impute training instances
    * feed to the model

#### Example:

In [12]:
import allennlp_models.structured_prediction.models.srl
import allennlp_models.pretrained
from spacy.tokenizer import Tokenizer


In [13]:
# Load the pretrained "structured-prediction-srl-bert" predictor on CUDA device 0.
# Later cells use only its tokenizer (`bert_srl._tokenizer`).
bert_srl = allennlp_models.pretrained.load_predictor("structured-prediction-srl-bert", cuda_device=0)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

### Helper functions

In [14]:
import ast

def col_print(*args, cw=12, sep='|'):
    """
    Print `args` on one line as left-aligned cells.

    args: values to print, one cell each
    cw: minimum width of every cell
    sep: separator placed between cells, surrounded by single spaces
    """
    cell_fmt = '{' + f":<{cw}" + '}'
    cells = [cell_fmt.format(value) for value in args]
    print(f" {sep} ".join(cells))

#### Load SRL from csv

In [15]:
def get_idx_or_none(arr, v, default=-1):
    """
    Position of the first occurrence of `v` in `arr`.

    arr: object exposing `.index` (e.g. a list)
    v: value to look up
    default: value handed back when the lookup fails for any reason

    returns the index, or `default` (-1 unless overridden)
    """
    try:
        position = arr.index(v)
    except Exception:
        position = default
    return position
    
def get_span_for_tag_bio(tokens, tags, tgt, default=[], verbose=True):
    """
    Return the tokens running from the first through the last occurrence of
    the tag `tgt`.

    tokens: sliceable sequence of tokens
    tags: sequence of tags (must support `.index` and reversal by slicing)
    tgt: target tag
    default: value returned when `tgt` is absent or the lookup errs
             (the shared default list is returned as-is; do not mutate it)
    verbose: when true, report the failure

    returns a slice of `tokens`, or `default`
    """
    try:
        first = tags.index(tgt)
        # Number of tags that follow the last occurrence of `tgt`.
        trailing = tags[::-1].index(tgt)
        if trailing == 0:
            # The span reaches the end of the sequence. Slicing with `-0`
            # is the same as slicing with `0` and would yield an empty span.
            return tokens[first:]
        return tokens[first:-trailing]
    except Exception as e:
        if verbose:
            print(f"error get tag {tgt} from the tokens {tokens}: {e}")
            utils.console_log(f"error get tag {tgt} from the tokens {tokens}: {e}")
        return default
    
def get_span_for_tag(tokens, tags, tgt, default="NAN", verbose=False):
    """
    Collect the span labelled `tgt`: the token tagged `B-<tgt>` followed by
    the tokens covered by the `I-<tgt>` tags.

    tokens: Iterable of tokens
    tags: Iterable of tags
    tgt: bare label, without the `B-`/`I-` prefix
    default: placeholder used as the first element when `B-<tgt>` is missing
    verbose: forwarded to the lookup helpers

    returns a list whose first element is the `B-` token (or `default`)
    """
    begin_tag = f"B-{tgt}"
    inside_tag = f"I-{tgt}"
    head = get_token_at_tag(tokens, tags, begin_tag, default=default, verbose=verbose)
    tail = get_span_for_tag_bio(tokens, tags, inside_tag, verbose=verbose)
    return [head] + tail

def get_token_at_tag(tokens, tags, tgt, default="NAN", verbose=True):
    """
    Look up the token aligned with the first occurrence of `tgt` in `tags`.

    tokens: Iterable of tokens
    tags: Iterable of tags
    tgt: target tag
    default: value returned when the lookup fails
    verbose: when true, report the failure

    returns an item of `tokens`, or `default`
    """
    try:
        position = tags.index(tgt)
        return tokens[position]
    except Exception as e:
        if verbose:
            print(f"error get tag {tgt} from the tokens {tokens}: {e}")
            utils.console_log(f"error get tag {tgt} from the tokens {tokens}: {e}")
        return default
def contains_tmp_arg(tokens, tags, tmps=['before', 'after'], verbose=True):
    """
    Return the text of the token tagged `B-ARGM-TMP` when its lowercased
    lemma is one of `tmps`; otherwise return None.

    tokens: Iterable of tokens exposing `lemma_` and `text`
    tags: Iterable of tags
    tmps: accepted connector lemmas
    verbose: forwarded to get_token_at_tag
    """
    connector = get_token_at_tag(tokens, tags, "B-ARGM-TMP", default=None, verbose=verbose)
    if connector is not None and connector.lemma_.lower() in tmps:
        return connector.text
    return None
    

In [16]:
def extract_tmp_relations(dt):
    """
    Pair every row of `dt` that carries a before/after temporal argument with
    the row of `dt` whose verb is the first one located after that argument.

    dt: DataFrame; the code reads the columns 'sent', 'verb', 'description',
        'tags', 'tokens' and 'tags_'.
        NOTE(review): presumably one frame per sentence with one row per
        predicate — confirm against the caller.
        `dt` is modified in place: the columns 'tmp_connector', 'v_idx' and
        'arg_tmp_idx' are added to it.

    returns a DataFrame restricted to `columns`; it is empty when no row qualifies
    """
    columns = [
        "sent", "verb", "description", "tags", "tmp_connector", "tmp_verb", "tmp_description", "tmp_tags"
    ]
    
    # Connector word of each row's temporal argument, or None when the row has
    # no B-ARGM-TMP tag or its lemma is not one of the accepted connectors.
    dt['tmp_connector'] = dt.apply(lambda s : contains_tmp_arg(s['tokens'],
                                                          s['tags_'], 
                                                           verbose=False), 
                               axis=1)

    # Position of the verb tag and of the start of the temporal argument
    # (-1 when the tag is absent).
    dt["v_idx"] = dt["tags_"].apply(lambda s : get_idx_or_none(s, 'B-V'))
    dt["arg_tmp_idx"] = dt["tags_"].apply(lambda s : get_idx_or_none(s, 'B-ARGM-TMP'))

    # Keep rows that have a connector, whose temporal argument starts at an
    # index greater than 0, and for which at least one row of `dt` has its
    # verb after that argument.
    dt_tmp = dt[(~dt['tmp_connector'].isna()) & (dt["arg_tmp_idx"] > 0) &  dt.apply(
        lambda s :  (dt["v_idx"] > s["arg_tmp_idx"]).any(), axis=1)
    ]

    if len(dt_tmp) == 0:
        return pd.DataFrame(columns=columns)
    
    # For each kept row, pick the row of `dt` with the smallest verb index
    # beyond the temporal argument and expose its fields under a `tmp_` prefix.
    dt_tmp_arg = dt_tmp.apply(
        lambda s : dt.loc[dt[dt["v_idx"] > s["arg_tmp_idx"]].v_idx.idxmin()][["verb", "description", "tags", "tags_"]].rename(
            {v: f"tmp_{v}" for v in ["verb", "description", "tags", "tags_"]}
        ), axis=1
    )
    
    return pd.concat([dt_tmp, dt_tmp_arg], axis=1).reset_index(drop=True)[columns]
    
    

### Fine-tuning

In [None]:
# Open a SQLite connection on DB_PATH.
# NOTE(review): DB_PATH is set to 'nyt_fine_tune.csv' in an earlier cell —
# confirm that this file is really a SQLite database.
con = sqlite3.connect(DB_PATH)

In [None]:
# Read the whole `extracted_rel` table into a DataFrame.
srl_procs = pd.read_sql(sql="select * from extracted_rel;", con=con)

In [None]:
# Tokenized form of each temporal connector, keyed by the connector word,
# produced with the SRL predictor's own tokenizer.
tmp_tokens = {'before': bert_srl._tokenizer.tokenize('before'),
              'after':  bert_srl._tokenizer.tokenize('after')}

In [None]:
def get_VBD(v):
    """
    Return the past-tense (Penn tag VBD) inflection of the verb `v`.

    v: the verb, given as a string, as a token exposing `lemma_` and/or
       `text`, or as a non-empty list/tuple of those (the first element is
       taken to be the verb)

    returns the first inflected form reported by lemminflect
    raises IndexError when `v` is an empty list/tuple or no form is returned
    """
    # A tagged span may be passed in; the verb itself is its first element.
    if isinstance(v, (list, tuple)):
        v = v[0]
    if not isinstance(v, str):
        # Prefer the lemma because inflection starts from a base form;
        # fall back to the surface text, then to the plain string form.
        v = getattr(v, 'lemma_', None) or getattr(v, 'text', None) or str(v)
    # Inflect the verb that was passed in (the verb used to be hard-coded
    # to 'survive', so `v` was ignored).
    return lemminflect.getInflection(v, tag='VBD')[0]


In [None]:
def get_instance(v, arg0, arg1, tmp_v, tmp_arg0, tmp_arg1, tmp, tmp_tokens,
                 txt=False, tokenizer=bert_srl._tokenizer.tokenize, verbose=True):
    """
    Build two sequences of the form
    `arg0 + v + arg1 + <connector> + tmp_arg0 + tmp_v + tmp_arg1`:
    the first uses the connector `tmp`, the second the opposite connector
    (before <-> after). Both verbs are first passed through get_VBD and
    re-tokenized with `tokenizer`.

    v, tmp_v: verbs of the two events, as accepted by get_VBD
    arg0, arg1, tmp_arg0, tmp_arg1: sequences that can be concatenated with
        the tokenizer's output
    tmp: connector word, 'before' or 'after'
    tmp_tokens: mapping from connector word to its tokenized form
    txt: when true, return space-joined strings instead of sequences
    tokenizer: callable turning a string into a token sequence
    verbose: when true, also print the failure

    returns a two-element list, or [] when anything goes wrong
    """
    opposite = {'before':'after', 'after':'before'}
    try:
        v, tmp_v = tokenizer(get_VBD(v)), tokenizer(get_VBD(tmp_v))

        def assemble(connector):
            return arg0 + v + arg1 + tmp_tokens[connector] + tmp_arg0 + tmp_v + tmp_arg1

        forward = assemble(tmp)
        flipped = assemble(opposite[tmp])
        if not txt:
            return [forward, flipped]

        rendered = []
        for seq in (forward, flipped):
            words = [item.text if hasattr(item, 'text') else str(item) for item in seq]
            rendered.append(' '.join(words))
        return rendered
    except Exception as e:
        if verbose:
            print(f"exception on {v}/{tmp_v}: {e}")
        # The failure is always logged to the console, regardless of `verbose`.
        console_log(f"exception on {v}/{tmp_v}: {e}")
        return []
        
def get_instance_from_df(ds, tmp_tokens, txt=False, verbose=True):
    """
    Build the training instances for one row by extracting the V / ARG0 / ARG1
    spans of both events and handing them to get_instance.

    ds: row exposing `tokens`, `tags_`, `tmp_tags_` and `tmp_connector`
    tmp_tokens: mapping from connector word to its tokenized form
    txt, verbose: forwarded to get_instance

    returns [] when any of the six spans is just the missing-value
    placeholder, otherwise whatever get_instance returns
    """
    connector = ds['tmp_connector']
    missing = "NAN"

    def own_span(role):
        return get_span_for_tag(ds.tokens, ds.tags_, role, default=missing)

    def tmp_span(role):
        return get_span_for_tag(ds.tokens, ds.tmp_tags_, role, default=missing)

    spans = [own_span("V"), own_span("ARG0"), own_span("ARG1"),
             tmp_span("V"), tmp_span("ARG0"), tmp_span("ARG1")]

    # A span equal to [placeholder] means neither the B- nor the I- tag was found.
    if [missing] in spans:
        return []
    return get_instance(*spans, connector, tmp_tokens, txt=txt, verbose=verbose)
    
    

In [41]:
# Generate the fine-tuning sentences batch by batch and append them to the
# `srl_data` table of the target database.
# NOTE(review): `n` (number of batches) and `bs` (batch size) are not defined
# in any visible cell — they must exist in the kernel before this cell runs.
# NOTE(review): this rebinds `con` to a different database, and the
# connection is not closed in the visible cells.
con = sqlite3.connect('/shared/zjiayao/exp_db/nyt_temp_keys.db')
for i in tqdm.tqdm(range(n)):

    # The last batch takes every remaining row.
    if i == n - 1:
        srl_df = srl_procs.iloc[i*bs:].copy()
    else:
        srl_df = srl_procs.iloc[i*bs:(i+1)*bs].copy()
    
    # NOTE(review): the "index" column is presumably left over from an earlier
    # to_sql export — verify.
    srl_df = srl_df.reset_index(drop=True).drop(columns=["index"])
    srl_df["tokens"] = srl_df["sent"].apply(bert_srl._tokenizer.tokenize)
    # The tag columns are stored as text; parse them back into Python objects.
    srl_df["tags_"] = srl_df["tags"].apply(ast.literal_eval)
    srl_df["tmp_tags_"] = srl_df["tmp_tags"].apply(ast.literal_eval)

    # Each row yields a list of sentences (possibly empty); flatten all of
    # them into a single 'sent' column.
    raw_df = pd.DataFrame.from_dict({'sent':
     list(itertools.chain(*list(srl_df.apply(
         lambda s : get_instance_from_df(s, tmp_tokens, txt=True), axis=1)
                               )
                         ))
    })
    raw_df.to_sql(name="srl_data",con=con, if_exists='append')
    console_log(f"{i:04}/{n:04}")

100%|██████████| 5/5 [00:11<00:00,  2.30s/it]


In [27]:
# srl_df = srl_procs.iloc[:100].copy().reset_index(drop=True).drop(columns=["index"])
# srl_df["tokens"] = srl_df["sent"].apply(bert_srl._tokenizer
# .tokenize)
# srl_df["tags_"] = srl_df["tags"].apply(ast.literal_eval)
# srl_df["tmp_tags_"] = srl_df["tmp_tags"].apply(ast.literal_eval)

# raw_df = pd.DataFrame.from_dict({'sent':
#  list(itertools.chain(*list(srl_df.apply(
#      lambda s : get_instance_from_df(s, tmp_tokens, txt=True), axis=1)
#                            )
#                      ))
# })

In [29]:
# Re-wrap `raw_df` as a DataFrame.
# NOTE(review): after the batching loop `raw_df` is already a DataFrame and
# holds only the last batch — confirm this cell is still needed.
raw_df = pd.DataFrame.from_dict(raw_df)

### Model finetune

In [7]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM, Trainer,DataCollatorForLanguageModeling,TrainingArguments
import datasets

In [10]:
# Load the fine-tuning sentences into a DataFrame and wrap them as a
# Hugging Face dataset.
# `ft_dataset` must be created with `pd.read_csv`: calling `.read_csv` on a
# name that has not been assigned yet fails on a fresh kernel.
# NOTE(review): DATA_PATH is a directory and DB_PATH ('nyt_fine_tune.csv') is
# the only CSV named in this notebook, so the file is assumed to live at
# DATA_PATH / DB_PATH — confirm the location and that it has a 'sent' column.
ft_dataset = pd.read_csv(DATA_PATH / DB_PATH)
raw_data = datasets.Dataset.from_pandas(ft_dataset)

In [12]:
# Derive the seed from a stable checksum. The built-in hash() of a str is
# salted per interpreter process (see PYTHONHASHSEED), so a seed built from it
# changes on every kernel restart and runs cannot be reproduced.
# The value is kept in `rd_seed` because the TrainingArguments cell reads
# that name.
import zlib

rd_seed = zlib.crc32("some_random_str".encode("utf-8")) % (2 ** 32 - 1)
transformers.set_seed(rd_seed)

## MaskedLM

In [15]:
# Pretrained "roberta-base" checkpoint with a masked-LM head, plus its tokenizer.
model = AutoModelForMaskedLM.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

In [17]:
# Use the end-of-sequence token for padding, then resize the model's token
# embeddings to the tokenizer's vocabulary size.
# NOTE(review): confirm that overriding the tokenizer's pad token is intended
# for this checkpoint.
tokenizer.pad_token = tokenizer.eos_token
model.resize_token_embeddings(len(tokenizer))

Embedding(50265, 768, padding_idx=1)

In [19]:
# Training configuration: train only (no evaluation), write checkpoints to
# MODEL_PATH and keep at most two of them.
# NOTE(review): `rd_seed` is not defined in any visible cell — it must exist
# in the kernel before this cell runs.
tr_cfg = TrainingArguments(
    output_dir=MODEL_PATH,
    do_train=True,
    do_eval=False,
    save_total_limit=2,
    seed=rd_seed,
    disable_tqdm=False,
)

In [20]:
# Tokenize the 'sent' column in batches of 500 using 4 worker processes.
# `return_special_tokens_mask=True` asks the tokenizer to also return the
# special-tokens mask for each example.
tokenized_datasets = raw_data.map(
    lambda s: tokenizer(s['sent'], return_special_tokens_mask=True),
    batched=True, num_proc=4,
    batch_size=500,
)


In [21]:
# Assemble the Trainer for masked-LM fine-tuning; the data collator is
# configured with a masking probability of 0.15.
trainer = Trainer(
    
    model=model,
    train_dataset=tokenized_datasets,
    args=tr_cfg,
     tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                 mlm_probability=0.15)
    
)

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


In [None]:
# Run the fine-tuning loop.
# NOTE(review): `tmp_path` is not defined in any visible cell — it must exist
# in the kernel before this cell runs. Also confirm that the installed
# transformers version still accepts the `model_path` argument.
train_result = trainer.train(model_path=tmp_path)

Step,Training Loss
500,2.0208
1000,1.9132
1500,1.9868
2000,1.9393
2500,1.9249
3000,1.9124
3500,1.9182
4000,1.8824
4500,1.8849
5000,1.912


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter serve