In [1]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [139]:
# import torch

# # Create a custom collate function
# def collate_fn(data):
#     # Initialize empty tensors for each resolution
#     resolution1_tensor = torch.empty(0)
#     resolution2_tensor = torch.empty(0)
#     resolution3_tensor = torch.empty(0)

#     # Traverse through each image in the data
#     for image in data: 
#         transformer = transforms.PILToTensor()
#         image_tensor = transformer(image[0])
#         # Check the resolution of the image
#         if (image_tensor.shape[1] == resolution1 & image_tensor.shape[2] == resolution1):
#             resolution1_tensor = torch.cat((resolution1_tensor, image_tensor), 0)
#         elif (image_tensor.shape[1] == resolution2 & image_tensor.shape[2] == resolution2):
#             resolution2_tensor = torch.cat((resolution2_tensor, image_tensor), 0)
#         elif (image_tensor.shape[1] == resolution3 & image_tensor.shape[2] == resolution3):
#             resolution3_tensor = torch.cat((resolution3_tensor, image_tensor), 0)
            

#     # Return the tensors for each resolution
#     return resolution1_tensor, resolution2_tensor, resolution3_tensor

# # Define the resolutions for the images
# resolution1 = 32
# resolution2 = 48
# resolution3 = 64

# # Create the dataloader
# train_dataset = datasets.ImageFolder(root='mnist-varres/train/')
# dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn)

In [140]:
# tensor1, tensor2, tensor3 = next(iter(dataloader))

In [141]:
# Create a custom collate function
def collate_fn(data):
    # Initialize empty tensors for each resolution
    resolution1_tensor = torch.empty(0)
    resolution2_tensor = torch.empty(0)
    resolution3_tensor = torch.empty(0)

    # Traverse through each image in the data
    for image in data: 
        transformer = transforms.ToPILImage()
        transformer1 = transforms.ToTensor()
        image_tensor = transformer1(transformer(image[0]))
        # Check the resolution of the image
        if (image_tensor.shape[1] == resolution1 & image_tensor.shape[2] == resolution1):
            resolution1_tensor = torch.cat((resolution1_tensor, image_tensor), 0)
        elif (image_tensor.shape[1] == resolution2 & image_tensor.shape[2] == resolution2):
            resolution2_tensor = torch.cat((resolution2_tensor, image_tensor), 0)
        elif (image_tensor.shape[1] == resolution3 & image_tensor.shape[2] == resolution3):
            resolution3_tensor = torch.cat((resolution3_tensor, image_tensor), 0)
            

    # Return the tensors for each resolution
    return resolution1_tensor, resolution2_tensor, resolution3_tensor

# Define the resolutions for the images
resolution1 = 32
resolution2 = 48
resolution3 = 64

In [142]:
def get_train_valid_loader(data_dir,
                           batch_size,
                           random_seed,
                           augment=False,
                           valid_size=0.2,
                           shuffle=True,
                           show_sample=False,
                           num_workers=1,
                           pin_memory=True):
    """
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
      mentioned in the paper. Only applied on the train split.
    - random_seed: fix seed for reproducibility.
    - valid_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices.
    - show_sample: plot 9x9 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST

    # define transforms
    valid_transform = transforms.Compose([
#             transforms.Resize((28,28)),
            transforms.ToTensor(),
            normalize
        ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
#             transforms.Resize((28,28)),
            transforms.ToTensor(),
            normalize
        ])
    else:
        train_transform = transforms.Compose([
#             transforms.Resize((28,28)),
            transforms.ToTensor(),
            normalize
        ])

    # load the dataset
    train_dataset = datasets.ImageFolder(root=data_dir,transform=train_transform)
    valid_dataset = datasets.ImageFolder(root=data_dir,transform=valid_transform)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle == True:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(train_dataset, 
                    batch_size=batch_size, sampler=train_sampler, 
                    num_workers=num_workers, pin_memory=pin_memory,collate_fn=collate_fn)

    valid_loader = torch.utils.data.DataLoader(valid_dataset, 
                    batch_size=batch_size, sampler=valid_sampler, 
                    num_workers=num_workers, pin_memory=pin_memory,collate_fn=collate_fn)


    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(train_dataset, 
                                                    batch_size=9, 
                                                    shuffle=shuffle, 
                                                    num_workers=num_workers,
                                                    pin_memory=pin_memory,collate_fn=collate_fn)
        data_iter = iter(sample_loader)
        images, labels = data_iter.next()
        X = images.numpy()
        plot_images(X, labels)

    return (train_loader, valid_loader)

In [143]:
def get_test_loader(data_dir, 
                    batch_size,
                    shuffle=True,
                    num_workers=1,
                    pin_memory=True):
    """
    Utility function for loading and returning a multi-process 
    test iterator over the MNIST dataset.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - data_loader: test set iterator.
    """
    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST

    # define transform
    transform = transforms.Compose([
#         transforms.Resize((28,28)),
        transforms.ToTensor(),
        normalize
    ])

    dataset = datasets.ImageFolder(data_dir,transform=transform)

    data_loader = torch.utils.data.DataLoader(dataset, 
                                              batch_size=batch_size, 
                                              shuffle=shuffle, 
                                              num_workers=num_workers,
                                              pin_memory=pin_memory,collate_fn=collate_fn)

    return data_loader

In [144]:
train_loader, valid_loader = get_train_valid_loader('mnist-varres/train/..',
                           batch_size = 16,
                           random_seed = 100,
                           augment=False,
                           valid_size=1/6,
                           shuffle=True,
                           show_sample=False,
                           num_workers=0,
                           pin_memory=True)

In [145]:
test_loader = get_test_loader('mnist-varres/test/..', 
                    batch_size=10,
                    shuffle=True,
                    num_workers=0,
                    pin_memory=True)

In [146]:
loaders = {'train': train_loader,
          'test': test_loader,
          'val': valid_loader} 

In [147]:
tensor1, tensor2, tensor3 = next(iter(train_loader))

In [5]:
class CNN(nn.Module):
    def _init_(self, input_channels, num_classes, N):
        super(CNN, self)._init_()

        self.conv1 = nn.Conv2d(input_channels, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(2)

        self.conv3 = nn.Conv2d(32, N, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d(2)

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(N, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.conv3(x)
        x = self.relu3(x)
        x = self.maxpool3(x)

        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)

        return x

In [7]:
cnn = CNN(1,10,32)
cnn

TypeError: __init__() takes 1 positional argument but 4 were given