In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

class Autoencoder(torch.nn.Module):
    """Action-conditioned convolutional autoencoder.

    The encoder halves the spatial resolution five times with stride-2
    convolutions, a one-hot action map is concatenated to the encoded
    features, a single linear layer mixes them, and the decoder doubles the
    resolution five times with transposed convolutions, ending in a sigmoid.

    The linear layer expects a 4x4 encoder output, which the five stride-2
    convolutions produce from a 128x128 input.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Encoder: 128 -> 64 -> 32 -> 16 -> 8 -> 4 (spatial size), with a ReLU
        # between consecutive convolutions and none after the last one.
        encoder_channels = [3, 16, 32, 64, 128, 128]
        encoder_layers = []
        for in_ch, out_ch in zip(encoder_channels[:-1], encoder_channels[1:]):
            encoder_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1))
            encoder_layers.append(nn.ReLU())
        encoder_layers.pop()  # drop the trailing ReLU
        self.encoder = nn.Sequential(*encoder_layers)

        # Fully connected layer: 128 feature channels + 4 action channels on a
        # 4x4 grid are mapped back to 128 channels on a 4x4 grid.
        self.layer = nn.Sequential(
            nn.Linear((128 + 4) * 4 * 4, 128 * 4 * 4),
        )

        # Decoder: 4 -> 8 -> 16 -> 32 -> 64 -> 128 (spatial size), ReLU after
        # every transposed convolution except the last, which gets a sigmoid.
        decoder_channels = [128, 128, 64, 32, 16, 3]
        decoder_layers = []
        for in_ch, out_ch in zip(decoder_channels[:-1], decoder_channels[1:]):
            decoder_layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            decoder_layers.append(nn.ReLU())
        decoder_layers[-1] = nn.Sigmoid()
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, image, action):
        """Return the decoded image for ``image`` conditioned on ``action``.

        Args:
            image: batch of 3-channel images in channel-first layout; the
                linear layer requires the encoder to reduce them to 4x4.
            action: per-sample action ids; cast to int64 and one-hot encoded
                into 4 classes.

        Returns:
            Tensor with 3 channels and values in [0, 1] (sigmoid output).
        """
        features = self.encoder(image)
        height, width = features.shape[2], features.shape[3]

        # Turn every action id into a 4-channel constant map with the same
        # spatial size as the encoded features.
        action_ids = action.to(torch.int64)
        one_hot = torch.nn.functional.one_hot(action_ids, num_classes=4)
        action_map = one_hot.reshape(-1, 4, 1, 1).float().repeat(1, 1, height, width)

        # Stack along the channel axis and flatten for the linear layer.
        combined = torch.cat((features, action_map), 1)
        flat = combined.view(-1, (128 + 4) * combined.shape[2] * combined.shape[3])

        hidden = self.layer(flat).view(-1, 128, 4, 4)
        return self.decoder(hidden)

In [2]:
# Load the raw arrays: input images, the action taken, and the resulting
# position / image after the action.
images = np.load("data/images.npy")
actions = np.load("data/actions.npy")
positions_after = np.load("data/positions_after.npy")
images_after = np.load("data/images_after.npy")

# Scale pixel values by 1/255 (presumably 8-bit images -> [0, 1]; confirm
# against the data generator) so targets match the decoder's sigmoid range.
images = images / 255.0
images_after = images_after / 255.0

# Wrap everything in one dataset so the four arrays stay aligned per sample.
dataset = torch.utils.data.TensorDataset(
    torch.tensor(images, dtype=torch.float32),
    torch.tensor(actions, dtype=torch.float32),
    torch.tensor(positions_after, dtype=torch.float32),
    torch.tensor(images_after, dtype=torch.float32),
)

# 80% train / 10% validation / remainder test.
train_size = int(0.8*len(images))
valid_size = int(0.1*len(images))
test_size = len(images) - train_size - valid_size

# Seed the split so the partition is identical on every kernel run. Without a
# fixed seed, a checkpoint trained in one session could be evaluated in a
# later session on a "test" set that contains its own training samples.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, valid_size, test_size], generator=split_generator
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
# The validation split is materialised once as four tensors and evaluated as a single batch.
valid_images, valid_actions, valid_positions_after, valid_images_after = valid_dataset[:]
# The whole test split is served as one batch.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [5]:


def train(num_epochs=2000, lr=0.0025, checkpoint_path="hw1_3.pth"):
    """Train the reconstruction autoencoder and checkpoint the best model.

    Uses the module-level ``device``, ``train_loader`` and the validation
    tensors ``valid_images`` / ``valid_actions`` / ``valid_images_after``.
    After every epoch the model is evaluated on the whole validation set in
    one batch; whenever the validation MSE is at least as good as the best
    seen so far, the state dict is written to ``checkpoint_path``.

    Args:
        num_epochs: number of passes over the training set.
        lr: Adam learning rate.
        checkpoint_path: file the best state dict is saved to.

    Returns:
        List with the sample-weighted mean training loss of every epoch.
    """
    reconstruction_model = Autoencoder().to(device)
    optimizer = torch.optim.Adam(reconstruction_model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    # The validation set never changes, so move it to the device once
    # instead of copying it on every epoch.
    valid_inputs = valid_images.to(device)
    valid_action_ids = valid_actions.to(device)
    valid_targets = valid_images_after.to(device)

    training_losses = []
    best_valid_loss = float("inf")

    for epoch in range(num_epochs):
        # training
        reconstruction_model.train()
        running_loss = 0.0
        # positions_after is part of the dataset but not used by this model.
        for images, actions, _, images_after in train_loader:
            images = images.to(device)
            actions = actions.to(device)
            images_after = images_after.to(device)

            optimizer.zero_grad()
            outputs = reconstruction_model(images, actions)
            loss = criterion(outputs, images_after)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch does not skew the mean.
            running_loss += loss.item()*images.size(0)

        epoch_loss = running_loss/len(train_loader.dataset)
        training_losses.append(epoch_loss)

        # validation
        reconstruction_model.eval()
        with torch.no_grad():
            outputs = reconstruction_model(valid_inputs, valid_action_ids)
            valid_loss = criterion(outputs, valid_targets).item()

        print(epoch, epoch_loss)
        # Track the best value directly rather than rescanning the full
        # history each epoch; a NaN loss compares False and is never saved.
        if valid_loss <= best_valid_loss:
            best_valid_loss = valid_loss
            print("Saving best model, epoch: ", epoch)
            torch.save(reconstruction_model.state_dict(), checkpoint_path)
    return training_losses

def test(checkpoint_path="hw1_3.pth"):
    """Evaluate the saved reconstruction model on the test set.

    Loads the state dict from ``checkpoint_path`` into a fresh CPU model,
    computes the MSE between the prediction and ``images_after`` for every
    batch of the module-level ``test_loader``, and for the first batch saves
    a 3-row figure (input / output / ground truth) of up to five samples to
    ``results/reconstruction_test_results.png``.

    Args:
        checkpoint_path: file holding the state dict written by training.

    Returns:
        List with one MSE value per test batch.
    """
    reconstruction_model = Autoencoder()
    # The model is evaluated on the CPU; map_location lets a checkpoint that
    # was saved from a CUDA model load on a machine without a GPU.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    reconstruction_model.load_state_dict(state_dict)
    reconstruction_model.eval()
    criterion = torch.nn.MSELoss()

    test_losses = []
    with torch.no_grad():
        for i, (images, actions, positions_after, images_after) in enumerate(test_loader):

            outputs = reconstruction_model(images, actions)
            loss = criterion(outputs, images_after)
            test_losses.append(loss.item())

            # visualize the result
            if i == 0:
                rows = [
                    ("Input", images.cpu().numpy()),
                    ("Output", outputs.cpu().numpy()),
                    ("Ground Truth", images_after.cpu().numpy()),
                ]
                # Never index past the end of a batch smaller than five.
                num_shown = min(5, len(rows[0][1]))
                for j in range(num_shown):
                    for row, (label, batch) in enumerate(rows):
                        plt.subplot(3, 5, row*5 + j + 1)
                        # Channel-first -> channel-last for imshow.
                        plt.imshow(batch[j].transpose(1, 2, 0))
                        if j == 0:
                            plt.ylabel(label)
                        plt.yticks([])
                        plt.xticks([])
                plt.savefig("results/reconstruction_test_results.png")
                plt.close()
    return test_losses


def plot_loss(training_losses):
    """Save a log-scale plot of the per-epoch training loss.

    Args:
        training_losses: sequence with one loss value per epoch.

    The figure is written to ``results/reconstruction_training_loss.png``
    and closed afterwards.
    """
    fig, ax = plt.subplots()
    ax.plot(training_losses)
    ax.grid(alpha=0.3)
    ax.set_yscale('log')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (Log Scale)')
    ax.set_title('Training Loss of Image Reconstruction Model')
    fig.savefig('results/reconstruction_training_loss.png')
    plt.close(fig)


In [None]:
# Train the model (train() also saves the best checkpoint to disk as a side
# effect), then save a plot of the per-epoch training loss.
training_losses = train()
plot_loss(training_losses)

In [7]:
# Evaluate the saved checkpoint on the test set (one loss value per batch).
test_losses = test()

# Persist the list of test losses as a single line of text.
report_line = f"Reconstruction test loss: {test_losses}\n"
with open("results/reconstruction_test_results.txt", "w") as f:
    f.write(report_line)