# Head

In [78]:
import torch, importlib
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import GPT2Tokenizer, DistilBertTokenizer, GPT2LMHeadModel
import matplotlib.pyplot as plt

In [79]:
prompt = "\n"
sentence = prompt + open('Genshin Impact.txt').read().replace('\n', ' ')

model_name = "distilbert-base-uncased"
bert_tokenizer = DistilBertTokenizer.from_pretrained(model_name)
bert_inputs = bert_tokenizer(sentence, return_tensors="pt", max_length = 128, truncation = True, 
                             add_special_tokens = False)

model_name = "distilgpt2"
gpt_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
gpt_inputs = gpt_tokenizer(sentence, return_tensors="pt", max_length = 128, truncation = True)

# Main

In [None]:
import TextVAE
importlib.reload(TextVAE)
from TextVAE import TextVAE

tvae = TextVAE()
tvae.train()
epochs = 50
optim = AdamW(tvae.parameters(), 5e-4)
scheduler = CosineAnnealingLR(optim, T_max = epochs, eta_min = 0)
for epoch in range(epochs):
    loss, kl_loss = tvae(bert_inputs['input_ids'], bert_inputs['attention_mask'], 
                gpt_inputs['input_ids'], gpt_inputs['attention_mask'])
    (loss + kl_loss).backward()
    optim.step()
    optim.zero_grad()
    scheduler.step()
    print(f"epoch: {epoch + 1}; loss: {loss.item()/128, kl_loss.item()/768/4};")

epoch: 1; loss: (3.6127283573150635, 0.005113375683625539);
epoch: 2; loss: (3.4085443019866943, 0.00807518387834231);
epoch: 3; loss: (3.3747751712799072, 0.00529614028831323);
epoch: 4; loss: (2.702524185180664, 0.005362549175818761);
epoch: 5; loss: (2.239224910736084, 0.00994647853076458);
epoch: 6; loss: (2.0090229511260986, 0.01038353517651558);
epoch: 7; loss: (1.7497045993804932, 0.006154864405592282);
epoch: 8; loss: (1.3401188850402832, 0.005929786711931229);
epoch: 9; loss: (1.1326369047164917, 0.0056842391689618426);
epoch: 10; loss: (0.9813932180404663, 0.005180810888608296);
epoch: 11; loss: (1.1136592626571655, 0.005124591911832492);
epoch: 12; loss: (0.7368621826171875, 0.003474247952302297);
epoch: 13; loss: (1.2699332237243652, 0.002689758315682411);
epoch: 14; loss: (0.6200105547904968, 0.002891318562130133);
epoch: 15; loss: (0.5271949768066406, 0.002945307952662309);
epoch: 16; loss: (0.5569228529930115, 0.0029155568530162177);
epoch: 17; loss: (0.4474020004272461,

In [143]:
tvae.eval()
bias = tvae(bert_inputs['input_ids'], bert_inputs['attention_mask'], do_sample = False)

gpt = GPT2LMHeadModel.from_pretrained('distilgpt2')
with torch.no_grad():
    gpt.transformer.h[-2].mlp.c_fc.bias += bias[0,0]
    gpt.transformer.h[-1].mlp.c_fc.bias += bias[0,1]

In [145]:
with torch.no_grad():
    tokens = gpt(gpt_inputs['input_ids'])[0][0].argmax(-1)
print(gpt_tokenizer.decode(tokens))
print((tokens[:-1] != gpt_inputs['input_ids'][0, 1:]).sum().item())

Genshin Impact: A Masterpiece in the World of Gaming,   Introduction by GHoYo in September 2010, Genshin Impact has taken the gaming world by storm. With its stunning open world, intricate lore, and stunningivating characters, this action role-playing game hasRPG) has garnered a massive global following. With article delves into what makes Genshin Impact a standout title in the crowded landscape of modern gaming. This  Masterast World Beautiful World  Released of Genshin Impact's most notable features is its stunning open world.  game takes place in the fantasy land of Teyvat, the
15


In [137]:
test_ids = gpt_tokenizer(prompt, return_tensors="pt").input_ids
if torch.cuda.is_available():
   test_ids = test_ids.to('cuda')

generated_ids = gpt.generate(
    test_ids,
    max_length = 128,
    num_return_sequences = 1,
    no_repeat_ngram_size = 5,
    num_beams = 5,
    eos_token_id = 50256,
    bad_words_ids = [[198]],
)

# 解码生成的文本
generated_text = gpt_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_text[0])
print(len(generated_ids[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Genshin Impact: A Masterpiece in the World of Gaming  Introduction  Released by miHoYo in September 2020, Genshin Impact has taken the gaming world by storm. With its expansive open world, intricate lore, and captivating characters, this action role-playing game (RPG) has garnered a massive global following. This article delves into what makes Genshin Impact a standout title in the crowded landscape of modern gaming.  A Vast and Beautiful World  One of Genshin Impact's most notable features is its stunning open world. The game takes place in the fantasy land of Teyvat,
128
