-
Notifications
You must be signed in to change notification settings - Fork 52
/
segnext.py
64 lines (47 loc) · 2.02 KB
/
segnext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# !/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2022/9/26 14:17
# @Author : liumin
# @File : segnext.py
"""
SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation
https://arxiv.org/pdf/2209.08575.pdf
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.models.backbones import build_backbone
from src.models.heads import build_head
from src.losses.seg_loss import CrossEntropyLoss2d
class SegNeXt(nn.Module):
    """SegNeXt segmentation model: a backbone followed by a head.

    The head output is bilinearly resized to the spatial size of the input
    images. Depending on ``mode`` the forward pass returns per-pixel class
    indices, a loss dict, or both.
    """

    def __init__(self, dictionary=None, model_cfg=None):
        """Build the backbone, the head and the weighted cross-entropy loss.

        Args:
            dictionary: Sequence of dicts. Every key is collected as a
                category name and every value as a loss weight;
                ``len(dictionary)`` is used as the number of classes
                (presumably one single-entry dict per class -- confirm
                against the dataset definition). Required in practice
                despite the ``None`` default.
            model_cfg: Config object exposing ``BACKBONE`` and ``HEAD``
                sections; ``HEAD`` must support item assignment. Required
                in practice despite the ``None`` default.
        """
        super().__init__()
        self.dictionary = dictionary
        self.model_cfg = model_cfg

        self.input_size = [1024, 2048]
        self.dummy_input = torch.zeros(1, 3, *self.input_size)

        self.num_classes = len(self.dictionary)
        self.category = [name for entry in self.dictionary for name in entry]
        # Every key is in ``self.category`` by construction, so the weights
        # are simply all values in iteration order.
        self.weight = [w for entry in self.dictionary for w in entry.values()]

        # Must run before the head is built: it writes num_classes into the
        # head config.
        self.setup_extra_params()
        self.backbone = build_backbone(self.model_cfg.BACKBONE)
        self.head = build_head(self.model_cfg.HEAD)

        class_weights = torch.from_numpy(np.array(self.weight)).float()
        criterion = CrossEntropyLoss2d(weight=class_weights)
        # Only move to the GPU when one exists, so the model can still be
        # constructed on CPU-only hosts.
        if torch.cuda.is_available():
            criterion = criterion.cuda()
        self.criterion = criterion

    def setup_extra_params(self):
        """Write the class count into the head config as ``num_classes``."""
        self.model_cfg.HEAD['num_classes'] = self.num_classes

    def forward(self, imgs, targets=None, mode='infer', **kwargs):
        """Run the network and return predictions and/or losses.

        Args:
            imgs: Input batch; its dimensions from index 2 onward are used
                as the output spatial size (assumes an N x C x H x W
                tensor).
            targets: Ground truth passed to the criterion. Required for
                every mode other than ``'infer'``.
            mode: ``'infer'`` returns the argmax prediction; ``'val'``
                returns ``(losses, prediction)``; any other value returns
                only ``losses``.
            **kwargs: Accepted and ignored.

        Returns:
            A tensor of class indices, a dict with keys ``'ce_loss'`` and
            ``'loss'``, or a tuple of both, depending on ``mode``.

        Raises:
            ValueError: If ``mode`` is not ``'infer'`` and ``targets`` is
                ``None``.
        """
        logits = self.head(self.backbone(imgs))
        outputs = F.interpolate(
            logits, size=imgs.shape[2:], mode='bilinear', align_corners=False
        )

        if mode == 'infer':
            return torch.argmax(outputs, dim=1)

        if targets is None:
            raise ValueError(
                f"targets must be provided when mode is {mode!r}"
            )

        ce_loss = self.criterion(outputs, targets)
        losses = {'ce_loss': ce_loss, 'loss': ce_loss}

        if mode == 'val':
            return losses, torch.argmax(outputs, dim=1)
        return losses