In [1]:
# Import necessary PyTorch libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms


# Additional libraries for visualization and utilities
import matplotlib.pyplot as plt
import numpy as np
from unet_decoder import UNetDecoder

In [2]:
def get_device():
    """Selects the best available device for PyTorch computations.

    Returns:
        torch.device: The selected device.
    """

    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

device = get_device()
print(f"using device: {device}")

using device: mps


In [3]:
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, Normalize, ToTensor,Resize

from torch.utils.data import DataLoader, random_split

# Define the transformation with resizing
transform = transforms.Compose([
    Resize((28,28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the Omniglot dataset
dataset = datasets.Omniglot(root='./data', download=True, transform=transform, background=True)

# Print the total number of images in the dataset
print(f"Total number of images in the dataset: {len(dataset)}")

# Splitting dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

Files already downloaded and verified
Total number of images in the dataset: 19280


In [4]:
# Extract images from the MNIST training data
omniglot_images = [val_dataset[i][0] for i in range(len(val_dataset))]
omniglot_images = torch.stack(omniglot_images)  # Convert list to tensor

print(f"Loaded {omniglot_images.size(0)} Omniglot images from the training dataset")

Loaded 3856 Omniglot images from the training dataset


In [10]:
import torch

# Load the sampled data
sampled_data_path = 'omniglot_cold_l1_alg2.pt'
sampled_data = torch.load(sampled_data_path)

# Extract the images from the dictionary
sampled_images = [sampled_data[key]['sampled'] for key in sampled_data]
sampled_images = torch.stack(sampled_images)  # Convert list to tensor

# Normalize the sampled images
sampled_images = sampled_images.float() / 255.0  # Scale back to [0, 1]
sampled_images = (sampled_images - 0.5) / 0.5  # Normalize using the same mean and std as MNIST

print(f"Loaded and normalized {sampled_images.size(0)} sampled images from {sampled_data_path}")

Loaded and normalized 4096 sampled images from omniglot_cold_l1_alg2.pt


In [11]:
from torch.utils.data import TensorDataset

# Create labels for the sampled data (0 for generated images)
sampled_labels = torch.zeros(sampled_images.size(0), dtype=torch.long)

# Create labels for the original MNIST data (1 for original images)
omniglot_labels = torch.ones(omniglot_images.size(0), dtype=torch.long)

# Combine the images and labels into a single dataset
combined_images = torch.cat((sampled_images, omniglot_images), dim=0)
combined_labels = torch.cat((sampled_labels, omniglot_labels), dim=0)

# Create a permutation of indices
indices = torch.randperm(len(combined_images))

# Apply permutation to shuffle the dataset
shuffled_images = combined_images[indices]
shuffled_labels = combined_labels[indices]

# Create a TensorDataset with the shuffled data
combined_dataset = TensorDataset(shuffled_images, shuffled_labels)

# Splitting dataset into training and validation sets
train_size = int(0.8 * len(combined_dataset))
val_size = len(combined_dataset) - train_size
train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

# Print the number of images in the train and validation sets
print(f"Number of images in the training set: {len(train_dataset)}")
print(f"Number of images in the validation set: {len(val_dataset)}")

# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Inspect the contents of the train_loader
train_batches = 0
for batch in train_loader:
    train_batches += 1

print(f"Number of batches in the training loader: {train_batches}")
print(f"Total number of images in the training loader: {train_batches * 64}")

Number of images in the training set: 6361
Number of images in the validation set: 1591
Number of batches in the training loader: 100
Total number of images in the training loader: 6400


In [12]:
import torch.nn as nn
import torch.optim as optim

# Define a simple neural network
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64*7*7)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [13]:
# Define the validate function
def validate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return val_loss / len(val_loader), accuracy

# Initialize the model, loss function, and optimizer
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the neural network with validation
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    val_loss, val_accuracy = validate(model, val_loader, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

print("Training complete")

Epoch [1/50], Loss: 0.3088, Val Loss: 0.2437, Val Accuracy: 90.57%
Epoch [2/50], Loss: 0.2460, Val Loss: 0.2896, Val Accuracy: 88.06%
Epoch [3/50], Loss: 0.2508, Val Loss: 0.2353, Val Accuracy: 90.89%
Epoch [4/50], Loss: 0.2353, Val Loss: 0.2459, Val Accuracy: 90.45%
Epoch [5/50], Loss: 0.2299, Val Loss: 0.2365, Val Accuracy: 91.07%
Epoch [6/50], Loss: 0.2275, Val Loss: 0.2387, Val Accuracy: 91.33%
Epoch [7/50], Loss: 0.2226, Val Loss: 0.2332, Val Accuracy: 91.14%
Epoch [8/50], Loss: 0.2136, Val Loss: 0.2660, Val Accuracy: 89.00%
Epoch [9/50], Loss: 0.2120, Val Loss: 0.2476, Val Accuracy: 90.13%
Epoch [10/50], Loss: 0.1998, Val Loss: 0.2507, Val Accuracy: 90.32%
Epoch [11/50], Loss: 0.1936, Val Loss: 0.2612, Val Accuracy: 89.31%
Epoch [12/50], Loss: 0.1909, Val Loss: 0.2500, Val Accuracy: 90.26%
Epoch [13/50], Loss: 0.1742, Val Loss: 0.2514, Val Accuracy: 90.13%
Epoch [14/50], Loss: 0.1570, Val Loss: 0.2600, Val Accuracy: 90.19%
Epoch [15/50], Loss: 0.1595, Val Loss: 0.2677, Val Accura