In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging as log
# Configure the root logger at DEBUG level. Library loggers that propagate to it
# become visible in the notebook output as well (e.g. the torch INFO lines printed
# when the model is loaded).
log.basicConfig(level=log.DEBUG)

In [2]:
import sys
# Put the parent directory on the import path so the `baselines` package resolves.
sys.path.append('../')
# NOTE(review): the star import hides where names come from. This notebook relies on
# it for at least `cfg` and `get_labels` (plus `load_data` / `get_hypothesis` in the
# commented-out dataset-building cell) — consider importing them explicitly.
from baselines.utils import *
import os

# Presumably set so that CUDA errors are raised at the failing call instead of
# asynchronously at a later one — it has to be in the environment before the first
# CUDA call. TODO confirm it is still wanted; it slows GPU execution.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
import torch

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
# Per-run overrides of the shared configuration: which pretrained checkpoint to use
# and how many examples go into one mini-batch.
cfg.update({'model_name': 'bert-base-uncased', 'batch_size': 32})
cfg

{'raw_data_dir': '../dataset/',
 'train_path': 'train.json',
 'test_path': 'test.json',
 'dev_path': 'dev.json',
 'model_name': 'bert-base-uncased',
 'max_length': 512,
 'models_save_dir': '/scratch/shu7bh/contract_nli/models',
 'dataset_dir': '/scratch/shu7bh/contract_nli/dataset',
 'batch_size': 32}

In [5]:
# Make sure the model and dataset output directories exist (parents included);
# directories that are already there are left untouched.
from pathlib import Path

for _dir_key in ("models_save_dir", "dataset_dir"):
    Path(cfg[_dir_key]).mkdir(parents=True, exist_ok=True)

In [6]:
# tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'])
# bert = AutoModelForMaskedLM.from_pretrained(cfg['model_name'])

# tokenizer.save_pretrained(cfg['models_save_dir'])
# bert.save_pretrained(cfg['models_save_dir'])

In [7]:
# Cache the pretrained checkpoint locally on first use. Without this guard the cell
# only works if the download/save step was once run by hand, which breaks
# "Restart Kernel -> Run All" on a fresh machine. `config.json` is used as the marker
# that a checkpoint has already been saved into the directory.
if not os.path.isfile(os.path.join(cfg['models_save_dir'], 'config.json')):
    AutoTokenizer.from_pretrained(cfg['model_name']).save_pretrained(cfg['models_save_dir'])
    AutoModelForMaskedLM.from_pretrained(cfg['model_name']).save_pretrained(cfg['models_save_dir'])

# Always load from the local copy so later runs do not need network access.
tokenizer = AutoTokenizer.from_pretrained(cfg['models_save_dir'])
bert = AutoModelForMaskedLM.from_pretrained(cfg['models_save_dir'])

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpmn5zwib1
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpmn5zwib1/_remote_module_non_scriptable.py


In [8]:
from icecream import ic

In [9]:
from torch.utils.data import Dataset
import torch

class NLIDataset(Dataset):
    """Pre-tokenized (hypothesis, span) pairs with span-evidence and NLI labels.

    Every hypothesis is paired with every span of every document, and the whole
    cross product is tokenized once in ``__init__``.

    Args:
        documents: sequence of dicts. The code reads ``doc['text']`` (a string),
            ``doc['spans']`` (pairs of start/end offsets into ``doc['text']``) and
            ``doc['annotation_sets'][0]['annotations'][key]`` with the fields
            ``'spans'`` (ids of the evidence spans) and ``'choice'`` (a label name
            that is a key of ``get_labels()``).
        tokenizer: callable taking two lists of strings plus ``padding``,
            ``truncation``, ``max_length`` and ``return_tensors`` keyword arguments
            and returning a mapping with ``'input_ids'``, ``'attention_mask'`` and
            ``'token_type_ids'`` tensors; it must also expose ``sep_token_id``.
        hypothesis: mapping from hypothesis key to hypothesis text.
        max_length: optional length passed to the tokenizer for padding/truncation.
            ``None`` (the default) leaves the choice to the tokenizer, which is what
            happened before this parameter existed.

    Attributes:
        label_dict: result of ``get_labels()``.
        span_label: list with 1 where the span is evidence for the hypothesis, else 0.
        nli_label: list of label ids; pairs whose document-level choice is not
            ``'NotMentioned'`` but whose span is not evidence get the ``'Ignore'`` id.
        tokenized_data: the tokenizer output for all pairs.
        sep_indices: position of the first separator token in every row.
    """

    def __init__(self, documents, tokenizer, hypothesis, max_length=None):
        self.label_dict = get_labels()

        # Flatten all documents into a single list of spans, remembering their origin.
        spans = [
            {'doc_id': doc_id, 'span_id': span_id, 'text': doc['text'][span[0]:span[1]]}
            for doc_id, doc in enumerate(documents)
            for span_id, span in enumerate(doc['spans'])
        ]

        hypotheses = []
        premises = []
        self.span_label = []
        self.nli_label = []

        for key, hypothesis_text in hypothesis.items():
            for span in spans:
                annotation = documents[span['doc_id']]['annotation_sets'][0]['annotations'][key]

                is_evidence = int(span['span_id'] in annotation['spans'])
                nli = self.label_dict[annotation['choice']]

                # The choice is annotated per document. A span that is not evidence for
                # a mentioned hypothesis carries no NLI signal, so it is masked out.
                if nli != self.label_dict['NotMentioned'] and is_evidence == 0:
                    nli = self.label_dict['Ignore']

                hypotheses.append(hypothesis_text)
                premises.append(span['text'])
                self.span_label.append(is_evidence)
                self.nli_label.append(nli)

        # input_ids = [CLS] [HYPOTHESIS] [SEP] [PREMISE] [SEP] [PAD] [PAD] ...
        self.tokenized_data = tokenizer(
            hypotheses,
            premises,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )

        # Index of the FIRST separator in every row (argmax returns the first maximal
        # position). Taking every second hit of a flattened search would silently
        # misalign all following rows as soon as one row does not contain exactly two
        # separator ids.
        is_sep = self.tokenized_data['input_ids'] == tokenizer.sep_token_id
        self.sep_indices = is_sep.int().argmax(dim=1)

    def __len__(self):
        """Number of (hypothesis, span) pairs."""
        return len(self.tokenized_data['input_ids'])

    def __getitem__(self, idx):
        """Return the tensors of pair ``idx`` as a dict; the key names are what the model's loss step looks up."""
        return {
            'input_ids': self.tokenized_data['input_ids'][idx],
            'attention_mask': self.tokenized_data['attention_mask'][idx],
            'token_type_ids': self.tokenized_data['token_type_ids'][idx],
            'span_label': torch.tensor(self.span_label[idx], dtype=torch.long),
            'nli_label': torch.tensor(self.nli_label[idx], dtype=torch.long),
            'sep_indices': self.sep_indices[idx]
        }

In [10]:
# train_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['train_path']))
# dev_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['dev_path']))
# test_data = load_data(os.path.join(cfg['raw_data_dir'], cfg['test_path']))

# hypothesis = get_hypothesis(train_data)

# train_data = train_data['documents']
# dev_data = dev_data['documents']
# test_data = test_data['documents']

# train_data = train_data[:10]
# dev_data = dev_data[:10]
# test_data = test_data[:10]

# ic(len(train_data), len(dev_data), len(test_data))
# train_dataset = NLIDataset(train_data, tokenizer, hypothesis)
# dev_dataset = NLIDataset(dev_data, tokenizer, hypothesis)
# test_dataset = NLIDataset(test_data, tokenizer, hypothesis)

# del train_data
# del dev_data
# del test_data
# del hypothesis
# # save the datasets
# torch.save(train_dataset, os.path.join(cfg['dataset_dir'], 'train_dataset.pt'))
# torch.save(dev_dataset, os.path.join(cfg['dataset_dir'], 'dev_dataset.pt'))
# torch.save(test_dataset, os.path.join(cfg['dataset_dir'], 'test_dataset.pt'))

In [11]:
# Restore the pre-tokenized datasets cached under cfg['dataset_dir'].
# They were written with torch.save as whole dataset objects (see the commented-out
# cell that builds them), so the NLIDataset class must already be defined in this kernel.
train_dataset = torch.load(Path(cfg['dataset_dir']) / 'train_dataset.pt')
dev_dataset = torch.load(Path(cfg['dataset_dir']) / 'dev_dataset.pt')
# test_dataset = torch.load(os.path.join(cfg['dataset_dir'], 'test_dataset.pt'))

In [12]:
# Wrap the datasets in mini-batch iterators; both are reshuffled on every pass.
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=cfg['batch_size'],
    shuffle=True,
)
dev_dataloader = DataLoader(
    dataset=dev_dataset,
    batch_size=cfg['batch_size'],
    shuffle=True,
)

In [13]:
from tqdm import tqdm
import numpy as np

In [30]:
from torch import nn
class ContractNLI(nn.Module):
    """Two classification heads on top of a frozen encoder.

    * ``span_classifier`` scores, from the hidden state at ``sep_indices``, whether
      the span is evidence for the hypothesis (one logit per example).
    * ``nli_classifier`` predicts the NLI label from the hidden state at position 0.

    Both heads return raw logits; the losses (``BCEWithLogitsLoss`` and
    ``CrossEntropyLoss``) apply the sigmoid / softmax themselves.

    Args:
        bert: encoder module exposing ``config.hidden_size`` whose forward accepts
            ``input_ids``, ``attention_mask``, ``token_type_ids`` and
            ``output_hidden_states`` and returns an object with ``hidden_states``.
            All of its parameters are frozen here.
        num_labels: number of NLI classes (size of the NLI head output).
        ignore_index: label id that the NLI loss skips.
    """

    def __init__(self, bert, num_labels, ignore_index):
        super().__init__()
        self.bert = bert
        self.bert.eval()
        for param in self.bert.parameters():
            param.requires_grad = False

        self.embedding_dim = self.bert.config.hidden_size
        self.num_labels = num_labels
        # Weight of the NLI loss relative to the span loss; `fit` overwrites it.
        # (Previously stored under the misspelled name `labmda`, which left `lambda_`
        # undefined until `fit` had been called.)
        self.lambda_ = 1
        self.nli_criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
        # The span head ends in a plain Linear layer, i.e. it emits unbounded logits.
        # BCELoss requires probabilities in [0, 1]; BCEWithLogitsLoss applies the
        # sigmoid internally.
        self.span_criterion = nn.BCEWithLogitsLoss()

        self.span_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embedding_dim // 2, 1)
        )

        self.nli_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embedding_dim // 2, self.num_labels)
        )

    def train(self, mode=True):
        """Switch the heads between train/eval mode while keeping the frozen encoder in eval mode.

        Without this override ``self.train()`` would re-enable dropout inside the
        encoder even though its weights are never updated.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def forward(self, input_ids, attention_mask, token_type_ids, sep_indices):
        """Compute span and NLI logits for a batch.

        Args:
            input_ids, attention_mask, token_type_ids: encoder inputs, one row per example.
            sep_indices: one integer position per example; the hidden state at that
                position feeds the span head.

        Returns:
            ``(span_logits, nli_logits)`` with one span logit and ``num_labels`` NLI
            logits per example.
        """
        # For HuggingFace task models (e.g. *ForMaskedLM) `base_model` is the bare
        # encoder; calling it skips the vocabulary-sized LM head whose output was never
        # used. Plain modules without that attribute are called directly.
        encoder = getattr(self.bert, 'base_model', self.bert)

        # The encoder is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            hidden = encoder(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                output_hidden_states=True,
            ).hidden_states[-1]

        # Pick, for every row, the hidden state at that row's separator position.
        rows = torch.arange(hidden.size(0), device=hidden.device)
        span_logits = self.span_classifier(hidden[rows, sep_indices])

        nli_logits = self.nli_classifier(hidden[:, 0, :])

        return span_logits, nli_logits

    def fit(self, train_loader, dev_loader, epochs, lr, lambda_):
        """Train the heads with Adam and report train/dev loss after every epoch.

        Args:
            train_loader: iterable of training batches (dicts as read by the loss step).
            dev_loader: iterable of validation batches.
            epochs: number of passes over ``train_loader``.
            lr: Adam learning rate.
            lambda_: weight of the NLI loss in ``span_loss + lambda_ * nli_loss``.
        """
        self.lambda_ = lambda_
        # Only hand the trainable (head) parameters to the optimizer.
        self.optimizer = torch.optim.Adam(
            [param for param in self.parameters() if param.requires_grad], lr=lr
        )

        for epoch in range(epochs):
            print(f'Epoch: {epoch + 1}/{epochs}')
            train_loss = self.__train(train_loader)
            print(f'Train Loss: {train_loss:.4f}')
            dev_loss = self.__validate(dev_loader)
            print(f'Dev Loss: {dev_loss:.4f}')

    def __train(self, train_loader):
        """Run one optimisation pass over ``train_loader``; return the mean batch loss."""
        self.train()
        total_loss = []

        pbar = tqdm(train_loader)

        for batch in pbar:
            loss = self.__call(batch)
            total_loss.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            pbar.set_description(f'Loss: {loss.item():.4f}, Average Loss: {np.mean(total_loss):.4f}')

        return np.mean(total_loss)

    def __validate(self, dev_loader):
        """Evaluate on ``dev_loader`` without gradients; return the mean batch loss."""
        self.eval()
        total_loss = []

        with torch.no_grad():
            pbar = tqdm(dev_loader)

            for batch in pbar:
                loss = self.__call(batch)
                total_loss.append(loss.item())

                pbar.set_description(f'Loss: {loss.item():.4f}, Average Loss: {np.mean(total_loss):.4f}')

        return np.mean(total_loss)

    def __call(self, batch):
        """Move one batch to the model's device and return the combined loss.

        ``batch`` is a dict with the keys ``input_ids``, ``attention_mask``,
        ``token_type_ids``, ``span_label``, ``nli_label`` and ``sep_indices``. The
        entries are looked up by key, so the order of the dict does not matter.
        """
        # Follow the model's own device instead of relying on a notebook global.
        device = next(self.parameters()).device

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        span_label = batch['span_label'].to(device)
        nli_label = batch['nli_label'].to(device)
        sep_indices = batch['sep_indices'].to(device)

        span_logits, nli_logits = self(input_ids, attention_mask, token_type_ids, sep_indices)

        span_loss = self.span_criterion(span_logits, span_label.view(-1, 1).float())

        if (nli_label != self.nli_criterion.ignore_index).any():
            nli_loss = self.nli_criterion(nli_logits, nli_label)
        else:
            # Every target in the batch is ignored: the mean over zero elements would
            # be NaN and poison the update, so contribute a zero that stays in the graph.
            nli_loss = nli_logits.sum() * 0.0

        return span_loss + self.lambda_ * nli_loss

In [31]:
# from transformers import Trainer

# class ContractNLITrainer(Trainer):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)

#     def compute_loss(self, model, inputs, return_outputs=False):
#         nli_label = inputs.pop('nli_label')
#         span_label = inputs.pop('span_label')

#         outputs = model(**inputs)
#         span_logits, nli_logits = outputs[0], outputs[1]

#         span_loss = self.model.span_criterion(span_logits.view(-1), span_label.float())
#         nli_loss = self.model.nli_criterion(nli_logits, nli_label)

#         loss = span_loss + self.model.labmda * nli_loss

#         return (loss, outputs) if return_outputs else loss

In [32]:
# import wandb

# wandb.init(project="contract-nli", entity="contract-nli-db")

In [33]:
# give input to BERT
# bert.eval()
# for param in bert.parameters():
#     param.requires_grad = False
# bert = bert.to(DEVICE)
# bert_output

In [34]:
# Wrap the loaded BERT with the span and NLI heads and move everything to DEVICE.
# The number of NLI classes and the id of the 'Ignore' label (skipped by the NLI
# loss) both come from get_labels().
model = ContractNLI(bert, len(get_labels()), ignore_index=get_labels()['Ignore']).to(DEVICE)

In [35]:
from torchinfo import summary

In [36]:
# Show the layer tree with parameter counts.
# NOTE(review): the saved output below lists a Sigmoid layer and no NLI head, so it
# appears to predate the current ContractNLI definition — re-run this cell.
summary(model)

Layer (type:depth-idx)                                       Param #
ContractNLI                                                  --
├─BertForMaskedLM: 1-1                                       --
│    └─BertModel: 2-1                                        --
│    │    └─BertEmbeddings: 3-1                              (23,837,184)
│    │    └─BertEncoder: 3-2                                 (85,054,464)
│    └─BertOnlyMLMHead: 2-2                                  --
│    │    └─BertLMPredictionHead: 3-3                        (24,063,546)
├─CrossEntropyLoss: 1-2                                      --
├─BCELoss: 1-3                                               --
├─Sequential: 1-4                                            --
│    └─Linear: 2-3                                           295,296
│    └─ReLU: 2-4                                             --
│    └─Linear: 2-5                                           385
│    └─Sigmoid: 2-6                                          --

In [38]:
# Train the heads for 10 epochs with Adam at lr=1e-3; lambda_ is the weight of the
# NLI loss in `span_loss + lambda_ * nli_loss`.
model.fit(train_dataloader, dev_dataloader, epochs=10, lr=1e-3, lambda_=1)

In [None]:
import gc

import torch

# Collect unreachable Python objects first: tensors that are only kept alive by
# reference cycles are released here, which returns their blocks to torch's caching
# allocator. Emptying the cache afterwards can then hand that memory back to the GPU;
# in the opposite order the blocks freed by the collector would stay cached.
n_collected = gc.collect()
torch.cuda.empty_cache()

# Number of unreachable objects found by the collector (displayed as the cell output).
n_collected

6178