# RSSMを小説を使って学習

In [None]:
!nvidia-smi

Sat Mar 26 09:47:15 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    26W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
#@title clone repository
!git clone https://github.com/tortoise10101/LWM.git
%cd LWM

fatal: destination path 'LWM' already exists and is not an empty directory.
/content/LWM


In [None]:
#@title load model and book-corpus
# 5分くらいかかる
# http responseがたまに，403になるので，そのときはもう一回試してください
!./loader.sh

--2022-03-26 09:48:34--  https://docs.google.com/uc?export=download&confirm=t&id=1ibyDaM9OkLCW1GcYnDaoKmoOTwcluKeT
Resolving docs.google.com (docs.google.com)... 172.217.204.102, 172.217.204.139, 172.217.204.100, ...
Connecting to docs.google.com (docs.google.com)|172.217.204.102|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://doc-0g-7s-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/d394u00ci5knai572mhibp3oesniohvc/1648288050000/06512783872423863144/*/1ibyDaM9OkLCW1GcYnDaoKmoOTwcluKeT?e=download [following]
--2022-03-26 09:48:34--  https://doc-0g-7s-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/d394u00ci5knai572mhibp3oesniohvc/1648288050000/06512783872423863144/*/1ibyDaM9OkLCW1GcYnDaoKmoOTwcluKeT?e=download
Resolving doc-0g-7s-docs.googleusercontent.com (doc-0g-7s-docs.googleusercontent.com)... 142.250.98.132, 2607:f8b0:400c:c1a::84
Connecting to doc-0g-7s-docs.googleusercontent.com (doc-0

In [None]:
import torch
from torch.distributions.kl import kl_divergence
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_

from worldmodel import RSSM, ReplayBuffer
from seq2vec import VAE

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 文章をベクトル化するやつ
seq2vec = VAE.load().to(device)

latent_dim = 128
state_dim = 30 #@param {type:"integer"}
rnn_hidden_dim = 200 #@param {type: "integer"}
# RSSM
rssm = RSSM(
    state_dim=state_dim,
    rnn_hidden_dim=rnn_hidden_dim)

model_lr = 6e-4  #@param
eps = 1e-4 #@param
model_params = (
    list(rssm.transition.parameters()) +
    list(rssm.observation.parameters()))
model_optimizer = torch.optim.Adam(model_params, lr=model_lr, eps=eps)

capacity = 1000000 #@param {type: "integer"}
replay_buffer = ReplayBuffer(capacity=capacity, observation_shape=latent_dim)

batch_size = 128 #@param {type: "integer"}
chunk_length = 20 #@param {type: "integer"}

free_nats = 3
clip_grad_norm = 100

In [None]:
#@title train func
def train():
    global seq2vec, rssm, model_optimizer, replay_buffer
    for t in range(1000):
        observations, _ = replay_buffer.sample(batch_size, chunk_length)
        observations = torch.as_tensor(observations, device=device)
        observations = observations.view(chunk_length, batch_size, -1)
        # observations = observations.transpose(3, 4).transpose(2, 3)
        # observations = observations.transpose(0, 1)

        states = torch.zeros(
            chunk_length, batch_size, state_dim, device=device)
        rnn_hiddens = torch.zeros(
            chunk_length, batch_size, rnn_hidden_dim, device=device)

        state = torch.zeros(batch_size, state_dim, device=device)
        rnn_hidden = torch.zeros(batch_size, rnn_hidden_dim, device=device)

        kl_loss = 0
        for l in range(chunk_length-1):
            next_state_prior, next_state_posterior, rnn_hidden = \
                rssm.transition(state, rnn_hidden, observations[l+1])
            state = next_state_posterior.rsample()
            states[l+1] = state
            rnn_hiddens[l+1] = rnn_hidden
            kl = kl_divergence(
                next_state_prior, next_state_posterior).sum(dim=1)
            kl_loss += kl.clamp(min=free_nats).mean()
            # kl_loss += kl.mean()
        kl_loss /= (chunk_length-1)

        states = states[1:]
        rnn_hiddens = rnn_hiddens[1:]

        flatten_states = states.view(-1, state_dim)
        flatten_rnn_hiddens = rnn_hiddens.view(-1, rnn_hidden_dim)
        recon_observations = \
            rssm.observation(
                flatten_states,
                flatten_rnn_hiddens
                ).view(chunk_length-1, batch_size, latent_dim)

        obs_loss = \
            0.5 * F.mse_loss(
                recon_observations.float(),
                observations[1:].float(), reduction='none').mean([0, 1]).sum()

        model_loss = kl_loss + obs_loss
        model_optimizer.zero_grad()
        model_loss.backward()
        clip_grad_norm_(model_params, clip_grad_norm)
        model_optimizer.step()

        if t % 10 == 0:
            print('update_step: %3d model loss: %.5f, kl_loss: %.5f, obs_loss: %.5f' \
                % (t, model_loss.item(), kl_loss.item(), obs_loss.item()))

In [None]:
# book corpus からデータとってbufferに追加
def prepare_buffer():
    global replay_buffer
    with open('dataset/books_large_p2.txt', 'r') as f:
        r = f.read().split('\n')
        bs = 32*2*2
        for i in range(0, len(r)-bs, bs):
            print("%.3f %%, %.3f %%" % (100*i/len(r), 100*i/replay_buffer.capacity))
            if replay_buffer.is_filled:
                break
            if i/len(r) > 0.1:
                break
            _, _, z, _ = seq2vec.forward(
                seq2vec.tokenize(r[i:i+bs]).to(device))
            z = z.detach().cpu().numpy()
            # TODO add split
            replay_buffer.push_batch(z, torch.tensor([False]*bs).view(-1, 1))

def load_buffer():
    global replay_buffer
    import pickle
    with open('dataset/books_large_p2.pickle', 'rb') as f:
        replay_buffer = pickle.load(f)


def dump_buffer():
    import pickle
    with open('dataset/books_large_p2.pickle', 'wb') as f:
        pickle.dump(replay_buffer, f)
      
def imagine(text, nhorizon=10):
    _, _, z, _ = seq2vec.forward(seq2vec.tokenize(text).to(device))
    rnn_hidden = torch.zeros(1, rnn_hidden_dim, device=device)
    state = rssm.transition.posterior(rnn_hidden, z).sample()

    imagined_states = [None] * nhorizon
    imagined_rnn_hiddens = [None] * nhorizon
    for i in range(nhorizon):
        state_prior, rnn_hidden = \
            rssm.transition.prior(rssm.transition.recurrent(state, rnn_hidden))

        state = state_prior.sample()
        imagined_states[i] = state
        imagined_rnn_hiddens[i] = rnn_hidden

    return imagined_states, imagined_rnn_hiddens


def decode(imagined_states, imagined_rnn_hiddens):
    seqt = []
    for state, rnn_hidden in zip(imagined_states, imagined_rnn_hiddens):
        obs = rssm.observation(state, rnn_hidden)
        g = seq2vec.generate(obs, max_len=20, alg='greedy')
        seqt.append(seq2vec.detokenize(g))
    return seqt


def generate(text):
    imagined_states, imagined_rnn_hiddens = \
        imagine(text)
    seqt = decode(imagined_states, imagined_rnn_hiddens)
    seq = ""
    for s in seqt:
        for c in s[0]:
            seq += c
            seq += " "
        seq += ".\n"

    print(seq, sep="\n")
    # return seq

In [None]:
#@title 学習データのロード

# replay_bufferのcapacity分しかロードしてないです，
# replay_bufferも更新はしてない（今は）
# 10分くらい
prepare_buffer()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
1.059 %, 36.019 %
1.060 %, 36.032 %
1.060 %, 36.045 %
1.060 %, 36.058 %
1.061 %, 36.070 %
1.061 %, 36.083 %
1.062 %, 36.096 %
1.062 %, 36.109 %
1.062 %, 36.122 %
1.063 %, 36.134 %
1.063 %, 36.147 %
1.063 %, 36.160 %
1.064 %, 36.173 %
1.064 %, 36.186 %
1.065 %, 36.198 %
1.065 %, 36.211 %
1.065 %, 36.224 %
1.066 %, 36.237 %
1.066 %, 36.250 %
1.066 %, 36.262 %
1.067 %, 36.275 %
1.067 %, 36.288 %
1.068 %, 36.301 %
1.068 %, 36.314 %
1.068 %, 36.326 %
1.069 %, 36.339 %
1.069 %, 36.352 %
1.069 %, 36.365 %
1.070 %, 36.378 %
1.070 %, 36.390 %
1.071 %, 36.403 %
1.071 %, 36.416 %
1.071 %, 36.429 %
1.072 %, 36.442 %
1.072 %, 36.454 %
1.072 %, 36.467 %
1.073 %, 36.480 %
1.073 %, 36.493 %
1.074 %, 36.506 %
1.074 %, 36.518 %
1.074 %, 36.531 %
1.075 %, 36.544 %
1.075 %, 36.557 %
1.075 %, 36.570 %
1.076 %, 36.582 %
1.076 %, 36.595 %
1.077 %, 36.608 %
1.077 %, 36.621 %
1.077 %, 36.634 %
1.078 %, 36.646 %
1.078 %, 36.659 %
1.078 %, 36.672 %

In [None]:
#@title 学習

# 何回か実行してください
# 学習回数とか適当です（lossとかもまだ，みてない）
train()

update_step:   0 model loss: 50.82182, kl_loss: 3.56583, obs_loss: 47.25599
update_step:  10 model loss: 50.88021, kl_loss: 3.55641, obs_loss: 47.32380
update_step:  20 model loss: 50.63042, kl_loss: 3.56966, obs_loss: 47.06076
update_step:  30 model loss: 50.54004, kl_loss: 3.57856, obs_loss: 46.96147
update_step:  40 model loss: 51.29530, kl_loss: 3.55420, obs_loss: 47.74110
update_step:  50 model loss: 50.80862, kl_loss: 3.55933, obs_loss: 47.24929
update_step:  60 model loss: 51.02370, kl_loss: 3.59112, obs_loss: 47.43258
update_step:  70 model loss: 50.88790, kl_loss: 3.58386, obs_loss: 47.30404
update_step:  80 model loss: 50.59313, kl_loss: 3.54304, obs_loss: 47.05009
update_step:  90 model loss: 50.48477, kl_loss: 3.56327, obs_loss: 46.92150
update_step: 100 model loss: 50.46212, kl_loss: 3.56785, obs_loss: 46.89427
update_step: 110 model loss: 50.98101, kl_loss: 3.55609, obs_loss: 47.42493
update_step: 120 model loss: 51.02509, kl_loss: 3.58149, obs_loss: 47.44360
update_step:

In [None]:
#@title 生成
generate("This is a pen. This is a pineapple.")

`` by parking when i worked in that dried comments in their details mouth .
she goes for his bathroom to see the blood spray named telephone carpeting <unk> .
who meant about them are the wash <unk> without jay eye telephone mouth .
that mattress like her position -- the photo speakers requesting disney balance that .
he , head <unk> employees off the <unk> talking richie moldy versions of me ) .
woman , someone had like any patron with what speakers threw of her talent .
basically it or this paper by mc their direction covering jay in profit .
john was using the website and try an ombre honest jacket painted off how painful .
she asks in he just on navy that an error like tire above shit .
`` everyone is called '' was turning of their information selling professionalism prior : .

