In [None]:
class pyramid_layer(nn.Module):
    """Residual unit with a zero-padded (parameter-free) shortcut.

    The residual branch is BN => conv3x3 => BN => ReLU => conv3x3 => BN.
    When ``downsampled`` is true the first convolution uses stride 2 and the
    shortcut is a 2x2 average pool with ``ceil_mode=True``, so both paths
    produce the same spatial size.  Whichever path has fewer channels is
    padded with zero channels before the two are summed.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels produced by the residual branch.
        downsampled: if true, halve the spatial resolution (rounding up).
        use_shake_drop: request ShakeDrop regularisation on the branch.
            No ShakeDrop implementation is wired into this class, so a
            truthy value raises ``NotImplementedError``.
        p_shakedrop: probability stored for ShakeDrop; kept on the instance
            as ``self.p_shakedrop`` but otherwise unused.

    Raises:
        NotImplementedError: if ``use_shake_drop`` is truthy.
    """

    def __init__(self, in_ch, out_ch, downsampled, use_shake_drop=False, p_shakedrop=1.):
        super().__init__()
        if use_shake_drop:
            # Fail at construction instead of with an AttributeError on the
            # first forward pass: there is no ``self.shake_drop`` to call.
            raise NotImplementedError(
                "pyramid_layer has no ShakeDrop module attached; "
                "construct it with use_shake_drop=False"
            )
        stride = 2 if downsampled else 1
        # Layer order (and therefore state-dict keys ``branch.<idx>.*``) must
        # stay fixed so previously saved checkpoints keep loading.
        self.branch = nn.Sequential(
            nn.BatchNorm2d(in_ch),
            nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        self.downsampled = downsampled
        # The pool is only needed (and only used) on the downsampling path.
        self.shortcut = (
            nn.AvgPool2d(2, padding=0, ceil_mode=True) if downsampled else None
        )
        self.use_shake_drop = use_shake_drop
        self.p_shakedrop = p_shakedrop

    @staticmethod
    def _pad_channels(t, extra):
        """Append ``extra`` all-zero channels to ``t`` along dim 1."""
        zeros = t.new_zeros((t.size(0), extra, t.size(2), t.size(3)))
        return torch.cat([t, zeros], dim=1)

    def forward(self, x):
        """Return ``branch(x) + shortcut(x)`` with channel counts equalised."""
        h = self.branch(x)
        h0 = self.shortcut(x) if self.downsampled else x
        ch_diff = h.size(1) - h0.size(1)
        if ch_diff > 0:
            h0 = self._pad_channels(h0, ch_diff)
        elif ch_diff < 0:
            h = self._pad_channels(h, -ch_diff)
        return h + h0
class PyramidNet(nn.Module):
    """Small PyramidNet-style patch classifier with a single sigmoid output.

    Layout of ``self.net`` (order matters: state-dict keys are positional):
    conv3x3(3->10) => BN => ReLU => conv3x3(10->init_ch) => BN => ReLU =>
    MaxPool2d(2) => ``2 * n_layers`` pyramid_layer units => AvgPool2d(2).
    Each unit widens the channel count by ``add_rate`` and every second unit
    of a pair is built with its downsampling flag set.  ``self.lin_act`` maps
    the final ``out_ch`` features to one sigmoid-activated value.

    Args:
        alpha: total channel widening spread over all units
            (``add_rate = round(alpha / (2 * n_layers))``).
        n_layers: number of unit pairs.
        init_ch: channel count entering the first pyramid unit.
    """

    def __init__(self, alpha=42, n_layers=1, init_ch=16):
        super().__init__()
        self.n_units = n_layers * 2
        self.p_dec_ratio = .5 / self.n_units
        self.add_rate = round(alpha / self.n_units)
        self.p = 1
        self.in_ch = init_ch

        seq = []
        in_ch = self.in_ch
        p = self.p
        for _ in range(n_layers):
            for j in range(2):
                out_ch = in_ch + self.add_rate
                p -= self.p_dec_ratio
                # Append the pyramid unit here.  The fourth and fifth
                # positional arguments are ``False`` and the current ``p``.
                # NOTE(review): pyramid_layer must accept five positional
                # arguments for this call to succeed - confirm its signature.
                seq.append(pyramid_layer(in_ch, out_ch, j % 2 == 1, False, p))
                in_ch = out_ch

        self.net = nn.Sequential(
            nn.Conv2d(3, 10, 3, padding=1),
            nn.BatchNorm2d(10),
            nn.ReLU(inplace=True),
            nn.Conv2d(10, self.in_ch, 3, padding=1),
            nn.BatchNorm2d(self.in_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, padding=0),
            *seq,
            nn.AvgPool2d(2, padding=0),
        )
        self.out_ch = out_ch
        self.lin_act = nn.Sequential(nn.Linear(self.out_ch, 1), nn.Sigmoid())

    def forward(self, x):
        """Return one sigmoid score per input row as a 1-D tensor.

        NOTE(review): the reshape below assumes ``self.net`` reduces each
        input to a 1x1 spatial map; larger maps would be folded into the
        batch axis - confirm the expected patch size.
        """
        out = self.net(x)
        out = out.view(-1, self.out_ch)
        out = self.lin_act(out)
        # Drop only the trailing feature axis.  A bare ``torch.squeeze`` would
        # also drop the batch axis when a single patch is passed, yielding a
        # 0-d tensor that downstream per-pixel stacking cannot handle.
        return out.squeeze(-1)
  


In [None]:
# Build segmentation masks for the boxes in `first_box`: score a small window
# around every pixel of each box with the network, then refine the per-pixel
# scores with a dense CRF and write the labelling into `masks`.
# NOTE(review): `device`, `first_box`, `first_img` and `masks` are defined in
# other cells; the +4 offsets presumably match a 4-pixel padding of
# `first_img` - confirm.
model = PyramidNet().to(device)
model.load_state_dict(torch.load('../brown-datathon/src/deepcut_model/deepcut.pth'))

with torch.no_grad():
    model.eval()
    for i, box in enumerate(first_box):
        # Signed 64-bit instead of uint8: uint8 wraps at 256, which silently
        # corrupts boxes that reach (or, after the +4 offset, pass) that value.
        y1, x1 = np.floor(box[:2]).astype(np.int64)
        y2, x2 = np.ceil(box[2:]).astype(np.int64)
        y, x = np.mgrid[y1 + 4:y2 + 4, x1 + 4:x2 + 4]
        coords = np.vstack([y.ravel(), x.ravel()]).T
        # NOTE(review): this window spans -4..+3 (8x8), whereas the validation
        # cell uses -4..+4 (9x9) - confirm which size the model was trained on.
        this_patches = [first_img[cy - 4:cy + 4, cx - 4:cx + 4] for cy, cx in coords]
        this_patches = torch.from_numpy(np.stack(this_patches)).to(device).permute(0, 3, 1, 2).float()
        preds = model(this_patches).to('cpu').data.numpy()
        # Negative log-likelihoods used as the two unary energy rows.
        neg_preds = - np.log(1 - preds + 1e-10)
        preds = - np.log(preds + 1e-10)
        d = dcrf.DenseCRF2D(*np.flip(x.shape), 2)
        d.setUnaryEnergy(np.stack([neg_preds, preds]))
        d.addPairwiseGaussian(sxy=1, compat=2)
        image_in_box = first_img[y1 + 4:y2 + 4, x1 + 4:x2 + 4]
        image_in_box = image_in_box.copy(order='C')
        # Use the crop of the current box as the bilateral reference image
        # (previously an unrelated name was passed and this crop was unused).
        d.addPairwiseBilateral(sxy=10, srgb=20, rgbim=image_in_box.astype(np.uint8), compat=2)
        Q = d.inference(5)
        masks[y1:y2, x1:x2] = np.argmax(Q, axis=0).reshape(x.shape).astype(np.uint8)
# NOTE(review): `*masks` stores each first-axis slice of `masks` as its own
# array in the archive - confirm that is the intended file layout.
np.savez_compressed('train_masks.npz', *masks)

In [None]:
# Validation-set mask generation: for every foreground image, score a 9x9
# window around each pixel of each box with the network, refine the scores
# with a dense CRF, and paint the box's label into a per-image mask.
device = torch.device('cuda')
model = PyramidNet().to(device)
model.load_state_dict(torch.load('../brown-datathon/src/deepcut_model/deepcut.pth'))

images = np.load('val_imgs.npz')
fg_images = images['fg_patches']
bg_images = images['bg_patches']
box_labels = np.load('val_box_labels.npz')
box_coords = np.load('val_box_coords.npz')
labels = list(box_labels.values())
boxes_list = list(box_coords.values())
masks = []

# Pad every side by 10 px (fill value 255) so windows near the border stay
# inside the image; all box coordinates below are shifted by the same 10.
padder = iaa.Pad(px=10, pad_cval=255, keep_size=False)
fg_images = padder.augment_images(fg_images.astype(np.uint8))

model.eval()
with torch.no_grad():
    for j, image in enumerate(fg_images):
        boxes = boxes_list[j]
        label = labels[j] + 1
        mask = np.zeros((256, 256))
        for i, box in enumerate(boxes):
            y1, x1 = np.floor(box[:2]).astype(np.uint16)
            y2, x2 = np.ceil(box[2:]).astype(np.uint16)

            # Box extent expressed in padded-image coordinates.
            rows = slice(y1 + 10, y2 + 10)
            cols = slice(x1 + 10, x2 + 10)
            y, x = np.mgrid[rows, cols]
            coords = np.column_stack((y.ravel(), x.ravel()))

            # One 9x9 window per pixel, stacked and moved to channels-first.
            this_patches = np.stack(
                [image[cy - 4:cy + 5, cx - 4:cx + 5] for cy, cx in coords]
            )
            this_patches = torch.from_numpy(this_patches).to(device)
            this_patches = this_patches.permute(0, 3, 1, 2).float()

            preds = model(this_patches).cpu().numpy()
            # Negative log-likelihoods used as the two unary energy rows.
            neg_preds = - np.log(1 - preds + 1e-10)
            preds = - np.log(preds + 1e-10)

            height, width = x.shape
            d = dcrf.DenseCRF2D(width, height, 2)
            d.setUnaryEnergy(np.stack([neg_preds, preds]))
            d.addPairwiseGaussian(sxy=1, compat=2)
            image_in_box = image[rows, cols].copy(order='C')
            d.addPairwiseBilateral(
                sxy=10,
                srgb=20,
                rgbim=image_in_box.astype(np.uint8),
                compat=2,
            )
            Q = d.inference(5)

            # Pixels labelled 1 by the CRF take this box's label.
            seg_in_box = np.argmax(Q, axis=0).reshape(x.shape).astype(np.uint8)
            seg_in_box[seg_in_box == 1] = label[i]
            mask[y1:y2, x1:x2] = seg_in_box
        masks.append(mask)

np.savez_compressed('val_masks.npz', *masks)