In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell is executed, so edits to local .py files are picked up.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from os.path import join
from skimage import io, transform
import pandas as pd
import matplotlib.pyplot as plt 

from PIL import Image
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, utils

In [3]:
# Directory containing the training images (the same path is passed as
# root_dir when the dataset object is created further down).
datadir = '/scratche/data/diabetic-retinopathy-detection/train'

In [10]:
# Load the training labels and make sure the severity level is numeric.
df_labels = pd.read_csv('/scratche/data/diabetic-retinopathy-detection/trainLabels.csv')
df_labels['level'] = pd.to_numeric(df_labels['level'])

# Number of images per severity level. This must be the cell's last
# expression to be displayed; a trailing bare `df_labels` would discard it
# and dump the entire label table instead.
df_labels.groupby('level').count()

Unnamed: 0_level_0,image
level,Unnamed: 1_level_1
0,25810
1,2443
2,5292
3,873
4,708


In [13]:
# Count the entries in the training image directory.
train_entries = os.listdir('/scratche/data/diabetic-retinopathy-detection/train')
len(train_entries)

35126

In [14]:
# Create dataset class

class DiabRetinopathyDataset(Dataset):
    """Map-style dataset pairing each image file with its label.

    The annotation CSV supplies, per row, the image file stem (``image``
    column) and the label (``level`` column). Every item is returned as a
    dict with the keys ``'image'`` and ``'label'``.
    """

    def __init__(self, root_dir, csv_file, transform=None):
        """Remember the configuration and read the annotation table.

        Args:
            root_dir: Directory that holds the ``<image>.jpeg`` files.
            csv_file: Path of the CSV file with ``image`` and ``level`` columns.
            transform: Optional callable applied to the PIL image of a sample.
        """
        self.root_dir = root_dir
        self.csv_file = csv_file
        self.transform = transform
        self.annotations = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Args:
            idx: Row position in the annotation table; a tensor index is
                converted with ``tolist()`` first.

        Returns:
            dict: ``{'image': ..., 'label': ...}``. Without a transform the
            image is a PIL image; with a transform it is the transform's
            output converted to a NumPy array.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Fetch the annotation row once and read both columns from it.
        row = self.annotations.iloc[idx]

        img_path = join(self.root_dir, row['image']) + '.jpeg'
        pil_image = Image.fromarray(io.imread(img_path))
        label = row['level']

        if not self.transform:
            return {'image': pil_image, 'label': label}

        transformed = np.array(self.transform(pil_image))
        return {'image': transformed, 'label': label}
    

In [6]:
# Preprocessing applied to every image: resize, random 224-pixel crop,
# then conversion to a tensor.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
])

# Dataset over the training images and their labels.
dataset = DiabRetinopathyDataset(
    root_dir='/scratche/data/diabetic-retinopathy-detection/train',
    csv_file='/scratche/data/diabetic-retinopathy-detection/trainLabels.csv',
    transform=preprocess,
)

# Shuffled mini-batches of 32 samples, loaded with 10 worker processes.
dataloader = DataLoader(dataset, shuffle=True, batch_size=32, num_workers=10)

In [None]:
# Plot the first dataset elements on a 7 x 3 grid.
fig, axs = plt.subplots(figsize=(10, 25), nrows=7, ncols=3)

# Pair each axis with the next sample. zip() stops as soon as the axes are
# used up, so no manual index check / break is needed and no extra image is
# loaded from disk.
for i, (ax, sample) in enumerate(zip(axs.flat, dataset)):
    image = np.asarray(sample['image'])
    print(i, image.shape, sample['label'])

    # ToTensor produces channel-first (C, H, W) data, whereas imshow expects
    # channel-last (H, W, C); passing the array unchanged raises
    # "Invalid shape (3, H, W) for image data".
    if image.ndim == 3 and image.shape[0] == 3:
        image = np.transpose(image, (1, 2, 0))

    ax.set_title(f'Sample #{i}, Label - {sample["label"]}')
    ax.imshow(image)
    ax.axis('off')

# Show the figure once, after the loop, so it is rendered even when the
# dataset holds fewer samples than there are axes.
plt.show()

In [7]:
# Load the VGG16 architecture with pre-trained weights. On first use the
# weights file is downloaded into the local torch hub cache.
from torchvision import models
vgg16 = models.vgg16(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/users/sansiddh/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [9]:
# Inspect the fully connected classifier head of the loaded network.
# NOTE(review): the final Linear layer has 1000 outputs, while the labels in
# this notebook span 5 levels (0-4) - presumably this layer is meant to be
# replaced before training; confirm.
vgg16.classifier

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)