In [1]:
import os
# Restrict CUDA to physical GPUs 6 and 7. This is set before torch is imported
# (next cell) so the restriction is already in place when CUDA is first used.
os.environ['CUDA_VISIBLE_DEVICES'] = '6, 7'

In [2]:
import pickle

import numpy as np
import torch as T

from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

from tats.CONST import *
from tats import Net2NetTransformer
from tats.modules.gpt import sample_with_past
# Show the special-token ids. These names are not defined in this notebook,
# so presumably they come from the `tats.CONST` star import above.
print('CONST:', SOS, SPAN, BOS, EOS, PAD, SEP)

import clip

from PIL import Image

CONST: 1024 1025 1026 1027 1028 1029


In [3]:
# Load the demo inputs. `data` is indexed below by task name
# ('prediction', 'rewind', 'infilling').
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load a data.pkl that comes from a trusted source.
with open('./_input/data.pkl', 'rb') as f:  # context manager closes the handle
    data = pickle.load(f)

In [4]:
# Build the VQGAN from its YAML config, put it in eval mode on the GPU,
# then restore the pretrained weights (read onto the CPU first).
cfg = OmegaConf.load('./_ckpt/yaml_taming_128.yaml')
VQ = VQModel(**cfg.model.params)
VQ = VQ.eval().cuda()
# Assigning to `_` keeps the load_state_dict result out of the cell output.
_ = VQ.load_state_dict(T.load('./_ckpt/ckpt_taming_mugen_128.pt', map_location='cpu'))

Working with z of shape (1, 256, 8, 8) = 16384 dimensions.
loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


In [5]:
# Restore the transformer from its checkpoint (read onto the CPU first),
# then switch it to eval mode and move it to the GPU.
gpt = Net2NetTransformer._load_model_state(T.load('./_ckpt/ckpt_mmvg_mugen.pt', map_location='cpu'))
gpt = gpt.eval().cuda()

# Constants read off the loaded model and used by the helpers below.
LEN_VIDEO = gpt.args.sequence_length
VOC_COND = gpt.cond_stage_vocab_size
print('----- (LEN_VIDEO, VOC_COND):', LEN_VIDEO, VOC_COND, '-----')

----- (FIRST, COND, GPT): 16384 49408 65792 -----
----- (LEN_VIDEO, VOC_COND): 16 49408 -----


In [6]:
def img_post(dat):
    """Turn a batch of decoded image tensors into uint8 pixel arrays.

    Axis 1 of `dat` is moved to the end (presumably the channel axis, giving a
    channels-last layout), values are mapped linearly so that -1 -> 0 and
    1 -> 255 with out-of-range inputs clamped, and the result is truncated to
    `np.uint8`.

    Args:
        dat: 4-D torch tensor (any device).

    Returns:
        numpy uint8 array with the axes of `dat` reordered as (0, 2, 3, 1).
    """
    moved = dat.permute(0, 2, 3, 1)
    unit = ((moved+1.0)/2.0).clamp(0, 1)  # now in [0, 1]
    scaled = unit.cpu().numpy()*255.0
    return scaled.astype(np.uint8)

def prepare(typ, item):
    """Build the 1-D conditioning token sequence for one sample.

    The sequence is the CLIP-tokenized instruction (100 tokens, truncated if
    longer) followed by the visual condition: the given frame tokens wrapped
    in PAD/SPAN markers, right-padded with PAD to 262 entries, then BOS. Every
    visual-condition token is shifted by VOC_COND so it does not collide with
    the text vocabulary.

    Args:
        typ: 'prediction', 'rewind' or 'infilling'.
        item: dict with 'ins' (instruction text) and either 'img'
            (prediction / rewind) or 'img0' and 'img1' (infilling). The image
            entries must support `.flatten().tolist()`; presumably they hold
            VQ code indices -- TODO confirm against the data file.

    Returns:
        1-D int64 torch tensor of text tokens followed by visual tokens.

    Raises:
        ValueError: if `typ` is not one of the three supported tasks.
    """
    # SPAN presumably marks where the missing video sits relative to the given
    # frame(s): after it (prediction), before it (rewind) or between (infilling).
    if typ=='prediction': cz = [PAD]+item['img'].flatten().tolist()+[SPAN]
    elif typ=='rewind': cz = [SPAN]+item['img'].flatten().tolist()+[PAD]
    elif typ=='infilling':
        cz = [PAD]+item['img0'].flatten().tolist()+[SPAN]+item['img1'].flatten().tolist()+[PAD]
    else:
        # Previously an unknown task fell through and crashed on an unbound `cz`.
        raise ValueError("unknown typ %r; expected 'prediction', 'rewind' or 'infilling'"%(typ, ))

    # Pad the visual condition to a fixed 262 entries (no-op if already longer).
    # NOTE(review): 262 is presumably the length the checkpoint was trained with -- confirm.
    cz += [PAD]*(262-len(cz))
    cz.append(BOS)
    cz = [tok+VOC_COND for tok in cz]
    cx = clip.tokenize(item['ins'], context_length=100, truncate=True)[0].numpy().tolist()

    return T.from_numpy(np.array(cx+cz, np.int64))

def post(d):
    """Cut a sampled token sequence into whole 8x8 frames of VQ code indices.

    The sequence is truncated at the first EOS token (or at 64*LEN_VIDEO
    tokens when no EOS is present), rounded down to a whole number of
    64-token frames, and values are clipped into [0, 1023] so that any
    special token left in the kept range becomes a valid code index.

    Args:
        d: 1-D integer torch tensor of sampled tokens.

    Returns:
        Tensor of shape (l, 8, 8), where l is the number of complete frames.
        l may be 0 if EOS appears within the first 64 tokens.
    """
    # Explicit emptiness check instead of a bare `except:` so that unrelated
    # errors (and KeyboardInterrupt) are no longer silently swallowed.
    hit = T.where(d==EOS)[0]
    p = hit[0].item() if hit.numel()>0 else 64*LEN_VIDEO
    # Never ask for more frames than there are tokens; otherwise `view` fails
    # when the sequence is shorter than a full video and has no EOS.
    l = min(p, d.shape[0])//64
    return d[:l*64].clip(0, 1023).view([l, 8, 8])

def run(c, step):
    """Sample video tokens for a batch of conditions and decode them to images.

    Args:
        c: batch of conditioning token sequences (moved to the GPU here).
        step: number of tokens to sample per sequence.

    Returns:
        A list with one entry per sample; each entry is a list of RGB PIL
        images, one per decoded frame.
    """
    with T.no_grad():
        sampled = sample_with_past(c.cuda(), gpt.transformer, steps=step, sample_logits=False,
                                   temperature=1.0, top_k=10, top_p=0.98)

    videos = []
    for tokens in sampled:
        # Undo the VOC_COND offset, then keep only whole 8x8 frames of code indices.
        codes = post(tokens-VOC_COND)
        with T.no_grad():
            # Look up the codebook vectors and move the embedding axis to
            # position 1 before decoding.
            emb = VQ.quantize.embedding(codes.cuda())
            dec = VQ.decode(emb.permute(0, 3, 1, 2))
        videos.append([Image.fromarray(px).convert('RGB') for px in img_post(dec)])
    return videos

In [7]:
# Video prediction: condition on a given frame plus the instruction and sample the video.
cs = []
for idx, item in enumerate(data['prediction']):
    img, ins = item['img'], item['ins']
    print('-----', idx, ins, '-----')
    c = prepare('prediction', {'img': img, 'ins': ins})
    cs.append(c.unsqueeze(0))
cs = T.cat(cs, dim=0)

frames = run(cs, (64*16+3)-1)
for idx, frame in enumerate(frames):
    # append_images must exclude frame[0]: it is already the image being saved,
    # so passing the whole list wrote the first frame twice.
    frame[0].save('./_output/vp_%d.gif'%(idx), format='GIF', append_images=frame[1:],
                  duration=int(1000.0/5.0), save_all=True, loop=0,
                  quality=100, sub_sampling=0)

----- 0 Mugen jumps up to the platform on the right and collects two coins. -----
----- 1 Mugen jumps onto the ladder and then down onto the level to collect a coin. -----
----- 2 Mugen heads left and jumps onto a face, crushing it, while collecting a coin. It then heads back right and jumps down to the ground-level. -----
----- 3 Mugen collects a coin, and then a mouse runs into it from behind. -----


In [8]:
# Video rewind: same pipeline as prediction, but the condition is built with the 'rewind' layout.
cs = []
for idx, item in enumerate(data['rewind']):
    img, ins = item['img'], item['ins']
    print('-----', idx, ins, '-----')
    c = prepare('rewind', {'img': img, 'ins': ins})
    cs.append(c.unsqueeze(0))
cs = T.cat(cs, dim=0)

frames = run(cs, (64*16+3)-1)
for idx, frame in enumerate(frames):
    # append_images must exclude frame[0]: it is already the image being saved,
    # so passing the whole list wrote the first frame twice.
    frame[0].save('./_output/vr_%d.gif'%(idx), format='GIF', append_images=frame[1:],
                  duration=int(1000.0/5.0), save_all=True, loop=0,
                  quality=100, sub_sampling=0)

----- 0 Mugen jumps up and down a few times. -----
----- 1 Mugen runs right to left and it jump runs left to right and it collect coin and gem. -----
----- 2 Mugen walks to the right while on the ground level at a steady pace before jumping up to the platform. It collects a gem on this small platform and drops back down to the ground level. -----
----- 3 Mugen jumps onto a platform and moves from left to right, collects a coin, then jumps onto snail, squishing it. Mugen then passes under a bee and moves off the right edge of the platform, landing on the ground. -----


In [9]:
# Video infilling: condition on two given frames plus the instruction and sample the video.
cs = []
for idx, item in enumerate(data['infilling']):
    img0, img1, ins = item['img0'], item['img1'], item['ins']
    print('-----', idx, ins, '-----')
    c = prepare('infilling', {'img0': img0, 'img1': img1, 'ins': ins})
    cs.append(c.unsqueeze(0))
cs = T.cat(cs, dim=0)

frames = run(cs, (64*16+3)-1)
for idx, frame in enumerate(frames):
    # append_images must exclude frame[0]: it is already the image being saved,
    # so passing the whole list wrote the first frame twice.
    frame[0].save('./_output/vi_%d.gif'%(idx), format='GIF', append_images=frame[1:],
                  duration=int(1000.0/5.0), save_all=True, loop=0,
                  quality=100, sub_sampling=0)

----- 0 Mugen moves left. It jumps through a narrow gap up to a platform and moves right to collect coins. -----
----- 1 Mugen runs from left to right. it jumps over a gear, then make a run to collect a coin. -----
----- 2 Mugen bounced to the left two times, then bounced one time back to the right once it saw the wall. -----
----- 3 Mugen hops up and hits it head. It then leaps up on the platform and jumps to smash the worm after collecting a coin and then turns back left. -----
