diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index 07592b64bae..09f69276a10 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -46,39 +46,12 @@ def __init__(self, root, transform=None, target_transform=None, download=True): self.data = [] self.labels = [] fp = os.path.join(root, self.filename) - file = open(fp, 'r') - data = file.read() - file.close() - dataSplitted = data.split("\n")[:-1] - datasetLength = len(dataSplitted) - i = 0 - while i < datasetLength: - # Get the 'i-th' row - strings = dataSplitted[i] - - # Split row into numbers(string), and avoid blank at the end - stringsSplitted = (strings[:-1]).split(" ") - - # Get data (which ends at column 256th), then in a numpy array. - rawData = stringsSplitted[:256] - dataFloat = [float(j) for j in rawData] - img = np.array(dataFloat[:16]) - j = 16 - k = 0 - while j < len(dataFloat): - temp = np.array(dataFloat[k:j]) - img = np.vstack((img, temp)) - - k = j - j += 16 - - self.data.append(img) - - # Get label and convert it into numbers, then in a numpy array. - labelString = stringsSplitted[256:] - labelInt = [int(index) for index in labelString] - self.labels.append(np.array(labelInt)) - i += 1 + data = np.loadtxt(fp) + # convert value to 8 bit unsigned integer + # color (white #255) the pixels + self.data = (data[:, :256] * 255).astype('uint8') + self.data = np.reshape(self.data, (-1, 16, 16)) + self.labels = np.nonzero(data[:, 256:])[1] def __getitem__(self, index): """ @@ -91,9 +64,6 @@ def __getitem__(self, index): # doing this so that it is consistent with all other datasets # to return a PIL Image - # convert value to 8 bit unsigned integer - # color (white #255) the pixels - img = img.astype('uint8') * 255 img = Image.fromarray(img, mode='L') if self.transform is not None: