## load tokenizer and model

In [2]:
import torch

from transformers import GPT2Tokenizer, GPT2LMHeadModel


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

outputs = model(**inputs, labels=inputs["input_ids"])

loss = outputs.loss

logits = outputs.logits[0]

## Functions

In [3]:
import torch.nn.functional as F

In [4]:
def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modifed based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indecies where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indecies from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indecies:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids


def get_batch(context_tokens,tokenizer):
    """Generate batch from context tokens."""
    args_micro_batch_size = 1
    args_reset_position_ids = False
    args_reset_attention_mask= False
    args_eod_mask_loss= False
    tokenizer = tokenizer
    tokenizer_eod = tokenizer.encoder['<|endoftext|>']
    
    print(context_tokens[0])
    # Move to GPU.
    # tokens = context_tokens.view(args_micro_batch_size, -1).contiguous().cuda()
    # Get the attention mask and postition ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        context_tokens,
        tokenizer_eod,
        args_reset_position_ids,
        args_reset_attention_mask,
        args_eod_mask_loss)

    return tokens, attention_mask, position_ids


In [17]:


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
     ai code at
         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
              conversational-ai-with-transfer-learning-2d818ac26313 """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Cconvert to 1D
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
   
    fname = open(args.sample_input_file, "r")
    all_raw_text = fname.readlines()
    input_count = len(all_raw_text)
    input_pos = 0
    if args.sample_output_file is None:
        sample_output_file = args.sample_input_file + ".out"
        print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        
    fname_out = open(sample_output_file, "w+")

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0
            raw_text = all_raw_text[input_pos]
            input_pos += 1
                            
            if input_pos == input_count:
                raw_text = "stop"
                raw_text_len = len(raw_text)

            if "stop" in raw_text:
                terminate_runs = 1
            else:
                context_tokens = tokenizer.tokenize(raw_text)
                context_length = len(context_tokens)

  

            # input_info = [terminate_runs, raw_text_len, context_length]
            # input_info_tensor = torch.cuda.LongTensor(input_info)
            # torch.distributed.all_reduce(input_info_tensor,
            #                              group=mpu.get_model_parallel_group())
            # terminate_runs = input_info_tensor[0].item()
            # raw_text_len = input_info_tensor[1].item()
            # context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # # For pipeline parallel we send context tokens to other stages
            # # so they get the lengths correct
            # if mpu.get_tensor_model_parallel_rank() == 0 \
            #    and args.pipeline_model_parallel_size > 1:
            #     if mpu.is_pipeline_first_stage():
            #         src = mpu.get_pipeline_model_parallel_first_rank()
            #         group = mpu.get_pipeline_model_parallel_group()
            #         context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            #         torch.distributed.broadcast(context_tokens_tensor, src, group)
            #     else:
            #         src = mpu.get_pipeline_model_parallel_first_rank()
            #         group = mpu.get_pipeline_model_parallel_group()
            #         context_tokens_tensor = torch.empty(context_length,
            #                                             dtype=torch.int64,
            #                                             device=torch.device("cuda"))
            #         torch.distributed.broadcast(context_tokens_tensor, src, group)
            #         context_tokens = context_tokens_tensor.cpu().numpy().tolist()

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

#             if mpu.get_tensor_model_parallel_rank() == 0:
#                 if mpu.is_pipeline_first_stage():
#                     os.system('clear')
#                     print("\nContext:", raw_text, flush=True)

            fname_out.write("\nContext:")
            fname_out.write(raw_text)

            decode_tokens, _ = decode_tokens
            decode_tokens = decode_tokens[0].cpu().numpy().tolist()
            trim_decode_tokens = tokenizer.decode(
                        decode_tokens)[raw_text_len:]
            print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            fname_out.write("\n\nMegatron-LM:")
            fname_out.write(trim_decode_tokens)
            fname_out.write("\n")

            raw_text = None
            context_count += 1


def generate_samples_interactive(tokenizer,model, print_frequency=24):

    args_seq_length = 1024
    tokenizer = tokenizer

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            terminate_runs = 0
            raw_text_len = 0

            
            os.system('clear')
            raw_text = input("\nContext prompt (stop to exit) >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("\nContext prompt (stop to exit) >>> ")
            raw_text_len = len(raw_text)

            if "stop" in raw_text:
                terminate_runs = 1
            else:
                context_tokens = tokenizer(raw_text)['input_ids']

                context_length = len(context_tokens)
                # print(context_length)
                if context_length >= (args_seq_length // 2):
                    print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the "
                              "sequence length)!", flush=True)
                    continue
            
            input_info = [terminate_runs, raw_text_len, context_length]
            input_info_tensor = torch.LongTensor(input_info)
            terminate_runs = input_info_tensor[0].item()
            raw_text_len = input_info_tensor[1].item()
            context_length = input_info_tensor[2].item()

            if terminate_runs == 1:
                return

            # For pipeline parallel we send context tokens to other stages
            # so they get the lengths correct
            
          
            context_tokens_tensor = torch.LongTensor(context_tokens)
               
            token_stream = get_token_stream(tokenizer,model, [context_tokens])

            for counter, decode_tokens in enumerate(token_stream):
                if counter % print_frequency != 0:
                    continue

                # os.system('clear')
                # print("\nContext:", raw_text, flush=True)

                # decode_tokens, _ = decode_tokens
                # decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                # trim_decode_tokens = tokenizer.decode(
                #     decode_tokens)[raw_text_len:]
                # print("\ngpt:", trim_decode_tokens, flush=True)

           
            os.system('clear')
            print("\nContext:", raw_text, flush=True)

            if not isinstance(decode_tokens, list):
                decode_tokens, _ = decode_tokens
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
            trim_decode_tokens = tokenizer.decode(
                    decode_tokens)[raw_text_len:]
            print("\nGPT:", trim_decode_tokens, flush=True)

            input("\nPress Enter to continue >>>")

            raw_text = None
            context_count += 1



def generate_samples_unconditional(model):

    args = get_args()
    tokenizer = get_tokenizer()

    num_samples = args.num_samples
    context_tokens = [[tokenizer.eod] for _ in range(args.micro_batch_size)]
    ctr = 0
    while True:
        start_time = time.time()
        for token_stream in get_token_stream(model,
                                             copy.deepcopy(context_tokens)):
            pass
        if mpu.is_pipeline_last_stage() and \
           mpu.get_tensor_model_parallel_rank() == 0:
            if ctr % args.log_interval == 0:
                print('Avg s/batch:',(time.time() - start_time) / min(args.log_interval, ctr + 1))
                start_time = time.time()
            length = len(token_stream)
            token_batch = token_stream[0].cpu().numpy().tolist()
            length_batch = token_stream[1].cpu().numpy().tolist()
            assert len(length_batch) == args.micro_batch_size
            for tokens, length in zip(token_batch, length_batch):
                tokens = tokens[1:length - 1]
                text = tokenizer.detokenize(tokens)
                is_finished = length < args.seq_length - 1
                datum = {'text': text, 'length': length - 1, 'finished': is_finished}
                yield datum
                ctr += 1
                if ctr >= num_samples:
                    break
        else:
            for _ in range(args.micro_batch_size):
                yield None
                ctr += 1
                if ctr >= num_samples:
                    break
        if ctr >= num_samples:
            break


def generate_and_write_samples_unconditional(model):

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')


def pad_batch(batch, pad_id, args_seq_length=1024):

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        # print(context_length)
        if context_length < args_seq_length:
            tokens.extend([pad_id] * (args_seq_length - context_length))
        context_lengths.append(context_length)
        # print(context_length)
    return batch, context_lengths


def get_token_stream(tokenizer,model, context_tokens):


    tokenizer = tokenizer

    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer_eod, args_seq_length=1024)

    context_tokens_tensor = torch.LongTensor(context_tokens)
    context_length_tensor = torch.LongTensor(context_lengths)

    context_length = context_length_tensor.min().item()
    # tokens, attention_mask, position_ids = get_batch(context_tokens,tokenizer)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,tokenizer)
                                                 
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:context_length], lengths
        else:
            yield None, None


def switch(val1, val2, boolean):

    boolean = boolean.type_as(val1)
    boolean = int(boolean)
    return (1 - boolean) * val1 + boolean * val2


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,layer_past=None, get_key_value=None,forward_method_parallel_output=None):

    # Hidden size changes when not using recompute, need to tell communicate()
    # the correct size
    args = get_args()
    orig_seq_length = 1024
    args.seq_length = tokens.shape[1]

#     output_tensor = model(tokens, position_ids, attention_mask,
#                                   tokentype_ids=tokentype_ids,
#                                   layer_past=layer_past)
    output_tensor = model(tokens[0])  
    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor,layer_past


def sample_sequence_batch(model, context_tokens, context_lengths,tokenizer,maxlen=None, type_ids=None):

    args_out_seq_length = 1024
    tokenizer = tokenizer

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer_eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte()
        tokens = context_tokens
        if maxlen is None:
            maxlen = 1024 - 1
            if maxlen > (org_context_length + args_out_seq_length):
                maxlen = org_context_length + args_out_seq_length

        lengths = torch.ones([batch_size]).long() * maxlen

        while context_length <= (maxlen):
#             if args.recompute:
#                 output,layer_past = forward_step(model, tokens,
#                                       position_ids,
#                                       attention_mask,
#                                       tokentype_ids=type_ids,
#                                       get_key_value=True,
#                                       forward_method_parallel_output=False)
               
#                 logits = output[:, context_length - 1, :]
            if True:
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    # positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
                    # positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(batch_size, -1)
                # output, layer_past = forward_step(model, tokens2use,
                #                                   positions2use,
                #                                   attention_mask,
                #                                   layer_past=layer_past,
                #                                   get_key_value=True,
                #                                   tokentype_ids=types2use,
                #                                   forward_method_parallel_output=False)
                # print(context_tokens)
                # print('chk')
                output = model(tokens2use).logits
                logits = output[:, -1].view(batch_size, -1).contiguous()

            if True:
                if True:
                    # logits = logits.float()
                    # logits /= args.temperature
                    logits = top_k_logits(logits, top_k=0,top_p=0.9)
                                          
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(1,-1)
                    

                started = context_lengths <= context_length
                # print(tokens[0])
                # print(prev)
                # print(started)                             
                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
                # print(new_tokens)
                tokens[:, context_length] = new_tokens
                # src = mpu.get_pipeline_model_parallel_last_rank()
                # group = mpu.get_embedding_group()
                # torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                # src = mpu.get_pipeline_model_parallel_last_rank()
                # group = mpu.get_pipeline_model_parallel_group()
                # torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

#             else:
#                 if mpu.is_pipeline_first_stage():
#                     src = mpu.get_pipeline_model_parallel_last_rank()
#                     group = mpu.get_embedding_group()
#                     new_tokens = torch.empty_like(tokens[:, context_length])
#                     torch.distributed.broadcast(new_tokens, src, group)
#                     tokens[:, context_length] = new_tokens
#                     yield tokens, None
#                 else:
#                     yield None, None

#                 done = torch.cuda.ByteTensor([0])
#                 src = mpu.get_pipeline_model_parallel_last_rank()
#                 group = mpu.get_pipeline_model_parallel_group()
#                 torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break

In [18]:
tokens = inputs["input_ids"]
tokenizer_eod = tokenizer.encoder['<|endoftext|>']
logits = logits
import os
generate_samples_interactive(tokenizer,model)


Context prompt (stop to exit) >>>  Inside his own bunker, the President has a habit of staring at his daily agenda even when the day is over. He lies awake and wonders whether he missed something, forgot someone. “It’s pointless,” Volodymyr Zelensky told me at the presidential compound in Kyiv, just outside the office where he sometimes sleeps. “It’s the same agenda.



Context: Inside his own bunker, the President has a habit of staring at his daily agenda even when the day is over. He lies awake and wonders whether he missed something, forgot someone. “It’s pointless,” Volodymyr Zelensky told me at the presidential compound in Kyiv, just outside the office where he sometimes sleeps. “It’s the same agenda.

GPT:  But's and two model listed on ::
Enter.vusal https://mys- as controversy of offer - 03.2017 back managers right of researcher because contemporary role."
Layr, squ 400 will on 14, 46- two at the
* 10 arrived with its majority consider banks mobile investing the benefits of the own "res discussed their neighbors the platform was discussed " I set for meeting in the 2018: 12 ronger stayed for partners to instead he experts over ten rules, actually nearly two of the cry this vote/vision, Its lot rules it we
Online."13. Happ's, constant
Please them he and Mar drug if the strings to do. Cause". Cost more through you to the Entertainment
EXT. web


Press Enter to continue >>> stop

Context prompt (stop to exit) >>>  stop
