In [1]:
import torch, yaml
from Model import GPT2S
import torch.nn.functional as F
from encoder import get_encoder
from tqdm import trange

In [2]:
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

with open("config.yaml") as f:
    config = AttrDict(yaml.load(f))

  config = AttrDict(yaml.load(f))


In [3]:
model = GPT2S(config)
model

GPT2S(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm((768,), eps=5e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
        )
        (ln_2): LayerNorm((768,), eps=5e-05, elementwise_affine=True)
        (mlp): FeedForward(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): GELU(approximate='none')
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=5e-05, elementwise_affine=True)
  )
  (lm_head): LMHead(
    (decoder): Linear(in_features=768, out_features=50257, bias=False)
  )
)

In [4]:
def load_weight(model, state_dict):
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if key.endswith(".g"):
            new_key = key[:-2] + ".weight"
        elif key.endswith(".b"):
            new_key = key[:-2] + ".bias"
        elif key.endswith(".w"):
            new_key = key[:-2] + ".weight"
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    start_model = model
    if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
        start_model = model.transformer
    load(start_model, prefix="")

    # Make sure we are still sharing the output and input embeddings after loading weights
    model.set_tied()
    return model

In [5]:
def top_k_logits(logits, k):
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1]
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)

def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, sample=True):
    context = torch.full((batch_size, 1), start_token, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output

In [6]:
state_dict = torch.load('pytorch_model.bin', map_location='cpu' if not torch.cuda.is_available() else None)
tokenizer = get_encoder()
model = load_weight(model, state_dict)

In [7]:
text = "Just tell me anything"

In [8]:
context_tokens = tokenizer.encode(text)
generated = 0
for _ in range(1):
    out = sample_sequence(
            model=model, length=256,
            context=context_tokens,
            start_token=tokenizer.encoder['<|endoftext|>'],
            batch_size=1
        )
    out = out[:, len(context_tokens):].tolist()
    for i in range(1):
        generated += 1
        text = tokenizer.decode(out[i])
        print(text)

100%|██████████| 256/256 [00:11<00:00, 22.47it/s]

catcher

After pre-registration, any precompensated members who have photo identification will be presented with online questions and will be notified of their initial biographic confirmation. When nominations are completed, the party chair will arrange to solicit submissions once an administrative appeal is completed; if the appeal is pending, the chair will explain the reasons behind its conduct as: Insurance technology doesn't work; Government is concerned; A sponsor is sponsoring this party who is compromised or financially compromised.

Registered non-members are responsible for making sure the Canadian Taxpayers Federation is in compliance to the Canadian IRS Electronic System so that the Commissioners cannot correct any classification duces d'elys.[1]

A reply to the Commissioner's letter.

Letter mark and gilded memberhips

Available alumnae for re-election are fully repayable. Financial reserves exceeding $20,000 are granted.

Telephone tickets.

Men can attend any public (inc


