In [15]:
from mmengine import Config
import torch
from matplotlib import pyplot as plt
import numpy as np
import os
import pickle
from get_model import get_model

In [17]:
# The original code stripped 6 characters from every checkpoint key, i.e. it
# presumably expects keys of the form 'model.<param>' (wrapper attribute named
# `model`) — TODO confirm against the training code.
STATE_DICT_PREFIX = 'model.'

ckp_root = '07-13-24/modern-model'
config = Config.fromfile(os.path.join(ckp_root, 'config.py'))
# NOTE(review): newer torch versions default torch.load to weights_only=True;
# if this checkpoint stores non-tensor objects the call may need adjusting.
ckp = torch.load(os.path.join(ckp_root, 'last.ckpt'), map_location='cpu')

model = get_model(config.model_type, config.model_config)
# Strip the prefix only where it is actually present, instead of blindly
# chopping the first 6 characters off every key (which mangles unprefixed keys).
sd = {
    (k[len(STATE_DICT_PREFIX):] if k.startswith(STATE_DICT_PREFIX) else k): v
    for k, v in ckp['state_dict'].items()
}
# strict=False tolerates extra/absent keys, but report them: a prefix mismatch
# would otherwise silently leave the model with its freshly initialised weights.
load_result = model.load_state_dict(sd, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'missing keys: {load_result.missing_keys}')
    print(f'unexpected keys: {load_result.unexpected_keys}')
model.eval()  # inference only: switch layers such as dropout to eval behaviour

# NOTE: pickle.load can execute arbitrary code — only load vocab files you trust.
with open('DATA/word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)

with open('DATA/idx2word.pkl', 'rb') as f:
    idx2word = pickle.load(f)
model.word2idx = word2idx
model.idx2word = idx2word

In [18]:
def sample(model,
           str_strings: str,
           device = 'cpu',
           max_new_tokens: int = 50) -> list:
    """Autoregressively sample a continuation of ``str_strings`` from ``model``.

    The prompt is encoded character by character through ``model.word2idx``,
    prefixed with ``<SOS>``, and then extended one token at a time. Each new
    token is drawn from the softmax over the logits at the last position.
    Generation stops when ``<EOS>`` is drawn or after ``max_new_tokens`` tokens.

    Args:
        model: callable taking a ``(1, seq_len)`` LongTensor of token ids and
            returning logits indexed as ``out[batch, position, token]``. It must
            expose ``word2idx`` / ``idx2word`` vocab mappings and already live
            on ``device``. If it is a ``torch.nn.Module`` it is run in eval
            mode and its previous train/eval mode is restored afterwards.
        str_strings: prompt text. Every character must be a key of
            ``model.word2idx``; otherwise a ``KeyError`` is raised.
        device: device the input token tensor is created on.
        max_new_tokens: maximum number of tokens to sample (default 50).

    Returns:
        A one-element list with the decoded text. It includes the leading
        ``<SOS>`` and, if it was sampled, the trailing ``<EOS>``.
    """
    sos_token_idx = model.word2idx['<SOS>']
    eos_token_idx = model.word2idx['<EOS>']
    decode = lambda ids: ''.join(model.idx2word[i] for i in ids)
    encode = lambda text: [model.word2idx[ch] for ch in text]

    print('sampling...')
    inp = torch.tensor([sos_token_idx] + encode(str_strings),
                       dtype=torch.long, device=device).unsqueeze(0)

    # Sampling is pure inference: disable training-time behaviour (e.g. dropout)
    # and autograd bookkeeping, then restore the caller's mode afterwards.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                out = model(inp)
                assert out.shape[0] == 1
                # Distribution over the next token, shape (1, vocab).
                next_token_p = torch.nn.functional.softmax(out[0, -1:, :], dim=-1)
                # Random sampling (not argmax); result has shape (1, 1).
                next_token_idx = torch.multinomial(next_token_p, 1)
                inp = torch.cat([inp, next_token_idx], dim=1)
                if next_token_idx.item() == eos_token_idx:
                    break
    finally:
        if is_module:
            model.train(was_training)
    return [decode(inp[0].tolist())]

In [19]:
# Sampling is stochastic (torch.multinomial); seed the global RNG so the output
# below is reproducible on Restart & Run All.
torch.manual_seed(0)
sample(model, str_strings='去年今日', device='cpu')

sampling...


['<SOS>去年今日此花前|只欠檀羊一笑篇|春色至方无此景|灵根向得最侔旋<EOS>']