In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging as log
log.basicConfig(level=log.DEBUG)

In [None]:
import sys
sys.path.append('../')
from baselines.utils import *
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
import torch

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

In [None]:
cfg['model_name'] = 'bert-base-uncased'
cfg['batch_size'] = 32
cfg

In [None]:
# create dir if not exists
from pathlib import Path
Path(cfg["models_save_dir"]).mkdir(parents=True, exist_ok=True)
Path(cfg["dataset_dir"]).mkdir(parents=True, exist_ok=True)

In [None]:
# tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'])
# bert = AutoModelForMaskedLM.from_pretrained(cfg['model_name'])

# tokenizer.save_pretrained(cfg['models_save_dir'])
# bert.save_pretrained(cfg['models_save_dir'])

In [None]:
tokenizer = AutoTokenizer.from_pretrained(cfg['models_save_dir'])
bert = AutoModelForMaskedLM.from_pretrained(cfg['models_save_dir'])

In [None]:
from icecream import ic

In [236]:
from torch.utils.data import Dataset
import torch

class NLIDataset(Dataset):
    def __init__(self, documents, tokenizer, hypothesis, context_sizes, surround_character_size):
        spans = []
        self.label_dict = get_labels()
        self.tokenizer = tokenizer

        # for i, doc in enumerate(documents):
        #     for j, span in enumerate(doc['spans']):
        #         spans.append({
        #             'doc_id': i,
        #             'span_id': j,
        #             'text': doc['text'][span[0]:span[1]]
        #         })

        data_points = []
        self.span_label = []
        self.nli_label = []

        contexts = [{}]

        for context_size in context_sizes:
            for i, doc in enumerate(documents):
                ic(i)
                char_idx = 0
                while char_idx < len(doc['text']):
                    ic(char_idx)
                    document_spans = doc['spans']
                    cur_context = {
                        'doc_id': i,
                        'start_char_idx': char_idx,
                        'end_char_idx': char_idx + context_size,
                        'spans' : [],
                    }

                    ic(cur_context)

                    for j, (start, end) in enumerate(document_spans):
                        ic(j)
                        if end <= char_idx:
                            continue

                        cur_context['spans'].append({
                            'start_char_idx': max(start, char_idx),
                            'end_char_idx': min(end, char_idx + context_size),
                            'marked': start >= char_idx and end <= char_idx + context_size,
                            'span_id': j
                        })

                        if end > char_idx + context_size:
                            break

                    contexts.append(cur_context)
                    if len(cur_context['spans']) == 1 and not cur_context['spans'][0]['marked']:
                        char_idx = cur_context['end_char_idx'] - surround_character_size
                    else:
                        char_idx = cur_context['spans'][-1]['start_char_idx'] - surround_character_size

        self.contexts = contexts


        #     for span in spans:
        #         self.span_label.append(int(span['span_id'] in documents[span['doc_id']]['annotation_sets'][0]['annotations'][key]['spans']))

        #         data_points.append({ 'hypotheis': val, 'premise': span['text'] })

        #         self.nli_label.append(self.label_dict[documents[span['doc_id']]['annotation_sets'][0]['annotations'][key]['choice']])

        #         if self.nli_label[-1] != self.label_dict['NotMentioned'] and self.span_label[-1] == 0:
        #             self.nli_label[-1] = self.label_dict['Ignore']

        # self.data_points = data_points[:100]

    def __len__(self):
        return len(self.data_points)

    def __getitem__(self, idx):
        tokenized_data = self.tokenizer(
            [self.data_points[idx]['hypotheis']],
            [self.data_points[idx]['premise']],
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        tokenized_data['input_ids'] = tokenized_data['input_ids'].squeeze()
        tokenized_data['attention_mask'] = tokenized_data['attention_mask'].squeeze()
        tokenized_data['token_type_ids'] = tokenized_data['token_type_ids'].squeeze()

        sep_indices = torch.where(tokenized_data['input_ids'] == self.tokenizer.sep_token_id)[0][0]

        return {
            'input_ids': tokenized_data['input_ids'],
            'attention_mask': tokenized_data['attention_mask'],
            'token_type_ids': tokenized_data['token_type_ids'],
            'span_label': torch.tensor(self.span_label[idx], dtype=torch.long),
            'nli_label': torch.tensor(self.nli_label[idx], dtype=torch.long), 
            'sep_indices': sep_indices
        }

In [237]:
train_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['train_path']))
dev_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['dev_path']))
test_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['test_path']))

hypothesis = get_hypothesis(train_data)

train_data = train_data['documents']
dev_data = dev_data['documents']
test_data = test_data['documents']

train_data = train_data[:2]
dev_data = dev_data[:2]
test_data = test_data[:2]

ic(len(train_data), len(dev_data), len(test_data))
train_dataset = NLIDataset(train_data, tokenizer, hypothesis, [100], 10)
dev_dataset = NLIDataset(dev_data, tokenizer, hypothesis, [100], 10)
test_dataset = NLIDataset(test_data, tokenizer, hypothesis, [100], 10)

del train_data
del dev_data
del test_data
del hypothesis
# save the datasets
torch.save(train_dataset, os.path.join(cfg['dataset_dir'], 'train_dataset.pt'))
torch.save(dev_dataset, os.path.join(cfg['dataset_dir'], 'dev_dataset.pt'))
torch.save(test_dataset, os.path.join(cfg['dataset_dir'], 'test_dataset.pt'))

ic| len(train_data): 2, len(dev_data): 2, len(test_data): 2
ic| i: 0
ic| char_idx: 0
ic| cur_context: {'doc_id': 0, 'end_char_idx': 100, 'spans': [], 'start_char_idx': 0}
ic| j: 0
ic| j: 1
ic| char_idx: 35
ic| cur_context: {'doc_id': 0, 'end_char_idx': 135, 'spans': [], 'start_char_idx': 35}
ic| j: 0
ic| j: 1
ic| j: 2
ic| char_idx: 123
ic| cur_context: {'doc_id': 0, 'end_char_idx': 223, 'spans': [], 'start_char_idx': 123}
ic| j: 0
ic| j: 1
ic| j: 2
ic| char_idx: 123
ic| cur_context: {'doc_id': 0, 'end_char_idx': 223, 'spans': [], 'start_char_idx': 123}
ic| j: 0
ic| j: 1
ic| j: 2
ic| char_idx: 123
ic| cur_context: {'doc_id': 0, 'end_char_idx': 223, 'spans': [], 'start_char_idx': 123}
ic| j: 0
ic| j: 1
ic| j: 2
ic| char_idx: 123
ic| cur_context: {'doc_id': 0, 'end_char_idx': 223, 'spans': [], 'start_char_idx': 123}
ic| j: 0
ic| j: 1
ic| j: 2
ic| char_idx: 123
ic| cur_context: {'doc_id': 0, 'end_char_idx': 223, 'spans': [], 'start_char_idx': 123}
ic| j: 0
ic| j: 1
ic| j: 2
ic| char_idx: 1

KeyboardInterrupt: 

In [None]:
train_dataset.__getitem__(0)

In [None]:
# load the datasets
train_dataset = torch.load(os.path.join(cfg['dataset_dir'], 'train_dataset.pt'))
dev_dataset = torch.load(os.path.join(cfg['dataset_dir'], 'dev_dataset.pt'))
# test_dataset = torch.load(os.path.join(cfg['dataset_dir'], 'test_dataset.pt'))

In [None]:
from tqdm import tqdm
import numpy as np

In [None]:
from torch import nn
class ContractNLI(nn.Module):
    def __init__(self, bert, num_labels, ignore_index):
        super().__init__()
        self.bert = bert
        self.bert.eval()
        for param in self.bert.parameters():
            param.requires_grad = False

        self.embedding_dim = self.bert.config.hidden_size
        self.num_labels = num_labels
        self.lambda_ = 1
        self.nli_criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.span_criterion = nn.BCEWithLogitsLoss()

        self.span_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embedding_dim // 2, 1)
        )

        self.nli_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embedding_dim // 2, self.num_labels)
        )

    def forward(self, input_ids, attention_mask, token_type_ids, sep_indices):
        outputs = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True).hidden_states[-1]

        span_logits = self.span_classifier(torch.gather(outputs, 1, sep_indices.unsqueeze(1).expand(-1, outputs.shape[-1]).unsqueeze(1)).squeeze(1))

        nli_logits = self.nli_classifier(outputs[:, 0, :])

        return span_logits, nli_logits

In [None]:
# import wandb

# wandb.init(project="contract-nli", entity="contract-nli-db")

In [None]:
from typing import Dict, List, Optional
from torch.utils.data import Dataset
from transformers import Trainer

class ContractNLITrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        span_label = inputs.pop('span_label')
        nli_label = inputs.pop('nli_label')

        outputs = model(**inputs)
        span_logits, nli_logits = outputs[0], outputs[1]

        span_loss = self.model.span_criterion(span_logits, span_label.reshape(-1, 1).float())
        nli_loss = self.model.nli_criterion(nli_logits, nli_label)

        if torch.isnan(nli_loss):
            nli_loss = torch.tensor(0, dtype=torch.float32, device=DEVICE)

        loss = span_loss + self.model.lambda_ * nli_loss

        if torch.isnan(loss):
            ic(inputs['input_ids'])
            ic(nli_label)
            ic(nli_logits)

        return (loss, outputs) if return_outputs else loss

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    auto_find_batch_size=True,
    output_dir='./results',          # output directory
    num_train_epochs=10,             # total number of training epochs
    # warmup_steps=10,               # number of warmup steps for learning rate scheduler
    logging_steps=2,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=1,
    load_best_model_at_end=True,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    run_name='1',
    label_names=['nli_label', 'span_label'],
    report_to='none'
)

In [None]:
trainer = ContractNLITrainer(
    model=ContractNLI(bert, len(get_labels()), ignore_index=get_labels()['Ignore']).to(DEVICE),
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=dev_dataset,            # evaluation dataset
)

In [None]:
trainer
trainer.train()