Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 6 additions & 36 deletions torchvision/datasets/semeion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,39 +46,12 @@ def __init__(self, root, transform=None, target_transform=None, download=True):
self.data = []
self.labels = []
fp = os.path.join(root, self.filename)
file = open(fp, 'r')
data = file.read()
file.close()
dataSplitted = data.split("\n")[:-1]
datasetLength = len(dataSplitted)
i = 0
while i < datasetLength:
# Get the 'i-th' row
strings = dataSplitted[i]

# Split the row into number strings, dropping the trailing blank at the end
stringsSplitted = (strings[:-1]).split(" ")

# Get the data (the first 256 columns), then put it in a numpy array.
rawData = stringsSplitted[:256]
dataFloat = [float(j) for j in rawData]
img = np.array(dataFloat[:16])
j = 16
k = 0
while j < len(dataFloat):
temp = np.array(dataFloat[k:j])
img = np.vstack((img, temp))

k = j
j += 16

self.data.append(img)

# Get the label columns and convert them to integers, then store them in a numpy array.
labelString = stringsSplitted[256:]
labelInt = [int(index) for index in labelString]
self.labels.append(np.array(labelInt))
i += 1
data = np.loadtxt(fp)
# Scale the values by 255 and convert them to 8-bit unsigned integers,
# coloring the pixels white (255).
self.data = (data[:, :256] * 255).astype('uint8')
self.data = np.reshape(self.data, (-1, 16, 16))
self.labels = np.nonzero(data[:, 256:])[1]

def __getitem__(self, index):
"""
Expand All @@ -91,9 +64,6 @@ def __getitem__(self, index):

# doing this so that it is consistent with all other datasets
# to return a PIL Image
# Convert the values to 8-bit unsigned integers and scale them by 255,
# coloring the pixels white (255).
img = img.astype('uint8') * 255
img = Image.fromarray(img, mode='L')

if self.transform is not None:
Expand Down