Hi! This is an example dataset class that demonstrates the methods you need to override/implement. You may not be able to run this code as-is, because it includes a lot of disparate examples of processing steps in case you ever have to work with different types of data.

This is all still really rough, so if you have any questions please ask!

In [None]:
import torch
import torchvision
from torch.utils.data import Dataset
import pandas as pd


class ExampleData(Dataset):
    """Example map-style ``torch.utils.data.Dataset`` read from a CSV file.

    The two methods PyTorch needs are ``__len__`` (how many samples there are)
    and ``__getitem__`` (return sample ``i``). ``processData`` is optional
    scaffolding that shows where data-processing code can go.

    Args:
        file_name: Path or file-like object of a CSV whose first row is the
            header; anything ``pandas.read_csv`` accepts.
        mode: ``'train'`` keeps the first ``train_fraction`` of the rows,
            ``'test'`` keeps the remaining rows, and ``None`` keeps every row
            (handy when train and test data already live in separate files).
        train_fraction: Fraction of rows, in [0, 1], assigned to ``'train'``.
        x_col: Column holding the data you pass through the network.
        y_col: Column holding the labels.

    Raises:
        ValueError: If ``mode`` is not ``'train'``, ``'test'`` or ``None``,
            or ``train_fraction`` is outside [0, 1].
    """

    def __init__(self, file_name, mode=None, *, train_fraction=0.8, x_col="x", y_col="y"):
        # Read the CSV (e.g. exported from Excel) into a pandas DataFrame;
        # header=0 takes the column names from the first row.
        data = pd.read_csv(file_name, header=0)

        # Work out which rows belong to this split. If train and test data are
        # stored in different files, pass a different file_name with
        # mode=None; if everything is in one file, split it here.
        if not 0.0 <= train_fraction <= 1.0:
            raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")
        # The cut point comes from the full DataFrame (not from a partially
        # built split), so X and Y are always cut at the same row.
        cut = int(train_fraction * len(data))
        if mode == "train":
            rows = slice(None, cut)
        elif mode == "test":
            rows = slice(cut, None)
        elif mode is None:
            rows = slice(None)
        else:
            raise ValueError(f"mode must be 'train', 'test' or None, got {mode!r}")

        # .iloc slices by position; reset_index(drop=True) relabels the rows
        # 0..n-1 so label-based indexing matches position in both splits.
        self.X = data[x_col].iloc[rows].reset_index(drop=True)
        self.Y = data[y_col].iloc[rows].reset_index(drop=True)

        # Optional processing step. Doing it here is fine for small datasets;
        # with a lot of data, move the per-sample work into __getitem__ so it
        # runs only when a sample is requested.
        self.X, self.Y = self.processData(self.X, self.Y)

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.X)

    def __getitem__(self, i):
        """Return the ``i``-th ``(data, label)`` pair as tensors of shape ``(1, k)``.

        The data you pass through the network is stored in ``self.X`` and the
        labels in ``self.Y``. ``k`` is 1 for a scalar value and the number of
        elements for a sequence/array value. The ``reshape`` is optional:
        drop or change it if your network expects a different shape (for
        example, image tensors you don't want flattened).
        """
        x = self._item(self.X, i)
        y = self._item(self.Y, i)
        # as_tensor also accepts values processData already turned into tensors.
        return torch.as_tensor(x).reshape(1, -1), torch.as_tensor(y).reshape(1, -1)

    @staticmethod
    def _item(values, i):
        """Fetch element ``i`` by position from a pandas object or a plain sequence."""
        return values.iloc[i] if hasattr(values, "iloc") else values[i]

    def processData(self, dataX, dataY):
        """Preprocessing hook; returns ``(dataX, dataY)`` unchanged by default.

        You do not need to implement this method for the class to work; it
        only exists to keep processing code out of ``__init__`` and
        ``__getitem__``. Whatever it returns is stored as ``self.X`` and
        ``self.Y``, so it must return a ``(dataX, dataY)`` pair.
        """
        # Numerical data: arrays of numbers are easiest to keep in a pandas
        # DataFrame/Series and manipulate there, OR you can use numpy. Just
        # make sure each sample can be turned into a tensor in __getitem__.
        # In-depth pandas tutorial:
        # https://www.learndatasci.com/tutorials/python-pandas-tutorial-complete-introduction-for-beginners/
        #
        # Image data: if the x column holds image file paths, you can read
        # each image with OpenCV (pip install opencv-python), e.g.
        #
        #     import cv2
        #     images = [torch.as_tensor(cv2.imread(path)) for path in dataX]
        #     return images, dataY
        #
        # cv2.imread returns None instead of raising when a file can't be read,
        # so check for that with real data.
        # To alter images, see OpenCV's geometric transformations:
        # https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_geometric_transformations/py_geometric_transformations.html
        # or the torchvision transforms (https://pytorch.org/vision/stable/transforms.html);
        # check each transform's docs for whether it expects a PIL image or a tensor.
        return dataX, dataY

