In [13]:
import yaml
from gpt2 import ruGPT2
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn.functional as F

In [14]:
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
gpt2 = GPT2LMHeadModel.from_pretrained(model_name)

In [18]:
class Config:
    def __init__(self, file_path):
        with open(file_path, 'r') as file:
            config_data = yaml.safe_load(file)
        self._dict_to_object(config_data, self)

    def _dict_to_object(self, d, obj):
        for key, value in d.items():
            if isinstance(value, dict):
                setattr(obj, key, type('ConfigObject', (), {}))
                self._dict_to_object(value, getattr(obj, key))
            else:
                setattr(obj, key, value)

config = Config('config.yaml')                
rugpt2 = ruGPT2(config)

In [19]:
state_dict = gpt2.state_dict()

renamed_state_dict = {}
for key in state_dict.keys():
    new_key = key
    if 'transformer' in new_key:
        new_key = new_key.replace('transformer.', '')
    renamed_state_dict[new_key] = state_dict[key]

In [21]:
rugpt2.load_state_dict(renamed_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['h.0.attn.bias', 'h.1.attn.bias', 'h.2.attn.bias', 'h.3.attn.bias', 'h.4.attn.bias', 'h.5.attn.bias', 'h.6.attn.bias', 'h.7.attn.bias', 'h.8.attn.bias', 'h.9.attn.bias', 'h.10.attn.bias', 'h.11.attn.bias'], unexpected_keys=[])

In [38]:
def top_k_logits(logits, top_k=50):
    if top_k > 0:
        values, _ = torch.topk(logits, top_k)
        min_values = values[:, -1]
        logits = torch.where(logits < min_values, torch.ones_like(logits) * -float('inf'), logits)
    return logits


def custom_generate(model, input_ids, max_length=10, temperature=1.0, top_k=50, no_repeat_n_gram_size=2):
    model.eval()
    with torch.no_grad():
        for _ in range(max_length):
            logits = model(input_ids)[:, -1, :] / temperature
            filtered_logits = top_k_logits(logits, top_k=top_k)
            
            if no_repeat_n_gram_size > 0:
                for _ in range(no_repeat_n_gram_size):
                    if input_ids.shape[1] > no_repeat_n_gram_size:
                        logits[:, input_ids[:, -no_repeat_n_gram_size:]] = -float('inf')
            
            probabilities = F.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probabilities, 1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    return generated_text

prompt = "Once upon a time I met a"
input_ids = tokenizer.encode(prompt, return_tensors='pt')
generated_text = custom_generate(rugpt2, input_ids=input_ids)
print(generated_text.replace('\n', ''))

Once upon a time I met aField who understood an age an who understood how the
