In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import random_split
from pathlib import Path

# setup TensorBoard

In [2]:
from dataclasses import dataclass

@dataclass
class Config:
    batch_size: int = 64
    dataset_size: int = 2048
    start_epoch: int = 0
    n_epochs: int = 1000
    lr: float = 0.0002
    relu_slope: float = 0.2
    dropout: float = 0.2
    # randn dim
    in_dim: int = 100
    # 3 channels, 218x178
    out_dim: int = 3 * 218 * 178
    save_epoch_interval: int = 50
    exp_id: str = 'mnist_gan_exp_1'

config = Config()

In [3]:
from torch.utils.tensorboard import SummaryWriter
import datetime

def setup_tensorboard(id):
    # template = "%Y-%m-%d_%H-%M-%S"
    print(f'logdir=runs/{id}')
    writer = SummaryWriter(f'runs/{id}')
    return writer

writer = setup_tensorboard(config.exp_id)

logdir=runs/mnist_gan_exp_1


## Dataset


In [4]:
def clean_cache():
    # empty cache
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])

# data = datasets.MNIST(root='./dataset', download=True, transform=transform)
# <https://stackoverflow.com/questions/70896841/error-downloading-celeba-dataset-using-torchvision>
data = datasets.CelebA(root='./dataset', download=True, transform=transform)

# load dataset partially
if config.dataset_size > 0:
    data, _ = random_split(data, [config.dataset_size, len(data) - config.dataset_size])

dataloader = DataLoader(data, batch_size=config.batch_size, shuffle=True, drop_last=True)

print(f'data={len(dataloader) * config.batch_size}, batch_size={config.batch_size}, n_epochs={config.n_epochs}')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clean_cache()


Files already downloaded and verified
data=2048, batch_size=64, n_epochs=1000
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |

# Model

In [6]:
class Generator(nn.Module):
    def __init__(self, config: Config):
        super(Generator, self).__init__()
        sizes = [config.in_dim, 256, 512, 1024, config.out_dim]
        self.slope = config.relu_slope
        self.fc1 = nn.Linear(sizes[0], sizes[1])
        self.fc2 = nn.Linear(sizes[1], sizes[2])
        self.fc3 = nn.Linear(sizes[2], sizes[3])
        self.fc4 = nn.Linear(sizes[3], sizes[4])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        slope = self.slope
        x = F.leaky_relu(self.fc1(x), slope)
        x = F.leaky_relu(self.fc2(x), slope)
        x = F.leaky_relu(self.fc3(x), slope)
        x = F.tanh(self.fc4(x))
        return x


class Discriminator(nn.Module):
    def __init__(self, config: Config):
        super(Discriminator, self).__init__()
        sizes = [config.out_dim, 1024, 512, 256, 1]
        self.slope = config.relu_slope
        self.dropout = config.dropout
        self.fc1 = nn.Linear(sizes[0], sizes[1])
        self.fc2 = nn.Linear(sizes[1], sizes[2])
        self.fc3 = nn.Linear(sizes[2], sizes[3])
        self.fc4 = nn.Linear(sizes[3], sizes[4])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        slope = self.slope
        dropout = self.dropout
        x = F.leaky_relu(self.fc1(x), slope)
        x = F.dropout(x, dropout)
        x = F.leaky_relu(self.fc2(x), slope)
        x = F.dropout(x, dropout)
        x = F.leaky_relu(self.fc3(x), slope)
        x = F.dropout(x, dropout)
        x = F.sigmoid(self.fc4(x))
        return x

# Train

In [7]:

def save_model(generator: Generator, discriminator: Discriminator, optim_generator, optim_discriminator, config: Config, epoch: int):
    print(f'save models @ epoch={epoch}')
    torch.save({
        'epoch': epoch,
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'optim_generator': optim_generator.state_dict(),
        'optim_discriminator': optim_discriminator.state_dict(),
    }, f'models/{config.exp_id}_{epoch}.pt')

def create_model(config: Config):
    generator = Generator(config).to(device)
    discriminator = Discriminator(config).to(device)

    optim_generator = optim.AdamW(generator.parameters(), lr=config.lr)
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=config.lr)

    return generator, discriminator, optim_generator, optim_discriminator


def load_model(config: Config, start_epoch: int):
    checkpoint = torch.load(f'models/{config.exp_id}_{start_epoch}.pt')

    config.start_epoch = checkpoint['epoch'] + 1
    generator = Generator(config).to(device)
    generator.load_state_dict(checkpoint['generator'])
    generator.train()
    discriminator = Discriminator(config).to(device)
    discriminator.load_state_dict(checkpoint['discriminator'])
    discriminator.train()

    optim_generator = optim.AdamW(generator.parameters(), lr=config.lr)
    optim_generator.load_state_dict(checkpoint['optim_generator'])
    optim_discriminator = optim.Adam(discriminator.parameters(), lr=config.lr)
    optim_discriminator.load_state_dict(checkpoint['optim_discriminator'])
    
    return generator, discriminator, optim_generator, optim_discriminator

def train_discriminator(
    generator: Generator, 
    discriminator: Discriminator, 
    optim_discriminator, 
    x: torch.Tensor,
    criterion,
    config: Config,
    step: int,
):
    discriminator.zero_grad()

    x_real, y_real = x.view(-1, config.out_dim).to(device), torch.ones(config.batch_size, 1).to(device)
    x_real, y_real = Variable(x_real), Variable(y_real)
    
    d_output = discriminator(x_real)
    # print(f'd_output: {d_output.shape}, y_real: {y_real.shape}')
    loss_real = criterion(d_output, y_real)

    # train discriminator with fake data
    z = Variable(torch.randn(config.batch_size, config.in_dim, device = device))
    x_fake, y_fake = generator(z), torch.zeros(config.batch_size, 1).to(device)
    
    d_output = discriminator(x_fake)
    loss_fake = criterion(d_output, y_fake)

    loss = loss_real + loss_fake
    loss.backward()
    optim_discriminator.step()
    l = loss.item()
    del loss
    writer.add_scalar('loss/discriminator', l, step)
    
def train_generator(
    generator: Generator,
    discriminator: Discriminator,
    optim_generator,
    criterion,
    config: Config,
    step: int,
):
    generator.zero_grad()
    batch_size = config.batch_size
    in_dim = config.in_dim
    z = Variable(torch.randn(batch_size, in_dim, device = device))
    y = Variable(torch.ones(batch_size, 1).to(device))

    g_output = generator(z)
    d_output = discriminator(g_output)
    loss_generator = criterion(d_output, y)

    loss_generator.backward()
    optim_generator.step()
    loss = loss_generator.item()
    writer.add_scalar('loss/generator', loss, step)
    del loss_generator


def train(
    writer: SummaryWriter,
    generator: Generator, 
    discriminator: Discriminator,
    optim_generator,
    optim_discriminator,
    dataloader: DataLoader,
    config: Config,
):
    criterion = nn.BCELoss()
    for epoch in range(config.start_epoch, config.start_epoch + config.n_epochs):
        now = datetime.datetime.now
        print(f'[{now()}] Epoch {epoch}')
        for i, (x, _) in enumerate(dataloader):
            x = x.to(device)
            # print(f'[{now()}] Epoch {epoch} batch {i}')
            step = epoch * len(dataloader) + i
            train_discriminator(generator, discriminator, optim_discriminator, x, criterion, config, step)
            train_generator(generator, discriminator, optim_generator, criterion, config, step)

        # generate image
        with torch.no_grad():
            # input
            z = torch.randn(config.batch_size, config.in_dim).to(device)
            images = generator(z).view(-1, *shape)
            writer.add_images(f'generated_image', images, epoch)

        if epoch % config.save_epoch_interval == 0:
            save_model(generator, discriminator, optim_generator, optim_discriminator, config, epoch)

In [8]:
clean_cache()

# (channel, width, height)
shape = iter(dataloader).next()[0].shape[1:]
generator, discriminator, optim_generator, optim_discriminator = load_model(config, 2)
# generator, discriminator, optim_generator, optim_discriminator = create_model(config)
train(writer, generator, discriminator, optim_generator, optim_discriminator, dataloader, config)

[2022-06-16 02:52:29.556031] Epoch 3




save models @ epoch=3


RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\caffe2\serialize\inline_container.cc:300] . unexpected pos 1440127744 vs 1440127632

# Test

In [None]:
generator, _, _ , _ = load_model(config, 0)
generator.eval()

z = torch.randn(config.batch_size, config.in_dim).to(device)
images = generator(z).view(-1, *shape)

print(images)

Generator(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=116412, bias=True)
)
Discriminator(
  (fc1): Linear(in_features=116412, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)
tensor([[[[-1.1310e-02,  2.4825e-02, -3.0744e-02,  ...,  3.0162e-02,
           -8.7601e-03,  4.7575e-02],
          [-7.2179e-02,  4.0158e-02,  1.0752e-02,  ...,  1.9712e-02,
           -2.4964e-02,  5.5346e-02],
          [ 4.9146e-02, -1.0183e-01, -3.4617e-02,  ...,  1.3988e-02,
            8.6680e-03,  1.9955e-02],
          ...,
          [-3.6754e-02, -1.9263e-02, -3.5289e-02,  ..., -1.1569e-02,
            3.4800e-02, -6.9653e-02],
          [-3.3608e-02,