# Mac-Torch-Gpu-VAE
Use MPS (provided in [Pytorch v1.12.0+](https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/)) to accelerate the calculations in ARM Macbook.

## 1. Define Parameters
Parameters for training the VAE.

In [21]:
import argparse

parser = argparse.ArgumentParser(description='MAC Torch GPU VAE Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=3, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--no-GPU', action='store_true', default=False,
                    help='disables GPU training') # Use CPU or GPU? True:False;
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args(args=[])

Parameters for using GPU

In [22]:
import torch

# Mac: mps
args.GPU = not args.no_GPU and torch.backends.mps.is_available()
# Linux: cuda
# args.GPU = not args.no_GPU and torch.cuda.is_available()

print("args.GPU:" + str(args.GPU))

torch.manual_seed(args.seed) # Fix inital parameters of model
# Mac: mps
device = torch.device("mps" if args.GPU else "cpu") # Choose CPU or GPU device
# Linux: cuda
# device = torch.device("cuda" if args.GPU else "cpu")

kwargs = {"num_workers":2, "pin_memory":True} if args.GPU else {}
print("kwargs:" + str(kwargs))

args.GPU:True
kwargs:{'num_workers': 2, 'pin_memory': True}


## 2. VAE Model Development

In [23]:
from torch import nn, optim
from torch.nn import functional as F

# VAE Model
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(28*28, 400)
        self.fc21 = nn.Linear(400, 20) # Mu
        self.fc22 = nn.Linear(400, 20) # log(Var)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 28*28)
    
    def encoder(self, x):
        h1 = F.relu(self.fc1(x))
        mu = self.fc21(h1)
        logvar = self.fc22(h1)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))
    
    def forward(self, x):
        mu, logvar = self.encoder(x.view(-1, 28*28))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr = 1e-4) # Adam optimizer

# Loss_function = -ELBO = 
def loss_function(x_hat, x, mu, logvar):
    # Reconstruction Term
    Recon = F.binary_cross_entropy(x_hat, x.view(-1, 28*28), reduction='sum')

    # Regularization Loss 
    KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return Recon + KL

## 3. Train and Test Processes

In [24]:
import torch.utils.data
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Dataloader for training and testing (read img as chw format)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train = True, download = False, transform = transforms.ToTensor()),
    batch_size = args.batch_size, shuffle = True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train = False, transform = transforms.ToTensor()),
    batch_size = args.batch_size, shuffle = False, **kwargs)

def train(epoch):
    model.train() # Enable: batch normalization + drop out
    train_loss = 0 # total loss in an epoch
    for batch_idx, (data, _) in enumerate(train_loader):
        # Initialization
        data = data.to(device)
        optimizer.zero_grad()
        # Loss backward
        x_hat, mu, logvar = model(data)
        loss = loss_function(x_hat, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        # Optimize
        optimizer.step()
        # Print Information (Batch)
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss:{:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx/len(train_loader),
                loss.item() / len(data)))
    
    # Print Information (Epoch)
    print("[Epoch]:{}, Average Loss:{:.4f}".format(
        epoch, train_loss / len(train_loader.dataset)))

def test(epoch):
    model.eval() # Disable: batch normalization + drop out
    test_loss = 0 # total loss in an epoch
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            # Initialization
            data = data.to(device)
            # Loss Calculation
            x_hat, mu, logvar = model(data)
            test_loss += loss_function(x_hat, data, mu, logvar).item()
            # Print Information (The First Batch)
            if i == 0:
                n = min(data.size(0), 8) # n samples shown
                # Compare x and x_hat
                comparison = torch.cat([data[:n],
                    x_hat.view(args.batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(), './results/rec_' + str(epoch) + '.png', nrows=n)
    
    test_loss /= len(test_loader.dataset)
    print("Test Loss:{:.6f}".format(test_loss))


## 4. Main()

In [25]:
if __name__ == "__main__":
    for epoch in range(1, args.epochs +1):
        train(epoch)
        test(epoch)
        # Test Decoder
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                './results/sample_'+str(epoch)+'.png')       

[Epoch]:1, Average Loss:260.1906
Test Loss:196.609915
[Epoch]:2, Average Loss:177.6280
Test Loss:163.634380
[Epoch]:3, Average Loss:157.1286
Test Loss:149.357027
