In [4]:
import torch
from torch import nn

### Load the pretrained encoder weights

In [13]:
# The checkpoint is a state dict: it is passed to load_state_dict further down.
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; the model below is built on the CPU anyway.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted
# files, or pass weights_only=True on torch >= 1.13.
encoder = torch.load('encoder_model.pt', map_location='cpu')

### Encoder Network

In [56]:
class EncoderModel(nn.Module):
    """Small convolutional encoder made of two (BatchNorm -> Conv -> MaxPool) stages.

    Maps a 3-channel image batch to a 4-channel feature map. Each 3x3
    unpadded convolution trims 2 pixels from every spatial side length and
    each 2x2 max-pool halves it (flooring), so a side of length H becomes
    ((H - 2) // 2 - 2) // 2.

    The attribute names ``a``..``f`` are deliberately kept: they are the
    prefixes of the saved state-dict keys (e.g. ``b.weight``), so renaming
    them would break ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: normalise the RGB input, 3 -> 16 channels, downsample by 2.
        self.a = nn.BatchNorm2d(3)
        self.b = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=0)
        self.c = nn.MaxPool2d(kernel_size=2)
        # Stage 2: normalise, 16 -> 4 channels, downsample by 2 again.
        self.d = nn.BatchNorm2d(16)
        self.e = nn.Conv2d(in_channels=16, out_channels=4, kernel_size=3, stride=1, padding=0)
        self.f = nn.MaxPool2d(kernel_size=2)

    def forward(self, X):
        """Encode ``X`` (expected layout (N, 3, H, W)) into an (N, 4, H', W') tensor."""
        features = X
        # Apply the layers strictly in registration order a -> f.
        for stage in (self.a, self.b, self.c, self.d, self.e, self.f):
            features = stage(features)
        return features

### Load the weights into the model

In [17]:
# Build the architecture, then copy the pretrained weights in. strict=True is
# the default and is spelled out here: any missing or unexpected key raises.
encoder_model = EncoderModel()
encoder_model.load_state_dict(state_dict=encoder, strict=True)
# Switch to eval mode so BatchNorm uses its stored running statistics;
# as the last expression, the model's repr is displayed.
encoder_model.eval()

EncoderModel(
  (a): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (b): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (c): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (d): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (e): Conv2d(16, 4, kernel_size=(3, 3), stride=(1, 1))
  (f): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

### Create a custom dataset

In [21]:
from torch.utils.data import Dataset
from torchvision import transforms
# Shared transform used by ImageData.__getitem__. Per the torchvision docs,
# ToTensor turns a PIL image into a float tensor laid out as (C, H, W),
# scaling 8-bit modes such as RGB to the range [0, 1].
to_tensor =  transforms.ToTensor()

In [32]:

import os
import PIL
import torch
import numpy as np
from torch.utils.data import Dataset

# Seed every RNG the notebook may touch so a Restart & Run All is reproducible.
import random

SEED = 0

random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed seeds the CPU generator and all CUDA devices.
# manual_seed_all is kept for explicitness: it also covers the current device,
# so a separate torch.cuda.manual_seed call is unnecessary, and it is silently
# ignored when CUDA is unavailable.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

class ImageData(Dataset):
    """Dataset of RGB images read from one flat directory.

    Each item is a dict with keys:
        'data'    -- the image converted to RGB and passed through the
                     module-level ``to_tensor`` transform;
        'label'   -- the filename without its final extension;
        'img_idx' -- the integer index that was requested.

    Args:
        root: Directory to read images from. Defaults to ``image_data``
            under the current working directory (the previous behaviour).
    """

    def __init__(self, root=None):
        if root is None:
            root = os.path.join(os.getcwd(), 'image_data')
        self.overall_dataset_dir = root
        # os.listdir order is arbitrary and platform-dependent; sort so the
        # index -> file mapping is stable across runs. Skip sub-directories
        # (e.g. .ipynb_checkpoints) and hidden files (e.g. .DS_Store), which
        # PIL cannot open.
        self.all_filenames = sorted(
            name
            for name in os.listdir(root)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(root, name))
        )

    def __len__(self):
        return len(self.all_filenames)

    @staticmethod
    def _label_for(filename):
        """Return ``filename`` without its final extension ("a.v2.png" -> "a.v2").

        Splitting on the first "." would map "a.v1.png" and "a.v2.png" to the
        same label, and their saved encodings would overwrite each other.
        """
        return os.path.splitext(filename)[0]

    def __getitem__(self, idx):
        # `import PIL` alone does not import the PIL.Image submodule.
        from PIL import Image

        selected_filename = self.all_filenames[idx]
        path = os.path.join(self.overall_dataset_dir, selected_filename)
        # The context manager closes the file handle as soon as pixels are read.
        with Image.open(path) as imagepil:
            image = to_tensor(imagepil.convert('RGB'))
        return {
            'data': image,
            'label': self._label_for(selected_filename),
            'img_idx': idx,
        }

### Encode the sample images and save the encodings

In [50]:
# Dataset over the image files in ./image_data (see ImageData above).
my_image = ImageData()
# NOTE(review): my_image_iter is not used in any of the cells shown here.
# ImageData defines no __iter__, so iter() presumably falls back to Python's
# legacy __getitem__ sequence protocol (stopping at the first IndexError).
my_image_iter = iter(my_image)

In [54]:
from torch.utils.data import DataLoader
# One image per batch: the recorded output shows encodings of different
# spatial sizes, and the default collate function can only stack tensors of
# equal shape. No shuffling, so batch k corresponds to my_image[k].
my_image_loader = DataLoader(dataset=my_image, batch_size=1, shuffle=False)

In [60]:
# Encode every image and save one tensor file per image, named after its label.
OUTPUT_DIR = 'image_encoded_data'
os.makedirs(OUTPUT_DIR, exist_ok=True)  # torch.save does not create directories

encoder_model.eval()  # BatchNorm must use running stats regardless of cell order
with torch.no_grad():  # inference only: don't build autograd graphs
    for batch in my_image_loader:
        encoded = encoder_model(batch['data'])
        # Save each sample separately so this stays correct if batch_size > 1.
        # Slicing keeps the leading batch dimension (same shape as before for
        # batch_size=1); clone() stops torch.save from writing the whole
        # batch's storage into every file.
        for i, label in enumerate(batch['label']):
            torch.save(encoded[i:i + 1].clone(), os.path.join(OUTPUT_DIR, label + '.pt'))

0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.Size([1, 4, 73, 73])
0 torch.Size([1, 4, 147, 147])
1 torch.S