In [None]:
import os
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
class AutoEncoder(nn.Module):
    """
    Fully convolutional autoencoder for 3-channel (RGB) image batches.

    Encoder: three stride-2 convolutions (kernel 4, padding 1). Each one halves
    the spatial size while the channels grow 3 -> 32 -> 64 -> 128, so the latent
    tensor has shape (N, 128, H/8, W/8) -- e.g. 32x32x128 for a 256x256 input,
    or 4x4x128 for a 32x32 input.

    Decoder: three stride-2 transposed convolutions (kernel 4, padding 1), each
    doubling the spatial size, followed by a sigmoid so that every output value
    lies in [0, 1]. The reconstruction has the input's shape when H and W are
    divisible by 8.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        # Latent shape at this point: (N, 128, H/8, W/8).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            # Squash to [0, 1] to match image tensors produced by ToTensor().
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Encode and reconstruct a batch.

        Args:
            x: Tensor of shape (N, 3, H, W).

        Returns:
            tuple: ``(reconstruction, latent)`` where ``latent`` is the raw
            encoder output and ``reconstruction`` is the decoder output for it.
        """
        latent = self.encoder(x)
        return self.decoder(latent), latent

In [None]:
# Directory the Places365 files are downloaded to and read from.
data_root = './places365_data'

# Preprocessing applied to every image. No resizing happens here (the Resize step
# was removed), so images keep the resolution of the `small=True` variant.
# ToTensor() converts each PIL image to a (C, H, W) float tensor scaled to [0, 1].
# NOTE: the name `transform_32x32` is historical and kept only so that any other
# cell referring to it keeps working -- it does NOT resize to 32x32.
transform_32x32 = transforms.Compose([
    transforms.ToTensor()
])

places365_kwargs = dict(
    root=data_root,
    split='val',              # Split: 'train-standard', 'train-challenge', or 'val'
    small=True,               # Use the small (256x256) variant of the images
    transform=transform_32x32,
)

try:
    # First run: fetch and extract the archive.
    dataset = datasets.Places365(download=True, **places365_kwargs)
except RuntimeError:
    # NOTE(review): some torchvision releases appear to raise RuntimeError from
    # `download=True` when the extracted image directory already exists, which
    # would break "Restart & Run All" -- confirm against the installed version.
    # Retrying without downloading reuses the files on disk; if they are really
    # missing or corrupt, this second call raises its own error.
    dataset = datasets.Places365(download=False, **places365_kwargs)

BATCH_SIZE = 8
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)


In [None]:
# Training hyperparameters.
EPOCHS = 5
LEARNING_RATE = 1e-3

# Build the network on the selected device, then its loss and optimizer.
model = AutoEncoder()
model = model.to(device)
criterion = nn.MSELoss()  # mean squared error between reconstruction and input
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Count only the parameters that will receive gradient updates.
trainable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)

print(f"\nModel initialized and sent to {device}.")
print(f"Total parameters: {trainable_param_count:,}")
print("-" * 30)

# --- 4B. The Training Function ---
def train_model(model, dataloader, criterion, optimizer, num_epochs):
    """
    Train an autoencoder to reconstruct its own input.

    Args:
        model: Module whose forward pass returns ``(reconstruction, latent)``.
        dataloader: Sized iterable yielding ``(data, label)`` batches; labels
            are ignored because the input itself is the target.
        criterion: Callable ``criterion(reconstruction, data)`` returning a
            scalar loss tensor. The epoch average weights each batch loss by
            its batch size, which assumes a per-batch mean -- TODO confirm if a
            different reduction is used.
        optimizer: Optimizer that updates the model's parameters.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        list[float]: One sample-weighted average loss per epoch (``nan`` for an
        epoch in which the dataloader produced no samples).
    """
    model.train()

    # Move batches to wherever the model lives rather than relying on a
    # notebook-level `device` global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    training_losses = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        samples_seen = 0

        for batch_idx, (data, _) in enumerate(dataloader):
            if model_device is not None:
                data = data.to(model_device)
            batch_size = data.size(0)

            # Forward pass; the original data is the reconstruction target.
            reconstructed, _ = model(data)
            loss = criterion(reconstructed, data)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss * batch_size
            samples_seen += batch_size

            if batch_idx % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {batch_loss:.6f}')

        # Divide by the samples actually processed, so the average stays correct
        # for any dataloader (custom sampler, drop_last=True, a different dataset).
        avg_epoch_loss = epoch_loss / samples_seen if samples_seen else float("nan")
        training_losses.append(avg_epoch_loss)
        print(f"--- Epoch {epoch+1} finished. Average Loss: {avg_epoch_loss:.6f} ---")

    return training_losses

# Run the training loop and keep the per-epoch average losses it returns.
training_losses = train_model(
    model=model,
    dataloader=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=EPOCHS,
)

print("\nTraining complete.")


In [None]:
# --- 5A. Visualization Function ---
def visualize_results(model, dataloader, num_images=5):
    """
    Plot original images above their reconstructions and print the forward-pass time.

    Args:
        model: Module whose forward pass returns ``(reconstruction, latent)``.
            It is switched to eval mode and left in that mode.
        dataloader: Iterable yielding ``(images, labels)`` batches; only the
            first batch is used. Images are assumed to be (N, C, H, W) floats in
            [0, 1] so they can be passed to ``imshow`` -- TODO confirm for other
            datasets.
        num_images: Maximum number of images to show; capped at the batch size.

    Raises:
        ValueError: If there are no images to display.

    Side effects:
        Prints the elapsed time of the forward pass in milliseconds and shows a
        matplotlib figure.
    """
    model.eval()

    # Take the first batch and keep at most `num_images` of it.
    images, _ = next(iter(dataloader))
    images = images[:num_images].to(device)
    shown = images.size(0)
    if shown == 0:
        raise ValueError("No images available to visualize.")

    with torch.no_grad():
        if images.is_cuda:
            # CUDA kernels run asynchronously, so time them with CUDA events and
            # synchronize before reading the result.
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            reconstructed, _ = model(images)
            end.record()
            torch.cuda.synchronize()
            elapsed_ms = start.elapsed_time(end)
        else:
            # CUDA events cannot be used without a GPU; use a wall-clock timer
            # and report milliseconds as well.
            t0 = time.perf_counter()
            reconstructed, _ = model(images)
            elapsed_ms = (time.perf_counter() - t0) * 1000.0
    print(elapsed_ms)

    # squeeze=False keeps `axes` two-dimensional even when a single image is shown.
    fig, axes = plt.subplots(2, shown, figsize=(15, 6), squeeze=False)

    for i in range(shown):
        # Original image: move to CPU and convert (C, H, W) -> (H, W, C) for imshow.
        original_img = images[i].cpu().permute(1, 2, 0).numpy()
        axes[0, i].imshow(original_img)
        axes[0, i].set_title("Original")
        axes[0, i].axis('off')

        # Reconstructed image, clipped so the values are valid image data (0 to 1).
        reconstructed_img = reconstructed[i].cpu().permute(1, 2, 0).numpy()
        reconstructed_img = reconstructed_img.clip(0, 1)
        axes[1, i].imshow(reconstructed_img)
        axes[1, i].set_title("Reconstructed")
        axes[1, i].axis('off')

    plt.suptitle("Original vs. Reconstructed Images")
    plt.show()

# --- 5B. Run Visualization ---
# NOTE: The model is trained for only a few epochs (see EPOCHS), so the reconstruction may be poor.
# This code is primarily for checking the model's functionality and visualization setup.
# With more training, the results should improve.
try:
    if len(dataloader) == 0:
        # There is no batch to draw images from.
        print("DataLoader is empty, skipping visualization.")
    else:
        visualize_results(model, dataloader, num_images=5)
except Exception as e:
    # Best-effort cell: report the failure instead of raising.
    print(f"Error during visualization: {e}")
