In [1]:
# ==== import from package ==== #
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import rearrange, repeat, reduce
from collections import defaultdict
# ==== import from this folder ==== #
from model_VQVAE import VQVAE
from discriminator import NLayerDiscriminator, weights_init
from dataset import get_dataloader
from util import reset_dir, weight_scheduler, compact_large_image
from logger import Logger
DEVICE = torch.device("cuda")
print("DEVICE:", DEVICE)

DEVICE: cuda


In [2]:
batch_size = 32
dataloader = get_dataloader(mode='train_and_validate', batch_size=batch_size)

called train_and_validate 32 None


100%|██████████| 9664/9664 [00:11<00:00, 811.24it/s]
100%|██████████| 1120/1120 [00:01<00:00, 772.98it/s]


In [3]:
net = VQVAE().to(DEVICE)
net = torch.load(f'model_ckpt/VQVAE/epoch_AE_50.pt')
net.eval()
print()




In [4]:
def get_ind_data(net, dataloader):
    indices = []
    for now_step, batch_data in tqdm(enumerate(dataloader), total=len(dataloader)):
        raw_img, seg_img, brain_idx, z_idx = [
            data.to(DEVICE) for data in batch_data]
        with torch.no_grad():
            batch_size = raw_img.shape[0]
            latent = net.encode(raw_img)
            quant, diff_loss, (_, _, ind) = net.quantize(latent)
            ind = rearrange(ind, '(b c h w) -> b c h w', b=batch_size,
                            h=net.z_shape[0], w=net.z_shape[1])
            indices.append(ind.detach())
    indices = torch.cat(indices)
    indices = reduce(indices, 'b c h w -> b h w', 'min')
    return indices

In [5]:
with torch.no_grad():
    ind = get_ind_data(net, dataloader).detach()

100%|██████████| 337/337 [01:16<00:00,  4.38it/s]


In [6]:
print(ind.shape)
print(torch.max(ind), torch.min(ind))
for i in range(16):
    for j in range(16):
        x = ind[:, i, j]
        print('%4d' % len(torch.unique(x)), end='')
    print()


torch.Size([10784, 16, 16])
tensor(252, device='cuda:0') tensor(17, device='cuda:0')
   8   6   4   6   6   6   6   6   6   7  10  11   6   6   7   6
   4   3   6   7   7   7   7   7   7   6   6   5   5   8   6   6
   5   6   8   8   7   7   6   7   7   6   8   7   6   9   6   6
   7   8   7   7   7   8   7   6   7   7   7   7   8   8   7   6
   8   8   7   6   7   8   8   6   6   8   8   8   9   7   6   7
   9   9   6   8   8   6   6   6   6   8   8   7   6   6   7   8
   8   7   7   6   7   6   7   6   6   6   6   6   8   8   7   8
   7   5   7   8   6   7   7   6   6   6   6   6   6   7   7   7
   7   6   7   6   6   6   6   6   6   6   6   7   7   6   6   8
   6   5   8   6   6   6   6   6   6   5   6   7   8   8   7   7
   7   5   7   8   6   6   6   7   6   7   6   6   7   8   6   9
   9   7   8   8   8   6   8   6   7   7   7   7   7   6   7   9
   8   9   8   7   7   7   6   6   7   6   8   7   9   7   9   8
   8   7   8   8   7   8   7   7   7   8   8   8   9   8  10   7
   8 

In [7]:
# ind = rearrange(ind[:27], 'b w h -> b w h ')
print(ind.shape)
# ind shape: [batch_size, w*h]
emb = net.quantize.embedding(ind)
# emb shape: [batch_size, w*h, emb_len]
print(emb.shape)

ind = rearrange(ind, 'N W H -> N (W H)')
emb = rearrange(emb, 'N W H E -> N ( W H ) E')

torch.Size([10784, 16, 16])
torch.Size([10784, 16, 16, 8])


In [8]:
import math
class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 初始化Shape为(max_len, d_model)的PE (positional encoding)
        pe = torch.zeros(max_len, d_model)
        # 初始化一个tensor [[0, 1, 2, 3, ...]]
        position = torch.arange(0, max_len).unsqueeze(1)
        # 这里就是sin和cos括号中的内容，通过e和ln进行了变换
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        # 计算PE(pos, 2i)
        pe[:, 0::2] = torch.sin(position * div_term)
        # 计算PE(pos, 2i+1)
        pe[:, 1::2] = torch.cos(position * div_term)
        # 为了方便计算，在最外面在unsqueeze出一个batch
        pe = pe.unsqueeze(0)
        # 如果一个参数不参与梯度下降，但又希望保存model的时候将其保存下来
        # 这个时候就可以用register_buffer
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        x 为embedding后的inputs，例如(1,7, 128)，batch size为1,7个单词，单词维度为128
        """
        # 将x和positional encoding相加。
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

In [96]:
class PixelCNN(nn.Module):

    def __init__(self, vae):
        super(PixelCNN, self).__init__()

        self.embed_dim = vae.quantize.e_dim
        self.n_embed = vae.quantize.n_e
        self.z_shape = vae.z_shape
        self.transformer = nn.Transformer(d_model=self.embed_dim,
                                     nhead=8,
                                     num_encoder_layers=6,
                                     num_decoder_layers=6,
                                     dim_feedforward=256,
                                     dropout=0.1,
                                     batch_first=True).cuda()
        
        self.positional_encoding = PositionalEncoding(d_model=self.embed_dim, max_len=self.z_shape[0] * self.z_shape[1], dropout=0)

        self.predictor = nn.Linear(self.embed_dim, self.n_embed)

    def generate_random_mask(self, batch_size, S, low=None, high=None):
        if low is None:
            low = 0
        if high is None:
            high = S+1
        mask = torch.zeros([batch_size, S], device=DEVICE, dtype=torch.long)
        for i in range(batch_size):
            mask_num = torch.randint(low=low, high=high, size=(1,),  device=DEVICE)
            perm = torch.randperm(S, device=DEVICE)[None, :]
            mask[i, :] = (perm < mask_num).long()
        return mask

    def forward(self, src, low=None, high=None):
        N, S, E = src.shape
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(S).cuda()
        from einops import repeat, rearrange

        teacher_forcing = False
        tgt = src
        src_mask = self.generate_random_mask(batch_size=N, S=S, low=low, high=high)
        src = src * (1 - src_mask[:, :, None].long())
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)
        if teacher_forcing:
            out = self.transformer(src, tgt, tgt_mask=tgt_mask)
            out = self.predictor(out)
            return out
        else:
            src_enc = self.transformer.encoder(src)
            out_embs = torch.zeros([N, S, E], device=DEVICE)
            out_inds = []
            out_logits = []
            for i in range(S):
                # print(cur[0, :,0])
                print(i, out_embs.shape, src_enc.shape)
                mask = repeat(tgt_mask[i], 'S -> N S', N = N)

                out = self.transformer.decoder(out_embs, memory=src_enc, tgt_mask=mask)
                logits = self.predictor(out[:, i, :])
                out_logits.append(logits[:, None, :])
                out_ind = torch.argmax(logits.detach(), dim=1)
                out_inds.append(out_ind[:, None])
                out_emb = net.quantize.embedding(out_ind).detach()
                out_embs[:, i, :] = out_emb
            return torch.cat(out_logits, dim=1)


In [97]:
pixelCNN = PixelCNN(net).to(DEVICE)
from torchinfo import summary
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(pixelCNN.parameters(), lr=2e-4)

In [98]:
from torch.utils.data import TensorDataset, DataLoader

print(ind.shape, emb.shape)
dataset = TensorDataset(emb, ind)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

torch.Size([10784, 256]) torch.Size([10784, 256, 8])


In [99]:
def train(i):

    cut = min(i // 2, 15)
    low, high = cut*16, (cut+1) * 16
    print(low, high)
    total_loss = 0
    for now_step, batch_data in tqdm(enumerate(dataloader), total=len(dataloader)):
        emb, ind = [data.to(DEVICE) for data in batch_data]
        optimizer.zero_grad()

        out = pixelCNN(emb.detach(), low=low, high=high)
        loss = criteria(out.contiguous().view(-1, out.size(-1)), ind.contiguous().view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

In [100]:
for i in range(50):
    print(f"epoch {i}: loss: {train(i)}")


0 16


  0%|          | 0/10784 [00:00<?, ?it/s]

0 torch.Size([1, 256, 8]) torch.Size([1, 256, 8])





RuntimeError: The shape of the 2D attn_mask is torch.Size([256, 1]), but should be (256, 256).

In [None]:

binc = torch.bincount(ind.view([-1]))
for i, c in enumerate(binc):
    print(i, c.item(), end='\t')
    if i % 16 == 15:
        print()

0 0	1 0	2 0	3 0	4 0	5 0	6 0	7 0	8 0	9 0	10 0	11 0	12 0	13 0	14 0	15 0	
16 0	17 72	18 0	19 0	20 0	21 0	22 0	23 0	24 47	25 0	26 0	27 0	28 0	29 0	30 0	31 0	
32 0	33 0	34 0	35 0	36 0	37 0	38 0	39 0	40 0	41 0	42 0	43 0	44 0	45 0	46 0	47 0	
48 0	49 0	50 0	51 0	52 0	53 0	54 0	55 0	56 0	57 0	58 0	59 0	60 2885	61 0	62 407961	63 0	
64 0	65 0	66 0	67 0	68 0	69 0	70 0	71 0	72 0	73 0	74 0	75 0	76 659550	77 0	78 0	79 0	
80 0	81 0	82 0	83 0	84 0	85 0	86 0	87 0	88 0	89 0	90 0	91 0	92 0	93 0	94 0	95 0	
96 0	97 0	98 0	99 0	100 0	101 0	102 0	103 0	104 0	105 0	106 0	107 0	108 0	109 0	110 0	111 0	
112 0	113 0	114 0	115 0	116 0	117 0	118 0	119 0	120 0	121 0	122 0	123 0	124 0	125 5391	126 0	127 0	
128 0	129 0	130 0	131 0	132 0	133 0	134 0	135 0	136 0	137 0	138 468018	139 0	140 0	141 0	142 0	143 0	
144 0	145 0	146 0	147 0	148 0	149 0	150 0	151 0	152 0	153 0	154 0	155 0	156 0	157 0	158 0	159 0	
160 0	161 0	162 0	163 25	164 0	165 0	166 0	167 0	168 0	169 0	170 0	171 0	172 0	173 0	174 0	175 0	
176 0	177 0	178 0	1

In [None]:
with torch.no_grad():
    
    # emb = rearrange(emb, 'N W H E -> N ( W H ) E')
    # cur = torch.zeros(32, 16 * 16, 8).to(DEVICE)
    # cur = net.quantize.embedding(ind[:32])
    # cur[:, 32:, :] = 0
    # ans = net.quantize.embedding(ind[:32])
    # ans = pixelCNN.positional_encoding(ans)
        
    # print(ind[0])
    N, S, E = 32, 16*16, 8
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(S).cuda()
    from einops import repeat, rearrange
    
    cur = torch.zeros([N, S, E], device=DEVICE)
    out_embs = torch.zeros([N, S, E], device=DEVICE)
    out_inds = []
    
    for i in range(256):
        # print(cur[0, :,0])
        pos_cur = pixelCNN.positional_encoding(cur)
        out = pixelCNN.transformer(pos_cur, out_embs, tgt_mask=tgt_mask)
        out_ind = torch.argmax(pixelCNN.predictor(out[:, i, :]), dim=1)
        out_inds.append(out_ind[:, None])
        out_emb = net.quantize.embedding(out_ind)
        cur[:, i, :] = out_emb
        out_embs[:, i, :] = out_emb
        # print(out_emb[:, None, :].shape)
        # cur [:, i, :] = 
    # nxt_ind = torch.argmax(pixelCNN(net.quantize.embedding(ind[:32])), dim=2)
    out_inds = torch.cat(out_inds, dim=1)
    print(out_inds)
    print(out_inds.shape)
    z_q = net.quantize.embedding(out_inds)
    z_q = rearrange(z_q, 'b (w h) c -> b c w h', w = 16)
    sample = net.decode(z_q).cpu()

    print(sample.shape)
    from util import compact_large_image
    imgs = compact_large_image(sample, HZ=4, WZ=8)
    for idx in range(imgs.shape[0]):
        plt.imsave('test.png', imgs[0] * 0.5 + 0.5, cmap='gray')


tensor([[ 62,  76,  62,  ..., 138,  76,  76],
        [ 62,  76,  76,  ..., 138,  76, 138],
        [ 62,  62,  76,  ...,  76, 138, 138],
        ...,
        [ 76,  76,  76,  ...,  76,  76, 138],
        [ 62,  76,  76,  ...,  76, 227,  76],
        [ 76,  76,  76,  ...,  76,  62, 138]], device='cuda:0')
torch.Size([32, 256])
torch.Size([32, 1, 256, 256])


In [None]:
print(ind.shape)
map_id = { x.item(): i for i, x in enumerate(torch.unique(ind.view([-1])))}
print(map_id)
remap_id = { i: x for x, i in map_id.items()}
ind2 = ind.cpu().numpy()
for i, x in map_id.items():
    ind2[ind2 == i] = x
print(ind2[:, :3])

torch.Size([10784, 256])
{17: 0, 24: 1, 60: 2, 62: 3, 76: 4, 125: 5, 138: 6, 163: 7, 194: 8, 227: 9, 251: 10, 252: 11}
[[8 6 6]
 [8 9 4]
 [8 9 4]
 ...
 [8 9 4]
 [8 9 4]
 [8 4 4]]
