In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from ResidualAutoEncoder import ResidualAutoencoder
from ResidualAutoEncoder import Criterion

In [2]:
def main():
    """Train a ResidualAutoencoder on CIFAR-10, save it, and plot reconstructions.

    Steps, in order:
      1. Build CIFAR-10 train/test loaders (images normalized with mean 0.5 /
         std 0.5 per channel and resized to 128x128).
      2. Train through ``model.train_harness`` using Adam and the project's
         ``Criterion``.
      3. Write the model's ``state_dict`` to ``models/residual_autoencoder.pth``,
         creating the ``models`` directory first if it is missing.
      4. Call ``model.evaluate_harness`` on the test loader and show up to ten
         input/output pairs (inputs on the top row, outputs on the bottom row).

    Takes no arguments and returns nothing; its effects are the dataset
    download under ``./data``, the checkpoint file, and the displayed figure.
    """
    # Hyperparameters
    batch_size = 64
    learning_rate = 0.001
    epochs = 5
    checkpoint_path = 'models/residual_autoencoder.pth'

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    use_pinned_memory = device.type == 'cuda'

    # Data transformation and loading
    transform = transforms.Compose([
        transforms.ToTensor(),
        # One mean/std per RGB channel. The previous `(0.5)` was a plain float
        # (no trailing comma) that only worked through broadcasting.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        transforms.Resize((128, 128))
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=use_pinned_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=4, pin_memory=use_pinned_memory)

    # Model, loss function, and optimizer
    model = ResidualAutoencoder(num_blocks=3, block_depth=2, bottleneck_dim=128, channels=3, device=device)
    criterion = Criterion(model=model)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    # NOTE(review): the output stored with this notebook shows `Loss: nan` after
    # epoch 1. Nothing in this function can detect or prevent that; the cause
    # has to be looked for in Criterion / train_harness.
    model.train_harness(model, train_loader, criterion, optimizer, epochs)

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)

    # Evaluate the model
    inputs, outputs = model.evaluate_harness(model, test_loader, device)

    def to_image(sample):
        """Turn one normalized sample into something imshow can draw.

        Undoes the 0.5/0.5 normalization, drops autograd/GPU state, clamps to
        the [0, 1] range imshow expects for float RGB data, and moves the
        first axis last (assumed channel-first input - TODO confirm).
        """
        return unnormalize(sample).detach().cpu().clamp(0, 1).permute(1, 2, 0)

    # Visualize the results
    # NOTE(review): evaluate_harness is fed the normalized test_loader, so the
    # inputs it returns are presumably still in [-1, 1]; both rows are therefore
    # unnormalized here (previously only `outputs` was) - confirm against the harness.
    num_images = min(10, len(inputs), len(outputs))
    # squeeze=False keeps `axes` two-dimensional even when only one column is drawn.
    fig, axes = plt.subplots(2, num_images, figsize=(15, 4), squeeze=False)
    for i in range(num_images):
        axes[0, i].imshow(to_image(inputs[i]))
        axes[0, i].axis('off')
        axes[1, i].imshow(to_image(outputs[i]))
        axes[1, i].axis('off')
    axes[0, 0].set_title('Input', loc='left')
    axes[1, 0].set_title('Output', loc='left')

    plt.show()
    
def unnormalize(tensor):
    """Undo a mean-0.5 / std-0.5 normalization.

    Multiplies the argument by 0.5 and then adds 0.5, so -1 maps to 0,
    0 maps to 0.5 and 1 maps to 1. Works on anything that supports ``*`` and
    ``+`` with a float; the result is a new value, the argument is not
    modified.
    """
    half = 0.5
    rescaled = tensor * half
    return rescaled + half

# Entry point: run the whole train / save / evaluate / plot pipeline when this
# cell (or the exported script) is executed directly rather than imported.
if __name__ == "__main__":
    main()

Files already downloaded and verified
Files already downloaded and verified


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch [1/5], Loss: nan, recon_loss: 0.0001 learning rate: 0.001


In [None]:
# Scratch cell: build a ResidualAutoencoder with the asym_block / asym_depth
# options set and write its freshly initialized (no training call here)
# weights to disk.
model = ResidualAutoencoder(num_blocks=2, block_depth=2, bottleneck_dim=128, channels=3, asym_block=2, asym_depth=2)
# torch.save does not create missing parent directories.
os.makedirs('models', exist_ok=True)
# NOTE(review): the file name is spelled 'autoencorder', unlike the
# 'residual_autoencoder.pth' checkpoint written by main(). Left unchanged on
# purpose: correcting the spelling would make this cell overwrite the trained
# checkpoint with untrained weights. Confirm the intended name.
torch.save(model.state_dict(), 'models/residual_autoencorder.pth')

3 64 2 3 2
64 128 2 3 2
