In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.nn.functional as F

import matplotlib.pyplot as plt

from collections import OrderedDict

import pandas as pd
import seaborn as sns

from tqdm import tqdm

def get_cuda_device_or_cpu():
  """Pick the device string to run on.

  Returns:
    'cpu' when CUDA is unavailable. Otherwise 'cuda:<index>' for the visible
    GPU that reports the most free memory; on a tie the lowest index wins.
  """
  if not torch.cuda.is_available():
    return 'cpu'

  # Only the first element of mem_get_info() is compared across devices.
  free_per_device = [
    torch.cuda.mem_get_info(idx)[0]
    for idx in range(torch.cuda.device_count())
  ]

  # Start from device 0 with a zero baseline so that a later device is only
  # selected when it has strictly more free memory than the best seen so far.
  best_idx, best_free = 0, 0
  for idx, free in enumerate(free_per_device):
    if free > best_free:
      best_idx, best_free = idx, free
  return f'cuda:{best_idx}'

def get_model(dim_z: int, kind='normal_ae', batch_size=128, less_than=10):
  """Build the MNIST data loader, the autoencoder and its optimizer.

  Args:
    dim_z: Width of the encoder's final layer.
    kind: Architecture selector. Only 'stochastic_ae' is implemented in this
      function; every other value (including the default) raises ValueError.
    batch_size: Batch size of the returned DataLoader.
    less_than: Only training samples whose label is strictly smaller than this
      value are kept.

  Returns:
    A `(model, dataloader, optimizer)` tuple. `model` is an `nn.Sequential`
    with the named submodules 'encoder' and 'decoder'; `optimizer` is Adam
    with default settings over all model parameters.

  Raises:
    ValueError: If `kind` is not a supported architecture.
  """
  # Fail fast, before the dataset is downloaded/loaded. Previously an
  # unsupported kind only surfaced later as an unbound `encoder` variable.
  if kind != 'stochastic_ae':
    raise ValueError(
      f"Unsupported kind {kind!r}; only 'stochastic_ae' is implemented."
    )

  # Download training data from open datasets.
  training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
  )

  # Build the label mask once and apply it to both images and labels.
  keep = training_data.targets < less_than
  training_data.data = training_data.data[keep]
  training_data.targets = training_data.targets[keep]

  # Create data loaders.
  dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

  # The decoder input width is fixed at 2: run_stochastic treats the first two
  # encoder outputs as the latent mean and feeds a 2-D sample to the decoder.
  latent_dim = 2

  encoder = nn.Sequential(
    nn.Linear(784, 512),
    nn.Tanh(),
    nn.Linear(512, 128),
    nn.Tanh(),
    nn.Linear(128, 64),
    nn.Tanh(),
    nn.Linear(64, dim_z),
  )

  decoder = nn.Sequential(
    nn.Linear(latent_dim, 64),
    nn.Tanh(),
    nn.Linear(64, 128),
    nn.Tanh(),
    nn.Linear(128, 512),
    nn.Tanh(),
    nn.Linear(512, 784),
    nn.Sigmoid(),  # outputs in (0, 1)
  )

  model = nn.Sequential(OrderedDict([
            ('encoder', encoder),
            ('decoder', decoder),
          ]))

  optimizer = torch.optim.Adam(model.parameters())

  return model, dataloader, optimizer

def train(model, dataloader, optimizer, run, epochs=5):
  """Train `model` for several epochs and collect the loss history.

  Args:
    model: Module to train; it is moved to the selected device and switched to
      training mode.
    dataloader: Passed through to `run` unchanged.
    optimizer: Passed through to `run` unchanged.
    run: Callable invoked once per epoch as `run(model, dataloader, optimizer)`;
      its returned tensor is appended to the history.
    epochs: Number of times `run` is invoked.

  Returns:
    The concatenation of all tensors returned by `run`, in epoch order (an
    empty tensor when `epochs` is 0).
  """
  target_device = get_cuda_device_or_cpu()
  print(target_device)

  model.to(target_device)
  model.train()

  # Seed with an empty tensor so the result is well defined for zero epochs.
  history = torch.zeros(0)
  for _epoch in tqdm(range(epochs)):
    epoch_losses = run(model, dataloader, optimizer)
    history = torch.cat((history, epoch_losses))

  return history



In [2]:
def run_stochastic(model, dataloader, optimizer):
  """Train the stochastic autoencoder for one pass over `dataloader`.

  The encoder output is split into a 2-D mean (first two columns) and a scale
  (remaining columns, made positive with softplus). A latent sample
  `z = mu + sigma * eps` with standard-normal `eps` is decoded, and the binary
  cross-entropy between the reconstruction and the flattened input is
  minimised.

  Args:
    model: Module exposing the submodules 'encoder' and 'decoder'.
    dataloader: Sized iterable yielding `(x, y)` batches; `x` is reshaped to
      rows of 28*28 values and `y` is ignored.
    optimizer: Optimizer over the model's parameters.

  Returns:
    A 1-D CPU tensor with one loss value per batch.
  """
  # Follow the model's own device instead of relying on a notebook-global
  # `device` variable, which may be undefined or point at a different GPU
  # than the one the model was moved to.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')

  hist = torch.zeros(len(dataloader))

  encoder = model.get_submodule('encoder')
  decoder = model.get_submodule('decoder')

  model.train()

  for batch, (x, _) in enumerate(dataloader):
    x = x.view([-1, 28*28]).to(device)

    h = encoder(x)
    mu = h[:, :2]
    sigma = F.softplus(h[:, 2:])  # softplus keeps the scale strictly positive

    # Reparameterised sample with zero-mean Gaussian noise. (Uniform noise
    # from rand_like would shift the sample away from mu.)
    z = mu + sigma * torch.randn_like(mu)

    pred = decoder(z)
    loss = F.binary_cross_entropy(pred, x)

    # Backpropagation. Gradients are cleared first so that anything
    # accumulated before this call cannot leak into the first step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    hist[batch] = loss.item()

  return hist

In [3]:
# Build the stochastic autoencoder (4 encoder outputs) and train it for 10 epochs.
model, dataloader, optimizer = get_model(dim_z=4, kind='stochastic_ae')
hist = train(model, dataloader, optimizer, run_stochastic, epochs=10)

# `hist` holds one loss value per batch, concatenated over all epochs.
plt.plot(hist)
plt.xlabel('training batch')
plt.ylabel('binary cross-entropy loss')
plt.title('Stochastic autoencoder training loss')
plt.show()

cuda:1


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [44]:
# Diagnostic: show which device string the helper currently selects.
get_cuda_device_or_cpu()

'cuda:1'

In [43]:
# Diagnostic: index of PyTorch's currently selected CUDA device. This can
# differ from the device chosen by get_cuda_device_or_cpu().
torch.cuda.current_device()

0

In [4]:
!nvidia-smi

Tue Jun 20 18:43:11 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:65:00.0 Off |                    0 |
| N/A   38C    P0    63W / 300W |  80747MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:B3:00.0 Off |                    0 |
| N/A   45C    P0   251W / 300W |  12982MiB / 81920MiB |     95%      Default |
|       

In [5]:
!nvidia-smi

Wed Jun 21 09:21:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  Off  | 00000000:65:00.0 Off |                    0 |
| N/A   31C    P0    60W / 300W |  80753MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  Off  | 00000000:B3:00.0 Off |                    0 |
| N/A   31C    P0    56W / 300W |   1443MiB / 81920MiB |      0%      Default |
|       