In [1]:
import torchvision
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns

In [4]:
class MSCTD(Dataset):
    def __init__(self, root, split, image_transform=None, text_transform=None, sentiment_transform=None):
        self.root = root
        self.split = split
        self.image_path = os.path.join(root, split)
        self.sentiment_path = os.path.join(root, split, 'sentiment_'+ split + '.txt')
        self.text_path = os.path.join(root, split, 'english_' + split + '.txt')
        self.image_index_path = os.path.join(root, split, 'image_index_' + split + '.txt')
        self.sentiment = []
        self.text = []
        self.image_index = []
        self.image = []
        self.image_transform = image_transform
        self.text_transform = text_transform
        self.sentiment_transform = sentiment_transform
        self.load_data()
        
    def load_data(self):
        with open(self.sentiment_path, 'r') as f:
            for line in f:
                self.sentiment.append(line.strip())
        with open(self.text_path, 'r') as f:
            for line in f:
                self.text.append(line.strip())
        with open(self.image_index_path, 'r') as f:
            for line in f:
                self.image_index.append(line.strip()[1:-1].split(', '))
        for root, _, files in os.walk(self.image_path):
            for file in files:
                if file.endswith('.jpg'):
                    self.image.append(Image.open(os.path.join(root, file)))
                    


    def __getitem__(self, index):
        image = self.image[index]
        sentiment = self.sentiment[index]
        text = self.text[index]
        image_index = self.image_index[index]
        if self.image_transform:
            image = self.image_transform(image)
        if self.text_transform:
            text = self.text_transform(text)
        if self.sentiment_transform:
            sentiment = self.sentiment_transform(sentiment)
        return image, sentiment, text, image_index

    def __len__(self):
        return len(self.image)
    

In [5]:
%pwd

'c:\\Users\\saeedzou\\Documents\\PycharmProjects\\Deep-Learning-Project'

In [6]:
MSCTD_test = MSCTD(root='data', split='test')

In [7]:
MSCTD_test[0]

(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x550>,
 '1',
 'With this asshole?',
 ['0', '1', '2', '3', '4', '5'])

In [8]:
MSCTD_test[1]

(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1280x553>,
 '1',
 'Two guys walk in, one walks out.',
 ['6',
  '7',
  '8',
  '9',
  '10',
  '11',
  '12',
  '13',
  '14',
  '15',
  '16',
  '17',
  '18',
  '19',
  '20',
  '21',
  '22',
  '23',
  '24',
  '25'])