In [None]:
import torch
from torchvision import datasets, transforms

# ToTensor converts each PIL image into a float tensor in channel-first (C, H, W)
# layout with pixel values scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# Mini-batch size shared by both loaders.
BATCH_SIZE = 64

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Build the data loaders.
# The training loader drops the last partial batch, so every training batch
# has exactly BATCH_SIZE samples.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# The test loader keeps every sample: dropping the final partial batch would
# silently exclude test images from any evaluation.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)


In [None]:
import matplotlib.pyplot as plt

# Inspect the first training sample. Print its shape and label rather than
# dumping every pixel value into the notebook output.
sample_img, sample_label = train_dataset[0]
print(sample_img.shape, sample_label)

# The image is single-channel. Drop the channel axis to get a 2-D (H, W) array
# and render it in grayscale instead of the default color map.
plt.imshow(sample_img.squeeze(0), cmap='gray')
plt.title(f'Tensor Visualization (label: {sample_label})')
plt.show()


In [None]:
from torch import nn

class Encoder(nn.Module):
    """MLP encoder of a VAE: maps a flat input vector to a sampled latent code.

    Args:
        data_dim: Size of each flattened input vector.
        hid_dim: Width of the first hidden layer.
        hid_dim2: Width of the second hidden layer.
        lat_dim: Dimensionality of the latent space.

    ``forward`` returns ``(z, mu, logvar)``:
    - ``mu`` and ``logvar`` are the per-dimension mean and log-variance predicted
      for the latent distribution.
    - ``z`` is a reparameterized sample ``mu + eps * exp(0.5 * logvar)``.
    """

    def __init__(self, data_dim, hid_dim, hid_dim2, lat_dim):
        super(Encoder, self).__init__()
        self.data_dim = data_dim
        self.hid_dim = hid_dim
        self.hid_dim2 = hid_dim2
        self.lat_dim = lat_dim

        # Shared trunk: data_dim -> hid_dim -> hid_dim2 with ReLU activations.
        self.net = nn.Sequential(
            nn.Linear(self.data_dim, self.hid_dim), nn.ReLU(),
            nn.Linear(self.hid_dim, self.hid_dim2), nn.ReLU(),
        )
        # Two separate linear heads predict the mean and log-variance.
        self.mu_net = nn.Linear(self.hid_dim2, self.lat_dim)
        self.logvar_net = nn.Linear(self.hid_dim2, self.lat_dim)

    def reparameterization(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        The sample is drawn in one vectorized call. ``randn_like`` creates ``eps``
        with the same shape, device and dtype as ``logvar``, so this works
        unchanged on CPU or GPU and keeps ``z`` differentiable w.r.t. ``mu`` and
        ``logvar``.
        """
        epsilon = torch.randn_like(logvar)
        return mu + epsilon * torch.exp(0.5 * logvar)

    def reparameterization2(self, mu, logvar):
        """Loop-based equivalent of ``reparameterization``, one row at a time.

        Kept for backward compatibility; prefer ``reparameterization``, which
        avoids the Python loop. Assumes ``mu``/``logvar`` are
        (batch, lat_dim). Noise is created on ``mu``'s device/dtype so the
        arithmetic never mixes CPU and GPU tensors.
        """
        z = []
        for mu_i, logvar_i in zip(mu, logvar):
            epsilon = torch.randn(self.lat_dim, device=mu.device, dtype=mu.dtype)
            z.append(mu_i + epsilon * torch.exp(0.5 * logvar_i))
        return torch.stack(z, dim=0)

    def forward(self, x):
        """Encode ``x`` and return ``(z, mu, logvar)``.

        Uses the vectorized reparameterization: it samples all noise in one call
        on the same device as the inputs.
        """
        h = self.net(x)
        mu = self.mu_net(h)
        logvar = self.logvar_net(h)
        z = self.reparameterization(mu, logvar)
        return z, mu, logvar

class Decoder(nn.Module):
    """MLP decoder of a VAE: maps a latent code to per-element probabilities.

    Args:
        lat_dim: Dimensionality of the latent input.
        hid_dim: Width of the first hidden layer.
        hid_dim2: Width of the second hidden layer.
        data_dim: Size of the flattened output vector.

    ``forward`` returns sigmoid outputs in (0, 1), the range expected by
    ``nn.BCELoss``.
    """

    def __init__(self, lat_dim, hid_dim, hid_dim2, data_dim):
        super(Decoder, self).__init__()
        self.data_dim = data_dim
        self.hid_dim = hid_dim
        self.hid_dim2 = hid_dim2
        self.lat_dim = lat_dim

        # lat_dim -> hid_dim -> hid_dim2 -> data_dim. The last layer emits raw
        # logits; the sigmoid is applied in forward().
        self.net = nn.Sequential(
            nn.Linear(self.lat_dim, self.hid_dim), nn.ReLU(),
            nn.Linear(self.hid_dim, self.hid_dim2), nn.ReLU(),
            nn.Linear(self.hid_dim2, self.data_dim),
        )

    def forward(self, x):
        """Decode latent codes ``x`` (last dim = lat_dim) into probabilities."""
        # torch.sigmoid avoids constructing a new nn.Sigmoid module on every call.
        return torch.sigmoid(self.net(x))

In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

In [None]:
import torch.nn.functional as F

# Number of passes over the training set.
epoch = 200

# Seed the global torch RNG so weight init, shuffling and sampling are reproducible.
torch.manual_seed(0)

# 28*28 flattened pixels -> 400 -> 200 -> 2-D latent, and the mirror image back.
encoder = Encoder(28*28, 400, 200, 2).to(device)
decoder = Decoder(2, 200, 400, 28*28).to(device)

# A single Adam optimizer updates encoder and decoder jointly.
parameters = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(parameters, lr=0.001)

encoder.train()
decoder.train()

# Mean per-sample loss of each epoch, stored as plain Python floats.
loss_stack = []

for i in range(epoch):
    loss_r = 0.0
    n_seen = 0
    for x_batch, _ in train_loader:
        # Flatten each image to a vector using the actual batch size rather than
        # a hard-coded 64, so a partial final batch also works.
        batch_size = x_batch.size(0)
        x_batch = x_batch.view(batch_size, -1).to(device)
        z, mu, logvar = encoder(x_batch)

        recon_x_batch = decoder(z)

        # Reconstruction term: binary cross-entropy summed over pixels and batch.
        reconst_loss = F.binary_cross_entropy(recon_x_batch, x_batch, reduction='sum')
        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims and batch.
        regular_loss = 0.5 * torch.sum(mu**2 + torch.exp(logvar) - logvar - 1)

        loss = reconst_loss + regular_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() extracts a detached Python float. Accumulating the tensor itself
        # would chain autograd nodes across the whole epoch.
        loss_r += loss.item()
        n_seen += batch_size

    # Normalize by the samples actually processed (drop_last may skip some).
    real_loss = loss_r / n_seen
    loss_stack.append(real_loss)

    print(f" For epoch{i+1}, Loss : {real_loss}")
    print("-"*30)

In [None]:
# Decode a regular grid of latent points to visualize the learned 2-D latent space.
GRID_SIZE = 20  # points per latent axis -> GRID_SIZE**2 decoded images

# GRID_SIZE evenly spaced values from -1 to 1 inclusive on each axis
# (spacing 2 / (GRID_SIZE - 1), about 0.105 for 20 points).
x_values = torch.linspace(-1, 1, GRID_SIZE, device=device)
y_values = torch.linspace(-1, 1, GRID_SIZE, device=device)

# Cartesian product of the two axes, flattened to (GRID_SIZE**2, 2) latent codes.
# indexing='ij' states the ordering explicitly (and silences meshgrid's warning):
# x varies along grid rows, y along grid columns.
grid_tensor = torch.stack(torch.meshgrid(x_values, y_values, indexing='ij'), dim=-1)
result_tensor = grid_tensor.reshape(-1, 2)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    sampled_images = decoder(result_tensor).view(-1, 28, 28).cpu().numpy()

# Show the decoded images as a GRID_SIZE x GRID_SIZE mosaic.
fig = plt.figure(figsize=(10, 10))
for idx, img in enumerate(sampled_images):
    ax = fig.add_subplot(GRID_SIZE, GRID_SIZE, idx + 1)
    ax.imshow(img, cmap='gray')
    ax.axis('off')
plt.show()


In [None]:
# Decode one hand-picked latent point and display it as a 28x28 image.
z = torch.tensor([0.3, 0.2], device=device)

# Inference only: skip building an autograd graph. The decoder output is already
# on `device`, so no extra .to(device) is needed.
with torch.no_grad():
    a = decoder(z).view(28, 28)
fig = plt.figure(figsize=(1, 1))

img = a.cpu().numpy()
plt.imshow(img, cmap='gray')
plt.show()