### GPT-2 모델 로딩하기

In [1]:
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Filtering is applied along the last (vocabulary) dimension, so both a
    single distribution of shape (vocabulary size,) and a batch of shape
    (batch size x vocabulary size) are accepted.

    Args:
        logits: logits tensor whose last dimension is the vocabulary.
            It is modified in place.
        top_k: if > 0, keep only the top k tokens with highest probability
            (top-k filtering). Float values such as 0.0 or 5.0 are accepted
            and truncated to int.
        top_p: if > 0.0, keep the smallest set of top tokens whose cumulative
            probability reaches top_p (nucleus filtering); the token that
            crosses the threshold is kept as well.
            Nucleus filtering is described in Holtzman et al.
            (http://arxiv.org/abs/1904.09751)
        filter_value: value written over every removed logit.

    Returns:
        The same tensor object as `logits`, with removed entries set to
        `filter_value`.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    # Safety check; int() because torch.topk rejects a float k.
    top_k = min(int(top_k), logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Scatter the sorted mask back to the original vocabulary order.
        # dim=-1 (instead of a hard-coded 1) keeps this valid for 1-D logits too.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

def sample_sequence(model, tokenizer, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
    """Sample `length` new tokens after `context`, one token per model call.

    Args:
        model: callable invoked as `model(**inputs)`; `outputs[0]` is indexed
            as `[:, -1, :]` to obtain the logits of the last position.
        tokenizer: unused here; kept so existing callers keep working.
        length: number of tokens to append.
        context: sequence of token ids used as the prompt.
        num_samples: number of sequences sampled in parallel from the same
            prompt (the prompt is repeated along the batch dimension).
        temperature: logits are divided by this value when it is > 0;
            0 selects greedy (argmax) decoding instead of sampling.
        top_k, top_p: forwarded to `top_k_top_p_filtering`.
        repetition_penalty: CTRL-style penalty applied to the logits of every
            token already present in the sequence; 1.0 disables it.
        is_xlnet: append a dummy token plus `perm_mask` / `target_mapping`
            inputs so the model predicts that last position.
        is_xlm_mlm, xlm_mask_token: append `xlm_mask_token` as the position to
            predict (only when both are truthy).
        xlm_lang: if not None, a `langs` input filled with this id is added.
        device: device on which all tensors are created.

    Returns:
        Long tensor of shape (num_samples, len(context) + length) holding the
        prompt followed by the sampled tokens.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context
    with torch.no_grad():
        for _ in range(length):
            # Helper tensors must share the batch size of `generated`;
            # a hard-coded 1 breaks torch.cat as soon as num_samples > 1.
            batch_size = generated.shape[0]

            inputs = {'input_ids': generated}
            if is_xlnet:
                input_ids = torch.cat((generated, torch.zeros((batch_size, 1), dtype=torch.long, device=device)), dim=1)
                perm_mask = torch.zeros((batch_size, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((batch_size, 1, input_ids.shape[1]), dtype=torch.float, device=device)
                target_mapping[:, 0, -1] = 1.0  # predict last token
                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

            if is_xlm_mlm and xlm_mask_token:
                input_ids = torch.cat((generated, torch.full((batch_size, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
                inputs = {'input_ids': input_ids}

            if xlm_lang is not None:
                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)

            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                seen_logits = torch.gather(next_token_logits, 1, generated)
                # Dividing a negative logit would raise its probability, so
                # negative logits are multiplied by the penalty instead.
                seen_logits = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )
                # Duplicate token ids all write the same penalized value, so
                # each seen token is penalized exactly once.
                next_token_logits = next_token_logits.scatter(1, generated, seen_logits)

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
    return generated


def generate(model, tokenizer, raw_text, length=20, temperature=1.0, top_k=0.0, top_p=0.9, device=torch.device('cpu')):
    """Return a sampled continuation of `raw_text` as decoded text.

    The prompt is encoded with `tokenizer.encode`, extended by
    `sample_sequence`, and only the newly generated tokens of the first
    sample (everything after the prompt ids) are decoded and returned.
    """
    prompt_ids = tokenizer.encode(raw_text)
    sampling_options = {
        'model': model,
        'tokenizer': tokenizer,
        'context': prompt_ids,
        'length': length,
        'temperature': temperature,
        'top_k': top_k,
        'top_p': top_p,
        'device': device,
        'is_xlnet': False,
    }
    sequences = sample_sequence(**sampling_options)

    # Drop the prompt prefix so only the continuation is decoded.
    prompt_len = len(prompt_ids)
    continuation_ids = sequences[0, prompt_len:].tolist()
    return tokenizer.decode(continuation_ids, clean_up_tokenization_spaces=True)

In [2]:
import torch
import numpy as np
from tqdm import trange
from transformers import GPT2LMHeadModel, GPT2Tokenizer
#from generation import generate

In [3]:
def set_seed(seed):
    """Seed the NumPy and PyTorch random number generators with `seed`.

    The CUDA generators are seeded too when at least one CUDA device is
    reported by `torch.cuda.device_count()`.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.device_count() <= 0:
        return
    torch.cuda.manual_seed_all(seed)

In [4]:
# Notebook-level configuration.
# NOTE(review): MAX_LENGTH is not referenced anywhere else in the visible cells.
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
# The four values below equal the defaults of generate(); the generate() calls
# visible in this notebook rely on those defaults and do not pass these variables.
length = 20

temperature = 1.0
top_k = 0.0
top_p = 0.9

# Maps a model type name to its (model class, tokenizer class) pair.
MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
}

In [5]:
# Default to the CPU; the following cell switches to CUDA when it is available.
device = torch.device('cpu')

In [6]:
# Prefer the GPU whenever PyTorch reports CUDA as available.
if torch.cuda.is_available():
    device = torch.device('cuda')

In [7]:
# Seed value later passed to set_seed().
seed = 42

In [8]:
# Key into MODEL_CLASSES selecting the (model class, tokenizer class) pair.
model_type = 'gpt2'
# Identifier handed to from_pretrained() when the tokenizer and model are loaded.
model_name_or_path = 'gpt2'
model_class, tokenizer_class = MODEL_CLASSES[model_type]

In [9]:
# Bare expression: the notebook displays the selected classes as the cell output.
model_class, tokenizer_class

(transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel,
 transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer)

In [10]:
# Load the pretrained tokenizer and model identified by model_name_or_path.
tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
model = model_class.from_pretrained(model_name_or_path)
# Move the model to the selected device, then switch it to evaluation mode.
model.to(device)
# As the last expression of the cell, the return value of eval() is what the
# notebook prints (the module structure shown in the output).
model.eval()

Downloading vocab.json:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/523M [00:00<?, ?B/s]

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [11]:
# Seed the NumPy/PyTorch RNGs so the sampling cells below start from a fixed RNG state.
set_seed(seed)

### generation test

In [12]:
# Plain string literal: the prompt has no placeholders, so no f-prefix is needed.
prompt = "Whether the Ministry of Gender Equality & Family should remain in existence again surfaced as a hot button issue Friday"
# Sample a continuation using generate()'s default length/temperature/top_k/top_p.
text = generate(model, tokenizer, prompt, device=device)

In [13]:
# generate() returns only the continuation, so prepend the prompt to display the full text.
prompt + text

'Whether the Ministry of Gender Equality & Family should remain in existence again surfaced as a hot button issue Friday, the question that prompted further polarisation was "suspicions."\n\nThe human rights'

### GPT2를 이용한 few shot 러닝

GPT2 few shot 러닝을 위한 프롬프트 텍스트

In [15]:
# Few-shot prompt: ten "base = past tense" example pairs followed by the query "click =".
# NOTE: the few-shot loop further down rebuilds this prompt for each verb in `verbs`.
prompt = 'play = played . sing = sang . view = viewed . act = acted . say = said . type = typed . note = noted . see = saw . clean = cleaned . tell = told . click ='

few shot 러닝을 테스트할 동사 단어들

In [16]:
# Query verbs; each one is substituted into the few-shot prompt in turn.
verbs = ['click', 'work', 'walk', 'run', 'jump']

few shot 러닝에 사용되는 서포트셋과 프롬프트의 구성

```
# 10개의 서포트셋
play = played           # 1 shot
sing = sang             # 2 shot
view = viewed           # 3 shot
act = acted             # 4 shot
say = said              # 5 shot
type = typed            # 6 shot
note = noted            # 7 shot
see = saw               # 8 shot
clean = cleaned         # 9 shot
tell = told             # 10 shot
```

In [17]:
# NOTE(review): despite the name, `plurals` collects the generated past-tense
# forms (the prompt pairs are "base = past tense"); the name is kept because the
# printing cell reads it.
plurals = []
for verb in verbs:
    # Same ten support examples for every query; only the final verb changes.
    prompt = f"play = played . sing = sang . view = viewed . act = acted . say = said . type = typed . note = noted . see = saw . clean = cleaned . tell = told . {verb} ="
    plural = generate(model, tokenizer, prompt, device=device)
    # Keep only the text before the first "." separator; surrounding whitespace is not stripped.
    plural = plural.split(".")[0]
    plurals.append(plural)

In [18]:
# Print each query verb next to the form generated for it.
for v, p in zip(verbs, plurals):
    print(f'{v} -> {p}')

click ->  clicked
work ->  worked
walk ->  walked
run ->  run
jump ->  jump


규칙 동사인 click, work, walk 는 원형과 과거형이 잘 매칭된 것을 확인할 수 있다. 반면 run(정답: ran)과 jump(정답: jumped)는 원형이 그대로 출력되어 과거형 생성에 실패했다.