In [None]:
import os
import torch
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import numpy as np

# variables
SEED = 711
# Fall back to the CPU so the notebook still runs on machines without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
NUM_EPOCHS = 100
BATCH_SIZE = 128
# Directory holding samples_train.npy / samples_eval.npy. Override it with the
# UGVGPR_DATA_ROOT environment variable; the default is the original author's
# local Windows path and will not exist on other machines.
DATA_ROOT = os.environ.get(
    "UGVGPR_DATA_ROOT",
    "C:\\Users\\sultan.abughazal\\Documents\\Datasets\\ugvgpr-dataset",
)

In [None]:
# functions


In [None]:
# define the model

class ModelType(nn.Module):
    """Fully-connected autoencoder with tanh activations.

    The encoder maps ``input_dim`` features through ``hidden_dims`` down to a
    ``latent_dim`` code (no activation on the code). The decoder mirrors the
    widths back up to ``input_dim``, and its final ``Tanh`` bounds the
    reconstruction to [-1, 1].

    With the defaults, the layer layout (and therefore the ``state_dict`` keys)
    is identical to the original hard-coded 511-128-64-16 network.

    Args:
        input_dim: Number of features per sample (default 511).
        hidden_dims: Widths of the hidden layers, outermost first
            (default ``(128, 64)``).
        latent_dim: Size of the bottleneck code (default 16).
    """

    def __init__(self, input_dim=511, hidden_dims=(128, 64), latent_dim=16):
        super().__init__()
        widths = [input_dim, *hidden_dims, latent_dim]
        n_layers = len(widths) - 1

        encoder_layers = []
        for i, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            encoder_layers.append(nn.Linear(d_in, d_out))
            # Leave the bottleneck linear: no activation after the last encoder layer.
            if i < n_layers - 1:
                encoder_layers.append(nn.Tanh())
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_widths = widths[::-1]
        decoder_layers = []
        for d_in, d_out in zip(decoder_widths[:-1], decoder_widths[1:]):
            # Every decoder layer, including the output, is followed by Tanh.
            decoder_layers.append(nn.Linear(d_in, d_out))
            decoder_layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the latent code and decode it back to input space."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


In [None]:
# define the dataset

class DatasetType(Dataset):
    """Per-sample min-max-scaled rows loaded from ``samples_{split}.npy``.

    Each row of the loaded array is one sample. Column 0 is dropped
    (NOTE(review): presumably an index/label column — confirm). The remaining
    values are rescaled independently per sample to [-1, 1]. Items are returned
    as ``(sample, sample)`` so the dataset can feed an autoencoder directly.

    Args:
        data_root: Directory containing ``samples_train.npy`` / ``samples_eval.npy``.
        split: Either ``"train"`` or ``"eval"``.

    Raises:
        AssertionError: if ``split`` is not one of the two allowed values.
    """

    def __init__(self, data_root, split):
        super().__init__()
        assert split in ["train", "eval"], "Invalid split!"
        self.samples = np.load(os.path.join(data_root, f"samples_{split}.npy"))
        self.split = split

    def __len__(self):
        return self.samples.shape[0]

    def __getitem__(self, index):
        # Copy out of numpy. torch.from_numpy shares memory, and .type(torch.float)
        # is a no-op for float32 arrays, so in-place arithmetic on such a tensor
        # would silently rewrite self.samples.
        sample = torch.tensor(self.samples[index, 1:], dtype=torch.float)

        sample = sample - sample.min()
        value_range = sample.max()
        if value_range > 0:
            sample = (sample / value_range) * 2 - 1
        else:
            # Constant row: min-max scaling is undefined (0/0 -> NaN); map to 0.
            sample = torch.zeros_like(sample)

        return sample, sample

In [None]:
# define the model, loss function and optimizer

# Seed every RNG before the model is built so weight initialisation is
# reproducible (torch.manual_seed seeds the CPU and all CUDA generators).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = ModelType().to(DEVICE)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# train

# Seed every RNG used below. DataLoader shuffling draws from torch's CPU
# generator, which torch.cuda.manual_seed alone does not touch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

ds = DatasetType(DATA_ROOT, "train")
data_loader = torch.utils.data.DataLoader(dataset=ds, batch_size=BATCH_SIZE, shuffle=True)

outputs = []  # NOTE(review): never filled; only referenced by commented-out plotting code
model.train()  # the evaluate cell switches to eval mode; re-running this cell must undo that
for epoch in range(NUM_EPOCHS):
    epoch_loss_sum = 0.0
    n_seen = 0
    for batch_input, batch_target in data_loader:
        batch_input = batch_input.to(DEVICE)
        batch_target = batch_target.to(DEVICE)

        reconstruction = model(batch_input)
        loss = criterion(reconstruction, batch_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # MSELoss returns a per-element mean. Weighting by batch size makes the
        # epoch value a true mean even when the last batch is shorter.
        batch_size = batch_input.size(0)
        epoch_loss_sum += loss.item() * batch_size
        n_seen += batch_size

    epoch_loss = epoch_loss_sum / max(n_seen, 1)
    print(f'Epoch:{epoch+1:0>3}, Loss:{epoch_loss:.8f}')

In [None]:
# evaluate

model.eval()
ds = DatasetType(DATA_ROOT, "eval")

# Reconstruct one randomly chosen evaluation sample and overlay it on the input.
rand_idx = random.randrange(len(ds))
sample, _ = ds[rand_idx]
with torch.no_grad():  # inference only: no autograd graph needed
    reconstruction = model(sample.to(DEVICE)).cpu()

fig, ax = plt.subplots(figsize=(25, 10))
ax.plot(sample.numpy(), label="Input")
ax.plot(reconstruction.numpy(), label="Reconstructed")
ax.set(
    title=f"Eval sample {rand_idx}: input vs. reconstruction",
    xlabel="Feature index",
    ylabel="Min-max scaled value",
)
ax.legend()
plt.show()
# for k in range(0, num_epochs, 4):
#     imgs = outputs[k][1].detach().numpy()
#     recon = outputs[k][2].detach().numpy()
#     for i, item in enumerate(imgs):
#         if i >= 9: break
#         plt.subplot(2, 9, i+1)
#         # item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
#         # item: 1, 28, 28
#         plt.imshow(item[0])

#     for i, item in enumerate(recon):
#         if i >= 9: break
#         plt.subplot(2, 9, 9+i+1) # row_length + i + 1
#         # item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
#         # item: 1, 28, 28
#         plt.imshow(item[0])