"""MagicPoint-style interest-point detector, its loss terms, and a COCO training script.

NOTE(review): this file is a concatenation of what appear to be three modules
(model / loss / train).  Imports are hoisted to the top so the file runs
top-to-bottom; the original referenced `nn` before any import appeared.
"""
import os

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import Dataset_Coco, Dataset_Syn
from function import Getmap, getidxs, Get_mask, warp_points


class Magicpoint_Model(nn.Module):
    """Encoder/decoder point detector.

    input:  (N, 1, H, W) grayscale image batch
    output: (N, 1, H, W) "semi" point-response map; ReLU followed by tanh
            restricts each value to [0, 1), so it is usable with BCELoss.
    """

    def __init__(self):
        super(Magicpoint_Model, self).__init__()
        # conv encoder: 4 stages of two 3x3 same-padded convs, each BN+ReLU
        self.conv1a = nn.Conv2d(1, 64, 3, 1, 1)
        self.batchnorm_conv1a = nn.BatchNorm2d(64)
        self.conv1b = nn.Conv2d(64, 64, 3, 1, 1)
        self.batchnorm_conv1b = nn.BatchNorm2d(64)
        self.conv2a = nn.Conv2d(64, 64, 3, 1, 1)
        self.batchnorm_conv2a = nn.BatchNorm2d(64)
        self.conv2b = nn.Conv2d(64, 64, 3, 1, 1)
        self.batchnorm_conv2b = nn.BatchNorm2d(64)
        self.conv3a = nn.Conv2d(64, 64, 3, 1, 1)
        self.batchnorm_conv3a = nn.BatchNorm2d(64)
        self.conv3b = nn.Conv2d(64, 64, 3, 1, 1)
        self.batchnorm_conv3b = nn.BatchNorm2d(64)
        self.conv4a = nn.Conv2d(64, 64, 3, 1, 1)
        self.batchnorm_conv4a = nn.BatchNorm2d(64)
        self.conv4b = nn.Conv2d(64, 64, 3, 1, 1)
        self.batchnorm_conv4b = nn.BatchNorm2d(64)
        # point head: 64 -> 256 -> 64 -> 32 -> 1 channels
        self.convp1 = nn.Conv2d(64, 256, 3, 1, 1)
        self.batchnorm_convp1 = nn.BatchNorm2d(256)
        self.convp2 = nn.Conv2d(256, 64, 3, 1, 1)
        self.batchnorm_convp2 = nn.BatchNorm2d(64)
        self.convp3 = nn.Conv2d(64, 32, 3, 1, 1)
        self.batchnorm_convp3 = nn.BatchNorm2d(32)
        self.convp4 = nn.Conv2d(32, 1, 1, 1, 0)
        self.batchnorm_convp4 = nn.BatchNorm2d(1)  # NOTE(review): unused in forward()
        # others
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2)
        # NOTE(review): "unsample" is a typo for "upsample"; kept so that
        # existing checkpoint state_dict keys still load.
        self.unsample = nn.UpsamplingBilinear2d(scale_factor=2)
        self.tanh = nn.Tanh()

    def forward(self, x):
        # encode
        # NOTE(review): ReLU on the raw input zeroes negative pixels; kept
        # as in the original (a no-op for [0, 1]-normalized images).
        x = self.relu(x)
        x = self.relu(self.batchnorm_conv1a(self.conv1a(x)))
        x = self.relu(self.batchnorm_conv1b(self.conv1b(x)))
        x1 = x  # full-resolution skip
        x = self.pool(x)
        x = self.relu(self.batchnorm_conv2a(self.conv2a(x)))
        x = self.relu(self.batchnorm_conv2b(self.conv2b(x)))
        x2 = x  # 1/2-resolution skip
        x = self.pool(x)
        x = self.relu(self.batchnorm_conv3a(self.conv3a(x)))
        x = self.relu(self.batchnorm_conv3b(self.conv3b(x)))
        x3 = x  # 1/4-resolution skip
        x = self.pool(x)
        x = self.relu(self.batchnorm_conv4a(self.conv4a(x)))
        x = self.relu(self.batchnorm_conv4b(self.conv4b(x)))
        # decode: three bilinear upsamplings with additive skip connections
        x = self.unsample(self.unsample(self.unsample(x) + x3) + x2) + x1
        # point head
        p = self.relu(self.batchnorm_convp1(self.convp1(x)))
        p = self.relu(self.batchnorm_convp2(self.convp2(p)))
        p = self.relu(self.batchnorm_convp3(self.convp3(p)))
        semi = self.relu(self.convp4(p))
        semi = self.tanh(semi)  # relu -> tanh gives values in [0, 1)
        return semi


class Loss_Model(nn.Module):
    """Loss terms for the point detector.

    Args:
        device: torch device the losses are computed on.
        W: image width in pixels.
        H: image height in pixels.
    """

    def __init__(self, device, W, H):
        super(Loss_Model, self).__init__()
        self.device = device
        self.H = H
        self.W = W
        # Hoisted out of loss_shaperegularizer: BCELoss is stateless, so
        # constructing it once per call was wasted work.  It has no
        # parameters/buffers, so state_dict keys are unchanged.
        self.bce = torch.nn.BCELoss(reduction='sum').to(self.device)

    def loss_shaperegularizer(self, semis, label_2ds):
        """Summed BCE between predicted maps and dense 2-D ground truth."""
        return self.bce(semis, label_2ds)

    def loss_classification(self, semis, label_paths, pad):
        """Patch-softmax-style classification loss around labeled points.

        For each labeled point, pushes exp(center) up relative to the sum of
        exp over the surrounding (2*pad+1)^2 patch (center excluded from the
        denominator, as in the original).

        Args:
            semis: iterable of predicted maps, each squeezable to (H, W).
            label_paths: per-image .npy paths of point coordinates.
            pad: half-width of the patch; border points are discarded.

        Returns:
            Mean per-image loss, or 0 if no image had any usable point.
        """
        BATCH_N = len(semis)
        LOSS = 0
        # batch_loss
        for semi, label_path in zip(semis, label_paths):
            semi = semi.squeeze()
            points = np.load(label_path)
            # NOTE(review): original comment said "n*(y,x)" but the code
            # below treats column 0 as x (compared against W) — verify the
            # stored layout against the label generator.
            points = points.astype('int')
            # drop points whose patch would cross the image border
            points = points[points[:, 0] >= pad]
            points = points[points[:, 0] <= (self.W - pad - 1)]
            points = points[points[:, 1] >= pad]
            points = points[points[:, 1] <= (self.H - pad - 1)]
            # img_loss
            num = len(points)
            if num == 0:
                BATCH_N -= 1
                continue
            loss_sum = 0
            for point in points:
                x = point[0]
                y = point[1]
                up = torch.exp(semi[y][x])
                down = torch.sum(
                    torch.exp(semi[y - pad:y + pad + 1, x - pad: x + pad + 1])) - up
                # FIX: torch.log2_ is not a public function (in-place ops
                # exist only as Tensor methods); use out-of-place torch.log2.
                loss_sum -= torch.log2(up / down)
            LOSS += (loss_sum / num)
        # FIX: guard division by zero when every image lost all its points
        if BATCH_N == 0:
            return LOSS
        return LOSS / BATCH_N

    def loss_twice(self, semis1, semis2, mats, label_paths):
        """Consistency loss between two homography-related views.

        Penalizes squared response differences at corresponding labeled
        points of view 1 and view 2 (view 2 coordinates obtained by warping
        with `mat`); responses are masked to the mutually visible region.
        """
        LOSS_S = 0
        N = len(semis1)
        for semi1, semi2, label_path, mat in zip(semis1, semis2, label_paths, mats):
            semi1 = semi1.squeeze()
            semi2 = semi2.squeeze()
            mat = np.array(mat)
            mat_i = np.linalg.inv(mat)
            # mask out pixels that leave the frame under either warp
            mask_1 = Get_mask(mat_i, self.H, self.W).to(self.device)
            mask_2 = Get_mask(mat, self.H, self.W).to(self.device)
            semi1 = semi1 * mask_1
            semi2 = semi2 * mask_2
            points = np.load(label_path)
            points = points.astype('int')
            points1 = points[:, (1, 0)]  # stored (y, x) -> (x, y) for warping
            points2 = warp_points(points1, mat).astype('int')
            loss_sum = 0
            n = len(points)
            for point1, point2 in zip(points1, points2):
                x1 = point1[0]
                y1 = point1[1]
                x2 = point2[0]
                y2 = point2[1]
                # skip points warped outside the image
                # NOTE(review): `<= 0` also rejects the valid index 0;
                # kept as in the original — confirm intent.
                if ((x2 >= self.W) or (y2 >= self.H) or (x2 <= 0) or (y2 <= 0)):
                    n -= 1
                    continue
                loss_sum += torch.pow(semi1[y1][x1] - semi2[y2][x2], 2)
            if n == 0:
                N -= 1
                continue
            LOSS_S += loss_sum / n
        # FIX: guard division by zero when every image was skipped
        if N == 0:
            return LOSS_S
        return LOSS_S / N

    def loss_forward(self, semis1, semis2=None, label_paths=None, label_2d=None,
                     mats=None, pad=4, FLAG=(0, 1, 0)):
        """Dispatch the enabled loss terms.

        Args:
            FLAG: 3 truthy/falsy switches for (shape-regularizer,
                  classification, twice/consistency).  Default changed from a
                  mutable list to an equivalent tuple (indexed identically).

        Returns:
            (loss_s, loss_c, loss_t); disabled terms are 0.
        """
        loss_s = 0
        loss_c = 0
        loss_t = 0
        if FLAG[0]:  # 's'
            loss_s = self.loss_shaperegularizer(semis1, label_2d)
        if FLAG[1]:  # 'c'
            loss_c = self.loss_classification(semis1, label_paths, pad)
        if FLAG[2]:  # 't'
            loss_t = self.loss_twice(semis1, semis2, mats, label_paths)
        return loss_s, loss_c, loss_t


# ---------------------------------------------------------------------------
# Training script (COCO) — runs at import time, as in the original.
# ---------------------------------------------------------------------------
# log
writer = SummaryWriter('log')
save_dir = 'save_model/model_coco_1'
os.makedirs(save_dir, exist_ok=True)

# create net
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)
Net = Magicpoint_Model().to(device)

# dataset
img_dir = 'dataset/COCO/img'
point_dir = 'dataset/COCO/point_xy'
label_dir = 'dataset/COCO/label_g'
dataset = Dataset_Coco(img_dir, point_dir, label_dir, resize_HW=(240, 320))
print('dataset_len:', dataset.__len__())

# create loss
LOSS_FUNCTION = Loss_Model(device, 320, 240).to(device)
learning_rate = 1e-3
optimizer = torch.optim.Adam(Net.parameters(), lr=learning_rate, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)

# train
batch_size = 16
epoch = 100
dataloader = DataLoader(dataset, batch_size, pin_memory=True, shuffle=True)
Net.train()
i = 0
for e in range(epoch):
    LOSS = 0
    N = 0
    for data in tqdm(dataloader, desc=str(e)):
        # loader (only the fields actually used are unpacked)
        inputs = data['input']
        point_paths = data['point_path']
        label_2ds = data['label_2d']
        # cuda
        imgs1 = inputs.to(device)          # (b,1,h,w)
        label_2ds = label_2ds.to(device)   # (b,1,h,w)
        # net
        semis1 = Net(imgs1)                # (b,1,h,w)
        semis2 = None
        loss_s, loss_c, loss_t = LOSS_FUNCTION.loss_forward(
            semis1, semis2, point_paths, label_2ds, None, 4, [1, 0, 0])
        loss = loss_s + loss_c + loss_t
        # FIX: accumulate a detached float — `LOSS += loss` kept every
        # batch's autograd graph alive for the whole epoch (memory leak).
        LOSS += loss.item()
        N += 1
        # log
        writer.add_scalar('loss_s', loss_s, i)
        writer.add_scalar('loss_c', loss_c, i)
        writer.add_scalar('loss_t', loss_t, i)
        writer.add_scalar('loss', loss, i)
        i += 1
        # optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if not i % 100:
            torch.save(Net.state_dict(), '{}/superpoint_{}.pth'.format(save_dir, i))
    # FIX: step the LR scheduler once per epoch.  The original stepped it
    # every batch; with StepLR(step_size=5, gamma=0.8) that multiplies the
    # LR by 0.8 every 5 iterations, collapsing it to ~0 within one epoch.
    scheduler.step()
    writer.add_scalar('LOSS', (LOSS / N), e)