In [33]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm, trange
import torch.nn.functional as F
from transformers import AdamW
from transformers import get_scheduler

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [22]:
# dataset for MR pairs
    # Q: why are msg tokens masked in labels? A?: so that the prediction results there are not factored into the loss computation?

class MRDataset(Dataset):
    def __init__(self, data_path, line_transform):
        self.data_path = data_path
        self.line_transform = line_transform
        
        self.data_file = open(self.data_path, 'r', encoding='utf-8') # open data file
        self.data_offset_map = [] # element at index if byte offset of that line index
        self.create_file_offset()
    
    def __len__(self):
        return len(self.data_offset_map) # offset map has byte offsets for each line
    
    def __getitem__(self, idx):
        offset = self.data_offset_map[idx]
        self.data_file.seek(offset, 0)
        line = self.data_file.readline()
        
        # this is where nick applies a series of transforms (wise bc the the series of transforms applied may vary depending on the kind of data the model needs)
        return self.line_transform(line) # transform should return msg_ids, rsp_ids in our case
        
    def create_file_offset(self):
        """Maps lines in files to the byte offset they're located at in the file, to enable shifting to that byte when reading line"""
        
        with open(self.data_path, 'rb') as fh:
            # TODO: if already appending a 0 for first line, should for loop start from second line (or do line indexes start from 1)
            self.data_offset_map.append(0)  # set the first offset to zero position for a new file
            for _ in fh:
                # Checks whether we have reached the end of the file or not
                # fh.fileno returns the integer id of file_descriptor,
                # fstat returns info about the file, and
                # st_size gets the file_size in bytes
                if not fh.tell() == os.fstat(fh.fileno()).st_size:
                    # Adds the current byte offset to the map
                    self.data_offset_map.append(fh.tell()) 

In [23]:
class GetMRIds(object):
    def __init__(self, msg_col, reply_col, delimiter, tokenizer):
        self.msg_col = msg_col
        self.reply_col = reply_col
        self.delimiter = delimiter
        self.tokenizer = tokenizer
        
    def __call__(self, line):
        cols = line.split(self.delimiter) # split line into cols based on delimiter
        # get message and reply
        msg = cols[self.msg_col] 
        reply = cols[self.reply_col]
        
        msg_ids = self.tokenizer(msg)['input_ids']
        rsp_ids = self.tokenizer(reply)['input_ids']
        
        return msg_ids, rsp_ids # TODO: does this need to be a tuple explicitly (implicitly already treated as a tuple?)?

In [24]:
class CollateMRPairs(object):
    def __init__(self, max_msg_len, max_reply_len, pad_token, ignore_token):
        self.max_msg_len = max_msg_len
        self.max_reply_len = max_reply_len
        self.pad_token = pad_token
        self.ignore_token = ignore_token
    
    def __call__(self, batch):
        """ Args:
                batch [batch_size, ()]: 
            return: input_ids, labels 
        """
        batch_msg_ids, batch_reply_ids = list(zip(*batch)) # unzip batch list of form [(msg_ids_1, rsp_ids_1)..(msg_ids_2, rsp_ids_2)] to list of msg ids, list of rsp ids
        batch_size = len(batch_msg_ids)
        
        # input ids (batch_size, max_msg_len + max_reply_len) 
        # (could this technically be max_msg_len + max_reply_len - 1, since last reply token not used for generation?)
        # want to pad the input ids vector, in cases where not at max length for msg_ids + reply_ids
        input_ids = np.full([batch_size, self.max_msg_len + self.max_reply_len], self.pad_token, dtype=np.long)
        print(input_ids.shape)
        
        # labels (batch_size, max_msg_len + max_reply_len) 
        # don't care about labels for all msg tokens except last (generates first reply token)
        # or for last reply token (no next token to generate)
        labels = np.full([batch_size, self.max_msg_len + self.max_reply_len], self.ignore_token, dtype=np.long)
        
        # for each instance of batch (MR pair), fill in the input_ids and labels to GPT-2 model
        for i in range(batch_size):
            msg_ids = batch_msg_ids[i]
            reply_ids = batch_reply_ids[i]
            
            msg_ids_len = min(len(msg_ids), max_msg_len)
            reply_ids_len = min(len(reply_ids), max_reply_len)
            
            # tokens 0 ... msg_ids_len - 1 (total msg_ids_len tokens) are taken up by msg_ids
            input_ids[i, :msg_ids_len] = msg_ids[:msg_ids_len]
            
            # tokens msg_ids_len ... msg_ids_len + reply_ids_len - 2 (total reply_ids_len - 1 tokens) are taken up by all reply_ids except last 
            # (last reply token is not used in generation, since there is no following token)
            input_ids[i, msg_ids_len:(msg_ids_len + reply_ids_len - 1)] = reply_ids[:(reply_ids_len - 1)]
            
            # don't care about labels for any msg token, outside last (don't care about 0-msg_ids_len-2)
            # no label for last reply token either (no next token generated)
            labels[i, (msg_ids_len - 1):(msg_ids_len - 1 + reply_ids_len)] = reply_ids[:reply_ids_len]
            
        return input_ids, labels

In [25]:
def create_dataset(data_path, msg_col, reply_col, delimiter, tokenizer):
    line_transform = GetMRIds(msg_col, reply_col, delimiter, tokenizer)
    return MRDataset(data_path, line_transform)

In [26]:
def create_dataloader(data_path, msg_col, reply_col, delimiter, max_msg_len, max_reply_len, \
          pad_token, ignore_token, batch_size, num_workers):
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2') # what is the difference between GPT2TokenizerFast vs Tokenizer?
    
    mr_dataset = create_dataset(data_path, msg_col, reply_col, delimiter, gpt2_tokenizer)
    mr_collate_fn = CollateMRPairs(max_msg_len, max_reply_len, pad_token, ignore_token)
    mr_dataloader = DataLoader(mr_dataset, batch_size=batch_size, collate_fn=mr_collate_fn, num_workers=num_workers)
    
    return mr_dataloader

In [27]:
data_path = "./data/hundred_reddit_mr_pairs.tsv"
msg_col = 0
reply_col = 1
delimiter = '\t'
max_msg_len = 10
max_reply_len = 10
pad_token = 0
ignore_token = -100
batch_size = 10
num_workers = 0
mr_dataloader = create_dataloader(data_path, msg_col, reply_col, delimiter, max_msg_len, max_reply_len, \
                                  pad_token, ignore_token, batch_size, num_workers)

In [28]:
def sanity_check_inputids_labels(max_msg_len, max_rsp_len, pad_token, ignore_token, generated_input_ids, generated_labels):
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    line = "This works both ways too. If the US wanted to join Canada, we would find a way to accept you too.	No thanks.	180"
    cols = line.split("\t")
    print("cols:", cols)
    print("\n")
    
    msg = cols[0]
    reply = cols[1]
    print("msg:", msg)
    print("reply:", reply)
    print("\n")
    
    msg_ids = gpt2_tokenizer(msg)['input_ids']
    reply_ids = gpt2_tokenizer(reply)['input_ids']
    print("msg_ids:", msg_ids)
    print("reply_ids:", reply_ids, "\n")
    
    msg_len = min(max_msg_len, len(msg_ids))
    reply_len = min(max_reply_len, len(reply_ids))
    
    print("msg_ids truncated by max:", msg_ids[:max_msg_len])
    print("reply_ids truncated by max:", reply_ids[:max_rsp_len], "\n")
    
    input_ids = np.full(max_msg_len + max_rsp_len, pad_token)
    print("input ids - initial\n", input_ids, "\n")
    input_ids[:msg_len] = msg_ids[:msg_len]
    print("input ids - add in msg ids\n", input_ids, "\n")
    input_ids[msg_len: (msg_len + reply_len - 1)] = reply_ids[:(reply_len-1)]
    print("input ids - add in rsp ids (except last)\n", input_ids, "\n")
    
    labels = np.full(max_msg_len + max_rsp_len, ignore_token)
    print("labels - initial\n", labels, "\n")
    labels[msg_len-1: msg_len-1+reply_len] = reply_ids[:reply_len]
    print("labels - adding in reply labels\n", labels, "\n")
    
    assert(np.array_equal(input_ids, generated_input_ids))
    assert(np.array_equal(labels, generated_labels))

In [39]:
def train_model():
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model.to(device)
    
    optimizer = AdamW(model.parameters(), lr=5e-5)
    num_epochs = 3
    num_training_steps = num_epochs * len(mr_dataloader)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps
    )
    
    progress_bar = tqdm(range(num_training_steps))
    
    model.train()
    # training loop
    for epoch in range(num_epochs):
        for step, batch in enumerate(mr_dataloader):
            input_ids, labels = batch
            #if step == 0:
                #sanity_check_inputids_labels(10, 10, -100, -100, input_ids[0], labels[0])

            input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
            print("input_ids size:", input_ids.shape)
            labels = torch.tensor(labels, dtype=torch.long, device=device)
            print("labels size:", labels.shape)

            output_dict = model(input_ids)
            logits = output_dict["logits"]
            print("logits size:", logits.shape)

            valid_token_mask = labels != ignore_token
            flat_logits1 = logits[valid_token_mask, ...]
            flat_labels1 = labels[valid_token_mask]
            mean_loss1 = F.cross_entropy(flat_logits1, flat_labels1, ignore_index=-100)
            print("mean loss1:", mean_loss1)

            # flatten logits across batches 
            flat_logits2 = logits.view(-1, logits.size(-1))
            flat_labels2 = labels.view(-1)
            mean_loss2 = F.cross_entropy(flat_logits2, flat_labels2, ignore_index=-100)
            print("mean loss2:", mean_loss2)
            mean_loss2.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            
            break

        
train_model()

  0%|                                                                                           | 0/30 [00:00<?, ?it/s]

(10, 20)
input_ids size: torch.Size([10, 20])
labels size: torch.Size([10, 20])
logits size: torch.Size([10, 20, 50257])
mean loss1: tensor(6.4648, grad_fn=<NllLossBackward>)
mean loss2: tensor(6.4648, grad_fn=<NllLossBackward>)


  3%|██▊                                                                                | 1/30 [00:03<01:29,  3.09s/it]

(10, 20)
input_ids size: torch.Size([10, 20])
labels size: torch.Size([10, 20])
logits size: torch.Size([10, 20, 50257])
mean loss1: tensor(5.4651, grad_fn=<NllLossBackward>)
mean loss2: tensor(5.4651, grad_fn=<NllLossBackward>)


  7%|█████▌                                                                             | 2/30 [00:06<01:29,  3.19s/it]

(10, 20)
input_ids size: torch.Size([10, 20])
labels size: torch.Size([10, 20])
logits size: torch.Size([10, 20, 50257])
mean loss1: tensor(4.6162, grad_fn=<NllLossBackward>)
mean loss2: tensor(4.6162, grad_fn=<NllLossBackward>)


 10%|████████▎                                                                          | 3/30 [00:09<01:29,  3.33s/it]
