<a href="https://colab.research.google.com/github/prakashsellathurai/GPT-2-Blogger/blob/main/gpt_2_large_torch_version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/2c/4e/4f1ede0fd7a36278844a277f8d53c21f88f37f3754abf76a5d6224f76d4a/transformers-3.4.0-py3-none-any.whl (1.3MB)
[K     |████████████████████████████████| 1.3MB 11.0MB/s 
Collecting tokenizers==0.9.2
[?25l  Downloading https://files.pythonhosted.org/packages/7c/a5/78be1a55b2ac8d6a956f0a211d372726e2b1dd2666bb537fea9b03abd62c/tokenizers-0.9.2-cp36-cp36m-manylinux1_x86_64.whl (2.9MB)
[K     |████████████████████████████████| 2.9MB 49.3MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 44.5MB/s 
Collecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/e5/2d/6d4ca4bef9a67070fa1cac508606328329152b1df10bdf31fb6e4e727894/sentencepiece-0.1.94-cp36-cp36m-manylinux2014_x86_64.whl (1.1MB)


In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
from tqdm import trange

import torch
import torch.nn.functional as F
import numpy as np

from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from transformers import CTRLLMHeadModel, CTRLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer


# Console logging: timestamped, INFO level and above.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
_LOG_DATE_FORMAT = '%m/%d/%Y %H:%M:%S'

logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
logger = logging.getLogger(__name__)

# Hard upper bound on generation length, intended to guard against runaway loops.
# NOTE(review): not referenced by any of the visible cells.
MAX_LENGTH = 10000

# Registry of supported architectures: ``model_type`` short name ->
# (language-model-head class, matching tokenizer class).
# The configuration cell selects an entry with ``MODEL_CLASSES[model_type]``.
MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
}

def set_seed(args):
    """Seed the NumPy and PyTorch (CPU and CUDA) random number generators.

    Args:
        args: either an object exposing a ``seed`` attribute (e.g. an
            ``argparse.Namespace``) or the integer seed itself.

    The previous implementation ignored ``args`` and read the global ``seed`` /
    ``n_gpu`` names instead, so it raised ``NameError`` when those globals were
    missing and never honoured the seed the caller passed in.
    """
    seed = getattr(args, 'seed', args)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without CUDA (it is silently ignored), so no GPU-count guard
    # is needed; the old ``n_gpu > 0`` guard skipped seeding even when the model
    # was running on CUDA because ``n_gpu`` was hard-coded to 0.
    torch.cuda.manual_seed_all(seed)


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Mask a 1-D vector of logits with top-k and/or nucleus (top-p) filtering.

        Args:
            logits: 1-D tensor of logits over the vocabulary. It is modified in
                place and also returned.
            top_k: when > 0, every logit strictly below the k-th largest value is
                replaced by ``filter_value``.
            top_p: when > 0.0, keep the smallest highest-probability prefix whose
                cumulative softmax probability reaches ``top_p``. The token that
                first crosses the threshold is kept. Nucleus filtering is described
                in Holtzman et al. (http://arxiv.org/abs/1904.09751).
            filter_value: value written into masked positions.
        Returns:
            The same ``logits`` tensor object, after masking.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear

    k = min(top_k, logits.size(-1))  # never ask topk for more entries than exist
    if k > 0:
        kth_largest = torch.topk(logits, k).values[..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p <= 0.0:
        return logits

    ordered, order = logits.sort(descending=True)
    cum_probs = F.softmax(ordered, dim=-1).cumsum(dim=-1)

    # A sorted position is dropped when the mass *before* it already exceeds
    # top_p; position 0 is therefore always kept.
    drop_sorted = torch.zeros_like(cum_probs, dtype=torch.bool)
    drop_sorted[..., 1:] = cum_probs[..., :-1] > top_p

    logits[order[drop_sorted]] = filter_value
    return logits


def _apply_repetition_penalty(logits, history, penalty):
    """Apply the CTRL repetition penalty in place to the 1-D ``logits``.

    Every token id that appears in ``history`` is made less likely. Positive
    logits are divided by ``penalty`` and negative logits are multiplied by it.
    Dividing a negative logit would move it toward zero and so make the
    repeated token *more* likely, which is what the previous code did.
    """
    if penalty == 1.0:
        return  # identity penalty: nothing to do
    seen = torch.unique(history)
    scores = logits[seen]
    logits[seen] = torch.where(scores < 0, scores * penalty, scores / penalty)


def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
    """Autoregressively extend ``context`` by ``length`` tokens.

    Args:
        model: LM-head model. ``model(**inputs)[0]`` is indexed here as
            (batch, position, vocab) logits (assumed; matches the models in
            ``MODEL_CLASSES`` — confirm for new architectures).
        length: number of tokens to generate.
        context: list of prompt token ids.
        num_samples: number of independent continuations, sampled row by row.
        temperature: logits are divided by it; 0 selects greedy (argmax) decoding.
        top_k, top_p: forwarded to ``top_k_top_p_filtering``.
        repetition_penalty: CTRL-style penalty for tokens already in each row.
        is_xlnet: add XLNet's masked dummy token, ``perm_mask`` and ``target_mapping``.
        is_xlm_mlm, xlm_mask_token: for XLM MLM models, append a mask token to predict.
        xlm_lang: optional language id broadcast into a ``langs`` input.
        device: device on which input tensors are created.

    Returns:
        LongTensor of shape (num_samples, len(context) + length) holding the
        prompt followed by the generated ids.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0).repeat(num_samples, 1)
    batch_size = generated.shape[0]
    with torch.no_grad():
        for _ in trange(length):

            inputs = {'input_ids': generated}
            if is_xlnet:
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                input_ids = torch.cat((generated, torch.zeros((batch_size, 1), dtype=torch.long, device=device)), dim=1)
                seq_len = input_ids.shape[1]
                perm_mask = torch.zeros((batch_size, seq_len, seq_len), dtype=torch.float, device=device)
                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((batch_size, 1, seq_len), dtype=torch.float, device=device)
                target_mapping[:, 0, -1] = 1.0  # predict last token
                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

            if is_xlm_mlm and xlm_mask_token:
                # XLM MLM models are direct models (predict same token, not next token)
                # => need one additional dummy token in the input (will be masked and guessed)
                mask_column = torch.full((batch_size, 1), xlm_mask_token, dtype=torch.long, device=device)
                inputs = {'input_ids': torch.cat((generated, mask_column), dim=1)}

            if xlm_lang is not None:
                seq_len = inputs["input_ids"].shape[1]
                inputs["langs"] = torch.tensor([xlm_lang] * seq_len, device=device).view(1, -1).repeat(batch_size, 1)

            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            # Last-position logits for every sample: shape (batch_size, vocab).
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)

            # Each sample is penalised, filtered and sampled independently, using
            # only its own history (previously only row 0 was used, and
            # num_samples > 1 crashed in the final torch.cat).
            next_tokens = []
            for row_logits, row_history in zip(next_token_logits, generated):
                _apply_repetition_penalty(row_logits, row_history, repetition_penalty)
                filtered_logits = top_k_top_p_filtering(row_logits, top_k=top_k, top_p=top_p)
                if temperature == 0:  # greedy decoding
                    next_token = torch.argmax(filtered_logits).unsqueeze(0)
                else:
                    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                next_tokens.append(next_token)
            generated = torch.cat((generated, torch.stack(next_tokens)), dim=1)
    return generated



In [None]:

# Generation settings used by continue_sentence() below.
model_name_or_path = 'gpt2-large'
model_type = 'gpt2'
no_cuda = False
padding_text = ''  # NOTE(review): not used by the visible cells
prompt = ''        # NOTE(review): not used by the visible cells
repetition_penalty = 1.0
seed = 42
stop_token = None
temperature = 1.0
top_k = 0
top_p = 0.9

# Honour ``no_cuda`` and fall back to the CPU when no GPU is available, instead
# of unconditionally calling model.to("cuda") (which crashes on CPU-only runtimes).
device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
# Kept consistent with the chosen device (was hard-coded to 0 while running on CUDA).
n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0

model_class, tokenizer_class = MODEL_CLASSES[model_type]
tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
model = model_class.from_pretrained(model_name_or_path)
model.to(device)
model.eval()


def continue_sentence(raw_text, length):
    """Generate ``length`` new tokens continuing ``raw_text`` and return only the new text.

    Uses the module-level ``tokenizer``, ``model``, ``model_type`` and sampling
    settings (``temperature``, ``top_k``, ``top_p``, ``repetition_penalty``,
    ``stop_token``). When ``stop_token`` is set and occurs in the decoded text,
    everything from its first occurrence onward is dropped. Otherwise the text
    is returned unchanged.
    """
    context_tokens = tokenizer.encode(raw_text)
    # Create inputs on whatever device the model actually lives on rather than
    # hard-coding "cuda", so CPU-only runtimes work too.
    device = next(model.parameters()).device
    out = sample_sequence(
        model=model,
        context=context_tokens,
        length=length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        is_xlnet=(model_type == "xlnet"),
        device=device,
    )
    # Drop the prompt and keep only the newly generated ids of the single sample.
    out = out[0, len(context_tokens):].tolist()

    text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
    if stop_token:
        stop_at = text.find(stop_token)
        # str.find returns -1 when the token is absent; slicing with -1 would
        # silently drop the last character, so only trim on a real match.
        if stop_at != -1:
            text = text[:stop_at]
    return text


10/30/2020 12:38:57 - INFO - filelock -   Lock 140171953019536 acquired on /root/.cache/torch/transformers/69f8d734111f39eaa51a85907bfdc81a7ef42242d638ffab6f77df305402b2b2.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…

10/30/2020 12:38:57 - INFO - filelock -   Lock 140171953019536 released on /root/.cache/torch/transformers/69f8d734111f39eaa51a85907bfdc81a7ef42242d638ffab6f77df305402b2b2.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71.lock
10/30/2020 12:38:57 - INFO - filelock -   Lock 140171953020432 acquired on /root/.cache/torch/transformers/38d28acc17953e356348dca948e152c653c0ccf5058a552eea30168e27f02046.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda.lock





HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…

10/30/2020 12:38:57 - INFO - filelock -   Lock 140171953020432 released on /root/.cache/torch/transformers/38d28acc17953e356348dca948e152c653c0ccf5058a552eea30168e27f02046.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda.lock





10/30/2020 12:38:57 - INFO - filelock -   Lock 140171952808456 acquired on /root/.cache/torch/transformers/c8f887cdfff4327916f4b7ed06a379c0add42bd9c66e1fe3b4a5a8525a4b2678.7a56eb872b7d0abfad5ae7e76f318ac26b189fad23442b4017703dd0b946115a.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=764.0, style=ProgressStyle(description_…

10/30/2020 12:38:57 - INFO - filelock -   Lock 140171952808456 released on /root/.cache/torch/transformers/c8f887cdfff4327916f4b7ed06a379c0add42bd9c66e1fe3b4a5a8525a4b2678.7a56eb872b7d0abfad5ae7e76f318ac26b189fad23442b4017703dd0b946115a.lock
10/30/2020 12:38:57 - INFO - filelock -   Lock 140171934105272 acquired on /root/.cache/torch/transformers/eeb916d81211b381b5ca53007b5cbbd2f5b12ff121e42e938751d1fee0e513f6.999a50942f8e31ea6fa89ec2580cb38fa40e3db5aa46102d0406bcfa77d9142d.lock





HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3247202234.0, style=ProgressStyle(descr…

10/30/2020 12:39:48 - INFO - filelock -   Lock 140171934105272 released on /root/.cache/torch/transformers/eeb916d81211b381b5ca53007b5cbbd2f5b12ff121e42e938751d1fee0e513f6.999a50942f8e31ea6fa89ec2580cb38fa40e3db5aa46102d0406bcfa77d9142d.lock





In [None]:


# Prompt (wrapped in newlines), extended by 100 generated tokens and printed.
txt = "\nHenlo hooman\n"
txt += continue_sentence(txt, 100)
print("\n" + txt)

100%|██████████| 100/100 [00:04<00:00, 20.42it/s]



Henlo hooman

Hitoki no Pin

I, Warrior Princess


Natsume Ver. Alias: Season 1

Original Title: Ookami Rin no Boushi-sama

Notes: First season began airing on June 6, 2008. First series illustrated by Kouhei Tanaka.

Natsume Ver. Character: Jun

Original Title: Natsume Erina no Shōjo Eru Fushigi na Dungeon de Shōwa Genroku Rakug



