Skip to content

Commit

Permalink
Fix efficientNet and add NMS in torchvision
Browse files Browse the repository at this point in the history
  • Loading branch information
v.toandm2 committed Dec 5, 2019
1 parent 8c7ca08 commit 923a51b
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 546 deletions.
17 changes: 13 additions & 4 deletions models/efficientdet.py
Expand Up @@ -35,7 +35,7 @@ def forward(self, inputs):
return classification, regression, anchors
else:
transformed_anchors = self.regressBoxes(anchors, regression)
transformed_anchors = self.clipBoxes(transformed_anchors, img_batch)
transformed_anchors = self.clipBoxes(transformed_anchors, inputs)
scores = torch.max(classification, dim=2, keepdim=True)[0]
scores_over_thresh = (scores>0.05)[0, :, 0]
if scores_over_thresh.sum() == 0:
Expand All @@ -44,8 +44,17 @@ def forward(self, inputs):
classification = classification[:, scores_over_thresh, :]
transformed_anchors = transformed_anchors[:, scores_over_thresh, :]
scores = scores[:, scores_over_thresh, :]
anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :], 0.5)
nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1)
return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]]

anchors_nms_idx = []
for i in range(transformed_anchors.size(0)):
anchors_nms_idx.append(nms(transformed_anchors[i], scores[i].view(-1), iou_threshold=0.5))

nms_scores = []
nms_class = []
for i in range(classification.size(0)):
_score, _class = classification[i, anchors_nms_idx[i], :].max(dim=1)
nms_scores.append(_score)
nms_class.append(_class)
return [nms_scores, nms_class]


1 change: 0 additions & 1 deletion models/efficientnet.py
Expand Up @@ -181,7 +181,6 @@ def extract_features(self, inputs):
num_repeat = 0
index+=1
P.append(x)
print('P len: ', len(P))
return P

def forward(self, inputs):
Expand Down
3 changes: 2 additions & 1 deletion test.py
Expand Up @@ -2,9 +2,10 @@
from models import EfficientDet

if __name__ == '__main__':
inputs = torch.randn(2, 3, 512, 512)
inputs = torch.randn(5, 3, 512, 512).cuda()

model = EfficientDet(num_classes=2, is_training=False)
model = model.cuda()
output = model(inputs)
for p in output:
print(p.size())
Expand Down
1 change: 1 addition & 0 deletions train.py
Expand Up @@ -114,6 +114,7 @@ def train():
if(iteration%100==0):
print('Epoch/Iteration: {}/{}, classification: {}, regression: {}, totol_loss: {}'.format(epoch, iteration, classification_loss.item(), regression_loss.item(), np.mean(total_loss)))
iteration+=1
torch.save(model.state_dict(), './weights/checkpoint_{}.pth'.format(epoch))

if __name__ == '__main__':
train()
115 changes: 0 additions & 115 deletions train_lr.py

This file was deleted.

115 changes: 0 additions & 115 deletions train_pytoan.py

This file was deleted.

0 comments on commit 923a51b

Please sign in to comment.