# Import site packages

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from torchvision.datasets import ImageFolder
import os

# Class Dataset

In [None]:
class CarRacingDataset(Dataset):
    """Dataset backed by a single ``.npy`` array loaded fully into memory.

    Indexing along the first axis of the array yields one sample, which is
    returned as a ``torch.float32`` tensor.
    """

    def __init__(self, npy_file):
        """Load the array stored at *npy_file* with ``numpy.load``."""
        self.data = np.load(npy_file)

    def __len__(self):
        """Return the number of samples (length of the first axis)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample *idx* as a freshly copied float32 tensor."""
        # torch.tensor always copies, so the returned tensor never aliases
        # the underlying numpy array.
        return torch.tensor(self.data[idx], dtype=torch.float32)

# Class AE

### Encoder

In [None]:
class AEEncoder(nn.Module):
    """Two-layer fully connected encoder of a plain autoencoder.

    Maps inputs whose last dimension is ``input_size`` to latent codes whose
    last dimension is ``latent_size``.

    Args:
        input_size: Number of features of a flattened input sample.
        hidden_size: Width of the hidden layer.
        latent_size: Dimensionality of the latent code.
    """

    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        # A deterministic autoencoder emits exactly one latent vector.
        # The previous ``latent_size * 2`` width (mean + log-variance of a
        # VAE) did not match the decoder, which consumes ``latent_size``.
        self.fc2 = nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        """Return the latent code for *x* (ReLU hidden layer, linear output)."""
        z = self.fc2(torch.relu(self.fc1(x)))
        return z

### Decoder

In [None]:
class AEDecoder(nn.Module):
    """Two-layer fully connected decoder.

    Maps a latent code (last dimension ``latent_size``) to a reconstruction
    (last dimension ``output_size``) squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Decode latent code *x* into a reconstruction."""
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return torch.sigmoid(logits)

### AE Module

In [None]:
class AE(nn.Module):
    """Autoencoder that chains an ``AEEncoder`` and an ``AEDecoder``.

    The decoder's output size equals the encoder's input size, so the model
    reconstructs its own input.
    """

    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.encoder = AEEncoder(
            input_size=input_size,
            hidden_size=hidden_size,
            latent_size=latent_size,
        )
        self.decoder = AEDecoder(
            latent_size=latent_size,
            hidden_size=hidden_size,
            output_size=input_size,
        )

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the input batch *x*."""
        latent = self.encoder(x)
        return self.decoder(latent), latent

### Loss function

In [None]:
def AE_Loss(reconstruct, x):
    """Return the summed squared error between *reconstruct* and *x*.

    The loss is summed over every element (``reduction='sum'``), not
    averaged, so its magnitude grows with the batch size.
    """
    return nn.functional.mse_loss(reconstruct, x, reduction='sum')

# Start Training

### Load Dataset

In [None]:
# Load every recorded state into memory, then wrap it in a shuffling loader
# that is used to inspect the batch shape.
dataset = CarRacingDataset(npy_file="car_racing_states_2.npy")
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=64)

In [None]:
# Peek at a single batch to check the tensor shape the loader produces.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    print(first_batch.shape)

torch.Size([64, 72, 72])


### Initialize

In [None]:
# Layer sizes: flattened 72x72 frame -> hidden layer -> latent code.
input_size, hidden_size, latent_size = 72 * 72, 800, 50

In [None]:
# Optimisation hyper-parameters.
lr = 1e-3
batch_size = 64

# model_nums is the checkpoint index to resume from (0 = start fresh);
# num_epochs is what is left of a 10000-epoch budget after that checkpoint.
model_nums = 0
num_epochs = 10000 - model_nums

# Loader used by the training loop.
data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Checkpoint to resume from; only read when model_nums > 0.
old_model = f'AE_Model/model_{model_nums}.pth'
model = AE(input_size, hidden_size, latent_size).to(device)

if model_nums > 0:
    # map_location is a torch.load argument (it remaps tensor storage while
    # the file is deserialised); load_state_dict does not accept it.
    model.load_state_dict(torch.load(old_model, map_location=device))

optimizer = optim.Adam(model.parameters(), lr=lr)

### Train AE

In [None]:
# Make sure the checkpoint directory exists before the first save; otherwise
# torch.save fails only after several epochs of training have already run.
os.makedirs('AE_Model', exist_ok=True)

# num_epochs already has model_nums subtracted from it, so the loop must run
# for num_epochs iterations starting at model_nums. Using num_epochs as the
# end index would subtract model_nums twice when resuming.
for epoch in range(model_nums, model_nums + num_epochs):
    epoch_loss = 0.0
    for x in tqdm(data_loader):
        # Flatten every frame to a vector of input_size features.
        x = x.view(-1, input_size).to(device)
        # The plain autoencoder returns (reconstruction, latent); there is no
        # mean / log-variance pair and the loss takes no KL term.
        reconstruct, _ = model(x)
        loss = AE_Loss(reconstruct, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Checkpoint every 10th epoch; the file name uses the 1-based epoch count.
    if epoch % 10 == 9:
        torch.save(model.state_dict(), f'AE_Model/model_{epoch+1}.pth')
    # AE_Loss sums over the batch, so dividing by the dataset size gives the
    # mean loss per sample.
    print(f'epoch {epoch+1} loss: {epoch_loss / len(dataset):.4f}')

### Generate Picture

In [None]:
# Decode a random latent vector into a 72x72 image and display it.
with torch.no_grad():
    # Sample on the model's device; a CPU tensor fed to a CUDA model raises
    # a device-mismatch error.
    z = torch.randn(1, latent_size, device=device)
    image = model.decoder(z).view(72, 72)
    # Move to host memory before converting: .numpy() only works on CPU
    # tensors.
    image = image.cpu().numpy()
    plt.imshow(image, cmap="gray")
    plt.show()

### Save Model

In [None]:
# Persist the weights of the trained model.
final_weights = model.state_dict()
torch.save(final_weights, 'AE_Model/model_final.pth')

# Test Model

In [None]:
# Index of the checkpoint to evaluate.
model_nums = 10_000

In [None]:
# Restore the chosen checkpoint onto the active device and switch to
# inference mode.
model.load_state_dict(
    torch.load(f'AE_Model/model_{model_nums}.pth', map_location=device)
)
model.eval()

sample_image = dataset[0]
with torch.no_grad():
    # The encoder returns a single latent tensor (not a mean / log-variance
    # pair), and its input must live on the same device as the model.
    latent = model.encoder(sample_image.view(-1, input_size).to(device))
    # Bring the reconstruction back to the CPU for plotting.
    generated_image = model.decoder(latent).view(72, 72).cpu()

# Show the original frame next to its reconstruction.
plt.subplot(1, 2, 1)
plt.title('Original Image')
plt.imshow(sample_image.view(72, 72), cmap='gray')
plt.subplot(1, 2, 2)
plt.title('Generated Image')
plt.imshow(generated_image.numpy(), cmap='gray')
plt.show()