Skip to content

Commit

Permalink
more on batch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yguan committed Aug 4, 2017
1 parent 2e23a54 commit 50766ae
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 24 deletions.
64 changes: 50 additions & 14 deletions mlbase/loaddata/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,17 @@ def getBatch(self, start, stop):
retIndex = 0
data = np.empty((stop-start, *self.shape[1:]))
for img in self.loadNextImage(start, stop):
if self.preprocessing is not None:
img = self.preprocessing.processImage(img)

ia = np.asarray(img)
iae = np.array(ia)
iae = np.rollaxis(iae, 2, 0)

data[retIndex, ...] = iae

retIndex += 1

print(self.meanMap)
if self.meanMap is not None:
data[:, 0, :, :] -= self.meanMap[0]
data[:, 1, :, :] -= self.meanMap[1]
Expand All @@ -118,6 +121,7 @@ def getName2IndexMap(self):
return self.name2index

def loadNextImage(self, start, stop):
    """Produce the images requested by a batch read.

    Subclasses should implement this. getBatch iterates over the result
    and expects each item to be a PIL.Image object.

    Args:
        start: index of the first image to load.
        stop: end index of the requested range (getBatch allocates room
            for ``stop - start`` images).

    Raises:
        NotImplementedError: always, in this base implementation. Failing
            here is clearer than returning None and letting the caller
            crash while trying to iterate over it.
    """
    raise NotImplementedError('Subclass should implement loadNextImage.')

Expand All @@ -142,7 +146,6 @@ def __init__(self, tar_file, **kargs):
members = ctar.getmembers()

index = 0
print(cFile)
for member in members:
self.upFile[member.name] = cFile
self.upIndex[member.name] = index
Expand Down Expand Up @@ -172,19 +175,51 @@ def __init__(self, tar_file, **kargs):


imgName = self.index2name[0]
print(imgName)

indexList = []
currentFileName = imgName
while currentFileName != 'root':
print(currentFileName, self.upIndex[currentFileName], self.upFile[currentFileName])
currentIndex = self.upIndex[currentFileName]
indexList.append(currentIndex)
currentFileName = self.upFile[currentFileName]
print(indexList)

cfh = open(self.tarFile, 'rb')
while len(indexList) > 0:
cfh = tarfile.open(fileobj=cfh)
cindex = indexList.pop()
members = cfh.getmembers()
cfh = cfh.extractfile(members[cindex])
img0 = Image.open(cfh)
if kargs['preprocessing'] is not None:
img0 = kargs['preprocessing'].processImage(img0)
img0 = np.asarray(img0)
height = img0.shape[0]
width = img0.shape[1]
if img0.shape[2] != 3:
raise ValueError('Expect a RGB image.')
self.updateShape((batchSize, 3, height, width))


def loadNextImage(self, start, stop):
    """Yield the images with indices ``start`` .. ``stop - 1`` from the tar file.

    Images may live inside nested tar archives. For each image the chain
    of containing archives is looked up through ``self.upFile`` (member
    name -> name of the archive holding it, ``'root'`` for the outermost
    tar file) and ``self.upIndex`` (member name -> position of the member
    in its archive's ``getmembers()`` list).

    Args:
        start: index of the first image to load.
        stop: index one past the last image to load.

    Yields:
        A PIL.Image object whose pixel data is already loaded, so it
        remains usable after the underlying file handles are closed.

    Raises:
        ValueError: if a member on the path cannot be extracted as a file.
    """
    for index in range(start, stop):
        # Collect the member positions from the image up to the outermost
        # archive; they are consumed in reverse (outermost first) below.
        indexList = []
        currentFileName = self.index2name[index]
        while currentFileName != 'root':
            indexList.append(self.upIndex[currentFileName])
            currentFileName = self.upFile[currentFileName]

        # Every handle opened on the way down is remembered so that all of
        # them are closed again; otherwise each image leaks descriptors.
        opened = []
        try:
            cfh = open(self.tarFile, 'rb')
            opened.append(cfh)
            for cindex in reversed(indexList):
                archive = tarfile.open(fileobj=cfh)
                opened.append(archive)
                cfh = archive.extractfile(archive.getmembers()[cindex])
                if cfh is None:
                    raise ValueError('Cannot extract a file from the tar member.')
                opened.append(cfh)

            img0 = Image.open(cfh)
            # Image.open is lazy; read the pixel data now, while the
            # handles are still open.
            img0.load()
        finally:
            for handle in reversed(opened):
                handle.close()

        yield img0


class JPGinFolder(RGBImage):
Expand Down Expand Up @@ -217,13 +252,14 @@ def __init__(self, folder, **kargs):
# the numpy array for PIL image is in height-width order
# the PIL reported number is in width-height order
# theano expects height-width order
fh = Image.open(os.path.join(self.dirpath, self.index2name[0]))
(width, height) = fh.size
channel = None
if fh.mode == 'RGB':
channel = 3
else:
raise NotImplementedError('Expect a RGB image.')
img0 = Image.open(os.path.join(self.dirpath, self.index2name[0]))
if kargs['preprocessing'] is not None:
img0 = kargs['preprocessing'].processImage(img0)
img0 = np.asarray(img0)
height = img0.shape[0]
width = img0.shape[1]
if img0.shape[2] != 3:
raise ValueError('Expect a RGB image.')
self.updateShape((batchSize, 3, height, width))

def loadNextImage(self, start, stop):
Expand Down
17 changes: 14 additions & 3 deletions mlbase/loaddata/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import PIL.Image as pil_image
from PIL import Image
import os.path
import tarfile

Expand Down Expand Up @@ -88,7 +88,7 @@ def normalize2RGB(self, *args, _image=None, _step=None):
if _image.mode == "RGB":
_image.load()
elif _image.mode == "L":
_image = pil_image.merge("RGB", (_image, _image, _image))
_image = Image.merge("RGB", (_image, _image, _image))
_image.load()
elif _image.mode == "CMYK":
_image = _image.convert('RGB')
Expand All @@ -98,10 +98,21 @@ def normalize2RGB(self, *args, _image=None, _step=None):

return _image


def processImage(self, img):
    """Apply every configured preprocessing step to a single image.

    Args:
        img: the image to process; must be a PIL ``Image.Image`` instance.

    Returns:
        The image produced by running each entry of ``self.steps`` in
        order, each step receiving the output of the previous one.

    Raises:
        ValueError: if ``img`` is not a PIL image.
    """
    if not isinstance(img, Image.Image):
        raise ValueError('Expect PIL.Image')

    for step in self.steps:
        # Same call convention as write2Tmp: Ellipsis fills the positional
        # slot (presumably ignored by the op -- confirm against the step
        # implementations), the image and step are passed by keyword.
        img = step['op'](..., _image=img, _step=step)
    return img


def write2Tmp(self, output_path=None):

for (name, fh) in self._imageNameGenerator():
im = pil_image.open(fh)
im = Image.open(fh)
for step in self.steps:
im = step['op'](..., _image=im, _step=step)
print("{}, {}, {}, {}".format(name, im.format, im.size, im.mode))
Expand Down
16 changes: 9 additions & 7 deletions mlbase/tests/loaddata/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
def test_JPGinFolder(tmpdir):
p = tmpdir.mkdir('img')

width = 3
height = 2
width = 300
height = 128

img1 = np.random.randint(0, 256, size=(height, width, 3))
img1 = img1.astype(np.uint8)
Expand All @@ -18,17 +18,19 @@ def test_JPGinFolder(tmpdir):
img2 = Image.fromarray(img2.astype(np.uint8))
img2.save(os.path.join(str(p),'img2.jpg'))

bd = JPGinFolder(str(p))
ppsize = 100
prepro = ImageProcessor().scale2Shorter(ppsize).centerCrop((ppsize, ppsize)).normalize2RGB()
bd = JPGinFolder(str(p), channel_mean_map=[123, 117, 104], channel_order="BGR", preprocessing=prepro)

assert len(bd) == 2
assert bd.shape == (2, 3, height, width)
assert bd.shape == (2, 3, ppsize, ppsize)

i1 = bd[0]
i2 = bd[1]

assert i1.shape == (1, 3, height, width)
assert i2.shape == (1, 3, height, width)
assert i1.shape == (1, 3, ppsize, ppsize)
assert i2.shape == (1, 3, ppsize, ppsize)

data = bd[0:2]

assert data.shape == (2, 3, height, width)
assert data.shape == (2, 3, ppsize, ppsize)

0 comments on commit 50766ae

Please sign in to comment.