In [94]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.io import read_image
import pandas as pd
import os

In [95]:
class PosterDataset(Dataset):
    """Dataset pairing each movie poster image with the movie's score.

    Rows are read from a CSV file. The ``imdbId`` column names the image file
    (``<img_dir>/<imdbId>.jpg``) and the ``Score`` column supplies the label.
    The ``Genre`` column is only read when ``genres`` is given.

    Args:
        csv_file: Path or file-like object handed to ``pd.read_csv``.
        img_dir: Directory containing the poster images.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label.
        genres: Optional genre name, or iterable of genre names. For every
            name a boolean column of that name is added to ``self.df``
            (case-insensitive, literal substring match against ``Genre``),
            and only rows matching at least one name are kept. ``None``
            disables filtering; an empty iterable keeps no rows.
    """

    def __init__(self, csv_file, img_dir, transform=None, target_transform=None, genres=None):
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        if genres is not None:
            # A bare string would otherwise be iterated character by character.
            if isinstance(genres, str):
                genres = [genres]

            # Start from "keep nothing" so an empty genre list gives an empty
            # dataset instead of indexing the frame with a stray integer.
            keep = pd.Series(False, index=self.df.index)

            for genre in genres:
                # regex=False: match the genre name literally, so names with
                # characters such as "(" or "+" behave as expected.
                # na=False: rows without a genre simply do not match.
                self.df[genre] = self.df['Genre'].str.contains(
                    genre, case=False, regex=False, na=False
                )
                keep |= self.df[genre]

            # Only movies in at least one requested genre remain.
            self.df = self.df[keep].reset_index(drop=True)

    def __len__(self):
        """Return the number of rows left after any genre filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``."""
        # Select the column before the row: building a whole-row Series first
        # can upcast an integer id to float (e.g. "123.0") on numeric frames.
        imdb_id = self.df['imdbId'].iloc[idx]
        img_path = os.path.join(self.img_dir, f'{imdb_id}.jpg')
        image = read_image(img_path)
        label = self.df['Score'].iloc[idx]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

In [96]:
# Locations of the metadata CSV and the poster images, relative to this notebook.
csv_file = '../data/movie_data.csv'
img_dir = '../data/MoviePosters'

# Keep only movies whose genre includes Action or Animation.
pod = PosterDataset(
    csv_file=csv_file,
    img_dir=img_dir,
    genres=['Action', 'Animation'],
)

In [97]:
# Inspect the first (image, label) pair; indexing goes through __getitem__.
pod[0]

(tensor([[[  3,   3,   4,  ...,   7,   0,   0],
          [  0,   0,   0,  ...,   6,   0,   0],
          [  0,   0,   0,  ...,  11,   5,   0],
          ...,
          [169, 174, 191,  ...,  93,  92,  92],
          [182, 184, 189,  ...,  69,  64,  59],
          [136, 132, 129,  ...,  58,  51,  47]],
 
         [[123, 121, 119,  ...,  98, 101, 103],
          [119, 118, 116,  ...,  88,  91,  93],
          [116, 114, 113,  ...,  79,  82,  84],
          ...,
          [ 97, 100, 113,  ...,  30,  31,  31],
          [101, 102, 102,  ...,  25,  25,  26],
          [ 56,  50,  44,  ...,  19,  19,  19]],
 
         [[184, 183, 182,  ..., 171, 171, 171],
          [180, 180, 177,  ..., 162, 162, 162],
          [177, 175, 174,  ..., 154, 154, 154],
          ...,
          [ 59,  55,  64,  ...,  13,  13,  13],
          [ 71,  64,  59,  ...,  22,  20,  17],
          [ 19,  12,   3,  ...,  22,  20,  18]]], dtype=torch.uint8),
 8.3)

In [98]:
# Number of rows left in the dataset after the genre filter.
len(pod)

6617

In [101]:
# 80/20 train/test split of the poster dataset. The sizes must be taken from
# the dataset `pod`, not from `pd` (the pandas module, which has no len()).
trainSize = int(len(pod) * 0.8)
testSize = len(pod) - trainSize

# A seeded generator makes the split identical on every Restart & Run All.
trainData, testData = torch.utils.data.random_split(
    pod,
    [trainSize, testSize],
    generator=torch.Generator().manual_seed(42),
)

In [102]:
# Wrap both splits in loaders with the same settings: batches of 2, reshuffled.
trainDataLoader, testDataLoader = (
    DataLoader(split, batch_size=2, shuffle=True)
    for split in (trainData, testData)
)