In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from google.colab import drive
import os
import zipfile
from tqdm import tqdm

# Mount Google Drive at /content/drive (`drive` comes from google.colab,
# so this step only works inside a Colab runtime)
drive.mount('/content/drive')

# Create the output directory on the mounted Drive; exist_ok=True keeps
# this from failing when the directory is already there (e.g. on a re-run)
save_path = '/content/drive/My Drive/CIFAR10_reconstruction'
os.makedirs(save_path, exist_ok=True)

# Define the Encoder-Decoder Architecture
class AutoEncoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    The encoder applies three 3x3 convolutions (3 -> 16 -> 32 -> 64
    channels), each followed by a ReLU, with a 2x2 max-pool after the first
    two, so height and width shrink by a factor of 4. The decoder restores
    the resolution with two stride-2 transposed convolutions
    (64 -> 32 -> 16 channels) and maps back to 3 channels with a 3x3
    convolution followed by a sigmoid, so every output value lies in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: conv + ReLU stages, pooling after the first two
        encoder_layers = [
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: two 2x upsampling stages, then project to 3 channels
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        latent = self.encoder(x)
        return self.decoder(latent)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The only preprocessing step is the conversion to tensors
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 train and test splits share every option except the `train` flag
dataset_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.CIFAR10(train=True, **dataset_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **dataset_kwargs)

# Batch both splits; only the training split is shuffled
batch_size = 128
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

# Model on the selected device, MSE reconstruction loss, Adam optimizer
model = AutoEncoder()
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: one full pass over train_loader per epoch
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for images, _ in progress:
        images = images.to(device)

        # Clear gradients left over from the previous step, then
        # reconstruct the batch and score it against the input itself
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, images)

        # Backpropagate and update the weights
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

# Generate and save reconstructed images, one .npy file per test image
model.eval()
with torch.no_grad():
    for i, (images, _) in enumerate(test_loader):
        images = images.to(device)
        reconstructed = model(images)

        # Move the whole batch to host memory once instead of issuing one
        # device-to-host copy per image.
        recon_batch = reconstructed.cpu().numpy()

        # test_loader is not shuffled, so batch i starts at dataset index
        # i * batch_size. The loader never yields more samples than the
        # dataset holds, so no upper-bound check on idx is needed.
        # NOTE: `:03d` pads to three digits only, so indices >= 1000 get
        # longer names; kept because the zip step uses the same pattern.
        for j, recon_img in enumerate(recon_batch):
            idx = i * batch_size + j
            save_file = os.path.join(save_path, f'image_{idx:03d}.npy')
            np.save(save_file, recon_img)

# Bundle every saved reconstruction into one archive inside save_path
zip_path = os.path.join(save_path, 'reconstructed_images.zip')
with zipfile.ZipFile(zip_path, mode='w') as archive:
    for i in range(len(test_dataset)):
        file_name = f'image_{i:03d}.npy'
        # Store each member under its bare file name (no directory prefix)
        archive.write(os.path.join(save_path, file_name), arcname=file_name)

print("Processing complete! Files saved to Google Drive.")

# Calculate MSE for verification
def calculate_mse(loader=None, net=None, loss_fn=None, target_device=None):
    """Return the per-sample average reconstruction loss over a data loader.

    Each batch loss is weighted by the number of samples in that batch, so a
    shorter final batch does not count as much as a full one.

    Args:
        loader: Iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``test_loader``.
        net: Model to evaluate. Defaults to the module-level ``model``.
        loss_fn: Callable ``loss_fn(reconstruction, images)`` returning a
            scalar tensor that is the mean over the batch (as the
            module-level ``criterion = nn.MSELoss()`` does). Defaults to
            ``criterion``.
        target_device: Device the images are moved to before the forward
            pass. Defaults to the module-level ``device``.

    Returns:
        The sample-weighted mean loss as a Python float.

    Raises:
        ValueError: If the loader yields no samples.
    """
    loader = test_loader if loader is None else loader
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    target_device = device if target_device is None else target_device

    # Evaluate in eval mode regardless of what the caller left the model in,
    # and put the previous mode back afterwards.
    was_training = net.training
    net.eval()

    total_loss = 0.0
    num_samples = 0
    try:
        with torch.no_grad():
            for images, _ in loader:
                images = images.to(target_device)
                reconstructed = net(images)
                batch_len = images.size(0)
                # Undo the per-batch mean so every sample weighs the same.
                total_loss += loss_fn(reconstructed, images).item() * batch_len
                num_samples += batch_len
    finally:
        net.train(was_training)

    if num_samples == 0:
        raise ValueError("cannot compute MSE: the data loader yielded no samples")
    return total_loss / num_samples

# Evaluate the trained model on test_loader and report the loss to four decimals
final_mse = calculate_mse()
print(f"Average MSE on test set: {final_mse:.4f}")


Mounted at /content/drive
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 39.5MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Epoch 1/50: 100%|██████████| 391/391 [01:17<00:00,  5.07it/s]


Epoch [1/50], Average Loss: 0.0194


Epoch 2/50: 100%|██████████| 391/391 [01:19<00:00,  4.91it/s]


Epoch [2/50], Average Loss: 0.0078


Epoch 3/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


Epoch [3/50], Average Loss: 0.0064


Epoch 4/50: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch [4/50], Average Loss: 0.0059


Epoch 5/50: 100%|██████████| 391/391 [01:17<00:00,  5.05it/s]


Epoch [5/50], Average Loss: 0.0056


Epoch 6/50: 100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


Epoch [6/50], Average Loss: 0.0053


Epoch 7/50: 100%|██████████| 391/391 [01:17<00:00,  5.03it/s]


Epoch [7/50], Average Loss: 0.0052


Epoch 8/50: 100%|██████████| 391/391 [01:17<00:00,  5.08it/s]


Epoch [8/50], Average Loss: 0.0050


Epoch 9/50: 100%|██████████| 391/391 [01:17<00:00,  5.05it/s]


Epoch [9/50], Average Loss: 0.0049


Epoch 10/50: 100%|██████████| 391/391 [01:15<00:00,  5.16it/s]


Epoch [10/50], Average Loss: 0.0046


Epoch 11/50: 100%|██████████| 391/391 [01:16<00:00,  5.14it/s]


Epoch [11/50], Average Loss: 0.0042


Epoch 12/50: 100%|██████████| 391/391 [01:17<00:00,  5.07it/s]


Epoch [12/50], Average Loss: 0.0041


Epoch 13/50: 100%|██████████| 391/391 [01:16<00:00,  5.14it/s]


Epoch [13/50], Average Loss: 0.0040


Epoch 14/50: 100%|██████████| 391/391 [01:15<00:00,  5.20it/s]


Epoch [14/50], Average Loss: 0.0039


Epoch 15/50: 100%|██████████| 391/391 [01:16<00:00,  5.12it/s]


Epoch [15/50], Average Loss: 0.0038


Epoch 16/50: 100%|██████████| 391/391 [01:21<00:00,  4.79it/s]


Epoch [16/50], Average Loss: 0.0038


Epoch 17/50: 100%|██████████| 391/391 [01:17<00:00,  5.07it/s]


Epoch [17/50], Average Loss: 0.0037


Epoch 18/50: 100%|██████████| 391/391 [01:16<00:00,  5.08it/s]


Epoch [18/50], Average Loss: 0.0036


Epoch 19/50: 100%|██████████| 391/391 [01:17<00:00,  5.05it/s]


Epoch [19/50], Average Loss: 0.0036


Epoch 20/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


Epoch [20/50], Average Loss: 0.0035


Epoch 21/50: 100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


Epoch [21/50], Average Loss: 0.0035


Epoch 22/50: 100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


Epoch [22/50], Average Loss: 0.0034


Epoch 23/50: 100%|██████████| 391/391 [01:17<00:00,  5.07it/s]


Epoch [23/50], Average Loss: 0.0033


Epoch 24/50: 100%|██████████| 391/391 [01:18<00:00,  4.95it/s]


Epoch [24/50], Average Loss: 0.0033


Epoch 25/50: 100%|██████████| 391/391 [01:16<00:00,  5.12it/s]


Epoch [25/50], Average Loss: 0.0032


Epoch 26/50: 100%|██████████| 391/391 [01:18<00:00,  5.01it/s]


Epoch [26/50], Average Loss: 0.0032


Epoch 27/50: 100%|██████████| 391/391 [01:18<00:00,  5.00it/s]


Epoch [27/50], Average Loss: 0.0031


Epoch 28/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s]


Epoch [28/50], Average Loss: 0.0031


Epoch 29/50: 100%|██████████| 391/391 [01:16<00:00,  5.09it/s]


Epoch [29/50], Average Loss: 0.0030


Epoch 30/50: 100%|██████████| 391/391 [01:15<00:00,  5.15it/s]


Epoch [30/50], Average Loss: 0.0030


Epoch 31/50: 100%|██████████| 391/391 [01:15<00:00,  5.16it/s]


Epoch [31/50], Average Loss: 0.0030


Epoch 32/50: 100%|██████████| 391/391 [01:16<00:00,  5.08it/s]


Epoch [32/50], Average Loss: 0.0029


Epoch 33/50: 100%|██████████| 391/391 [01:16<00:00,  5.13it/s]


Epoch [33/50], Average Loss: 0.0029


Epoch 34/50: 100%|██████████| 391/391 [01:17<00:00,  5.02it/s]


Epoch [34/50], Average Loss: 0.0029


Epoch 35/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


Epoch [35/50], Average Loss: 0.0028


Epoch 36/50: 100%|██████████| 391/391 [01:17<00:00,  5.07it/s]


Epoch [36/50], Average Loss: 0.0028


Epoch 37/50: 100%|██████████| 391/391 [01:18<00:00,  4.96it/s]


Epoch [37/50], Average Loss: 0.0028


Epoch 38/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


Epoch [38/50], Average Loss: 0.0028


Epoch 39/50: 100%|██████████| 391/391 [01:17<00:00,  5.04it/s]


Epoch [39/50], Average Loss: 0.0027


Epoch 40/50: 100%|██████████| 391/391 [01:15<00:00,  5.17it/s]


Epoch [40/50], Average Loss: 0.0027


Epoch 41/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


Epoch [41/50], Average Loss: 0.0027


Epoch 42/50: 100%|██████████| 391/391 [01:18<00:00,  4.99it/s]


Epoch [42/50], Average Loss: 0.0027


Epoch 43/50: 100%|██████████| 391/391 [01:16<00:00,  5.14it/s]


Epoch [43/50], Average Loss: 0.0027


Epoch 44/50: 100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


Epoch [44/50], Average Loss: 0.0026


Epoch 45/50: 100%|██████████| 391/391 [01:17<00:00,  5.05it/s]


Epoch [45/50], Average Loss: 0.0026


Epoch 46/50: 100%|██████████| 391/391 [01:16<00:00,  5.12it/s]


Epoch [46/50], Average Loss: 0.0026


Epoch 47/50: 100%|██████████| 391/391 [01:18<00:00,  4.98it/s]


Epoch [47/50], Average Loss: 0.0026


Epoch 48/50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


Epoch [48/50], Average Loss: 0.0026


Epoch 49/50: 100%|██████████| 391/391 [01:16<00:00,  5.10it/s]


Epoch [49/50], Average Loss: 0.0026


Epoch 50/50: 100%|██████████| 391/391 [01:17<00:00,  5.05it/s]


Epoch [50/50], Average Loss: 0.0025
Processing complete! Files saved to Google Drive.
Average MSE on test set: 0.0026
