In [None]:
class TempDataset(Dataset):
    """Dataset over a flat directory of images that have no ground-truth labels.

    Every entry of the image directory becomes one sample. Because no real
    labels are available, each sample carries the placeholder label ``0``.

    Args:
        mode: Stored as ``self.mode``; this class does not otherwise use it.
        transform: When truthy, each image is resized to 224x224, converted to
            a tensor and normalised per channel (see ``_init_transform``).
            When falsy, ``__getitem__`` returns the untransformed RGB PIL
            image. Note that a default ``DataLoader`` collate function cannot
            batch PIL images, so that mode is meant for direct indexing.
        root: Directory to scan for images. Defaults to ``/content/Images``.
    """

    def __init__(self, mode = 'train', transform = True, root = None):
        self.mode = mode
        self.transforms = transform
        self.root = os.path.join("/content", "Images") if root is None else root
        # Always define the attribute so __getitem__ can test it; previously it
        # only existed when transform was truthy and indexing crashed otherwise.
        self.transform = None
        self._init_dataset()
        if transform:
            self._init_transform()

    def _init_dataset(self):
        """Collect the image paths under ``self.root`` with placeholder labels."""
        self.files = []
        self.labels = []

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        for image in sorted(os.listdir(self.root)):
            self.files.append(os.path.join(self.root, image))
            self.labels.append(0)

    def _init_transform(self):
        """Build the preprocessing pipeline applied to every image."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)), # VGGNet example
            transforms.ToTensor(),
            # Per-channel statistics conventionally used with
            # ImageNet-pretrained models.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        path = self.files[index]
        label = self.labels[index]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed right away instead of lingering until garbage collection.
        with Image.open(path) as handle:
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of files found in the image directory."""
        return len(self.files)


In [None]:
# Run the model over every image in the unlabelled folder and collect the
# predicted class index (argmax over dim 1 of the model output) per batch.
images = TempDataset()
loader = DataLoader(images, batch_size=1, num_workers=2)
outputs = []

# Switch layers such as dropout / batch-norm to inference behaviour.
# NOTE: this leaves the model in eval mode; call model.train() before any
# further training.
model.eval()
with torch.no_grad():
    # The dataset's labels are placeholders, so they are ignored here.
    for img, _ in loader:
        img = img.to(device)
        # Call the module itself rather than .forward() so registered hooks run.
        output = model(img)
        outputs.append(output.argmax(dim=1))

# Build the class-index -> class-name lookup from the sub-directories of the
# test folder.
# os.listdir order is filesystem-dependent, so the names are sorted to make the
# mapping reproducible. Sorted, directory-only listing is also how torchvision's
# folder datasets assign class indices — presumably `testset` was built that
# way; confirm against the cell that creates it.
test_root = os.path.join("/content", "test")
dirs = sorted(
    name for name in os.listdir(test_root)
    if os.path.isdir(os.path.join(test_root, name))
)
get_classlabels = dict(enumerate(dirs))

print(get_classlabels)

In [None]:
import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
from random import randint
import numpy as np

# Draw a 5x5 grid of randomly chosen test samples. Each grid cell stacks two
# panels: the image (titled with its class name) on top and a bar chart of the
# model's output vector underneath.
fig = plot.figure(figsize=(30, 30))
outer = gridspec.GridSpec(5, 5, wspace=0.2, hspace=0.2)

for i in range(25):
    inner = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[i], wspace=0.1, hspace=0.1)
    # randint includes BOTH endpoints, so the upper bound must be len - 1 to
    # stay inside the dataset.
    rnd_number = randint(0, len(testset) - 1)
    # Fetch the sample once so the image and label come from the same access.
    sample_image, sample_label = testset[rnd_number]
    # Use the same `device` as the inference cell instead of hard-coding CUDA.
    pred_image = sample_image.to(device).unsqueeze(0)
    # NOTE: despite the name, this is the label stored in the dataset, not the
    # model's prediction.
    pred_class = get_classlabels[sample_label]
    # No gradients are needed for visualisation. The output is flattened
    # without assuming a fixed number of classes.
    # NOTE(review): these are the raw model outputs; whether they are
    # probabilities depends on the model's final layer — confirm.
    with torch.no_grad():
        pred_prob = model(pred_image).reshape(-1)

    # Top panel: the image, min-max rescaled to [0, 1] so imshow can render
    # the normalised tensor.
    ax = fig.add_subplot(inner[0])
    pred_image = (pred_image - torch.min(pred_image)) / (torch.max(pred_image) - torch.min(pred_image))
    ax.imshow(pred_image[0].permute(1, 2, 0).cpu().numpy())
    ax.set_title(pred_class)
    ax.set_xticks([])
    ax.set_yticks([])

    # Bottom panel: one bar per model output.
    ax = fig.add_subplot(inner[1])
    ax.bar(range(pred_prob.numel()), pred_prob.cpu().numpy())

# Figure.show() only works with GUI backends and merely warns on the inline
# notebook backend; pyplot.show() renders in both.
plot.show()
