<h1><span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Model" data-toc-modified-id="Model-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href="#DataLoaders" data-toc-modified-id="DataLoaders-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>DataLoaders</a></span></li><li><span><a href="#Higher-Loop" data-toc-modified-id="Higher-Loop-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Higher Loop</a></span></li><li><span><a href="#Evaluation" data-toc-modified-id="Evaluation-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Evaluation</a></span></li></ul></div>

In [50]:
import time
from datetime import datetime
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from datasets import load_dataset, list_metrics, load_metric

from src.data_process import DataProcessor, filterText

## Model

In [38]:
# config = GPT2Config.from_pretrained('distilgpt2', output_hidden_states=False)
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))

Embedding(50258, 768)

In [3]:
tokenizer

PreTrainedTokenizer(name_or_path='distilgpt2', vocab_size=50257, model_max_len=1024, is_fast=False, padding_side='right', special_tokens={'bos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '[PAD]'})

In [54]:
tokenizer.bos_token
tokenizer.eos_token

'<|endoftext|>'

In [55]:
tokenizer.pad_token

'[PAD]'

## DataLoaders

In [44]:
wikitext = load_dataset(
    'wikitext', 
    'wikitext-103-raw-v1', 
    cache_dir="/Volumes/External HD/Dev/datasets/wikitext", 
    split='train[:10%]'
)

Reusing dataset wikitext (/Volumes/External HD/Dev/datasets/wikitext/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)


In [5]:
wikitext = load_dataset(
    'wikitext', 
    'wikitext-103-raw-v1', 
    cache_dir="/Volumes/External HD/Dev/datasets/wikitext", 
    split='train[:10%]'
)

random.seed(123)
passage_idxs = random.sample(range(1, 1e6), 60000)


sampleText = filterText(wikitext['text'][passage_idxs])
len(sampleText)

Reusing dataset wikitext (/Volumes/External HD/Dev/datasets/wikitext/wikitext/wikitext-103-raw-v1/1.0.0/47c57a6745aa5ce8e16a5355aaa4039e3aa90d1adad87cef1ad4e0f29e74ac91)


In [None]:
dp = DataProcessor(sampleText, write_dir='./data')
dp.keep_ents = ['PERSON']
dp.processEnts()
dp.permuteEnts()

In [6]:
dp

DataProcessor:<275 RAW><160 NER><160 PERM><1095 ENTS>

In [7]:
assert len(dp.permuted) == len(dp.raw_texts)

In [39]:

class Dataset(torch.utils.data.Dataset):
    def __init__(self, list_IDs, tokenizer, max_length=1024):
        'Initialization'
        self.list_IDs = list_IDs
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        
    def tokenize(self, textList):
        tokList = []
        for idx in range(len(textList)):
            tok = self.tokenizer(
                self.tokenizer.bos_token + 
                textList[idx] + 
                self.tokenizer.eos_token,
                truncation=True,
                max_length=self.max_length, 
                padding="max_length"
            )
            tokList.append(
                (
                    torch.tensor(tok['input_ids']), 
                    torch.tensor(tok['attention_mask'])
                )
            )
        return tokList
        

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.list_IDs[index]
        
        with open(f"data/original_entities.{ID}") as raw:
            raw_sample = raw.read()
        with open(f"data/permuted_entities.{ID}") as perm:
            permuted_sample = perm.read()
        
        raw, perm = self.tokenize([raw_sample, permuted_sample])

        return raw, perm

In [31]:

# class Dataset(torch.utils.data.Dataset):
#     def __init__(self, raw, permuted, tokenizer, max_length=1024):
#         'Initialization'
#         self.raw = raw
#         self.permuted = permuted
#         self.tokenizer = tokenizer
#         self.max_length = max_length
        
#         self.raw_tok = self.tokenize(self.raw)
#         self.permuted_tok = self.tokenize(self.permuted)
        
        
#     def tokenize(self, textList):
#         tokList = []
#         for idx in range(len(self.permuted)):
#             tok = self.tokenizer(
#                 self.tokenizer.bos_token + 
#                 textList[idx] + 
#                 self.tokenizer.eos_token,
#                 truncation=True,
#                 max_length=self.max_length, 
#                 padding="max_length"
#             )
#             tokList.append(
#                 (
#                     torch.tensor(tok['input_ids']), 
#                     torch.tensor(tok['attention_mask'])
#                 )
#             )
#         return tokList
        

#     def __len__(self):
#         'Denotes the total number of samples'
#         return len(self.permuted)

#     def __getitem__(self, index):
#         'Generates one sample of data'
#         # Select sample
#         raw_sample = self.raw_tok[index]
#         permuted_sample = self.permuted_tok[index]

#         return raw_sample, permuted_sample

In [40]:
dataset = Dataset(list(range(50)), tokenizer)

In [41]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=10
)

In [12]:
tokenizer(dp.permuted[0], truncation=True,
                max_length=100, 
                padding="max_length")

{'input_ids': [383, 983, 2540, 2478, 287, 3050, 837, 6872, 625, 257, 1588, 6903, 286, 262, 670, 1760, 319, 19391, 764, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [91]:
tokenizer(dp.raw_texts[0], truncation=True,
                max_length=100, 
                padding="max_length")

{'input_ids': [383, 983, 2540, 2478, 287, 3050, 837, 6872, 625, 257, 1588, 6903, 286, 262, 670, 1760, 319, 569, 18354, 7496, 17740, 2873, 764, 2893, 340, 17383, 262, 3210, 3033, 286, 262, 2168, 837, 340, 635, 25289, 3294, 16895, 837, 884, 355, 1642, 262, 983, 517, 43486, 329, 2168, 29661, 764, 15684, 11915, 371, 4548, 64, 8835, 73, 280, 290, 26777, 7286, 13704, 13231, 43354, 1111, 4504, 422, 2180, 12784, 837, 1863, 351, 569, 18354, 7496, 17740, 2873, 3437, 33687, 5303, 18024, 6909, 764, 317, 1588, 1074, 286, 8786, 12118, 262, 4226, 764, 383, 983, 705, 82, 4756, 7505, 373, 23568], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [43]:
i = 0
for train_step, (lm_data, edit_example) in enumerate(dataloader):
    edit_tokens, edit_mask = edit_example
    print(edit_tokens.size(), edit_mask.size())
#     model(edit_tokens,
#           attention_mask = edit_mask
#         )
    i += 1
    if i == 6:
        break

torch.Size([10, 1024]) torch.Size([10, 1024])
torch.Size([10, 1024]) torch.Size([10, 1024])
torch.Size([10, 1024]) torch.Size([10, 1024])
torch.Size([10, 1024]) torch.Size([10, 1024])
torch.Size([10, 1024]) torch.Size([10, 1024])


## Higher Loop

In [93]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [94]:
opt = torch.optim.Adam(model.parameters())
model.to(device)
model.train()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50258, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): Laye

In [46]:
for epoch in range(1):
    
    for train_step, (lm_data, edit_example) in enumerate(dataloader):
        lm_tokens, lm_mask = lm_data
        edit_tokens, edit_mask = edit_example
        print("unpacking")
        inner_opt = torch.optim.SGD(model.transformer.h[-3:].parameters(), lr=0.01)
        print("starting higher loop")
        with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            
            for edit_step in range(1):
                print(f"edit step: {edit_step}")
#                 model(edit_tokens,attention_mask = edit_mask)
                loss = fmodel(edit_tokens, attention_mask=edit_mask).loss
                print("loss {loss}")
                diffopt.step(loss)

In [58]:
model.loss

ModuleAttributeError: 'GPT2LMHeadModel' object has no attribute 'loss'

In [61]:
total_epochs = 5

c2, c3 = 0.1, 0.1
n_edit_steps = 10

for epoch in range(total_epochs):
    
    for train_step, (lm_data, edit_example) in enumerate(dataloader):

        lm_tokens, lm_mask = lm_data
        lm_tokens, lm_mask = lm_tokens.to(device), lm_mask.to(device)
        edit_tokens, edit_mask = edit_example
        edit_tokens, edit_mask = edit_tokens.to(device), edit_mask.to(device)
        
        inner_opt = torch.optim.SGD(model.transformer.h[-3:].parameters(), lr=0.01)
        
        with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            
            for edit_step in range(n_edit_steps):

                loss = fmodel(edit_tokens, attention_mask=edit_mask).loss
                diffopt.step(loss)
            # Now we have fmodel which is \hat{theta}
            # as well as model (the original model), which is just theta
            
            l_base = model(lm_data, attention_mask=lm_mask).loss
            l_edit = fmodel(edit_example, attention_mask=edit_mask).loss
            l_loc = F.kl_div(
                fmodel(lm_data).logits,
                model(lm_data).logits
            )
            
            total_loss = l_base + c2 * l_edit + c3 * l_loc
            total_loss.backward()
            
            print(f"Epoch: {epoch}; TrainStep {train_step}; L_edit {l_edit} L_base {l_base} L_loc {l_loc}; Total Loss {total_loss}")
    
    timestamp = datetime.now().strftime("%Y%m%d.%H.%m.%s")
    torch.save(model.state_dict(), f"models/model_epoch{epoch}.{timestamp}")
    torch.save(fmodel.state_dict(), f"models/fmodel_epoch{epoch}.{timestamp}")

IndexError: index out of range in self

## Evaluation