diff --git a/dataset.py b/dataset.py
index 411057e8..96ff0e91 100644
--- a/dataset.py
+++ b/dataset.py
@@ -7,10 +7,12 @@
 from scipy.misc import imread, imresize
 import numpy as np
 
+
 # Round x to the nearest multiple of p and x' >= x
 def round2nearest_multiple(x, p):
     return ((x - 1) // p + 1) * p
 
+
 class TrainDataset(torchdata.Dataset):
     def __init__(self, odgt, opt, max_sample=-1, batch_per_gpu=1):
         self.root_dataset = opt.root_dataset
@@ -32,7 +34,10 @@ def __init__(self, odgt, opt, max_sample=-1, batch_per_gpu=1):
         # mean and std
         self.img_transform = transforms.Compose([
             transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
-            ])
+        ])
+
+        # number of layers used to make predictions (one ground-truth map per layer)
+        self.nr_layers = 4
 
         self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
 
@@ -48,9 +53,9 @@ def _get_sub_batch(self):
             # get a sample record
             this_sample = self.list_sample[self.cur_idx]
             if this_sample['height'] > this_sample['width']:
-                self.batch_record_list[0].append(this_sample) # h > w, go to 1st class
+                self.batch_record_list[0].append(this_sample)  # h > w, go to 1st class
             else:
-                self.batch_record_list[1].append(this_sample) # h <= w, go to 2nd class
+                self.batch_record_list[1].append(this_sample)  # h <= w, go to 2nd class
 
             # update current sample pointer
             self.cur_idx += 1
@@ -88,8 +93,7 @@ def __getitem__(self, index):
         batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32)
         for i in range(self.batch_per_gpu):
             img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
-            this_scale = min(this_short_size / min(img_height, img_width), \
-                    self.imgMaxSize / max(img_height, img_width))
+            this_scale = min(this_short_size / min(img_height, img_width), self.imgMaxSize / max(img_height, img_width))
             img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale
             batch_resized_size[i, :] = img_resized_height, img_resized_width
         batch_resized_height = np.max(batch_resized_size[:, 0])
@@ -99,11 +103,12 @@ def __getitem__(self, index):
         batch_resized_height = int(round2nearest_multiple(batch_resized_height, self.padding_constant))
         batch_resized_width = int(round2nearest_multiple(batch_resized_width, self.padding_constant))
 
-        assert self.padding_constant >= self.segm_downsampling_rate,\
-                'padding constant must be equal or large than segm downsamping rate'
+        assert self.padding_constant >= self.segm_downsampling_rate, \
+            'padding constant must be greater than or equal to the segm downsampling rate'
         batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width)
-        batch_segms = torch.zeros(self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
-                                batch_resized_width // self.segm_downsampling_rate).long()
+        batch_segms = torch.zeros(self.nr_layers, self.batch_per_gpu,
+                                  batch_resized_height // self.segm_downsampling_rate,
+                                  batch_resized_width // self.segm_downsampling_rate).long()
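+        # batch_segms holds one ground-truth map per prediction layer: (nr_layers, batch, H / rate, W / rate)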
 
         for i in range(self.batch_per_gpu):
             this_record = batch_records[i]
@@ -114,12 +119,12 @@ def __getitem__(self, index):
             img = imread(image_path, mode='RGB')
             segm = imread(segm_path)
 
-            assert(img.ndim == 3)
-            assert(segm.ndim == 2)
-            assert(img.shape[0] == segm.shape[0])
-            assert(img.shape[1] == segm.shape[1])
+            assert (img.ndim == 3)
+            assert (segm.ndim == 2)
+            assert (img.shape[0] == segm.shape[0])
+            assert (img.shape[1] == segm.shape[1])
 
-            if self.random_flip == True:
+            if self.random_flip:
                 random_flip = np.random.choice([0, 1])
                 if random_flip == 1:
                     img = cv2.flip(img, 1)
@@ -135,26 +140,40 @@ def __getitem__(self, index):
             segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8')
             segm_rounded[:segm.shape[0], :segm.shape[1]] = segm
 
-            segm = imresize(segm_rounded, (segm_rounded.shape[0] // self.segm_downsampling_rate, \
-                                           segm_rounded.shape[1] // self.segm_downsampling_rate), \
+            segm = imresize(segm_rounded, (segm_rounded.shape[0] // self.segm_downsampling_rate,
+                                           segm_rounded.shape[1] // self.segm_downsampling_rate),
                             interp='nearest')
-             # image to float
-            img = img.astype(np.float32)[:, :, ::-1] # RGB to BGR!!!
+
+            # construct ground-truth label map for each layer
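+            # coarse-to-fine assignment: starting from the coarsest layer (largest downsampling),
+            # a pixel is assigned to the first layer whose down/up-sampled label still matches
+            # the current label map; the remaining pixels fall through to the finer layers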
+            standard_segm_h, standard_segm_w = segm.shape[0], segm.shape[1]
+            for id_layer in reversed(range(self.nr_layers)):
+                # downsampling first
+                this_segm = imresize(segm, (standard_segm_h // (2 ** id_layer), standard_segm_w // (2 ** id_layer)),
+                                     interp='nearest')
+                # upsampling the downsampled segm
+                this_segm_upsampled = imresize(this_segm, (standard_segm_h, standard_segm_w), interp='nearest')
+                # labels that are still correct after down/up-sampling are predicted at this layer
+                this_segm_gt = this_segm_upsampled * (segm == this_segm_upsampled)
+                batch_segms[id_layer][i][:standard_segm_h, :standard_segm_w] = torch.from_numpy(this_segm_gt.astype(np.int)).long()
+                # remove already assigned labels (keep unassigned labels)
+                segm = segm * (this_segm_gt == 0)
+
+            # image to float
+            img = img.astype(np.float32)[:, :, ::-1]  # RGB to BGR!!!
             img = img.transpose((2, 0, 1))
             img = self.img_transform(torch.from_numpy(img.copy()))
 
             batch_images[i][:, :img.shape[1], :img.shape[2]] = img
-            batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int)).long()
 
-        batch_segms = batch_segms - 1 # label from -1 to 149
+        batch_segms = batch_segms - 1  # label from -1 to 149
         output = dict()
         output['img_data'] = batch_images
         output['seg_label'] = batch_segms
         return output
 
     def __len__(self):
-        return int(1e6) # It's a fake length due to the trick that every loader maintains its own list
-        #return self.num_sampleclass
+        return int(1e6)  # It's a fake length due to the trick that every loader maintains its own list
+        # return self.num_sampleclass
 
 
 class ValDataset(torchdata.Dataset):
@@ -165,17 +184,20 @@ def __init__(self, odgt, opt, max_sample=-1, start_idx=-1, end_idx=-1):
         # max down sampling rate of network to avoid rounding during conv or pooling
         self.padding_constant = opt.padding_constant
 
+        # number of layers used to make predictions (one ground-truth map per layer)
+        self.nr_layers = 4
+
         # mean and std
         self.img_transform = transforms.Compose([
             transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
-            ])
+        ])
 
         self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
 
         if max_sample > 0:
             self.list_sample = self.list_sample[0:max_sample]
 
-        if start_idx >= 0 and end_idx >= 0: # divide file list
+        if start_idx >= 0 and end_idx >= 0:  # divide file list
             self.list_sample = self.list_sample[start_idx:end_idx]
 
         self.num_sample = len(self.list_sample)
@@ -188,16 +210,17 @@ def __getitem__(self, index):
         image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
         segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
         img = imread(image_path, mode='RGB')
-        img = img[:, :, ::-1] # BGR to RGB!!!
-        segm = imread(segm_path)
+        img = img[:, :, ::-1]  # RGB to BGR!!!
+        segm_ori = imread(segm_path)
 
         ori_height, ori_width, _ = img.shape
 
         img_resized_list = []
+        segm_gt_list = []
         for this_short_size in self.imgSize:
             # calculate target height and width
             scale = min(this_short_size / float(min(ori_height, ori_width)),
-                    self.imgMaxSize / float(max(ori_height, ori_width)))
+                        self.imgMaxSize / float(max(ori_height, ori_width)))
             target_height, target_width = int(ori_height * scale), int(ori_width * scale)
 
             # to avoid rounding in network
@@ -215,15 +238,31 @@ def __getitem__(self, index):
             img_resized = torch.unsqueeze(img_resized, 0)
             img_resized_list.append(img_resized)
 
-        segm = torch.from_numpy(segm.astype(np.int)).long()
-
-        batch_segms = torch.unsqueeze(segm, 0)
-
-        batch_segms = batch_segms - 1 # label from -1 to 149
+            # construct ground-truth label map for each layer
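+            # note: the extra factor of 4 in 2 ** (2 + id_layer) below presumably mirrors the
+            # training-time label downsampling and the stride-4 finest prediction layer; treat
+            # this as an assumption rather than a documented invariant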
+            standard_segm_h, standard_segm_w = segm_ori.shape[0], segm_ori.shape[1]
+            segm = segm_ori.copy()
+            for id_layer in reversed(range(self.nr_layers)):
+                # downsampling first
+                this_segm = imresize(segm, (target_height // (2 ** (2 + id_layer)), target_width // (2 ** (2 + id_layer))),
+                                     interp='nearest')
+                # upsampling the downsampled segm
+                this_segm_upsampled = imresize(this_segm, (standard_segm_h, standard_segm_w), interp='nearest')
+                # labels that are still correct after down/up-sampling are predicted at this layer
+                this_segm_gt = this_segm_upsampled * (segm == this_segm_upsampled)
+                # insert at the front so the list ends up finest-first, matching the decoder's
+                # pred_list ordering (P2 ... P5) used in get_pred_map_via_gt
+                segm_gt_list.insert(0, torch.from_numpy(this_segm_gt.astype(np.int)).long() - 1)
+                # remove already assigned labels (keep unassigned labels)
+                segm = segm * (this_segm_gt == 0)
+
+        segm_ori = torch.from_numpy(segm_ori.astype(np.int)).long()
+
+        batch_segms = torch.unsqueeze(segm_ori, 0)
+
+        batch_segms = batch_segms - 1  # label from -1 to 149
         output = dict()
         output['img_ori'] = img.copy()
         output['img_data'] = [x.contiguous() for x in img_resized_list]
         output['seg_label'] = batch_segms.contiguous()
+        output['seg_gt_list'] = segm_gt_list
         output['info'] = this_record['fpath_img']
         return output
 
@@ -243,7 +282,7 @@ def __init__(self, odgt, opt, max_sample=-1):
         # mean and std
         self.img_transform = transforms.Compose([
             transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
-            ])
+        ])
 
         if isinstance(odgt, list):
             self.list_sample = odgt
@@ -261,7 +300,7 @@ def __getitem__(self, index):
         # load image and label
         image_path = this_record['fpath_img']
         img = imread(image_path, mode='RGB')
-        img = img[:, :, ::-1] # BGR to RGB!!!
+        img = img[:, :, ::-1]  # RGB to BGR!!!
 
         ori_height, ori_width, _ = img.shape
 
@@ -269,7 +308,7 @@ def __getitem__(self, index):
         for this_short_size in self.imgSize:
             # calculate target height and width
             scale = min(this_short_size / float(min(ori_height, ori_width)),
-                    self.imgMaxSize / float(max(ori_height, ori_width)))
+                        self.imgMaxSize / float(max(ori_height, ori_width)))
             target_height, target_width = int(ori_height * scale), int(ori_width * scale)
 
             # to avoid rounding in network
diff --git a/eval_multipro.py b/eval_multipro.py
index afd72325..f2ff1e0b 100644
--- a/eval_multipro.py
+++ b/eval_multipro.py
@@ -39,6 +39,11 @@ def visualize_result(data, preds, args):
     cv2.imwrite(os.path.join(args.result,
                 img_name.replace('.jpg', '.png')), im_vis)
 
+def get_pred_map_via_gt(seg_gt_list, pred_list):
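+    # assumes seg_gt_list and pred_list are ordered layer-wise consistently (finest layer first);
+    # each pixel is claimed by at most one layer's ground truth (label >= 0), so the masked sum
+    # picks, per pixel, the prediction of the layer responsible for it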
+    pred_map = torch.zeros_like(pred_list[0]).type_as(pred_list[0])
+    for this_seg_gt, this_pred in zip(seg_gt_list, pred_list):
+        # cast the mask to the prediction's dtype/device before masking
+        pred_map += (this_seg_gt >= 0).unsqueeze(0).type_as(this_pred) * this_pred
+    return pred_map
 
 def evaluate(segmentation_module, loader, args, dev_id, result_queue):
 
@@ -50,6 +55,7 @@ def evaluate(segmentation_module, loader, args, dev_id, result_queue):
         seg_label = as_numpy(batch_data['seg_label'][0])
 
         img_resized_list = batch_data['img_data']
+        seg_gt_list = batch_data['seg_gt_list']
 
         with torch.no_grad():
             segSize = (seg_label.shape[0], seg_label.shape[1])
@@ -60,11 +66,12 @@ def evaluate(segmentation_module, loader, args, dev_id, result_queue):
                 feed_dict['img_data'] = img
                 del feed_dict['img_ori']
                 del feed_dict['info']
+                del feed_dict['seg_gt_list']
                 feed_dict = async_copy_to(feed_dict, dev_id)
 
                 # forward pass
-                pred_tmp = segmentation_module(feed_dict, segSize=segSize)
-                pred = pred + pred_tmp.cpu() / len(args.imgSize)
+                pred_list = segmentation_module(feed_dict, segSize=segSize)
+                pred = pred + get_pred_map_via_gt(seg_gt_list, pred_list).cpu() / len(args.imgSize)
 
             _, preds = torch.max(pred.data.cpu(), dim=1)
             preds = as_numpy(preds.squeeze(0))
diff --git a/lib/nn/__init__.py b/lib/nn/__init__.py
index 98a96370..8203c02a 100644
--- a/lib/nn/__init__.py
+++ b/lib/nn/__init__.py
@@ -1,2 +1,3 @@
 from .modules import *
 from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
+from .prroi_pool import *
diff --git a/lib/nn/prroi_pool/.gitignore b/lib/nn/prroi_pool/.gitignore
new file mode 100644
index 00000000..18495ead
--- /dev/null
+++ b/lib/nn/prroi_pool/.gitignore
@@ -0,0 +1,2 @@
+*.o
+/_prroi_pooling
diff --git a/lib/nn/prroi_pool/__init__.py b/lib/nn/prroi_pool/__init__.py
new file mode 100644
index 00000000..0c40b7a7
--- /dev/null
+++ b/lib/nn/prroi_pool/__init__.py
@@ -0,0 +1,13 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# File   : __init__.py
+# Author : Jiayuan Mao, Tete Xiao
+# Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com
+# Date   : 07/13/2018
+# 
+# This file is part of PreciseRoIPooling.
+# Distributed under terms of the MIT license.
+# Copyright (c) 2017 Megvii Technology Limited.
+
+from .prroi_pool import *
+
diff --git a/lib/nn/prroi_pool/build.py b/lib/nn/prroi_pool/build.py
new file mode 100644
index 00000000..b1987908
--- /dev/null
+++ b/lib/nn/prroi_pool/build.py
@@ -0,0 +1,50 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# File   : build.py
+# Author : Jiayuan Mao, Tete Xiao
+# Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com
+# Date   : 07/13/2018
+# 
+# This file is part of PreciseRoIPooling.
+# Distributed under terms of the MIT license.
+# Copyright (c) 2017 Megvii Technology Limited.
+
+import os
+import torch
+
+from torch.utils.ffi import create_extension
+
+headers = []
+sources = []
+defines = []
+extra_objects = []
+with_cuda = False
+
+if torch.cuda.is_available():
+    with_cuda = True
+
+    headers += ['src/prroi_pooling_gpu.h']
+    sources += ['src/prroi_pooling_gpu.c']
+    defines += [('WITH_CUDA', None)]
+
+    this_file = os.path.dirname(os.path.realpath(__file__))
+    extra_objects_cuda = ['src/prroi_pooling_gpu_impl.cu.o']
+    extra_objects_cuda = [os.path.join(this_file, fname) for fname in extra_objects_cuda]
+    extra_objects.extend(extra_objects_cuda)
+else:
+    # TODO(Jiayuan Mao @ 07/13): remove this restriction after we support the cpu implementation.
+    raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implementations.')
+
+ffi = create_extension(
+    '_prroi_pooling',
+    headers=headers,
+    sources=sources,
+    define_macros=defines,
+    relative_to=__file__,
+    with_cuda=with_cuda,
+    extra_objects=extra_objects
+)
+
+if __name__ == '__main__':
+    ffi.build()
+
diff --git a/lib/nn/prroi_pool/functional.py b/lib/nn/prroi_pool/functional.py
new file mode 100644
index 00000000..036dbd8c
--- /dev/null
+++ b/lib/nn/prroi_pool/functional.py
@@ -0,0 +1,68 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# File   : functional.py
+# Author : Jiayuan Mao, Tete Xiao
+# Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com
+# Date   : 07/13/2018
+# 
+# This file is part of PreciseRoIPooling.
+# Distributed under terms of the MIT license.
+# Copyright (c) 2017 Megvii Technology Limited.
+
+import torch
+import torch.autograd as ag
+
+try:
+    from . import _prroi_pooling
+except ImportError:
+    raise ImportError('Cannot find the compiled Precise RoI Pooling library. Run ./travis.sh in this directory first.')
+
+__all__ = ['prroi_pool2d']
+
+
+class PrRoIPool2DFunction(ag.Function):
+    @staticmethod
+    def forward(ctx, features, rois, pooled_height, pooled_width, spatial_scale):
+        features = features.contiguous()
+        rois = rois.contiguous()
+        pooled_height = int(pooled_height)
+        pooled_width = int(pooled_width)
+        spatial_scale = float(spatial_scale)
+
+        params = (pooled_height, pooled_width, spatial_scale)
+        batch_size, nr_channels, data_height, data_width = features.size()
+        nr_rois = rois.size(0)
+        output = torch.zeros(
+            (nr_rois, nr_channels, pooled_height, pooled_width),
+            dtype=features.dtype, device=features.device
+        )
+
+        if features.is_cuda:
+            _prroi_pooling.prroi_pooling_forward_cuda(features, rois, output, *params)
+            ctx.params = params
+            # everything here is contiguous.
+            ctx.save_for_backward(features, rois, output)
+        else:
+            raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implementations.')
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        features, rois, output = ctx.saved_tensors
+        grad_input = grad_coor = None
+
+        if features.requires_grad:
+            grad_output = grad_output.contiguous()
+            grad_input = torch.zeros_like(features)
+            _prroi_pooling.prroi_pooling_backward_cuda(features, rois, output, grad_output, grad_input, *ctx.params)
+        if rois.requires_grad:
+            grad_output = grad_output.contiguous()
+            grad_coor = torch.zeros_like(rois)
+            _prroi_pooling.prroi_pooling_coor_backward_cuda(features, rois, output, grad_output, grad_coor, *ctx.params)
+
+        return grad_input, grad_coor, None, None, None
+
+
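+# rois is an (N, 5) tensor of (batch_index, x0, y0, x1, y1) boxes; the coordinates are
+# multiplied by spatial_scale inside the CUDA kernel to map them onto the feature map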
+prroi_pool2d = PrRoIPool2DFunction.apply
+
diff --git a/lib/nn/prroi_pool/prroi_pool.py b/lib/nn/prroi_pool/prroi_pool.py
new file mode 100644
index 00000000..998b2b80
--- /dev/null
+++ b/lib/nn/prroi_pool/prroi_pool.py
@@ -0,0 +1,28 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# File   : prroi_pool.py
+# Author : Jiayuan Mao, Tete Xiao
+# Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com
+# Date   : 07/13/2018
+# 
+# This file is part of PreciseRoIPooling.
+# Distributed under terms of the MIT license.
+# Copyright (c) 2017 Megvii Technology Limited.
+
+import torch.nn as nn
+
+from .functional import prroi_pool2d
+
+__all__ = ['PrRoIPool2D']
+
+
+class PrRoIPool2D(nn.Module):
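+    """Precise RoI Pooling: integrates the bilinearly interpolated feature map over each
+    output bin (continuous average pooling), which makes the result differentiable with
+    respect to the RoI coordinates as well as the features."""
+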
+    def __init__(self, pooled_height, pooled_width, spatial_scale):
+        super().__init__()
+
+        self.pooled_height = int(pooled_height)
+        self.pooled_width = int(pooled_width)
+        self.spatial_scale = float(spatial_scale)
+
+    def forward(self, features, rois):
+        return prroi_pool2d(features, rois, self.pooled_height, self.pooled_width, self.spatial_scale)
diff --git a/lib/nn/prroi_pool/src/prroi_pooling_gpu.c b/lib/nn/prroi_pool/src/prroi_pooling_gpu.c
new file mode 100644
index 00000000..51520841
--- /dev/null
+++ b/lib/nn/prroi_pool/src/prroi_pooling_gpu.c
@@ -0,0 +1,96 @@
+/*
+ * File   : prroi_pooling_gpu.c
+ * Author : Jiayuan Mao, Tete Xiao
+ * Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com 
+ * Date   : 07/13/2018
+ * 
+ * Distributed under terms of the MIT license.
+ * Copyright (c) 2017 Megvii Technology Limited.
+ */
+
+#include <math.h>
+#include <THC/THC.h>
+
+#include "prroi_pooling_gpu_impl.cuh"
+
+extern THCState *state;
+
+int prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale) {
+    const float *data_ptr = THCudaTensor_data(state, features);
+    const float *rois_ptr = THCudaTensor_data(state, rois);
+    float *output_ptr = THCudaTensor_data(state, output);
+
+    int nr_rois = THCudaTensor_size(state, rois, 0);
+    int nr_channels = THCudaTensor_size(state, features, 1);
+    int height = THCudaTensor_size(state, features, 2);
+    int width = THCudaTensor_size(state, features, 3);
+    int top_count = nr_rois * nr_channels * pooled_height * pooled_width;
+
+    cudaStream_t stream = THCState_getCurrentStream(state);
+
+    PrRoIPoolingForwardGpu(
+        stream, data_ptr, rois_ptr, output_ptr,
+        nr_channels, height, width, pooled_height, pooled_width, spatial_scale,
+        top_count
+    );
+
+    return 1;
+}
+
+int prroi_pooling_backward_cuda(
+    THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff,
+    int pooled_height, int pooled_width, float spatial_scale) {
+
+    const float *data_ptr = THCudaTensor_data(state, features);
+    const float *rois_ptr = THCudaTensor_data(state, rois);
+    const float *output_ptr = THCudaTensor_data(state, output);
+    const float *output_diff_ptr = THCudaTensor_data(state, output_diff);
+    float *features_diff_ptr = THCudaTensor_data(state, features_diff);
+
+    int nr_rois = THCudaTensor_size(state, rois, 0);
+    int batch_size = THCudaTensor_size(state, features, 0);
+    int nr_channels = THCudaTensor_size(state, features, 1);
+    int height = THCudaTensor_size(state, features, 2);
+    int width = THCudaTensor_size(state, features, 3);
+    int top_count = nr_rois * nr_channels * pooled_height * pooled_width;
+    int bottom_count = batch_size * nr_channels * height * width;
+    
+    cudaStream_t stream = THCState_getCurrentStream(state);
+
+    PrRoIPoolingBackwardGpu(
+        stream, data_ptr, rois_ptr, output_ptr, output_diff_ptr, features_diff_ptr,
+        nr_channels, height, width, pooled_height, pooled_width, spatial_scale,
+        top_count, bottom_count
+    );
+
+    return 1;
+}
+
+int prroi_pooling_coor_backward_cuda(
+    THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *coor_diff,
+    int pooled_height, int pooled_width, float spatial_scale) {
+
+    const float *data_ptr = THCudaTensor_data(state, features);
+    const float *rois_ptr = THCudaTensor_data(state, rois);
+    const float *output_ptr = THCudaTensor_data(state, output);
+    const float *output_diff_ptr = THCudaTensor_data(state, output_diff);
+    float *coor_diff_ptr= THCudaTensor_data(state, coor_diff);
+
+    int nr_rois = THCudaTensor_size(state, rois, 0);
+    int nr_channels = THCudaTensor_size(state, features, 1);
+    int height = THCudaTensor_size(state, features, 2);
+    int width = THCudaTensor_size(state, features, 3);
+    int top_count = nr_rois * nr_channels * pooled_height * pooled_width;
+    int bottom_count = nr_rois * 5;
+
+    cudaStream_t stream = THCState_getCurrentStream(state);
+
+    PrRoIPoolingCoorBackwardGpu(
+        stream, data_ptr, rois_ptr, output_ptr, output_diff_ptr, coor_diff_ptr,
+        nr_channels, height, width, pooled_height, pooled_width, spatial_scale,
+        top_count, bottom_count
+    );
+
+    return 1;
+}
+
diff --git a/lib/nn/prroi_pool/src/prroi_pooling_gpu.h b/lib/nn/prroi_pool/src/prroi_pooling_gpu.h
new file mode 100644
index 00000000..bc9d3518
--- /dev/null
+++ b/lib/nn/prroi_pool/src/prroi_pooling_gpu.h
@@ -0,0 +1,22 @@
+/*
+ * File   : prroi_pooling_gpu.h
+ * Author : Jiayuan Mao, Tete Xiao
+ * Email  : maojiayuan@gmail.com, jasonhsiao97@gmail.com 
+ * Date   : 07/13/2018
+ * 
+ * Distributed under terms of the MIT license.
+ * Copyright (c) 2017 Megvii Technology Limited.
+ */
+
+int prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale);
+
+int prroi_pooling_backward_cuda(
+    THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff,
+    int pooled_height, int pooled_width, float spatial_scale
+);
+
+int prroi_pooling_coor_backward_cuda(
+    THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *coor_diff,
+    int pooled_height, int pooled_width, float spatial_scale
+);
+
diff --git a/lib/nn/prroi_pool/src/prroi_pooling_gpu_impl.cu b/lib/nn/prroi_pool/src/prroi_pooling_gpu_impl.cu
new file mode 100644
index 00000000..8125c4d0
--- /dev/null
+++ b/lib/nn/prroi_pool/src/prroi_pooling_gpu_impl.cu
@@ -0,0 +1,443 @@
+/*
+ * File   : prroi_pooling_gpu_impl.cu
+ * Author : Tete Xiao, Jiayuan Mao
+ * Email  : jasonhsiao97@gmail.com 
+ * 
+ * Distributed under terms of the MIT license.
+ * Copyright (c) 2017 Megvii Technology Limited.
+ */
+
+#include "prroi_pooling_gpu_impl.cuh"
+
+#include <cstdio>
+#include <cfloat>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+        i < (n); \
+        i += blockDim.x * gridDim.x)
+
+#define CUDA_POST_KERNEL_CHECK \
+    do { \
+        cudaError_t err = cudaGetLastError(); \
+        if (cudaSuccess != err) { \
+            fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); \
+            exit(-1); \
+        } \
+    } while(0)
+
+#define CUDA_NUM_THREADS 512
+
+namespace {
+
+static int CUDA_NUM_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+__device__ static float PrRoIPoolingGetData(F_DEVPTR_IN data, const int h, const int w, const int height, const int width)
+{
+    bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
+    float retVal = overflow ? 0.0f : data[h * width + w];
+    return retVal;
+}
+
+__device__ static float PrRoIPoolingGetCoeff(float dh, float dw){
+    dw = dw > 0 ? dw : -dw;
+    dh = dh > 0 ? dh : -dh;
+    return (1.0f - dh) * (1.0f - dw);
+}
+
+__device__ static float PrRoIPoolingSingleCoorIntegral(float s, float t, float c1, float c2) {
+    return 0.5 * (t * t - s * s) * c2 + (t - 0.5 * t * t - s + 0.5 * s * s) * c1;
+}
+
+__device__ static float PrRoIPoolingInterpolation(F_DEVPTR_IN data, const float h, const float w, const int height, const int width){
+    float retVal = 0.0f;
+    int h1 = floorf(h);
+    int w1 = floorf(w);
+    retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));
+    h1 = floorf(h)+1;
+    w1 = floorf(w);
+    retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));
+    h1 = floorf(h);
+    w1 = floorf(w)+1;
+    retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));
+    h1 = floorf(h)+1;
+    w1 = floorf(w)+1;
+    retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));
+    return retVal;
+}
+
+__device__ static float PrRoIPoolingMatCalculation(F_DEVPTR_IN this_data, const int s_h, const int s_w, const int e_h, const int e_w,
+        const float y0, const float x0, const float y1, const float x1, const int h0, const int w0)
+{
+    float alpha, beta, lim_alpha, lim_beta, tmp;
+    float sum_out = 0;
+
+    alpha = x0 - float(s_w);
+    beta = y0 - float(s_h);
+    lim_alpha = x1 - float(s_w);
+    lim_beta = y1 - float(s_h);
+    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) 
+        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+    sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp;
+
+    alpha = float(e_w) - x1;
+    lim_alpha = float(e_w) - x0;
+    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) 
+        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+    sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp;
+
+    alpha = x0 - float(s_w);
+    beta = float(e_h) - y1;
+    lim_alpha = x1 - float(s_w);
+    lim_beta = float(e_h) - y0;
+    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) 
+        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+    sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp;
+
+    alpha = float(e_w) - x1;
+    lim_alpha = float(e_w) - x0;
+    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) 
+        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);   
+    sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp;
+
+    return sum_out;
+}
+
+__device__ static void PrRoIPoolingDistributeDiff(F_DEVPTR_OUT diff, const float top_diff, const int h, const int w, const int height, const int width, const float coeff)
+{
+    bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
+    if (!overflow) 
+        atomicAdd(diff + h * width + w, top_diff * coeff);
+}
+
+__device__ static void PrRoIPoolingMatDistributeDiff(F_DEVPTR_OUT diff, const float top_diff, const int s_h, const int s_w, const int e_h, const int e_w,
+        const float y0, const float x0, const float y1, const float x1, const int h0, const int w0)
+{
+    float alpha, beta, lim_alpha, lim_beta, tmp;
+
+    alpha = x0 - float(s_w);
+    beta = y0 - float(s_h);
+    lim_alpha = x1 - float(s_w);
+    lim_beta = y1 - float(s_h);
+    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) 
+        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+    PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp);
+
+    alpha = float(e_w) - x1;
+    lim_alpha = float(e_w) - x0;
+    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) 
+        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+    PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp);
+
+    alpha = x0 - float(s_w);
+    beta = float(e_h) - y1;
+    lim_alpha = x1 - float(s_w);
+    lim_beta = float(e_h) - y0;
+    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) 
+        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
+    PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp);
+
+    alpha = float(e_w) - x1;
+    lim_alpha = float(e_w) - x0;
+    tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha) 
+        * (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);   
+    PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp);
+}
+
+__global__ void PrRoIPoolingForward(
+        const int nthreads, 
+        F_DEVPTR_IN bottom_data,
+        F_DEVPTR_IN bottom_rois,
+        F_DEVPTR_OUT top_data,
+        const int channels, 
+        const int height,
+        const int width, 
+        const int pooled_height, 
+        const int pooled_width,
+        const float spatial_scale) {
+
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    
+    bottom_rois += n * 5;
+    int roi_batch_ind = bottom_rois[0];
+    
+    float roi_start_w = bottom_rois[1] * spatial_scale;
+    float roi_start_h = bottom_rois[2] * spatial_scale;
+    float roi_end_w = bottom_rois[3] * spatial_scale;
+    float roi_end_h = bottom_rois[4] * spatial_scale;
+
+    float roi_width = max(roi_end_w - roi_start_w, ((float)0.0));
+    float roi_height = max(roi_end_h - roi_start_h, ((float)0.0));
+    float bin_size_h = roi_height / static_cast<float>(pooled_height);
+    float bin_size_w = roi_width / static_cast<float>(pooled_width);
+
+    const float *this_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
+    float *this_out = top_data + index;
+
+    float win_start_w = roi_start_w + bin_size_w * pw;
+    float win_start_h = roi_start_h + bin_size_h * ph;
+    float win_end_w = win_start_w + bin_size_w;
+    float win_end_h = win_start_h + bin_size_h;
+    
+    float win_size = max(float(0.0), bin_size_w * bin_size_h);
+    if (win_size == 0) {
+        *this_out = 0;
+        return;
+    }
+
+    float sum_out = 0;
+
+    int s_w, s_h, e_w, e_h;
+    
+    s_w = floorf(win_start_w);
+    e_w = ceilf(win_end_w);
+    s_h = floorf(win_start_h);
+    e_h = ceilf(win_end_h);
+
+    for (int w_iter = s_w; w_iter < e_w; ++w_iter)
+        for (int h_iter = s_h; h_iter < e_h; ++h_iter)
+            sum_out += PrRoIPoolingMatCalculation(this_data, h_iter, w_iter, h_iter + 1, w_iter + 1, 
+                max(win_start_h, float(h_iter)), max(win_start_w, float(w_iter)),
+                min(win_end_h, float(h_iter) + 1.0), min(win_end_w, float(w_iter + 1.0)),
+                height, width);
+    *this_out = sum_out / win_size; 
+  }
+}
+
+__global__ void PrRoIPoolingBackward(
+        const int nthreads, 
+        F_DEVPTR_IN bottom_rois,
+        F_DEVPTR_IN top_diff,
+        F_DEVPTR_OUT bottom_diff,
+        const int channels, 
+        const int height, 
+        const int width,
+        const int pooled_height, 
+        const int pooled_width,
+        const float spatial_scale) {
+
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    bottom_rois += n * 5; 
+    
+    int roi_batch_ind = bottom_rois[0];
+    float roi_start_w = bottom_rois[1] * spatial_scale;
+    float roi_start_h = bottom_rois[2] * spatial_scale;
+    float roi_end_w = bottom_rois[3] * spatial_scale;
+    float roi_end_h = bottom_rois[4] * spatial_scale;
+    
+    float roi_width = max(roi_end_w - roi_start_w, (float)0);
+    float roi_height = max(roi_end_h - roi_start_h, (float)0);
+    float bin_size_h = roi_height / static_cast<float>(pooled_height);
+    float bin_size_w = roi_width / static_cast<float>(pooled_width);
+
+    const float *this_out_grad = top_diff + index;
+    float *this_data_grad = bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+    float win_start_w = roi_start_w + bin_size_w * pw;
+    float win_start_h = roi_start_h + bin_size_h * ph;
+    float win_end_w = win_start_w + bin_size_w;
+    float win_end_h = win_start_h + bin_size_h;
+
+    float win_size = max(float(0.0), bin_size_w * bin_size_h);
+
+    float sum_out = win_size == float(0) ? float(0) : *this_out_grad / win_size;
+
+    int s_w, s_h, e_w, e_h;
+
+    s_w = floorf(win_start_w);
+    e_w = ceilf(win_end_w);
+    s_h = floorf(win_start_h);
+    e_h = ceilf(win_end_h);
+
+    for (int w_iter = s_w; w_iter < e_w; ++w_iter)
+        for (int h_iter = s_h; h_iter < e_h; ++h_iter)
+            PrRoIPoolingMatDistributeDiff(this_data_grad, sum_out, h_iter, w_iter, h_iter + 1, w_iter + 1, 
+                max(win_start_h, float(h_iter)), max(win_start_w, float(w_iter)),
+                min(win_end_h, float(h_iter) + 1.0), min(win_end_w, float(w_iter + 1.0)),
+                height, width);
+
+  }
+}
+
+__global__ void PrRoIPoolingCoorBackward(
+        const int nthreads, 
+        F_DEVPTR_IN bottom_data,
+        F_DEVPTR_IN bottom_rois,
+        F_DEVPTR_IN top_data, 
+        F_DEVPTR_IN top_diff,
+        F_DEVPTR_OUT bottom_diff, 
+        const int channels, 
+        const int height, 
+        const int width,
+        const int pooled_height, 
+        const int pooled_width,
+        const float spatial_scale) {
+
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+    bottom_rois += n * 5;
+
+    int roi_batch_ind = bottom_rois[0];
+    float roi_start_w = bottom_rois[1] * spatial_scale;
+    float roi_start_h = bottom_rois[2] * spatial_scale;
+    float roi_end_w = bottom_rois[3] * spatial_scale;
+    float roi_end_h = bottom_rois[4] * spatial_scale;
+
+    float roi_width = max(roi_end_w - roi_start_w, (float)0);
+    float roi_height = max(roi_end_h - roi_start_h, (float)0);
+    float bin_size_h = roi_height / static_cast<float>(pooled_height);
+    float bin_size_w = roi_width / static_cast<float>(pooled_width);
+
+    const float *this_out_grad = top_diff + index;
+    const float *this_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
+    const float *this_top_data = top_data + index;
+    float *this_data_grad = bottom_diff + n * 5;
+    
+    float win_start_w = roi_start_w + bin_size_w * pw;
+    float win_start_h = roi_start_h + bin_size_h * ph;
+    float win_end_w = win_start_w + bin_size_w;
+    float win_end_h = win_start_h + bin_size_h;
+
+    float win_size = max(float(0.0), bin_size_w * bin_size_h);
+
+    float sum_out = win_size == float(0) ? float(0) : *this_out_grad / win_size;
+    
+    // WARNING: to be discussed
+    if (sum_out == 0)
+        return;
+
+    int s_w, s_h, e_w, e_h;
+
+    s_w = floorf(win_start_w);
+    e_w = ceilf(win_end_w);
+    s_h = floorf(win_start_h);
+    e_h = ceilf(win_end_h);
+
+    float g_x1_y = 0, g_x2_y = 0, g_x_y1 = 0, g_x_y2 = 0;
+    for (int h_iter = s_h; h_iter < e_h; ++h_iter) {
+        g_x1_y += PrRoIPoolingSingleCoorIntegral(max(win_start_h, float(h_iter)) - h_iter, 
+                min(win_end_h, float(h_iter + 1)) - h_iter, 
+                PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_start_w, height, width),
+                PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_start_w, height, width));
+    
+        g_x2_y += PrRoIPoolingSingleCoorIntegral(max(win_start_h, float(h_iter)) - h_iter, 
+                min(win_end_h, float(h_iter + 1)) - h_iter, 
+                PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_end_w, height, width),
+                PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_end_w, height, width));
+    }
+
+    for (int w_iter = s_w; w_iter < e_w; ++w_iter) {
+        g_x_y1 += PrRoIPoolingSingleCoorIntegral(max(win_start_w, float(w_iter)) - w_iter, 
+                min(win_end_w, float(w_iter + 1)) - w_iter, 
+                PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter, height, width),
+                PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter + 1, height, width));
+    
+        g_x_y2 += PrRoIPoolingSingleCoorIntegral(max(win_start_w, float(w_iter)) - w_iter, 
+                min(win_end_w, float(w_iter + 1)) - w_iter, 
+                PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter, height, width),
+                PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter + 1, height, width));
+    }
+
+    float partial_x1 = -g_x1_y + (win_end_h - win_start_h) * (*this_top_data);
+    float partial_y1 = -g_x_y1 + (win_end_w - win_start_w) * (*this_top_data);
+    float partial_x2 = g_x2_y - (win_end_h - win_start_h) * (*this_top_data);
+    float partial_y2 = g_x_y2 - (win_end_w - win_start_w) * (*this_top_data);
+
+    partial_x1 = partial_x1 / win_size * spatial_scale;
+    partial_x2 = partial_x2 / win_size * spatial_scale;
+    partial_y1 = partial_y1 / win_size * spatial_scale;
+    partial_y2 = partial_y2 / win_size * spatial_scale;
+    
+    // (b, x1, y1, x2, y2)
+    
+    this_data_grad[0] = 0;
+    atomicAdd(this_data_grad + 1, (partial_x1 * (1.0 - float(pw) / pooled_width) + partial_x2 * (1.0 - float(pw + 1) / pooled_width)) 
+            * (*this_out_grad));
+    atomicAdd(this_data_grad + 2, (partial_y1 * (1.0 - float(ph) / pooled_height) + partial_y2 * (1.0 - float(ph + 1) / pooled_height))
+            * (*this_out_grad));
+    atomicAdd(this_data_grad + 3, (partial_x2 * float(pw + 1) / pooled_width + partial_x1 * float(pw) / pooled_width)
+            * (*this_out_grad)); 
+    atomicAdd(this_data_grad + 4, (partial_y2 * float(ph + 1) / pooled_height + partial_y1 * float(ph) / pooled_height)
+            * (*this_out_grad)); 
+  }
+}
+
+} /* !anonymous namespace */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void PrRoIPoolingForwardGpu(
+    cudaStream_t stream,
+    F_DEVPTR_IN bottom_data,
+    F_DEVPTR_IN bottom_rois,
+    F_DEVPTR_OUT top_data,
+    const int channels_, const int height_, const int width_, 
+    const int pooled_height_, const int pooled_width_,
+    const float spatial_scale_,
+    const int top_count) {
+
+    PrRoIPoolingForward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(
+        top_count, bottom_data, bottom_rois, top_data,
+        channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);
+
+    CUDA_POST_KERNEL_CHECK;
+}
+
+void PrRoIPoolingBackwardGpu(
+    cudaStream_t stream,
+    F_DEVPTR_IN bottom_data,
+    F_DEVPTR_IN bottom_rois,
+    F_DEVPTR_IN top_data,
+    F_DEVPTR_IN top_diff,
+    F_DEVPTR_OUT bottom_diff,
+    const int channels_, const int height_, const int width_, 
+    const int pooled_height_, const int pooled_width_, 
+    const float spatial_scale_,
+    const int top_count, const int bottom_count) {
+
+    cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_count, stream);
+    PrRoIPoolingBackward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(
+        top_count, bottom_rois, top_diff, bottom_diff,
+        channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);
+    CUDA_POST_KERNEL_CHECK;
+}
+
+void PrRoIPoolingCoorBackwardGpu(
+    cudaStream_t stream,
+    F_DEVPTR_IN bottom_data,
+    F_DEVPTR_IN bottom_rois,
+    F_DEVPTR_IN top_data,
+    F_DEVPTR_IN top_diff,
+    F_DEVPTR_OUT bottom_diff,
+    const int channels_, const int height_, const int width_, 
+    const int pooled_height_, const int pooled_width_, 
+    const float spatial_scale_,
+    const int top_count, const int bottom_count) {
+
+    cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_count, stream);
+    PrRoIPoolingCoorBackward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(
+        top_count, bottom_data, bottom_rois, top_data, top_diff, bottom_diff,
+        channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);
+    CUDA_POST_KERNEL_CHECK;
+}
+
+} /* !extern "C" */
+
diff --git a/lib/nn/prroi_pool/src/prroi_pooling_gpu_impl.cuh b/lib/nn/prroi_pool/src/prroi_pooling_gpu_impl.cuh
new file mode 100644
index 00000000..4125086f
--- /dev/null
+++ b/lib/nn/prroi_pool/src/prroi_pooling_gpu_impl.cuh
@@ -0,0 +1,59 @@
+/*
+ * File   : prroi_pooling_gpu_impl.cuh
+ * Author : Tete Xiao, Jiayuan Mao
+ * Email  : jasonhsiao97@gmail.com 
+ * 
+ * Distributed under terms of the MIT license.
+ * Copyright (c) 2017 Megvii Technology Limited.
+ */
+
+#ifndef PRROI_POOLING_GPU_IMPL_CUH
+#define PRROI_POOLING_GPU_IMPL_CUH
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define F_DEVPTR_IN const float * 
+#define F_DEVPTR_OUT float * 
+
+void PrRoIPoolingForwardGpu(
+    cudaStream_t stream,
+    F_DEVPTR_IN bottom_data,
+    F_DEVPTR_IN bottom_rois,
+    F_DEVPTR_OUT top_data,
+    const int channels_, const int height_, const int width_, 
+    const int pooled_height_, const int pooled_width_,
+    const float spatial_scale_,
+    const int top_count);
+
+void PrRoIPoolingBackwardGpu(
+    cudaStream_t stream,
+    F_DEVPTR_IN bottom_data,
+    F_DEVPTR_IN bottom_rois,
+    F_DEVPTR_IN top_data,
+    F_DEVPTR_IN top_diff,
+    F_DEVPTR_OUT bottom_diff,
+    const int channels_, const int height_, const int width_, 
+    const int pooled_height_, const int pooled_width_, 
+    const float spatial_scale_,
+    const int top_count, const int bottom_count);
+
+void PrRoIPoolingCoorBackwardGpu(
+    cudaStream_t stream,
+    F_DEVPTR_IN bottom_data,
+    F_DEVPTR_IN bottom_rois,
+    F_DEVPTR_IN top_data,
+    F_DEVPTR_IN top_diff,
+    F_DEVPTR_OUT bottom_diff,
+    const int channels_, const int height_, const int width_, 
+    const int pooled_height_, const int pooled_width_, 
+    const float spatial_scale_,
+    const int top_count, const int bottom_count);
+
+#ifdef __cplusplus
+} /* !extern "C" */
+#endif
+
+#endif /* !PRROI_POOLING_GPU_IMPL_CUH */
+
diff --git a/lib/nn/prroi_pool/travis.sh b/lib/nn/prroi_pool/travis.sh
new file mode 100755
index 00000000..4d50d12f
--- /dev/null
+++ b/lib/nn/prroi_pool/travis.sh
@@ -0,0 +1,18 @@
+#! /bin/bash -e
+# File   : travis.sh
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+#
+# Distributed under terms of the MIT license.
+# Copyright (c) 2017 Megvii Technology Limited.
+
+cd src
+echo "Working directory: " `pwd`
+echo "Compiling prroi_pooling kernels by nvcc..."
+nvcc -c -o prroi_pooling_gpu_impl.cu.o prroi_pooling_gpu_impl.cu -x cu -Xcompiler -fPIC -arch=sm_52
+
+cd ../
+echo "Working directory: " `pwd`
+echo "Building python libraries..."
+python3 build.py
+
diff --git a/models/models.py b/models/models.py
index 4fb9b5ee..595c2771 100644
--- a/models/models.py
+++ b/models/models.py
@@ -1,15 +1,17 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torchvision
 from . import resnet, resnext
-from lib.nn import SynchronizedBatchNorm2d
+from lib.nn import SynchronizedBatchNorm2d, PrRoIPool2D
 
 
 class SegmentationModuleBase(nn.Module):
     def __init__(self):
         super(SegmentationModuleBase, self).__init__()
 
-    def pixel_acc(self, pred, label):
+    @staticmethod
+    def pixel_acc(pred, label):
         _, preds = torch.max(pred, dim=1)
         valid = (label >= 0).long()
         acc_sum = torch.sum(valid * (preds == label).long())
@@ -33,12 +35,17 @@ def forward(self, feed_dict, *, segSize=None):
             else:
                 pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
 
-            loss = self.crit(pred, feed_dict['seg_label'])
+            # fold the per-layer predictions and labels into the batch dimension
+            seg_label = feed_dict['seg_label']
+            seg_label = seg_label.view(-1, seg_label.size(2), seg_label.size(3))  # (nr_layers * b, h, w)
+            pred = torch.cat(pred, dim=0).view(-1, pred[0].size(1), seg_label.size(1), seg_label.size(2))  # (nr_layers * b, c, h, w)
+
+            loss = self.crit(pred, seg_label)
             if self.deep_sup_scale is not None:
                 loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label'])
                 loss = loss + loss_deepsup * self.deep_sup_scale
 
-            acc = self.pixel_acc(pred, feed_dict['seg_label'])
+            acc = self.pixel_acc(pred, seg_label)
             return loss, acc
         else: # inference
             pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
@@ -46,7 +53,7 @@ def forward(self, feed_dict, *, segSize=None):
 
 
 def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
-    "3x3 convolution with padding"
+    """3x3 convolution with padding"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                      padding=1, bias=has_bias)
 
@@ -59,12 +66,16 @@ def conv3x3_bn_relu(in_planes, out_planes, stride=1):
             )
 
 
-class ModelBuilder():
+class ModelBuilder:
+    def __init__(self):
+        pass
+
     # custom weights initialization
-    def weights_init(self, m):
+    @staticmethod
+    def weights_init(m):
         classname = m.__class__.__name__
         if classname.find('Conv') != -1:
-            nn.init.kaiming_normal_(m.weight.data)
+            nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')
         elif classname.find('BatchNorm') != -1:
             m.weight.data.fill_(1.)
             m.bias.data.fill_(1e-4)
@@ -168,12 +179,6 @@ def build_decoder(self, arch='ppm_bilinear_deepsup',
                 fc_dim=fc_dim,
                 use_softmax=use_softmax,
                 fpn_dim=512)
-        elif arch == 'upernet_tmp':
-            net_decoder = UPerNetTmp(
-                num_class=num_class,
-                fc_dim=fc_dim,
-                use_softmax=use_softmax,
-                fpn_dim=512)
         else:
             raise Exception('Architecture undefined!')
 
@@ -306,9 +311,9 @@ def forward(self, conv_out, segSize=None):
         x = self.conv_last(x)
 
         if self.use_softmax:  # is True during inference
-            x = nn.functional.upsample(
+            x = F.upsample(
                 x, size=segSize, mode='bilinear', align_corners=False)
-            x = nn.functional.softmax(x, dim=1)
+            x = F.softmax(x, dim=1)
             return x
 
         # deep sup
@@ -316,8 +321,8 @@ def forward(self, conv_out, segSize=None):
         _ = self.cbr_deepsup(conv4)
         _ = self.conv_last_deepsup(_)
 
-        x = nn.functional.log_softmax(x, dim=1)
-        _ = nn.functional.log_softmax(_, dim=1)
+        x = F.log_softmax(x, dim=1)
+        _ = F.log_softmax(_, dim=1)
 
         return (x, _)
 
@@ -339,11 +344,11 @@ def forward(self, conv_out, segSize=None):
         x = self.conv_last(x)
 
         if self.use_softmax: # is True during inference
-            x = nn.functional.upsample(
+            x = F.upsample(
                 x, size=segSize, mode='bilinear', align_corners=False)
-            x = nn.functional.softmax(x, dim=1)
+            x = F.softmax(x, dim=1)
         else:
-            x = nn.functional.log_softmax(x, dim=1)
+            x = F.log_softmax(x, dim=1)
 
         return x
 
@@ -380,7 +385,7 @@ def forward(self, conv_out, segSize=None):
         input_size = conv5.size()
         ppm_out = [conv5]
         for pool_scale in self.ppm:
-            ppm_out.append(nn.functional.upsample(
+            ppm_out.append(F.upsample(
                 pool_scale(conv5),
                 (input_size[2], input_size[3]),
                 mode='bilinear', align_corners=False))
@@ -389,11 +394,11 @@ def forward(self, conv_out, segSize=None):
         x = self.conv_last(ppm_out)
 
         if self.use_softmax:  # is True during inference
-            x = nn.functional.upsample(
+            x = F.upsample(
                 x, size=segSize, mode='bilinear', align_corners=False)
-            x = nn.functional.softmax(x, dim=1)
+            x = F.softmax(x, dim=1)
         else:
-            x = nn.functional.log_softmax(x, dim=1)
+            x = F.log_softmax(x, dim=1)
         return x
 
 
@@ -432,7 +437,7 @@ def forward(self, conv_out, segSize=None):
         input_size = conv5.size()
         ppm_out = [conv5]
         for pool_scale in self.ppm:
-            ppm_out.append(nn.functional.upsample(
+            ppm_out.append(F.upsample(
                 pool_scale(conv5),
                 (input_size[2], input_size[3]),
                 mode='bilinear', align_corners=False))
@@ -441,9 +446,9 @@ def forward(self, conv_out, segSize=None):
         x = self.conv_last(ppm_out)
 
         if self.use_softmax:  # is True during inference
-            x = nn.functional.upsample(
+            x = F.upsample(
                 x, size=segSize, mode='bilinear', align_corners=False)
-            x = nn.functional.softmax(x, dim=1)
+            x = F.softmax(x, dim=1)
             return x
 
         # deep sup
@@ -452,10 +457,10 @@ def forward(self, conv_out, segSize=None):
         _ = self.dropout_deepsup(_)
         _ = self.conv_last_deepsup(_)
 
-        x = nn.functional.log_softmax(x, dim=1)
-        _ = nn.functional.log_softmax(_, dim=1)
+        x = F.log_softmax(x, dim=1)
+        _ = F.log_softmax(_, dim=1)
 
-        return (x, _)
+        return x, _
 
 
 # upernet
@@ -471,7 +476,8 @@ def __init__(self, num_class=150, fc_dim=4096,
         self.ppm_conv = []
 
         for scale in pool_scales:
-            self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
+            # RoIs passed in forward are in feature-map coordinates, so spatial_scale = 1.0
+            self.ppm_pooling.append(PrRoIPool2D(scale, scale, 1.))
             self.ppm_conv.append(nn.Sequential(
                 nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                 SynchronizedBatchNorm2d(512),
@@ -498,19 +504,24 @@ def __init__(self, num_class=150, fc_dim=4096,
             ))
         self.fpn_out = nn.ModuleList(self.fpn_out)
 
-        self.conv_last = nn.Sequential(
-            conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
-            nn.Conv2d(fpn_dim, num_class, kernel_size=1)
-        )
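+        # one 1x1 classifier per FPN level, each producing the prediction for its own layer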
+        self.conv_last = nn.ModuleList()
+        for i in range(len(fpn_inplanes)):
+            self.conv_last.append(nn.Conv2d(fpn_dim, num_class, kernel_size=1))
 
     def forward(self, conv_out, segSize=None):
         conv5 = conv_out[-1]
 
         input_size = conv5.size()
         ppm_out = [conv5]
+
+        roi = []
+        for i in range(input_size[0]):  # batch size
+            roi.append(torch.Tensor([i, 0, 0, input_size[3], input_size[2]]).view(1, -1))  # (b, x0, y0, x1, y1)
+        roi = torch.cat(roi, dim=0).type_as(conv5)
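+        # with an RoI covering the whole feature map and spatial_scale = 1, PrRoIPool2D(scale, scale, 1.)
+        # behaves like adaptive average pooling to a scale x scale grid with continuous bin boundaries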
         for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
             ppm_out.append(pool_conv(nn.functional.upsample(
-                pool_scale(conv5),
+                pool_scale(conv5, roi.detach()),
                 (input_size[2], input_size[3]),
                 mode='bilinear', align_corners=False)))
         ppm_out = torch.cat(ppm_out, 1)
@@ -521,7 +532,7 @@ def forward(self, conv_out, segSize=None):
             conv_x = conv_out[i]
             conv_x = self.fpn_in[i](conv_x) # lateral branch
 
-            f = nn.functional.upsample(
+            f = F.upsample(
                 f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch
             f = conv_x + f
 
@@ -529,21 +540,17 @@ def forward(self, conv_out, segSize=None):
 
         fpn_feature_list.reverse() # [P2 - P5]
         output_size = fpn_feature_list[0].size()[2:]
-        fusion_list = [fpn_feature_list[0]]
-        for i in range(1, len(fpn_feature_list)):
-            fusion_list.append(nn.functional.upsample(
-                fpn_feature_list[i],
-                output_size,
-                mode='bilinear', align_corners=False))
-        fusion_out = torch.cat(fusion_list, 1)
-        x = self.conv_last(fusion_out)
+        output_list = []
 
-        if self.use_softmax:  # is True during inference
-            x = nn.functional.upsample(
-                x, size=segSize, mode='bilinear', align_corners=False)
-            x = nn.functional.softmax(x, dim=1)
-            return x
+        for i in range(len(fpn_feature_list)):
+            x = self.conv_last[i](fpn_feature_list[i])
+            x = F.upsample(x, size=output_size, mode='bilinear', align_corners=False)
+            output_list.append(x)
 
-        x = nn.functional.log_softmax(x, dim=1)
+        if self.use_softmax:  # is True during inference
+            output_list = [F.upsample(x, size=segSize, mode='bilinear', align_corners=False) for x in output_list]
+            output_list = [torch.softmax(x, dim=1) for x in output_list]
+            return output_list
 
-        return x
+        output_list = [torch.log_softmax(x, dim=1) for x in output_list]
+        return output_list