In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
# Import necessary PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

# Set the device to MPS if available, else CPU
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# Additional libraries for visualization and utilities
import matplotlib.pyplot as plt
import numpy as np

Using device: mps


In [2]:
# Import the adapted Echo noise functions
from echo import echo_sample, echo_loss

In [3]:
from torchvision import datasets
from torch.utils.data import DataLoader, random_split

# Define transformations: Resize if needed and normalize the data
transform = transforms.Compose([
    # transforms.Resize((28, 28)), # Uncomment if resizing is needed
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the MNIST dataset
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Splitting dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False)

print("Data loaders created for training and validation.")

Data loaders created for training and validation.


In [4]:
class Encoder(nn.Module):
    def __init__(self, input_shape, latent_dims):
        super(Encoder, self).__init__()
        self.input_shape = input_shape
        self.latent_dims = latent_dims
        
        self.conv1 = nn.Conv2d(input_shape[0], latent_dims[0], kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(latent_dims[0], latent_dims[1], kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(latent_dims[1], latent_dims[2], kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(latent_dims[2], latent_dims[3], kernel_size=5, stride=2, padding=2)
        self.conv5 = nn.Conv2d(latent_dims[3], latent_dims[4], kernel_size=7, stride=1, padding=0)
        
        self.fc_mean = nn.Linear(latent_dims[4] * 1 * 1, latent_dims[5])
        self.fc_log_var = nn.Linear(latent_dims[4] * 1 * 1, latent_dims[5])
        
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))
        
        x = x.view(x.size(0), -1)
        
        f_x = torch.tanh(self.fc_mean(x))
        log_var = self.fc_log_var(x)
        
        return f_x, log_var

In [5]:
class Decoder(nn.Module):
    def __init__(self, latent_dims, output_shape):
        super(Decoder, self).__init__()
        self.latent_dims = latent_dims
        self.output_shape = output_shape
        
        self.conv1 = nn.Conv2d(1, latent_dims[0], kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(latent_dims[0], latent_dims[1], kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(latent_dims[1], latent_dims[2], kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(latent_dims[2], latent_dims[3], kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(latent_dims[3], latent_dims[4], kernel_size=3, stride=1, padding=1)
        
        self.deconv1 = nn.ConvTranspose2d(latent_dims[4], latent_dims[3], kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.ConvTranspose2d(latent_dims[3], latent_dims[2], kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(latent_dims[2], latent_dims[1], kernel_size=3, stride=1, padding=1)
        self.deconv4 = nn.ConvTranspose2d(latent_dims[1], latent_dims[0], kernel_size=4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(latent_dims[0], output_shape[0], kernel_size=3, stride=1, padding=1)
        
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))
        
        x = torch.relu(self.deconv1(x))
        x = torch.relu(self.deconv2(x))
        x = torch.relu(self.deconv3(x))
        x = torch.relu(self.deconv4(x))
        x = torch.sigmoid(self.deconv5(x))
        
        return x

In [6]:
class EchoModel(nn.Module):
    def __init__(self, input_shape, latent_dims, output_shape, T=1000, batch_size=100):
        super(EchoModel, self).__init__()
        self.input_shape = input_shape
        self.latent_dims = latent_dims
        self.output_shape = output_shape
        self.T = T
        self.batch_size = batch_size  # Add batch_size as an attribute of EchoModel
        
        self.encoder = Encoder(input_shape, latent_dims)
        self.decoder = Decoder(latent_dims, output_shape)
        
        # Define the noise schedule
        self.alpha = self.create_noise_schedule(T)
        
    def create_noise_schedule(self, T):
        alpha = torch.linspace(0.9999, 1e-5, T)
        return alpha
    
    def forward(self, x):
        f_x, log_var = self.encoder(x)
        
        # Convert log-variance to diagonal elements of S(x)
        diagonal_sx = torch.exp(log_var)
        
        # Create the full square matrix representation of S(x)
        sx_matrix = torch.diag_embed(diagonal_sx)

        print(f"Shape of fx: {f_x.shape}")
        print(f"Shape of Sx: {sx_matrix.shape}")
        # Generate the noise variable z using echo_sample
        z = echo_sample([f_x, sx_matrix], d_max = 5, batch_size=self.batch_size)  # Pass batch_size explicitly
        print(f"Shape of output: {z.shape}")
        # Perform the reconstruction process using Algorithm 2
        reconstructed_x = self.reconstruct(x, z, f_x, sx_matrix)
        
        return reconstructed_x, f_x, sx_matrix
    
    def reconstruct(self, x_t, z, f_x, sx_matrix):
        x_s = x_t
        for s in range(self.T-1, 0, -1):
            sqrt_alpha_s = torch.sqrt(self.alpha[s])
            sqrt_one_minus_alpha_s = torch.sqrt(1 - self.alpha[s])
            
            # Estimate the original image using the decoder
            x_0_hat = self.decoder(x_s)
            
            # Calculate the estimated noise using Eq. (3)
            z_hat = (x_s - sqrt_alpha_s * x_0_hat) / sqrt_one_minus_alpha_s
            
            # Calculate D(x_0_hat, s) and D(x_0_hat, s-1) using Eq. (5) and (6)
            D_x_0_hat_s = sqrt_alpha_s * x_0_hat + sqrt_one_minus_alpha_s * z_hat
            D_x_0_hat_s_minus_1 = torch.sqrt(self.alpha[s-1]) * x_0_hat + torch.sqrt(1 - self.alpha[s-1]) * z_hat
            
            # Update x_s using Eq. (7)
            x_s = x_s - D_x_0_hat_s + D_x_0_hat_s_minus_1
        
        return x_s

In [7]:
def train(model, optimizer, train_loader, device, num_epochs, loss_weights):
    model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            
            optimizer.zero_grad()
            
            reconstructed_x, f_x, sx_matrix = model(data)
            
            # Calculate the reconstruction loss (L1 or L2 norm)
            reconstruction_loss = nn.functional.l1_loss(reconstructed_x, data)
            
            # Calculate the mutual information penalty using echo_loss
            mi_penalty = echo_loss([f_x, sx_matrix])
            
            # Calculate the total loss as a weighted sum of reconstruction loss and MI penalty
            total_loss = loss_weights['reconstruction'] * reconstruction_loss + loss_weights['mi_penalty'] * mi_penalty
            
            total_loss.backward()
            optimizer.step()
            
            epoch_loss += total_loss.item()
        
        # Print the average loss for the epoch
        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
    
    return model

# Define the input shape, latent dimensions, and output shape
input_shape = (1, 28, 28)  # Example shape for MNIST
latent_dims = [32, 32, 64, 64, 256, 32]  # Latent dimensions from echo.json
output_shape = (1, 28, 28)  # Example shape for MNIST

# Create an instance of the EchoModel
model = EchoModel(input_shape, latent_dims, output_shape).to(device)

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Define the number of epochs and loss weights
num_epochs = 100
loss_weights = {'reconstruction': 1.0, 'mi_penalty': 0.0}  # Adjust the weights as needed

# Train the model
trained_model = train(model, optimizer, train_loader, device, num_epochs, loss_weights)

Shape of fx: torch.Size([100, 32])
Shape of Sx: torch.Size([100, 32, 32])
Shape of output: torch.Size([100, 32])


RuntimeError: MPS backend out of memory (MPS allocated: 45.90 GB, other allocations: 17.78 MB, max allowed: 45.90 GB). Tried to allocate 9.57 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).