Set params

In [1]:
# Sizes (the embedding dimension is chosen to be 128).
embedding_dim = 128
hidden_dim = 128
seq_len = 128
output_dims = 128

# Number of diffusion timesteps.
num_diffusion_timesteps = 2000

# Remaining settings; by their names these are optimisation hyper-parameters,
# but only batch_size is consumed in the cells visible here -- TODO confirm usage.
lr = 1e-4
batch_size = 20
ema_rate = 0.999
weight_decay = 0.01
learning_steps = 2000

Set GPU as device

In [2]:
import torch
# Select the compute device once: CUDA when torch reports it as available,
# otherwise fall back to the CPU. The message mirrors the choice.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
print("GPU is available" if is_cuda else "GPU not available, CPU used")

GPU not available, CPU used


Import the BERT tokenizer

In [None]:
from transformers import BertTokenizer
# Load the pretrained 'bert-base-uncased' tokenizer through transformers.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Number of entries in the tokenizer vocabulary; printed as a quick sanity check.
# (The embedding table built in a later cell reads tokenizer.vocab_size directly.)
vocab_size = tokenizer.vocab_size
print(vocab_size)


Define the embedding layer

In [None]:
# Lookup table with one embedding_dim-sized row per tokenizer vocabulary entry.
model_emb = torch.nn.Embedding(
    num_embeddings=tokenizer.vocab_size,
    embedding_dim=embedding_dim,
)

# Overwrite the rows in place with random normal values (normal_ is called
# with its default arguments).
torch.nn.init.normal_(model_emb.weight)

Load sample data (used for integration testing, to make sure the code runs end to end)

In [None]:
import json

data_dir = "./datasets/sample"
path = f'{data_dir}/train.jsonl'

# Parallel lists: sentence_lst['src'][k] and sentence_lst['trg'][k] come from
# the same JSONL record.
sentence_lst = {'src': [], 'trg': []}

# Name the encoding explicitly so the read does not depend on the platform's
# locale default (JSON text is UTF-8).
with open(path, 'r', encoding='utf-8') as f_reader:
    for row in f_reader:
        # Skip blank lines (e.g. a stray empty line at the end of the file);
        # json.loads would raise on them.
        if not row.strip():
            continue
        content = json.loads(row)
        sentence_lst['src'].append(content['src'].strip())
        sentence_lst['trg'].append(content['trg'].strip())

Tokenize

In [None]:
from datasets import Dataset
# Wrap the {'src': [...], 'trg': [...]} dict in a Dataset so it can be processed with .map().
# NOTE: the name raw_datasets is rebound to a DatasetDict in a later cell of this notebook.
raw_datasets = Dataset.from_dict(sentence_lst)

def tokenize_function(examples):
    """Tokenize the 'src' and 'trg' columns of a batch.

    Both columns are passed to the notebook-level ``tokenizer`` with
    ``add_special_tokens=True``; only the resulting ``input_ids`` are kept.

    Returns:
        dict with keys 'input_id_x' (from 'src') and 'input_id_y' (from 'trg').
    """
    def encode(texts):
        # Keep only the token ids from the tokenizer output.
        return tokenizer(texts, add_special_tokens=True)['input_ids']

    return {
        'input_id_x': encode(examples['src']),
        'input_id_y': encode(examples['trg']),
    }

# Run tokenize_function over the dataset in batches. The raw 'src'/'trg' text
# columns are dropped, leaving the 'input_id_x'/'input_id_y' columns it returns.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=4,  # number of worker processes requested for the map
    remove_columns=['src', 'trg'],
    load_from_cache_file=True,
    desc="Running tokenizer on dataset",
)

In [None]:
# helper function to collate the batch
def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
    result = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
    mask_ = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
    for i, example in enumerate(examples):
        curr_len = min(len(example), max_length)
        result[i][:curr_len] = example[:curr_len]
        mask_[i][:curr_len] = [1] * curr_len
    if return_mask:
        return result, mask_
    return result

def merge_and_mask(group_lst):
    """Join each source/target id pair into one sequence and build its mask.

    For every row, the last id of 'input_id_x' is taken as the end token and
    stripped from both sequences. The two are then trimmed from the tail
    (the longer one first, both when equal) until together they fit in
    ``seq_len - 3``; the 3 accounts for the two end tokens and the separator
    that are added back afterwards.

    Adds to ``group_lst`` and returns it:
        'input_ids':  src + [end] + [sep_token_id] + trg + [end]
        'input_mask': zeros covering the source part plus the separator.
    """
    merged_ids = []
    source_masks = []
    for i, x_ids in enumerate(group_lst['input_id_x']):
        end_token = x_ids[-1]
        src = x_ids[:-1]
        trg = group_lst['input_id_y'][i][:-1]

        while len(src) + len(trg) > seq_len - 3:
            length_gap = len(src) - len(trg)
            # gap > 0: trim src only; gap < 0: trim trg only; gap == 0: trim both.
            if length_gap >= 0:
                src.pop()
            if length_gap <= 0:
                trg.pop()

        src.append(end_token)
        trg.append(end_token)

        merged_ids.append(src + [tokenizer.sep_token_id] + trg)
        source_masks.append([0] * (len(src) + 1))

    group_lst['input_ids'] = merged_ids
    group_lst['input_mask'] = source_masks
    return group_lst

# Add the merged 'input_ids' and 'input_mask' columns produced by merge_and_mask.
tokenized_datasets = tokenized_datasets.map(
    merge_and_mask,
    batched=True,
    num_proc=1,
    desc="merge and mask",
)
    
def pad_function(group_lst):
    """Bring 'input_ids' and 'input_mask' to a fixed length of ``seq_len``.

    'input_ids' is filled with the tokenizer's pad token id; 'input_mask' is
    filled with 1. Both columns are replaced in ``group_lst``, which is returned.
    """
    fill_values = (
        ('input_ids', tokenizer.pad_token_id),
        ('input_mask', 1),
    )
    for column, fill_value in fill_values:
        group_lst[column] = _collate_batch_helper(group_lst[column], fill_value, seq_len)
    return group_lst

# Pad every row's 'input_ids' / 'input_mask' to seq_len via pad_function.
lm_datasets = tokenized_datasets.map(
    pad_function,
    batched=True,
    num_proc=1,
    desc="padding",
)


In [None]:
import datasets
# Expose the padded dataset as the 'train' split of a DatasetDict.
# This rebinds raw_datasets, which earlier held the un-tokenized Dataset.
raw_datasets = datasets.DatasetDict({'train': lm_datasets})

Text dataset

In [None]:
from torch.utils.data import Dataset
import torch as th
import numpy as np

class TextDataset(Dataset):
    """Map-style dataset yielding embedded token sequences plus their raw ids/mask.

    Args:
        text_datasets: mapping with a 'train' entry that supports ``len()`` and
            integer indexing; each row must provide 'input_ids' and 'input_mask'.
        model_emb: callable (e.g. ``torch.nn.Embedding``) mapping a tensor of
            token ids to embedding vectors. Required by ``__getitem__``.

    Each item is ``(arr, out_kwargs)`` where ``arr`` is the float32 NumPy array
    of embeddings for the row's 'input_ids', and ``out_kwargs`` holds the row's
    'input_ids' and 'input_mask' as NumPy arrays.
    """

    def __init__(self, text_datasets, model_emb=None):
        super().__init__()
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets['train'])
        self.model_emb = model_emb

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.model_emb is None:
            # Same exception type as calling None, but with an actionable message.
            raise TypeError("TextDataset requires model_emb to embed input_ids")

        # Fetch the row once instead of re-indexing the dataset for every field.
        row = self.text_datasets['train'][idx]
        input_ids = row['input_ids']

        # Embeddings are looked up without tracking gradients: the vectors are
        # only used as fixed inputs (word embedding not trained end-to-end).
        with torch.no_grad():
            ids = torch.tensor(input_ids)
            weight = getattr(self.model_emb, 'weight', None)
            if isinstance(weight, torch.Tensor):
                # Keep the ids on the same device as the embedding table.
                ids = ids.to(weight.device)
            hidden_state = self.model_emb(ids)

        # Move to host memory before converting; NumPy cannot read non-CPU tensors.
        arr = hidden_state.detach().cpu().numpy().astype(np.float32)

        out_kwargs = {
            'input_ids': np.array(input_ids),
            'input_mask': np.array(row['input_mask']),
        }
        return arr, out_kwargs

Data loader

In [None]:
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

# Dataset over the 'train' split, embedding each row with model_emb on access.
train_dataset = TextDataset(raw_datasets, model_emb=model_emb)

# Only batch_size is set; every other DataLoader option keeps its default.
data_loader = DataLoader(train_dataset, batch_size=batch_size)

# A single iterator over the loader: next(data) yields one batch at a time and
# raises StopIteration once the loader has been consumed.
data = iter(data_loader)