In [None]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Lambda, Normalize, Grayscale, Resize

import matplotlib.pyplot as plt

import os
import pandas as pd
import numpy as np

from PIL import Image

In [None]:
# Render matplotlib figures directly in the notebook output cells.
%matplotlib inline

In [None]:
# Per-image preprocessing applied by the dataset: collapse to a single
# grayscale channel, resize to a fixed 180x180, then convert the PIL image
# to a float tensor (torchvision's ToTensor scales pixel values to [0, 1]).
transforms = torchvision.transforms.Compose(
    [Grayscale(), Resize((180, 180)), ToTensor()]
)

In [None]:
class BanglaMNISTDataset(Dataset):
    """Map-style dataset of labelled Bangla-MNIST images read from disk.

    The annotations CSV is read with ``pd.read_csv`` and indexed by column
    position: column 1 holds the image file name (relative to ``img_dir``)
    and column 2 holds the label. Other columns are not used.

    Args:
        annotations_file: Path to the annotations CSV.
        img_dir: Directory that the file names in column 1 are joined onto.
        transform: Optional callable applied to the PIL image before it is
            returned (e.g. a torchvision ``Compose``).
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples (rows in the CSV)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for CSV row ``idx``.

        ``image`` is the loaded PIL image, or ``transform(image)`` when a
        transform was given; ``label`` is the raw value from column 2.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 1])
        label = self.img_labels.iloc[idx, 2]
        # Image.open is lazy and keeps the file handle open until the pixels
        # are read; copy() forces the load so the handle is closed when the
        # with-block exits instead of leaking one descriptor per sample.
        with Image.open(img_path) as pil_img:
            img = pil_img.copy()
        if self.transform is not None:
            img = self.transform(img)
        return img, label

In [None]:
# Labelled split of the dataset; paths are relative to the notebook's working
# directory (looks like a Kaggle-style ../input layout — confirm if moved).
BMNIST_LABELS_CSV = '../input/banglamnist/labels.csv'
BMNIST_IMG_DIR = '../input/banglamnist/bangla-mnist/labeled/'
bmnist = BanglaMNISTDataset(BMNIST_LABELS_CSV, BMNIST_IMG_DIR, transforms)

In [None]:
# Shuffled mini-batches over the dataset. A seeded generator makes the shuffle
# order (and the base seed handed to worker processes) reproducible, so a
# Restart & Run All previews the same batch every time.
LOADER_SEED = 42
dl = DataLoader(
    bmnist,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [None]:
def show_samples(data, targets, n_rows=3, n_cols=3):
    """Plot the first ``n_rows * n_cols`` images of a batch in a grid.

    Args:
        data: Batch of images convertible with ``np.asarray`` (e.g. a CPU
            tensor from ``dl``). Items may be channel-first ``(C, H, W)`` or
            plain ``(H, W)``.
        targets: Labels indexable by position; ``targets[i]`` is used as the
            title of panel ``i``.
        n_rows, n_cols: Grid dimensions (default 3x3, i.e. 9 panels).

    Panels beyond the batch size are blanked instead of raising IndexError.
    Returns None so that calling it as a cell's last expression does not
    display the figure a second time.
    """
    data = np.asarray(data)
    print("tensor shape: " + str(data.shape))

    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.set_xticks([])
        ax.set_yticks([])
        if i >= len(data):
            ax.axis("off")
            continue

        img = data[i]
        if img.ndim == 3:
            # channel-first (C, H, W) -> channel-last (H, W, C) for imshow
            img = np.moveaxis(img, 0, -1)
            if img.shape[-1] == 1:
                # a single grayscale channel: drop it so imshow applies cmap
                img = img[..., 0]

        ax.imshow(img, cmap='gray', interpolation='none')
        ax.set_title(f"{targets[i]}")

    # Lay out once, after every panel exists.
    fig.tight_layout()

In [None]:
# Preview the first batch drawn from the loader. `dataiter` yields
# (batch_index, (images, labels)) pairs if later cells need more batches.
dataiter = enumerate(dl)
_, first_batch = next(dataiter)
sample_data, sample_targets = first_batch
show_samples(sample_data, sample_targets)