### 1. 모델 정의

MNIST 이미지(shape: (BATCH, 1, 28, 28))를 입력으로 하는

2D Conv Encoder & Decoder로 이루어진 ConvVAE, ConvAE와

Linear Encoder & Decoder로 이루어진 FlattenVAE, FlattenAE 정의

In [None]:
# src/model.py

import torch.nn.functional as F
import torch.nn as nn
import torch

class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for 1x28x28 images.

    Encoder: two stride-2 convolutions (1 -> 16 -> 64 channels,
    28 -> 14 -> 7 pixels), flattened and fed to two linear heads that
    produce the latent mean ``mu`` and log-variance ``logvar``.
    Decoder: a linear layer back to 64*7*7, reshaped to (B, 64, 7, 7), then
    two stride-2 transposed convolutions and a sigmoid so outputs lie in (0, 1).

    Args:
        latent_dim: size of the latent vector z.
    """

    def __init__(self, latent_dim=4):
        super().__init__()
        feat = 64 * 7 * 7  # flattened size of the encoder's (64, 7, 7) output

        # (B, 1, 28, 28) -> (B, 16, 14, 14) -> (B, 64, 7, 7)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
            nn.ReLU(),
            nn.Conv2d(16, 64, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
            nn.ReLU(),
        )

        # latent mean and log-variance heads, each (B, feat) -> (B, latent_dim)
        self.fc_mu = nn.Linear(feat, latent_dim)
        self.fc_logvar = nn.Linear(feat, latent_dim)

        # (B, latent_dim) -> (B, feat); reshaped to (B, 64, 7, 7) in decode()
        self.fc_decoder = nn.Linear(latent_dim, feat)
        # (B, 64, 7, 7) -> (B, 16, 14, 14) -> (B, 1, 28, 28)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
            nn.Sigmoid(),
        )

        self._init_weights()

    def encode(self, x):
        """Map images to ``(mu, logvar)``, each of shape (B, latent_dim)."""
        features = self.encoder(x)
        flat = features.reshape(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, latent_z):
        """Map latent vectors (B, latent_dim) back to images (B, 1, 28, 28)."""
        grid = self.fc_decoder(latent_z).view(-1, 64, 7, 7)
        return self.decoder(grid)

    def forward(self, x):
        """Return ``(reconst_x, mu, logvar, latent_z)``."""
        mu, logvar = self.encode(x)
        latent_z = self.reparameterize(mu, logvar)
        return self.decode(latent_z), mu, logvar, latent_z

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear/Conv/ConvTranspose layer."""
        layer_types = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if not isinstance(module, layer_types):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)


class FlattenVAE(nn.Module):
    """Fully-connected variational autoencoder for 1x28x28 images.

    Encoder: flatten to 784 features, then Linear+ReLU through widths
    512 -> 256 -> 128, followed by linear heads for the latent mean ``mu``
    and log-variance ``logvar``. Decoder: Linear+ReLU through
    128 -> 256 -> 512, a final Linear to 784 with sigmoid, reshaped to
    (B, 1, 28, 28).

    Args:
        latent_dim: size of the latent vector z.
    """

    def __init__(self, latent_dim=4):
        super().__init__()
        in_features = 28 * 28
        hidden = (512, 256, 128)

        # Flatten, then Linear+ReLU pairs: 784 -> 512 -> 256 -> 128
        enc_layers = [nn.Flatten()]
        for d_in, d_out in zip((in_features,) + hidden[:-1], hidden):
            enc_layers.extend([nn.Linear(d_in, d_out), nn.ReLU()])
        self.encoder = nn.Sequential(*enc_layers)

        # latent mean and log-variance heads
        self.fc_mu = nn.Linear(hidden[-1], latent_dim)
        self.fc_logvar = nn.Linear(hidden[-1], latent_dim)

        # Linear+ReLU pairs: latent -> 128 -> 256 -> 512, then 512 -> 784 + Sigmoid
        widths_up = hidden[::-1]
        dec_layers = []
        for d_in, d_out in zip((latent_dim,) + widths_up[:-1], widths_up):
            dec_layers.extend([nn.Linear(d_in, d_out), nn.ReLU()])
        dec_layers.extend([nn.Linear(widths_up[-1], in_features), nn.Sigmoid()])
        self.decoder = nn.Sequential(*dec_layers)

        self._init_weights()

    def encode(self, x):
        """Map images to ``(mu, logvar)``, each of shape (B, latent_dim)."""
        hidden_x = self.encoder(x)
        return self.fc_mu(hidden_x), self.fc_logvar(hidden_x)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, latent_z):
        """Map latent vectors back to images of shape (B, 1, 28, 28)."""
        return self.decoder(latent_z).view(-1, 1, 28, 28)

    def forward(self, x):
        """Return ``(reconst_x, mu, logvar, latent_z)``."""
        mu, logvar = self.encode(x)
        latent_z = self.reparameterize(mu, logvar)
        return self.decode(latent_z), mu, logvar, latent_z

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear layer."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)



class ConvAE(nn.Module):
    """Deterministic convolutional autoencoder for 1x28x28 images.

    Encoder: two stride-2 convolutions (1 -> 16 -> 64 channels,
    28 -> 14 -> 7 pixels), flattened and mapped by one linear layer directly
    to the latent code z (no sampling). Decoder: a linear layer back to
    64*7*7, reshaped to (B, 64, 7, 7), then two stride-2 transposed
    convolutions and a sigmoid.

    Args:
        latent_dim: size of the latent vector z.
    """

    def __init__(self, latent_dim=4):
        super().__init__()
        feat = 64 * 7 * 7  # flattened size of the encoder's (64, 7, 7) output

        # (B, 1, 28, 28) -> (B, 16, 14, 14) -> (B, 64, 7, 7)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
            nn.ReLU(),
            nn.Conv2d(16, 64, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
            nn.ReLU(),
        )

        # (B, feat) -> (B, latent_dim)
        self.fc_z = nn.Linear(feat, latent_dim)

        # (B, latent_dim) -> (B, feat); reshaped to (B, 64, 7, 7) in decode()
        self.fc_decoder = nn.Linear(latent_dim, feat)
        # (B, 64, 7, 7) -> (B, 16, 14, 14) -> (B, 1, 28, 28)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1, padding_mode='zeros'),
            nn.Sigmoid(),
        )

        self._init_weights()

    def encode(self, x):
        """Map images to latent codes of shape (B, latent_dim)."""
        features = self.encoder(x)
        return self.fc_z(features.reshape(features.size(0), -1))

    def decode(self, latent_z):
        """Map latent codes back to images of shape (B, 1, 28, 28)."""
        grid = self.fc_decoder(latent_z).view(-1, 64, 7, 7)
        return self.decoder(grid)

    def forward(self, x):
        """Return ``(reconst_x, None, None, latent_z)``.

        The two ``None`` slots stand in for ``mu``/``logvar`` so the return
        tuple has the same arity as the VAE models.
        """
        latent_z = self.encode(x)
        return self.decode(latent_z), None, None, latent_z

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear/Conv/ConvTranspose layer."""
        layer_types = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if not isinstance(module, layer_types):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)



class FlattenAE(nn.Module):
    """Deterministic fully-connected autoencoder for 1x28x28 images.

    Encoder: flatten to 784 features, Linear+ReLU through widths
    512 -> 256 -> 128, then a final Linear to ``latent_dim`` (no activation).
    Decoder: Linear+ReLU through 128 -> 256 -> 512, a final Linear to 784
    with sigmoid, reshaped to (B, 1, 28, 28).

    Args:
        latent_dim: size of the latent vector z.
    """

    def __init__(self, latent_dim=4):
        super().__init__()
        in_features = 28 * 28
        hidden = (512, 256, 128)

        # Flatten, Linear+ReLU pairs 784 -> 512 -> 256 -> 128, then 128 -> latent
        enc_layers = [nn.Flatten()]
        for d_in, d_out in zip((in_features,) + hidden[:-1], hidden):
            enc_layers.extend([nn.Linear(d_in, d_out), nn.ReLU()])
        enc_layers.append(nn.Linear(hidden[-1], latent_dim))
        self.encoder = nn.Sequential(*enc_layers)

        # Linear+ReLU pairs latent -> 128 -> 256 -> 512, then 512 -> 784 + Sigmoid
        widths_up = hidden[::-1]
        dec_layers = []
        for d_in, d_out in zip((latent_dim,) + widths_up[:-1], widths_up):
            dec_layers.extend([nn.Linear(d_in, d_out), nn.ReLU()])
        dec_layers.extend([nn.Linear(widths_up[-1], in_features), nn.Sigmoid()])
        self.decoder = nn.Sequential(*dec_layers)

        self._init_weights()

    def encode(self, x):
        """Map images to latent codes of shape (B, latent_dim)."""
        return self.encoder(x)

    def decode(self, latent_z):
        """Map latent codes back to images of shape (B, 1, 28, 28)."""
        return self.decoder(latent_z).view(-1, 1, 28, 28)

    def forward(self, x):
        """Return ``(reconst_x, None, None, latent_z)``.

        The two ``None`` slots stand in for ``mu``/``logvar`` so the return
        tuple has the same arity as the VAE models.
        """
        latent_z = self.encode(x)
        return self.decode(latent_z), None, None, latent_z

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear layer."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

### 2. Loss function 정의

입력 데이터 x와 복원 데이터 reconst_x 사이의 차이를 계산하여 모델이 데이터를 얼마나 잘 복원했는지(BCE 재구성 손실 계산, 데이터 충실도)

잠재 공간 분포를 정규 분포와 얼마나 가깝게 유지했는지(KL Divergence 손실 계산, 규제)

를 ${\beta}$를 조절(${\beta}$스케줄링)하며 KL항의 영향력을 학습 진행에 따라 조정할 수 있는 Loss Function 정의


한 줄 요약: 모델에게 "원본 이미지를 충실히 복원하렴" + "잠재 변수 분포를 정규 분포에 맞출 때 균형도 충실히 맞추렴"을 주문

In [None]:
# src/loss.py

import torch.nn.functional as F
import torch
import math

def reconst_loss(reconst_x, x):
    """Binary cross-entropy reconstruction loss averaged over the batch.

    BCE is summed over every element of the batch and then divided by the
    batch size ``x.size(0)``, i.e. it is the mean per-sample summed BCE.
    Both tensors must hold values in [0, 1], as required by
    ``F.binary_cross_entropy``.
    """
    batch_size = x.size(0)
    summed_bce = F.binary_cross_entropy(reconst_x, x, reduction='sum')
    return summed_bce / batch_size


def kl_divergence(mu, logvar):
    """KL divergence from N(mu, exp(logvar)) to N(0, I), averaged over the batch.

    Uses the closed form ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``,
    summed over all elements and divided by the batch size ``mu.size(0)``.
    """
    batch_size = mu.size(0)
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.sum() / batch_size


def beta(config, epoch):
    """Return the KL weight (beta) to use at ``epoch``.

    The schedule is chosen by ``config['beta_schedule']['type']``:

    - ``'linear'``: rises linearly from 1.0 at epoch 0 to ``max_beta`` at
      ``warmup_epochs`` and stays there.
    - ``'cosine'``: rises from 1.0 to ``max_beta`` along a half-cosine over
      ``warmup_epochs`` and stays there.
    - any other type, or no ``'beta_schedule'`` section / no ``'type'`` at
      all: the constant ``config['model']['beta']``.

    ``warmup_epochs`` values below 1 are treated as 1.

    Args:
        config: nested config dict.
        epoch: current epoch number (the training loop starts at 1).

    Returns:
        The beta value as a float (or whatever type ``config['model']['beta']`` holds).
    """
    # Missing schedule section falls back to the constant beta instead of KeyError.
    schedule = config.get('beta_schedule') or {}
    kind = schedule.get('type')

    if kind in ('linear', 'cosine'):
        t = min(1.0, epoch / max(1, schedule['warmup_epochs']))
        if kind == 'cosine':
            # Map linear progress t in [0, 1] onto a smooth half-cosine ramp.
            t = 0.5 * (1 - math.cos(math.pi * t))
        return 1.0 + t * (schedule['max_beta'] - 1.0)

    return config['model']['beta']

[참고] BCE + ${\beta\cdot}$ KLD는 이후 vae_train의 학습과정 중 구현

### 3. MNIST 데이터셋 준비하기

MNIST 데이터셋을 각각 Train, Test용 DataLoader로 준비

In [None]:
# src/dataset.py

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

def get_loaders(root, name="MNIST", batch_size=128, num_workers=0):
    """Build train/test DataLoaders for MNIST or FashionMNIST.

    ``name == "MNIST"`` selects ``datasets.MNIST``; any other value selects
    ``datasets.FashionMNIST``. Both splits are downloaded into ``root`` if
    needed and converted with ``transforms.ToTensor()``. The train loader
    shuffles; the test loader does not.

    Returns:
        ``(train_loader, test_loader)``.
    """
    dataset_cls = datasets.MNIST if name == "MNIST" else datasets.FashionMNIST
    to_tensor = transforms.ToTensor()

    train_set, test_set = (
        dataset_cls(root, train=split, download=True, transform=to_tensor)
        for split in (True, False)
    )

    def _loader(ds, shuffle):
        # one DataLoader per split, sharing batch size and worker count
        return DataLoader(
            dataset=ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    return _loader(train_set, True), _loader(test_set, False)

### 4. 학습

제목이 곧 내용(제곧내): 아래 `vae_train` 함수로 모델을 학습

In [None]:
# src/vae_train.py

from loss import *
from visualization import *
from tqdm.auto import tqdm
import torch
import os

def vae_train(model, train_loader, test_loader, config, device="cuda"):
    """Train ``model`` with reconstruction loss plus beta-weighted KL divergence.

    For each epoch in ``1..config['train']['epochs']`` the model is optimized
    with Adam on ``reconst_loss(reconst_x, x) + beta(config, epoch) *
    kl_divergence(mu, logvar)``. After each epoch, the mean total loss and
    the mean reconstruction loss per batch are printed. Models whose forward
    returns ``mu``/``logvar`` as ``None`` (plain autoencoders such as ConvAE /
    FlattenAE) are trained on the reconstruction loss alone.

    Args:
        model: module whose forward returns ``(reconst_x, mu, logvar, latent_z)``.
        train_loader: iterable of ``(x, label)`` batches; labels are ignored.
        test_loader: not used by this function (kept for interface compatibility).
        config: nested dict. Reads ``'device'`` (optional), ``['train']['lr']``,
            ``['train']['epochs']``, ``['paths']['out_dir']`` and the keys read
            by ``beta``.
        device: device used when ``config`` has no ``'device'`` entry.
    """
    model.train()

    # config['device'] still wins; the argument is now a real fallback instead of ignored.
    device = torch.device(config.get('device', device))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])

    out_dir = config['paths']['out_dir']
    os.makedirs(out_dir, exist_ok=True)

    num_epochs = config['train']['epochs']
    # Avoid ZeroDivisionError when averaging over an empty loader.
    num_batches = max(1, len(train_loader))

    # The progress bar tracks epochs; the description is updated per epoch
    # (the old f-string referenced `epoch` before it existed).
    progress = tqdm(range(1, num_epochs + 1), total=num_epochs)
    for epoch in progress:
        progress.set_description(f"[EPOCH {epoch}] training...")
        # beta depends only on the epoch, so compute it once per epoch.
        beta_term = beta(config, epoch)

        total = 0.0
        reconst_total = 0.0
        for x, _ in train_loader:
            x = x.to(device)
            # forward
            reconst_x, mu, logvar, _ = model(x)

            # loss
            reconst_loss_term = reconst_loss(reconst_x, x)
            if mu is None or logvar is None:
                # plain autoencoder: no posterior, hence no KL regularizer
                loss = reconst_loss_term
            else:
                loss = reconst_loss_term + beta_term * kl_divergence(mu, logvar)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total += loss.item()
            reconst_total += reconst_loss_term.item()

        print(f"[{epoch}] loss={total/num_batches:.4f} (beta={beta_term:.2f})")
        print(f"[{epoch}] reconst_loss={reconst_total/num_batches:.4f} (beta={beta_term:.2f})")