In [1]:
# https://datahacker.rs/019-siamese-network-in-pytorch-with-application-to-face-similarity
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn 
from torchvision.io import read_image, ImageReadMode
import torchvision.models as models 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
import random

In [2]:
RECORD_COUNT = 300 # 82736
BATCH_SIZE = 8
TRAIN_START = 0
TRAIN_END = 199
TEST_START = 200
TEST_END = 299
PAIR_COUNT = 32

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

cuda:0


In [4]:
xray_df = pd.read_csv(os.path.join("Data_Entry_2017.csv")) 
xray_df = xray_df.head(RECORD_COUNT)

In [5]:
class XRayNetworkDataset(torch.utils.data.Dataset):

    def __init__(self, dataframe, root_dir, start_index, end_index, transform=None, pair_count = PAIR_COUNT):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.dataframe = dataframe
        self.root_dir = root_dir
        self.transform = transform
        self.start_index = start_index
        self.end_index = end_index
        self.patient_map = {} 
        self.pair_count = pair_count
        self.patient_keys = []

        i = start_index
        while i <= end_index:
            patient_id = self.dataframe.iloc[i]['Patient ID']
            image_path = os.path.join(self.root_dir, dataframe.iloc[i]['Image Index'])
            if patient_id not in self.patient_map:
                self.patient_map[patient_id] = [image_path]
                self.patient_keys.append(patient_id)
            else:
                self.patient_map[patient_id].append(image_path)
            i += 1

    def __len__(self):
        return len(self.data_tuples)

    def __getitem__(self, idx):
        
        im0, im1 = None, None
        label = 0

        if (random.randint(0, 1) == 0):
            # Same patient 
            chosen_patient = random.choice(self.patient_keys)
            while len(self.patient_map[chosen_patient]) <= 1:
                chosen_patient = random.choice(self.patient_keys)
            im0, im1 = random.choices(self.patient_map[chosen_patient], k = 2)
            while im1 == im0:
                im0, im1 = random.choices(self.patient_map[chosen_patient], k = 2)
            label = torch.tensor(1)
            
        else:
            # Different patient
            p1, p2 = random.choices(self.patient_keys, k = 2)
            while p1 == p2:
                p1, p2 = random.choices(self.patient_keys, k = 2)
            im0 = random.choices(self.patient_map[p1], k = 1)
            im1 = random.choices(self.patient_map[p2], k = 1)
            label = torch.tensor(0)

        im0 = read_image(im0, mode=ImageReadMode.GRAY)
        im1 = read_image(im1, mode=ImageReadMode.GRAY)
        
        return (im0, im1, label)

In [8]:
dataset = XRayNetworkDataset(xray_df, os.path.join("images"), TRAIN_START, TRAIN_END)

In [9]:
for i in range(PAIR_COUNT):
    print (dataset[i])

RuntimeError: image::read_file() Expected a value of type 'str' for argument '_0' but instead found type 'list'.
Position: 0
Value: ['images\\00000035_000.png']
Declaration: image::read_file(str _0) -> (Tensor _0)
Cast error details: Unable to cast Python instance to C++ type (compile in debug mode for details)