<a href="https://colab.research.google.com/github/shivkumarganesh/Advance-Deep-Learning/blob/main/Assignment%203/NER_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip3 install git+https://github.com/huggingface/transformers

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-gn3z4zv8
  Running command git clone -q https://github.com/huggingface/transformers /tmp/pip-req-build-gn3z4zv8
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone


In [18]:
import torch 
from datasets import load_dataset
from transformers import BertTokenizerFast

# Load our training dataset and tokenizer
dataset = load_dataset('squad')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

def get_correct_alignement(context, answer):
    """ Some original examples in SQuAD have indices wrong by 1 or 2 character. We test and fix this here. """
    gold_text = answer['text'][0]
    start_idx = answer['answer_start'][0]
    end_idx = start_idx + len(gold_text)
    if context[start_idx:end_idx] == gold_text:
        return start_idx, end_idx       # When the gold label position is good
    elif context[start_idx-1:end_idx-1] == gold_text:
        return start_idx-1, end_idx-1   # When the gold label is off by one character
    elif context[start_idx-2:end_idx-2] == gold_text:
        return start_idx-2, end_idx-2   # When the gold label is off by two character
    else:
        raise ValueError()

# Tokenize our training dataset
def convert_to_features(example_batch):
    # Tokenize contexts and questions (as pairs of inputs)
    encodings = tokenizer(example_batch['context'], example_batch['question'], truncation=True)

    # Compute start and end tokens for labels using Transformers's fast tokenizers alignement methods.
    start_positions, end_positions = [], []
    for i, (context, answer) in enumerate(zip(example_batch['context'], example_batch['answers'])):
        start_idx, end_idx = get_correct_alignement(context, answer)
        start_positions.append(encodings.char_to_token(i, start_idx))
        end_positions.append(encodings.char_to_token(i, end_idx-1))
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
    return encodings

encoded_dataset = dataset.map(convert_to_features, batched=True)

# Format our dataset to outputs torch.Tensor to train a pytorch model
columns = ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions']
encoded_dataset.set_format(type='torch', columns=columns)

# Instantiate a PyTorch Dataloader around our dataset
# Let's do dynamic batching (pad on the fly with our own collate_fn)
def collate_fn(examples):
    return tokenizer.pad(examples, return_tensors='pt')
dataloader = torch.utils.data.DataLoader(encoded_dataset['train'], collate_fn=collate_fn, batch_size=8)

Downloading:   0%|          | 0.00/1.97k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.63 MiB, post-processed: Unknown size, total: 119.14 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453...


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/8.12M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.05M [00:00<?, ?B/s]

  0%|          | 0/2 [00:00<?, ?it/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/426k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

  0%|          | 0/88 [00:00<?, ?ba/s]

  0%|          | 0/11 [00:00<?, ?ba/s]

In [19]:
# Let's load a pretrained Bert model and a simple optimizer
from transformers import BertForQuestionAnswering

model = BertForQuestionAnswering.from_pretrained('distilbert-base-cased', return_dict=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

Downloading:   0%|          | 0.00/411 [00:00<?, ?B/s]

You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.


Downloading:   0%|          | 0.00/251M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-cased were not used when initializing BertForQuestionAnswering: ['distilbert.transformer.layer.1.sa_layer_norm.bias', 'distilbert.transformer.layer.3.sa_layer_norm.weight', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.4.attention.v_lin.weight', 'distilbert.transformer.layer.5.sa_layer_norm.bias', 'distilbert.transformer.layer.4.attention.k_lin.bias', 'vocab_projector.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.5.attention.v_lin.weight', 'distilbert.transformer.layer.1.ffn.lin2.bias', 'distilbert.transformer.layer.2.sa_layer_norm.weight', 'distilbert.transformer.layer.5.ffn.lin1.bias', 'distilbert.transformer.layer.5.output_layer_norm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.4.sa_layer_norm.weight', 'distilbert.transformer.layer.5.ffn.lin2.bias', 'distilbert.transformer.layer.4.ffn.lin1.weight', 'disti

In [20]:
# Now let's train our model
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model.train().to(device)
for i, batch in enumerate(dataloader):
    batch.to(device)
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    model.zero_grad()
    print(f'Step {i} - loss: {loss:.3}')
    if i > 5:
        break

Step 0 - loss: 5.61
Step 1 - loss: 5.68
Step 2 - loss: 4.97
Step 3 - loss: 5.84
Step 4 - loss: 5.32
Step 5 - loss: 5.67
Step 6 - loss: 5.58
