In [1]:
# ==== import from package ==== #
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import rearrange, repeat, reduce
from collections import defaultdict
from torch.utils.data import TensorDataset, DataLoader
# ==== import from this folder ==== #
from model_VAE import VQVAE
from model_discriminator import NLayerDiscriminator, weights_init
from dataset import get_dataloader
from util import reset_dir, weight_scheduler, compact_large_image, sinusoidal_embedding
from logger import Logger
DEVICE = torch.device("cuda")
print("DEVICE:", DEVICE)

DEVICE: cuda


In [2]:
mode = 'VQVAE'
load_epoch = 45

# Get latent set
latent_set = torch.load(f'collected_latents/{mode}_{load_epoch}.pt')
vae = torch.load(f'model_ckpt/{mode}/epoch_AE_{load_epoch}.pt')
vae.eval()

# Try to normalize latents with conditions
with torch.no_grad():
    latents, z_indices = latent_set.tensors[0], latent_set.tensors[1]
    quant, diff_loss, ind = vae.quantize(latents)
    ind = rearrange(ind, '(b 1 h w) -> b (h w)', b=len(latents),
                            h=vae.z_shape[0], w=vae.z_shape[1])
    dataset = TensorDataset(ind, z_indices)
    dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
    print(ind.shape)

torch.Size([11328, 256])


In [3]:
class PixelCNN(nn.Module):

    def __init__(self, vae):
        super(PixelCNN, self).__init__()

        self.embed_dim = 32
        self.n_embed = vae.quantize.n_e
        self.z_shape = vae.z_shape

        self.embed = nn.Embedding(vae.quantize.n_e, self.embed_dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=self.embed_dim, nhead=2)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)
        self.positional_encoding = sinusoidal_embedding(self.z_shape[0] * self.z_shape[1], self.embed_dim)
        self.predictor = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.n_embed),
        )
        self.cond_mlp = nn.Sequential(
            nn.Embedding(32, self.embed_dim),
            nn.SiLU(),
            nn.Linear(self.embed_dim, self.embed_dim),
        )

    def forward(self, dst, z_idx):
        N, S = dst.shape
        
        h = self.embed(dst) + self.positional_encoding[None, :, :].to(dst.device)
        memory = self.cond_mlp(z_idx)
        mask = nn.Transformer.generate_square_subsequent_mask(S).to(DEVICE)
        memory = repeat(memory, 'N E -> S N E', S=S)
        # Add gaussian noise
        memory = memory + torch.randn_like(memory)
        h = rearrange(h, 'N S E -> S N E')
        h = self.decoder(h, memory, tgt_mask=mask)
        h = self.predictor(h)
        h = rearrange(h, 'S N E-> N S E')
        return h
    
    def sample(self, batch_size):
        N, S, E = batch_size, 16*16, self.embed_dim
        T = 32
        # 32, E -> 32 * B * S, E
        memory = torch.randn([N, T, E]).to(DEVICE)
        z_idx = torch.arange(0, 32, step=1, device=DEVICE)
        cond = self.cond_mlp(z_idx)
        memory = memory + cond[None, :, :]
        memory = repeat(memory, 'N T E -> S (N T) E', S=S)

        pos_enc = self.positional_encoding[None, :, :].to(DEVICE)
        mask = nn.Transformer.generate_square_subsequent_mask(S).to(DEVICE)

        gen_ind = torch.zeros(S, N * T).long().to(DEVICE)
        for i in tqdm(range(S), total=S):
            rev_ind = rearrange(gen_ind, 'S (N T) -> (N T) S', N=N, T=T)
            h = self.embed(rev_ind) + pos_enc
            h = rearrange(h, '(N T) S E -> S (N T) E',N=N, T=T)
            h = self.decoder(h, memory, tgt_mask=mask)
            h = self.predictor(h)
            h = rearrange(h, 'S (N T) E-> (N T) S E', N=N, T=T)
            from torch.utils.data import WeightedRandomSampler
            h = h[:, i, :].detach()
            cur_gen_ind = torch.zeros([N * T, ]).long().to(DEVICE)
            for idx in range(N * T):
                weight = h[idx].detach().cpu()
                weight = torch.nn.functional.softmax(weight+1e-6, dim=0)
                sample = WeightedRandomSampler(weight, num_samples=1)
                sample = list(sample)[0]
                cur_gen_ind[idx] = sample

            # h = torch.argmax(h, dim=1)
            gen_ind[i] = cur_gen_ind
        return rearrange(gen_ind, 'S (N T) -> N T S', N=N, T=T)

In [4]:
pixelCNN = PixelCNN(vae).to(DEVICE)
from torchinfo import summary
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(pixelCNN.parameters(), lr=2e-4)

In [5]:
def train(i):
    total_loss = 0
    for now_step, batch_data in tqdm(enumerate(dataloader), total=len(dataloader)):
        ind, z_idx = [data.to(DEVICE) for data in batch_data]

        optimizer.zero_grad()

        out = pixelCNN(ind.detach(), z_idx.detach())
        out = rearrange(out, "N S E -> (N S) E")
        ind = rearrange(ind, "N S -> (N S)")
        loss = criteria(out, ind)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


In [6]:
for i in range(5):
    print(f"epoch {i}: loss: {train(i)}")
    torch.save(pixelCNN, f'model_ckpt/visualTransformer/model_{i}.pt')

100%|██████████| 2832/2832 [01:07<00:00, 42.04it/s]


epoch 0: loss: 0.3086366430558014


100%|██████████| 2832/2832 [01:11<00:00, 39.43it/s]


epoch 1: loss: 0.0006204869874112527


100%|██████████| 2832/2832 [01:15<00:00, 37.37it/s]


epoch 2: loss: 0.0001256646134663376


100%|██████████| 2832/2832 [01:13<00:00, 38.36it/s]


epoch 3: loss: 4.674267041783593e-05


100%|██████████| 2832/2832 [01:12<00:00, 39.28it/s]

epoch 4: loss: 2.358297165013106e-05





In [9]:
with torch.no_grad():
    from einops import repeat, rearrange
    for sample_idx in range(6):
        out_inds = pixelCNN.sample(1)[0]
        z_q = vae.quantize.embedding(out_inds)
        z_q = rearrange(z_q, 'b (h w) c -> b c h w', w = 16)
        cond = torch.arange(0, 32, 1).long().to(DEVICE)
        sample = vae.decode(z_q, cond).cpu()
        from util import compact_large_image
        imgs = compact_large_image(sample, HZ=4, WZ=8)
        for idx in range(imgs.shape[0]):
            plt.imsave(f'visualize/Transformer_vis/{sample_idx}.png', imgs[idx] * 0.5 + 0.5, cmap='gray')


100%|██████████| 256/256 [00:10<00:00, 23.86it/s]
100%|██████████| 256/256 [00:09<00:00, 25.67it/s]
100%|██████████| 256/256 [00:11<00:00, 22.51it/s]
100%|██████████| 256/256 [00:10<00:00, 25.24it/s]
100%|██████████| 256/256 [00:10<00:00, 24.99it/s]
100%|██████████| 256/256 [00:09<00:00, 25.97it/s]
