In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging as log
# Configure the root logger at DEBUG level so that debug-level messages
# (including those emitted by imported libraries) are shown in the notebook.
log.basicConfig(level=log.DEBUG)

In [2]:
import sys
# Add the parent directory to the import path so the `baselines` package
# resolves when the notebook is run from its own folder.
sys.path.append('../')
# NOTE(review): the star import hides where names come from. The cells visible
# here rely on at least `cfg` and `get_labels` (plus `load_data` and
# `get_hypothesis` in the commented-out cells) -- consider importing explicitly.
from baselines.utils import *
import os

In [3]:
import torch

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
# Override the checkpoint name in the shared config object (mutated in place;
# `cfg` comes from the `baselines.utils` star import).
# NOTE(review): the cell that downloads this checkpoint is commented out below,
# so this value only takes effect if that cell is re-enabled -- confirm the
# files in `models_save_dir` actually match this name.
cfg['model_name'] = 'distilbert-base-uncased'
cfg

{'DIR': '../dataset/',
 'train_path': 'train.json',
 'test_path': 'test.json',
 'dev_path': 'dev.json',
 'model_name': 'distilbert-base-uncased',
 'max_length': 512,
 'models_save_dir': '/scratch/shu7bh/contract_nli/models',
 'dataset_dir': '/scratch/shu7bh/contract_nli/dataset'}

In [5]:
# Ensure the model and dataset cache directories exist (no-op when they already do).
from pathlib import Path

for dir_key in ("models_save_dir", "dataset_dir"):
    Path(cfg[dir_key]).mkdir(parents=True, exist_ok=True)

In [6]:
# tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'], use_fast=True)
# bert = AutoModelForMaskedLM.from_pretrained(cfg['model_name'])

# tokenizer.save_pretrained(cfg['models_save_dir'])
# bert.save_pretrained(cfg['models_save_dir'])

In [7]:
# Load the tokenizer and the masked-LM encoder from the local directory
# `cfg['models_save_dir']` rather than from the hub.
# NOTE(review): the cell that downloads and saves these files is commented out
# above, so on a fresh machine this cell fails until that cell has been run
# once. The checkpoint loaded here is whatever was last saved to that
# directory, which is not necessarily `cfg['model_name']`.
tokenizer = AutoTokenizer.from_pretrained(cfg['models_save_dir'], use_fast=True)
bert = AutoModelForMaskedLM.from_pretrained(cfg['models_save_dir'])

In [8]:
# train_data = load_data(os.path.join(cfg['DIR'], cfg['train_path']))
# dev_data = load_data(os.path.join(cfg['DIR'], cfg['dev_path']))
# test_data = load_data(os.path.join(cfg['DIR'], cfg['test_path']))

In [9]:
# hypothesis = get_hypothesis(train_data)

In [10]:
from icecream import ic

In [11]:
from torch.utils.data import Dataset
import torch

class NLIDataset(Dataset):
    """Pairs every hypothesis with every document span for span + NLI training.

    One example is built for each (hypothesis, span) combination. Each example
    carries two labels:

    * ``span_label`` -- 1 when the span is listed as evidence for the hypothesis
      in the document's first annotation set, else 0.
    * ``nli_label`` -- the document-level choice for the hypothesis mapped
      through ``get_labels()``. When the choice is anything other than
      ``'NotMentioned'`` and the span is not evidence, the label is replaced by
      the ``'Ignore'`` label.

    Args:
        data: Mapping with a ``'documents'`` list; each document provides
            ``'text'``, ``'spans'`` (``[start, end]`` character offsets) and
            ``'annotation_sets'``.
        tokenizer: Hugging Face tokenizer; must define ``sep_token_id``.
        hypothesis: Mapping from hypothesis key to hypothesis text.
        max_len: Length every encoded pair is padded / truncated to.
    """

    def __init__(self, data, tokenizer, hypothesis, max_len = 475):
        self.data = data
        self.tokenizer = tokenizer
        self.hypothesis = hypothesis
        self.max_len = max_len
        self.label_dict = get_labels()

        # Flatten all spans of all documents, remembering where each came from.
        self.spans = [
            {
                'doc_id': doc_id,
                'span_id': span_id,
                'text': doc['text'][span[0]:span[1]],
            }
            for doc_id, doc in enumerate(self.data['documents'])
            for span_id, span in enumerate(doc['spans'])
        ]

        self.data_points = []
        self.span_label = []
        self.nli_label = []

        for key, val in self.hypothesis.items():
            for span in self.spans:
                annotation = self.data['documents'][span['doc_id']]['annotation_sets'][0]['annotations'][key]

                is_evidence = int(span['span_id'] in annotation['spans'])
                nli = self.label_dict[annotation['choice']]

                # A non-evidence span says nothing about an entailed or
                # contradicted hypothesis, so it gets the 'Ignore' label.
                if nli != self.label_dict['NotMentioned'] and is_evidence == 0:
                    nli = self.label_dict['Ignore']

                self.span_label.append(is_evidence)
                # The 'hypotheis' spelling is kept: it is a runtime key that
                # previously pickled datasets already contain.
                self.data_points.append({ 'hypotheis': val, 'premise': span['text'] })
                self.nli_label.append(nli)

        self.tokenized_data = self.tokenizer(
            [data_point['hypotheis'] for data_point in self.data_points],
            [data_point['premise'] for data_point in self.data_points],
            padding='max_length',
            truncation=True,
            # Previously max_len was stored but never forwarded, so the
            # tokenizer silently fell back to the model's maximum length.
            max_length=self.max_len,
            # __getitem__ reads 'token_type_ids'; some tokenizers (e.g.
            # DistilBERT) omit that key unless it is requested explicitly.
            return_token_type_ids=True,
            return_tensors='pt',
        )

        self.sep_indices = self._first_sep_indices(
            self.tokenized_data['input_ids'], self.tokenizer.sep_token_id
        )

    @staticmethod
    def _first_sep_indices(input_ids, sep_token_id):
        """Return, for each row of ``input_ids``, the column of its first SEP token.

        Works row by row, so a row with an unexpected number of SEP tokens
        cannot shift the indices of the rows after it. A row with no SEP token
        yields 0.
        """
        # argmax over a 0/1 mask returns the first position holding a 1.
        return (input_ids == sep_token_id).long().argmax(dim=1)

    def __len__(self):
        """Number of (hypothesis, span) examples."""
        return len(self.data_points)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, token_type_ids, span_label, nli_label, sep_index)``."""
        return (
            self.tokenized_data['input_ids'][idx],
            self.tokenized_data['attention_mask'][idx],
            self.tokenized_data['token_type_ids'][idx],
            torch.tensor(self.span_label[idx], dtype=torch.long),
            torch.tensor(self.nli_label[idx], dtype=torch.long),
            self.sep_indices[idx],
        )

In [12]:
# train_dataset = NLIDataset(train_data, tokenizer, hypothesis)
# dev_dataset = NLIDataset(dev_data, tokenizer, hypothesis)
# test_dataset = NLIDataset(test_data, tokenizer, hypothesis)

In [13]:
# save the datasets
# torch.save(train_dataset, os.path.join(cfg['dataset_dir'], 'train_dataset.pt'))
# torch.save(dev_dataset, os.path.join(cfg['dataset_dir'], 'dev_dataset.pt'))
# torch.save(test_dataset, os.path.join(cfg['dataset_dir'], 'test_dataset.pt'))

In [14]:
# Load the pre-built datasets cached under cfg['dataset_dir'].
# NOTE(review): these files hold whole pickled NLIDataset objects (see the
# commented-out torch.save cell above), so the NLIDataset class must already be
# defined before this cell runs, and only trusted files should be loaded --
# unpickling can execute arbitrary code.
train_dataset = torch.load(os.path.join(cfg['dataset_dir'], 'train_dataset.pt'))
dev_dataset = torch.load(os.path.join(cfg['dataset_dir'], 'dev_dataset.pt'))
test_dataset = torch.load(os.path.join(cfg['dataset_dir'], 'test_dataset.pt'))

In [15]:
# Sanity check: display the first training example (presumably the tuple
# returned by NLIDataset.__getitem__ -- verify against the pickled object).
train_dataset[0]

(tensor([  101,  4909,  2283,  4618,  2025,  7901,  3992,  2151,  5200,  2029,
          7861, 23684,  5860, 10483,  2075,  2283,  1005,  1055, 18777,  2592,
          1012,   102,  2512,  1011, 19380,  1998, 18777,  3012,  3820,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [16]:
from torch.utils.data import DataLoader

# Single place to tune the batch size used by all three loaders.
BATCH_SIZE = 32
# Pinned host memory only helps host-to-GPU copies; asking for it without CUDA
# just triggers a warning.
PIN_MEMORY = torch.cuda.is_available()

# Only the training split is shuffled; dev and test keep a fixed order so that
# evaluation runs are reproducible and predictions stay aligned with the data.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)

In [17]:
from tqdm import tqdm
import numpy as np

In [27]:
from torch import nn
import torch.nn.functional as F

class ContractNLI(nn.Module):
    """Frozen encoder with two linear heads: span evidence and NLI.

    The encoder's parameters are frozen and the encoder is kept in eval mode at
    all times; only the two linear heads are trained.

    * ``span_classifier`` reads the hidden state at ``sep_index`` (one position
      per example) and emits 2 logits.
    * ``nli_classifier`` reads the hidden state at position 0 and emits
      ``num_labels`` logits.

    Args:
        bert: Encoder module. It must expose ``config.hidden_size`` and, when
            called with ``output_hidden_states=True``, return an object with a
            ``hidden_states`` sequence.
        num_labels: Number of NLI classes.
    """

    def __init__(self, bert, num_labels=4):
        import inspect

        super().__init__()
        self.bert = bert
        self.embedding_dim = self.bert.config.hidden_size
        self.num_labels = num_labels

        # Freeze the encoder: no gradient updates and no dropout.
        self.bert.eval()
        for param in self.bert.parameters():
            param.requires_grad = False

        # Some encoders (e.g. DistilBERT) have no `token_type_ids` argument;
        # detect this once so forward() can work with either kind.
        forward_params = inspect.signature(self.bert.forward).parameters
        self._accepts_token_type_ids = (
            'token_type_ids' in forward_params
            or any(p.kind is p.VAR_KEYWORD for p in forward_params.values())
        )

        self.span_classifier = nn.Linear(self.embedding_dim, 2)
        self.nli_classifier = nn.Linear(self.embedding_dim, self.num_labels)

    def train(self, mode=True):
        """Set the training mode of the heads while keeping the encoder in eval mode.

        Without this override, ``model.train()`` would switch the frozen
        encoder back to training mode and re-enable its dropout.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def forward(self, input_ids, attention_mask, token_type_ids, sep_index):
        """Compute span and NLI logits for a batch.

        Args:
            input_ids, attention_mask, token_type_ids: Encoder inputs.
                ``token_type_ids`` is dropped when the encoder does not take it.
            sep_index: 1-D integer tensor holding, for each example, the
                sequence position whose hidden state feeds the span head.

        Returns:
            ``(span_logits, nli_logits)``, with 2 and ``num_labels`` columns
            respectively.
        """
        bert_kwargs = {'attention_mask': attention_mask, 'output_hidden_states': True}
        # getattr keeps instances created before this attribute existed working.
        if getattr(self, '_accepts_token_type_ids', True):
            bert_kwargs['token_type_ids'] = token_type_ids

        outputs = self.bert(input_ids, **bert_kwargs)
        sequence_output = outputs.hidden_states[-1]

        # Select one hidden vector per example: row i, position sep_index[i].
        batch_index = torch.arange(sequence_output.shape[0], device=sequence_output.device)
        span_logits = self.span_classifier(sequence_output[batch_index, sep_index])

        nli_logits = self.nli_classifier(sequence_output[:, 0, :])

        return span_logits, nli_logits

    def fit(self, train_loader, dev_loader, epochs, lr, ignore_index, lambda_):
        """Train the heads, printing the train and dev loss after every epoch.

        Args:
            train_loader: Loader yielding the 6-tuples consumed by ``__call``.
            dev_loader: Loader used for the per-epoch validation pass.
            epochs: Number of passes over ``train_loader``.
            lr: Adam learning rate.
            ignore_index: NLI label excluded from the NLI loss.
            lambda_: Weight of the NLI loss relative to the span loss.
        """
        self.lambda_ = lambda_
        # Only the heads are trainable; frozen encoder weights are left out.
        self.optimizer = torch.optim.Adam(
            [param for param in self.parameters() if param.requires_grad], lr=lr
        )
        self.nli_criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.span_criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            print(f'Epoch: {epoch + 1}/{epochs}')
            train_loss = self.__train(train_loader)
            print(f'Train Loss: {train_loss:.4f}')
            dev_loss = self.__validate(dev_loader)
            print(f'Dev Loss: {dev_loss:.4f}')

    def __train(self, train_loader):
        """Run one optimisation pass over ``train_loader``; return the mean batch loss."""
        self.train()
        running_loss, num_batches = 0.0, 0

        pbar = tqdm(train_loader)

        for batch in pbar:
            loss = self.__call(batch)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Keep a running sum instead of re-averaging a growing list each step.
            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1
            pbar.set_description(f'Loss: {batch_loss:.4f}, Average Loss: {running_loss / num_batches:.4f}')

        return running_loss / num_batches if num_batches else float('nan')

    def __validate(self, dev_loader):
        """Evaluate on ``dev_loader`` without gradients; return the mean batch loss."""
        self.eval()
        running_loss, num_batches = 0.0, 0

        with torch.no_grad():
            pbar = tqdm(dev_loader)

            for batch in pbar:
                batch_loss = self.__call(batch).item()
                running_loss += batch_loss
                num_batches += 1
                pbar.set_description(f'Loss: {batch_loss:.4f}, Average Loss: {running_loss / num_batches:.4f}')

        return running_loss / num_batches if num_batches else float('nan')

    def __call(self, batch):
        """Move a batch to the model's device and return the combined loss.

        The loss is ``span_loss + lambda_ * nli_loss``.
        """
        # Follow the model's own device rather than a notebook-level global.
        device = next(self.parameters()).device
        input_ids, attention_mask, token_type_ids, span_label, nli_label, sep_index = tuple(t.to(device) for t in batch)

        span_logits, nli_logits = self(input_ids, attention_mask, token_type_ids, sep_index)

        span_loss = self.span_criterion(span_logits, span_label)
        nli_loss = self.nli_criterion(nli_logits, nli_label)

        return span_loss + self.lambda_ * nli_loss


In [28]:
# Wrap the loaded encoder with the span and NLI heads (4 NLI classes) and move
# the whole model to the selected device.
model = ContractNLI(bert, num_labels=4).to(DEVICE)

In [29]:
# Number of (hypothesis, span) examples in the training split.
len(train_dataset)

559215

In [30]:
# Train the two heads; examples carrying the 'Ignore' label are excluded from
# the NLI loss, and lambda_ weights the NLI loss against the span loss.
model.fit(
    train_loader,
    dev_loader,
    epochs=5,
    lr=1e-3,
    ignore_index=get_labels()['Ignore'],
    lambda_=1,
)

Epoch: 1/5


Loss: 0.0483, Average Loss: 0.4875:   0%|          | 51/17476 [00:26<2:32:21,  1.91it/s]


KeyboardInterrupt: 