In [None]:
import os 
import tensorflow as tf
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchmetrics.image.fid import FrechetInceptionDistance

import matplotlib.pyplot as plt

# --- Reproducibility ---------------------------------------------------------
random_seed = 9292
# Seed PyTorch's global RNG so random draws are repeatable across fresh runs.
torch.manual_seed(random_seed)

# --- Hyperparameters ---------------------------------------------------------
LATENT_DIM = 100
BATCH_SIZE = 64
NUM_EPOCHS = 500
lr = 3e-4
# Lowercase spelling kept in case other cells reference it; bound to LATENT_DIM
# so the two names can never silently disagree.
latent_dims = LATENT_DIM

# --- Hardware ----------------------------------------------------------------
# At most one GPU is used, and zero when none is visible.
AVAIL_GPUS = min(1, torch.cuda.device_count())
# Use half of the logical CPUs for data loading. os.cpu_count() may return None
# when the count cannot be determined, so fall back to 1 (-> 0 workers, i.e.
# loading in the main process) instead of raising a TypeError.
NUM_WORKERS = (os.cpu_count() or 1) // 2
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Directory the CIFAR-10 files are downloaded to and read from.
data_dir = './data'

# ToTensor yields float tensors in [0, 1]; Normalize with mean 0.5 / std 0.5 per
# RGB channel then maps them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# download=True fetches the archive only if it is not already under data_dir.
training = CIFAR10(data_dir, train=True, transform=transform, download=True)
testing = CIFAR10(data_dir, train=False, transform=transform, download=True)

# Shuffle the training split so each epoch sees differently composed batches;
# without it every epoch replays the exact same batches in the same order.
# The order is still reproducible because the global torch seed is fixed.
train_dataloader = DataLoader(
    training,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
# The test split keeps its fixed order.
test_dataloader = DataLoader(
    testing,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
# Short display names for the ten CIFAR-10 label indices (label i -> classes[i]).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
rows = 8
columns = 6

# Pull a single batch for the preview instead of looping and breaking.
data, label = next(iter(train_dataloader))
# Clamp to the batch size so a batch smaller than the grid cannot raise IndexError.
num_shown = min(rows * columns, len(data))

# squeeze=False keeps `axs` two-dimensional even if rows or columns is set to 1.
fig, axs = plt.subplots(rows, columns, figsize=(20, 30), squeeze=False)
for i, ax in enumerate(axs.flat):
    ax.axis('off')
    if i >= num_shown:
        # Leave surplus grid cells blank.
        continue
    # Channels-first -> channels-last for imshow, and undo the mean/std 0.5
    # normalisation applied by `transform` so values are back in [0, 1].
    img = (data[i].permute(1, 2, 0) + 1) / 2
    # No cmap: matplotlib ignores colormaps for RGB input.
    ax.imshow(img)
    # .item() turns the 0-d label tensor into a plain int for tuple indexing.
    classlabel = classes[label[i].item()]
    ax.set_title(classlabel)

# Shape of one sample as delivered by the loader.
print(data[0].shape)

fig.subplots_adjust(wspace=0.1, hspace=0)