<a href="https://colab.research.google.com/github/yjyjy131/Study_Deep_Learning/blob/main/Pytorch_Basic/Chapter8_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Deep k-means algorithm

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import linear_sum_assignment as linear_assignment

In [None]:
# Mini-batch size handed to the train and test DataLoaders.
batch_size = 128
# Number of k-means clusters — presumably one per MNIST digit class; confirm where Kmeans is instantiated.
num_clusters = 10
# NOTE(review): presumably the dimensionality of the latent code shared by the
# encoder, decoder and centroids — the instantiation is outside this view.
latent_size = 10

In [None]:
# MNIST train/test splits, downloaded into ./data/ on first use and converted
# to tensors by ToTensor().
trainset = torchvision.datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
testset = torchvision.datasets.MNIST(
    root='./data/',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Both loaders draw shuffled mini-batches of `batch_size` samples.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [None]:
class Flatten(torch.nn.Module):
  """Collapse every dimension after the first into one feature axis.

  An input of shape (batch, d1, d2, ...) comes back as (batch, d1*d2*...).
  """

  def forward(self, x):
    # Dimension 0 is kept; view() infers the flattened feature length.
    return x.view(x.shape[0], -1)

class Deflatten(nn.Module):
  """Reshape flat feature vectors into square feature maps with ``k`` channels.

  An input of shape (batch, k * s * s) is viewed as (batch, k, s, s).
  """

  def __init__(self, k):
    """Store ``k``, the number of channels of the restored feature map."""
    super().__init__()
    self.k = k

  def forward(self, x):
    n_batch, n_features = x.size(0), x.size(1)
    # Side length of the square map: sqrt(features per channel), truncated.
    side = int((n_features // self.k) ** 0.5)
    return x.view(n_batch, self.k, side, side)

In [None]:
class Kmeans(nn.Module):
  """K-means layer with learnable centroids.

  Holds ``num_clusters`` centroids of dimension ``latent_size`` in a single
  ``nn.Parameter`` and assigns each input row to its nearest centroid under
  squared Euclidean distance.
  """

  def __init__(self, num_clusters, latent_size):
    """Create the centroid parameter, initialised uniformly in [0, 1).

    Args:
      num_clusters: number of centroids.
      latent_size: dimensionality of each centroid.
    """
    super(Kmeans, self).__init__()
    # The centroids are placed on the first CUDA device when one is available,
    # otherwise on the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.num_clusters = num_clusters
    self.centroids = nn.Parameter(torch.rand((self.num_clusters, latent_size)).to(device))

  def argminl2distance(self, a, b):
    """Return the row index at which ``a`` and ``b`` are closest.

    Squared differences are summed over dim 1 and the argmin is taken over
    dim 0, so for 2-D inputs the result is a 0-d index tensor.
    """
    return torch.argmin(torch.sum((a-b)**2, dim=1), dim=0)

  def forward(self, x):
    """Assign every row of ``x`` to its nearest centroid.

    Args:
      x: tensor of shape (batch, latent_size) on the same device as the
        centroids.

    Returns:
      A tuple ``(y_assign, assigned_centroids)`` where ``y_assign`` is a Python
      list with one int cluster index per row of ``x`` and
      ``assigned_centroids`` is ``self.centroids`` gathered at those indices,
      of shape (batch, latent_size). Gradients flow to the centroids through
      the gathered tensor.
    """
    # The argmin is not differentiable, so no graph is needed for the distance
    # computation. All pairwise distances are computed in one broadcasted op
    # rather than a per-sample Python loop with an .item() sync on every row.
    with torch.no_grad():
      diff = x.unsqueeze(1) - self.centroids.unsqueeze(0)  # (batch, clusters, latent)
      nearest = torch.argmin(torch.sum(diff**2, dim=2), dim=1)
    return nearest.tolist(), self.centroids[nearest]

In [None]:
class Encoder(nn.Module):
  """Convolutional encoder that maps single-channel images to latent codes.

  Three unpadded 3x3 convolutions (strides 2, 2, 1) with 16, 32 and 64 output
  channels are followed by a flatten and a linear layer to ``latent_size``.
  The linear layer expects 1024 flattened features, which is what the
  convolution stack yields for 1x28x28 inputs (64 channels of 4x4). The final
  ReLU makes every latent code non-negative.
  """

  def __init__(self, latent_size):
    """Build the encoder stack.

    Args:
      latent_size: dimensionality of the produced latent code.
    """
    super(Encoder, self).__init__()

    k = 16  # base channel width; later layers use 2*k and 4*k
    self.encoder = nn.Sequential(
        nn.Conv2d(1, k, 3, stride=2),
        nn.ReLU(),
        nn.Conv2d(k, 2*k, 3, stride=2),
        nn.ReLU(),
        nn.Conv2d(2*k, 4*k, 3, stride=1),
        nn.ReLU(),
        Flatten(),
        nn.Linear(1024, latent_size),
        nn.ReLU()
    )

  def forward(self, x):
    """Encode a batch of images into a (batch, latent_size) tensor."""
    return self.encoder(x)

In [None]:
class Decoder(nn.Module):
  """Transposed-convolution decoder that maps latent codes back to images.

  A linear layer expands the code to 1024 features, Deflatten reshapes them
  into 64 channels of 4x4, and three unpadded 3x3 transposed convolutions
  (strides 1, 2, 2) upsample to a single-channel 28x28 output. The final
  Sigmoid keeps every output value in (0, 1).
  """

  def __init__(self, latent_size):
    """Build the decoder stack for codes of dimension ``latent_size``."""
    super().__init__()
    base = 16  # base channel width
    wide, mid = 4 * base, 2 * base
    layers = [
        nn.Linear(latent_size, 1024),
        nn.ReLU(),
        Deflatten(wide),
        nn.ConvTranspose2d(wide, mid, 3, stride=1),
        nn.ReLU(),
        nn.ConvTranspose2d(mid, base, 3, stride=2),
        nn.ReLU(),
        # output_padding=1 adds the extra row/column needed to land on 28x28.
        nn.ConvTranspose2d(base, 1, 3, stride=2, output_padding=1),
        nn.Sigmoid(),
    ]
    self.decoder = nn.Sequential(*layers)

  def forward(self, x):
    """Decode a batch of latent codes into reconstructed images."""
    reconstruction = self.decoder(x)
    return reconstruction

In [None]:
# Cluster-label matching: score predicted clusters against ground-truth labels.
def cluster_acc(y_true, y_pred):
  """Clustering accuracy under the best one-to-one cluster-to-label mapping.

  Args:
    y_true: sequence of non-negative integer ground-truth labels.
    y_pred: sequence of non-negative integer predicted cluster ids, one per
      entry of ``y_true``.

  Returns:
    The fraction of samples whose cluster is matched to their true label by
    the optimal assignment.
  """
  true_ids = np.array(y_true)
  pred_ids = np.array(y_pred)
  n_ids = max(pred_ids.max(), true_ids.max()) + 1
  # counts[p, t] = number of samples in cluster p whose true label is t.
  counts = np.zeros((n_ids, n_ids), dtype=np.int64)
  for idx in range(pred_ids.size):
    counts[pred_ids[idx], true_ids[idx]] += 1
  # The assignment solver minimises cost, so turn counts into a cost matrix.
  rows, cols = linear_assignment(counts.max() - counts)
  matched = sum([counts[r, c] for r, c in zip(rows, cols)])
  return matched * 1.0 / pred_ids.size