In [22]:
import os
import pickle
from glob import glob

import numpy as np
try:
    from matplotlib import pyplot as plt
    MATPLOTLIB = True
except ImportError:
    # matplotlib is optional: plotting code checks this flag and skips
    # drawing when it is False. Only a missing package is tolerated here.
    MATPLOTLIB = False

from torch import nn
from torch.utils.data import Dataset, DataLoader

In [20]:
class CifarDataSet(Dataset):
    """CIFAR-style dataset read from python-pickle ``data_batch_*`` files.

    The ``data_batch_*`` files in ``batch_dir_path`` are sorted by name. All
    but the last form the training split; the last one forms the test split.
    Pixel values are divided by 255 and stored channel-first with shape
    ``(N, 3, 32, 32)`` in ``self.data``; labels are kept in ``self.labels``.

    Args:
        batch_dir_path: Directory containing the ``data_batch_*`` files.
        train: If True, load every batch except the last; otherwise load only
            the last batch.

    Raises:
        FileNotFoundError: If ``batch_dir_path`` is not a directory, or it
            contains no ``data_batch_*`` files.
        ValueError: If ``train`` is True but only one batch file exists, so
            nothing is left for the training split.
    """

    def __init__(self, batch_dir_path, train=True):
        super().__init__()
        if not os.path.isdir(batch_dir_path):
            # Built-in exception, consistent with the "no files" error below.
            raise FileNotFoundError(f"Directory does not exist: {batch_dir_path}")
        all_batch_files = sorted(glob(os.path.join(batch_dir_path, "data_batch_*")))
        if not all_batch_files:
            raise FileNotFoundError(f"No data_batch_n files found in {batch_dir_path}")

        # Hold out the last batch for testing; train on the first n-1 batches.
        split_files = all_batch_files[:-1] if train else all_batch_files[-1:]
        if not split_files:
            raise ValueError(
                f"Need at least 2 data_batch_n files in {batch_dir_path} to build "
                "a training split; found only 1"
            )

        lst_imgs, lst_labels = [], []
        for fi_path in split_files:
            imgs, labels = self.unpickle(fi_path)
            lst_imgs.append(imgs)
            lst_labels.extend(labels)
        # Rows are reshaped channel-first into (N, 3, 32, 32).
        # NOTE(review): presumably float64 after the /255.0 in unpickle; storing
        # float32 would halve memory — consider if RAM is tight.
        self.data = np.vstack(lst_imgs).reshape((-1, 3, 32, 32))
        self.labels = lst_labels

    def unpickle(self, f):
        """Load one batch file and return ``(data / 255.0, labels)``.

        Raises ``KeyError`` if the pickled dict lacks ``b"data"`` or ``b"labels"``.

        NOTE: ``pickle.load`` can execute arbitrary code — only use this on
        trusted dataset files.
        """
        with open(f, "rb") as fo:
            dct = pickle.load(fo, encoding="bytes")
        # Scale pixel values into [0, 1].
        return dct[b"data"] / 255.0, dct[b"labels"]

    def get_image(self, index, plot=False):
        """Return image ``index`` channel-last with shape ``(32, 32, 3)``, or plot it.

        When ``plot`` is True and matplotlib is available the image is drawn
        with ``plt.imshow`` and None is returned; otherwise the array is returned.
        """
        img = np.transpose(self.data[index], axes=(1, 2, 0))
        if plot and MATPLOTLIB:
            plt.imshow(img)
            return None
        return img

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is a float32 ``(3, 32, 32)`` array."""
        return (
            self.data[index].astype(np.float32),
            self.labels[index],
        )

    def __len__(self):
        """Number of images in the loaded split."""
        return self.data.shape[0]

In [21]:
# Location of the extracted "cifar-10-batches-py" folder. Set the CIFAR10_DIR
# environment variable to point elsewhere instead of editing this path.
CIFAR10_DIR = os.environ.get(
    "CIFAR10_DIR", "/home/pranjal/pytorch/Datasets/cifar-10-batches-py/"
)

# Training split = all data_batch files but the last; test split = the last one.
train = CifarDataSet(batch_dir_path=CIFAR10_DIR, train=True)
test = CifarDataSet(batch_dir_path=CIFAR10_DIR, train=False)

In [3]:
# NOTE(review): not referenced by any visible cell; presumably a dropout
# keep-probability (drop probability = 1 - keep_prob) — confirm or remove.
keep_prob: float = 0.7

In [133]:
def _down_block(in_ch, out_ch, normalize_input):
    """Build an encoder stage: [BatchNorm] -> 3x3 conv -> ReLU -> 2x2 max-pool.

    The pool uses ``return_indices=True``, so calling the stage returns
    ``(features, indices)``; the indices feed the matching ``MaxUnpool2d``.
    """
    stem = [nn.BatchNorm2d(num_features=in_ch)] if normalize_input else []
    return nn.Sequential(
        *stem,
        nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True),
    )


def _up_block(in_ch, out_ch, normalize_output=True):
    """Build a decoder stage: ReLU -> size-preserving 3x3 transposed conv -> [BatchNorm]."""
    tail = [nn.BatchNorm2d(num_features=out_ch)] if normalize_output else []
    return nn.Sequential(
        nn.ReLU(),
        nn.ConvTranspose2d(in_ch, out_ch, 3, padding=1),
        *tail,
    )


# Encoder: channels 3 -> 64 -> 128 -> 256 -> 512, halving H and W at each stage.
encoder_b1 = _down_block(3, 64, normalize_input=False)
encoder_b2 = _down_block(64, 128, normalize_input=True)
encoder_b3 = _down_block(128, 256, normalize_input=True)
encoder_b4 = _down_block(256, 512, normalize_input=True)
# Bottleneck: a 4x4 kernel with padding 1 shrinks each spatial side by 1
# (2x2 -> 1x1 for 32x32 inputs); no activation or pooling here.
encoder_b5 = nn.Sequential(
    nn.BatchNorm2d(num_features=512),
    nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, padding=1),
)

# Decoder mirrors the encoder; each MaxUnpool2d consumes the indices produced
# by the encoder pool of matching spatial size.
decoder_b1 = nn.Sequential(
    nn.ConvTranspose2d(1024, 512, 4, padding=1),  # grows each side by 1
    nn.BatchNorm2d(num_features=512),
)
decoder_b2 = nn.MaxUnpool2d(kernel_size=2, stride=2)
decoder_b3 = _up_block(512, 256)
decoder_b4 = nn.MaxUnpool2d(kernel_size=2, stride=2)
decoder_b5 = _up_block(256, 128)
decoder_b6 = nn.MaxUnpool2d(kernel_size=2, stride=2)
decoder_b7 = _up_block(128, 64)
decoder_b8 = nn.MaxUnpool2d(kernel_size=2, stride=2)
# Final stage maps back to 3 channels with no BatchNorm and no output squashing.
decoder_b9 = _up_block(64, 3, normalize_output=False)

In [134]:
# Shuffled mini-batches of 128 samples drawn from the training split.
tl = DataLoader(dataset=train, batch_size=128, shuffle=True)

In [135]:
# Smoke test: push one shuffled training batch through the encoder/decoder
# stages and check that the reconstruction has the same shape as the input.
# NOTE(review): no .eval() / torch.no_grad() is used, so this call tracks
# gradients and, assuming the modules are still in training mode (the default),
# updates the BatchNorm running statistics.
feat, lab = next(iter(tl))

op, i1 = encoder_b1(feat)
op, i2 = encoder_b2(op)
op, i3 = encoder_b3(op)
op, i4 = encoder_b4(op)
op = encoder_b5(op)

# Unpool with the saved pooling indices in reverse order (i4 .. i1).
op = decoder_b1(op)
op = decoder_b2(op, i4)
op = decoder_b3(op)
op = decoder_b4(op, i3)
op = decoder_b5(op)
op = decoder_b6(op, i2)
op = decoder_b7(op)
op = decoder_b8(op, i1)
op = decoder_b9(op)

assert op.shape == feat.shape, (
    f"reconstruction shape {tuple(op.shape)} != input shape {tuple(feat.shape)}"
)