In [1]:
# Import necessary PyTorch libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms


# Additional libraries for visualization and utilities
import matplotlib.pyplot as plt
import numpy as np
from unet_decoder import UNetDecoder

In [2]:
def get_device():
    """Selects the best available device for PyTorch computations.

    Preference order is CUDA, then Apple-Silicon MPS, then CPU.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # `torch.backends.mps` is missing on older PyTorch builds; look it up
    # defensively so those builds fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')


# Resolve the device once for this notebook session and report it.
device = get_device()
print(f"using device: {device}")

using device: mps


In [7]:
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, Normalize, ToTensor,Resize

from torch.utils.data import DataLoader, random_split

# Tunable settings for this cell.
BATCH_SIZE = 128   # mini-batch size used by both loaders
SPLIT_SEED = 42    # fixes the train/validation split so reruns are reproducible

# Preprocessing: convert each image to a tensor and standardise it with the
# mean/std pair (0.1307, 0.3081). No resizing is applied.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the MNIST dataset (downloaded into ./data on first use)
dataset = datasets.MNIST(root='./data', download=True, transform=transform)

# Print the total number of images in the dataset
print(f"Total number of images in the dataset: {len(dataset)}")

# Split the dataset 80/20 into training and validation sets. A seeded generator
# keeps the split identical across kernel restarts.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Print the number of images in the train and validation sets
print(f"Number of images in the training set: {len(train_dataset)}")
print(f"Number of images in the validation set: {len(val_dataset)}")

# Create DataLoader instances (only the training loader is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Inspect the train_loader. len() gives the batch count directly, without
# iterating (and thereby decoding/normalising) every image just to count.
train_batches = len(train_loader)

print(f"Number of batches in the training loader: {train_batches}")
# Report the true sample count; `train_batches * BATCH_SIZE` over-counts
# whenever the last batch is only partially filled.
print(f"Total number of images in the training loader: {len(train_dataset)}")

Total number of images in the dataset: 60000
Number of images in the training set: 48000
Number of images in the validation set: 12000
Number of batches in the training loader: 375
Total number of images in the training loader: 48000


In [5]:
import torch

# Load the sampled data
sampled_data_path = 'mnist_gaussian_ddpm.pt'
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source. map_location='cpu' lets a file written on another
# accelerator load on this machine.
sampled_data = torch.load(sampled_data_path, map_location='cpu')

# Fail early with a clear message instead of an opaque torch.stack error.
if not sampled_data:
    raise ValueError(f"No sampled images found in {sampled_data_path}")

# Extract the 'sampled' tensor from every entry of the dictionary and stack
# them along a new leading dimension.
sampled_images = torch.stack([entry['sampled'] for entry in sampled_data.values()])

# Normalize the sampled images.
# Assumes the stored samples are in the [0, 255] range — TODO confirm against
# the script that produced the file.
sampled_images = sampled_images.float() / 255.0  # Scale back to [0, 1]
# Apply the same (0.1307, 0.3081) mean/std standardisation used for the MNIST
# transform so both image sources are on the same scale.
sampled_images = (sampled_images - 0.1307) / 0.3081

print(f"Loaded and normalized {sampled_images.size(0)} sampled images from {sampled_data_path}")


Loaded and normalized 4096 sampled images from mnist_gaussian_ddpm.pt
