In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from ResidualAutoEncoder import ResidualAutoencoder

In [5]:
def main(batch_size=64, learning_rate=0.001, epochs=30, num_images=10, seed=None):
    """Train a ResidualAutoencoder on CIFAR-10 and plot test-set reconstructions.

    Pipeline:
      1. Load CIFAR-10 train/test splits, normalized to [-1, 1] and resized to 128x128.
      2. Train with MSE loss and Adam via ``model.train_harness``.
      3. Run ``model.evaluate_harness`` on the test loader.
      4. Show the first ``num_images`` inputs (top row) above their
         reconstructions (bottom row).

    Args:
        batch_size: Mini-batch size for both the train and the test loader.
        learning_rate: Adam learning rate.
        epochs: Number of training epochs forwarded to ``train_harness``.
        num_images: How many input/reconstruction pairs to display.
        seed: If not None, passed to ``torch.manual_seed`` so data shuffling and
            weight initialisation are reproducible. ``None`` (default) leaves the
            global RNG untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader, test_loader = _make_cifar10_loaders(batch_size)

    # Model, loss function, and optimizer
    model = ResidualAutoencoder(num_blocks=4, block_depth=2, bottleneck_dim=128, channels=3).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # NOTE(review): the harness methods are invoked on the instance *and* receive
    # the model explicitly; presumably they are static methods — confirm in
    # ResidualAutoEncoder.py.
    model.train_harness(model, train_loader, criterion, optimizer, device, epochs)

    inputs, outputs = model.evaluate_harness(model, test_loader, device)
    _plot_reconstructions(inputs, outputs, num_images)


def _make_cifar10_loaders(batch_size, image_size=128, root='./data'):
    """Build (train_loader, test_loader) for CIFAR-10; only the train loader shuffles."""
    transform = transforms.Compose([
        transforms.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
        # One mean/std entry per RGB channel; maps [0, 1] -> [-1, 1].
        # Keep these in sync with `unnormalize`.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Upsamples 32x32 CIFAR images; presumably the model expects 128x128 input.
        transforms.Resize((image_size, image_size)),
    ])
    train_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def _plot_reconstructions(inputs, outputs, num_images=10):
    """Show inputs (top row) above their reconstructions (bottom row).

    Both ``inputs`` and ``outputs`` are indexed per image and permuted CHW -> HWC,
    i.e. they are expected to be (N, C, H, W) tensors in normalized [-1, 1] space.
    """
    # Both rows must be mapped back to [0, 1]; otherwise imshow clips the
    # normalized inputs and the top row is displayed incorrectly. detach/cpu make
    # this safe whether or not the harness returned CUDA / grad-tracking tensors.
    inputs = unnormalize(inputs.detach().cpu()).clamp(0, 1)
    outputs = unnormalize(outputs.detach().cpu()).clamp(0, 1)

    num_images = min(num_images, len(inputs), len(outputs))
    if num_images <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when num_images == 1.
    fig, axes = plt.subplots(2, num_images, figsize=(1.5 * num_images, 4), squeeze=False)
    for i in range(num_images):
        axes[0, i].imshow(inputs[i].permute(1, 2, 0))
        axes[0, i].axis('off')
        axes[1, i].imshow(outputs[i].permute(1, 2, 0))
        axes[1, i].axis('off')
    fig.suptitle('Top: test inputs   Bottom: reconstructions')

    plt.show()
    
def unnormalize(tensor, mean=0.5, std=0.5):
    """Invert ``transforms.Normalize(mean, std)`` by computing ``tensor * std + mean``.

    Args:
        tensor: A normalized tensor (or anything supporting ``*`` and ``+``).
        mean: Scalar mean that was subtracted during normalization. The default
            0.5 matches the ``Normalize`` transform used in ``main``.
        std: Scalar std that was divided out during normalization (default 0.5).

    Returns:
        A new object with the normalization undone; ``tensor`` itself is not modified.
    """
    return tensor * std + mean

# Entry point. Inside a Jupyter/IPython kernel __name__ is also "__main__", so
# executing this cell runs the full train -> evaluate -> plot pipeline.
if __name__ == "__main__":
    main()

Files already downloaded and verified
Files already downloaded and verified
Epoch [1/30], Loss: 0.0720, learning rate: 0.001
Epoch [2/30], Loss: 0.0301, learning rate: 0.001
Epoch [3/30], Loss: 0.0278, learning rate: 0.001
Epoch [4/30], Loss: 0.0269, learning rate: 0.001
Epoch [5/30], Loss: 0.0259, learning rate: 0.001
Epoch [6/30], Loss: 0.0249, learning rate: 0.001
Epoch [7/30], Loss: 0.0246, learning rate: 0.001
Epoch [8/30], Loss: 0.0239, learning rate: 0.001
Epoch [9/30], Loss: 0.0237, learning rate: 0.001


: 