In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange, repeat, pack, unpack

import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader
import torch.optim as optim

import lovely_tensors as lt
lt.monkey_patch()

In [2]:
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import make_grid

In [3]:
training_data = datasets.CIFAR10(root="/home/aiteam/tykim/generative_model/data", 
                                 train=True, 
                                 download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

validation_data = datasets.CIFAR10(root="/home/aiteam/tykim/generative_model/data",
                                   train=False, 
                                   download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /home/aiteam/tykim/generative_model/data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /home/aiteam/tykim/generative_model/data/cifar-10-python.tar.gz to /home/aiteam/tykim/generative_model/data
Files already downloaded and verified


In [8]:
training_data.data.shape

(50000, 32, 32, 3)

In [5]:
np.var(training_data.data / 255.0)

0.06328692405746414

In [None]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # B, C, H, W => B, H, W, C
        inputs = inputs.permute(0, 2, 3, 1).contiguos()
        input_shape = inputs.shape

        #  B, H, W, C => BHW, C
        flat_input = inputs.view(-1, self._embedding_dim)
        # BHW, C
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) +
                     torch.sum(self._embedding.weight**2, dim=1)
                     -2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        # BHW, 1
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        

In [None]:
class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, commitment_cost, decay, epsilon=1e-5):
        

In [11]:
torch.argmin(torach.tensor([[1,2,3,4,],[5,4,3,2]]), dim=1)

tensor([0, 3])

In [13]:
src = torch.arange(1, 11).reshape((2,5))
src

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])

In [14]:
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)

tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

In [None]:
class VectorQuantize(nn.Module):
    def __init__(self, dim, codebook_size, heads=1):
        super().__init__()
        self.dim = dim

        codebook_dim = default(codebook_dim, dim)
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_size)


    def forward(self, x):
        only_one = x.ndim == 2
        if only_one:
            x = rearrange(x, 'b d -> b 1 d')
            

In [22]:
dim = 3
codebook_input_dim = 10
requires_projection = False

In [None]:
class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks = 1,
        kmeans_init = False,
        kmeans_iters = 10,
        sync_kmeans = True,
        decay = 0.8,
        eps = 1e-5,
        threshold_ema_dead_code = 2,
        use_ddp = False,
        learnable_codebook = False,
        sample_codebook_temp = 0
    ):
        super().__init__()
        self.decay = decay
        init_fn = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(num_codebooks, codebook_size, dim)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.sample_codebook_temp = sample_codebook_temp

        assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'

        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
        self.register_buffer('embed_avg', embed.clone())

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer('embed', embed)

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.initted:
            return

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            sample_fn = self.sample_fn,
            all_reduce_fn = self.kmeans_all_reduce_fn
        )

        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, batch_samples, batch_mask):
        for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
            self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...')

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
        self.replace(batch_samples, batch_mask = expired_codes)

    @autocast(enabled = False)
    def forward(self, x):
        needs_codebook_dim = x.ndim < 4

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, '... -> 1 ...')

        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, 'h ... d -> h (...) d')

        self.init_embed_(flatten)

        embed = self.embed if not self.learnable_codebook else self.embed.detach()

        dist = -torch.cdist(flatten, embed, p = 2)

        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = embed_ind.view(*shape[:-1])

        quantize = batched_embedding(embed_ind, self.embed)

        if self.training:
            cluster_size = embed_onehot.sum(dim = 1)

            self.all_reduce_fn(cluster_size)
            self.cluster_size.data.lerp_(cluster_size, 1 - self.decay)

            embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
            self.all_reduce_fn(embed_sum.contiguous())
            self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)

            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()

            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

        return quantize, embed_ind

In [23]:
project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()

In [None]:
# image forward test
x = torch.randn(32, 3, 16, 16)
x = rearrange(x, 'b c h w -> b (h w) c')
x = project_in(x)

In [19]:
t = torch.empty((3,4,10))
nn.init.kaiming_uniform_(t).shape

torch.Size([3, 4, 10])

In [20]:
def uniform_init(*shape):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


num_codebooks = 10
codebook_size = 10
dim = 256
embed = uniform_init(num_codebooks, codebook_size, dim)

In [None]:
class VQ(nn.Module):
  def __init__(self):
    super().__init__()
    num_embeddings = 10
    embedding_dims = 32
    self._embedding = nn.Embedding(num_embeddings, embedding_dims)
    self._embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)
    
  def forward(self, x):
    # embedidng_dim이 x의 shape와 똑같다고 가정
    x = x.permute(0, 2, 3, 1).contiguous()
    x = x.view(-1, embedding_dims)
    
    distances = torch.sum(x**2, dim=1, keepdim=True) + torch.sum(self._embedding.weight**2, dim=1) - 2*torch.matmul(x, self._embedding.weight.t())
    
    encoding_indicies = torch.argmin()

In [35]:
x = torch.randn(10, 32)
emb = torch.randn(5, 32)


sim = torch.matmul(x, emb.t())

# measure l2-normalized distance btw flattened encoder output and code words

dist = torch.sum(x**2, dim=1, keepdim=True) + torch.sum(emb**2, dim=1) -2 * sim
dist

tensor[10, 5] n=50 x∈[33.163, 94.940] μ=57.484 σ=12.957

In [34]:
torch.cdist(x, emb, p=1)

tensor[10, 5] n=50 x∈[27.652, 46.419] μ=36.437 σ=4.741

In [41]:
dist.shape
torch.argmin(dist, dim=1).shape

torch.Size([10])

In [42]:
torch.argmin(dist, dim=1)

tensor[10] i64 x∈[0, 4] μ=1.500 σ=1.650 [1, 0, 0, 3, 0, 4, 0, 2, 1, 4]

In [43]:
torch.argmin(dist, dim=1, keepdim=True)

tensor[10, 1] i64 x∈[0, 4] μ=1.500 σ=1.650 [[1], [0], [0], [3], [0], [4], [0], [2], [1], [4]]

In [44]:
torch.argmin(dist, dim=1).unsqueeze(1)

tensor[10, 1] i64 x∈[0, 4] μ=1.500 σ=1.650 [[1], [0], [0], [3], [0], [4], [0], [2], [1], [4]]

In [47]:
encoding_indicies = torch.argmin(dist, dim=1, keepdim=True)

In [46]:
encodings = torch.zeros(10, 5)
encodings

tensor[10, 5] [38;2;127;127;127mall_zeros[0m

In [49]:
# one-hot vector 처럼 만듬
encodings.scatter_(1, encoding_indicies, 1).v

tensor[10, 5] n=50 x∈[0., 1.000] μ=0.200 σ=0.404
tensor([[0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])

In [50]:
torch.matmul(encodings, emb).view(1, 10, 32)

tensor[1, 10, 32] n=320 x∈[-2.581, 2.787] μ=-0.053 σ=0.998

In [37]:
dist = torch.sum(x**2, dim=1, keepdim=True) + torch.sum(emb**2, dim=1) -2 * sim
dist

tensor[10, 5] n=50 x∈[33.163, 94.940] μ=57.484 σ=12.957

In [39]:
torch.norm(x, p=2)

tensor 16.544

In [None]:
class VQ(nn.Module):
  def __init__(self):
    self.emb = nn.Embedding(num_embeddings, embedding_dim)
    
  def forward(self, x):
    x = rearrange(x , 'b c h w -> (b h w) c')
    distannces = 

In [2]:
import torch

# Create a tensor with the original ball colors
ball_colors = torch.tensor([0, 1, 2, 3, 4, 2, 1, 0])

# Create a tensor with the indices of the balls you want to change
indices_to_change = torch.tensor([1, 3, 6])

# Create a tensor with the new color you want to change the balls to
new_color = torch.tensor([[5]])

# Use scatter_ to change the color of the specified balls
ball_colors.scatter_(0, indices_to_change, new_color)

print(ball_colors)

RuntimeError: Index tensor must have the same number of dimensions as src tensor

In [3]:
torch.manual_seed(7)

x = torch.rand(2, 5)
x

tensor([[0.5349, 0.1988, 0.6592, 0.6569, 0.2328],
        [0.4251, 0.2071, 0.6297, 0.3653, 0.8513]])

In [5]:
y = torch.zeros(3,5)
y

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [6]:
print(y.scatter_(0,torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 1]]),x))

tensor([[0.5349, 0.2071, 0.6297, 0.6569, 0.2328],
        [0.0000, 0.1988, 0.0000, 0.3653, 0.8513],
        [0.4251, 0.0000, 0.6592, 0.0000, 0.0000]])


In [3]:
torch.manual_seed(7)

x = torch.rand(2, 5)
x


tensor[2, 5] n=10 x∈[0.199, 0.851] μ=0.476 σ=0.225 [[0.535, 0.199, 0.659, 0.657, 0.233], [0.425, 0.207, 0.630, 0.365, 0.851]]

In [4]:
y = torch.zeros(3,5)
y

tensor[3, 5] [38;2;127;127;127mall_zeros[0m

In [7]:
y.scatter_(0,torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 1]]),1).v

tensor[3, 5] n=15 x∈[0., 1.000] μ=0.667 σ=0.488
tensor([[1., 1., 1., 1., 1.],
        [0., 1., 0., 1., 1.],
        [1., 0., 1., 0., 0.]])