PyTorch implementation of SAR-BagNet
-
PyTorch platform for Windows
-
Python 3.6+
-
Training the model requires a GPU with more than 12 GB of video memory
-
OpenCV
-
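
The requirements above can be checked before training with a short script. This is a minimal sketch and not part of the repository: the function name `check_environment` and the report keys are hypothetical, and the 12 GB figure is simply the requirement stated above.

```python
import importlib.util
import sys

MIN_PYTHON = (3, 6)
MIN_GPU_MEMORY_GB = 12  # threshold taken from the requirement above


def check_environment():
    """Return a dict describing whether each stated requirement is met."""
    report = {
        "python_ok": sys.version_info >= MIN_PYTHON,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "opencv_installed": importlib.util.find_spec("cv2") is not None,
        "gpu_memory_gb": None,  # stays None when no CUDA device is visible
    }
    if report["torch_installed"]:
        import torch

        if torch.cuda.is_available():
            total_bytes = torch.cuda.get_device_properties(0).total_memory
            report["gpu_memory_gb"] = total_bytes / 1024 ** 3
    return report


if __name__ == "__main__":
    for name, value in check_environment().items():
        print(f"{name}: {value}")
    print(f"required GPU memory: more than {MIN_GPU_MEMORY_GB} GB")
```

If `gpu_memory_gb` is `None` or below the threshold, training may run out of memory.
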
Unzip the images. We provide the MSTAR dataset in the images folder
-
Run trian_test.py. The training process is the same as for a traditional CNN. The script also includes this project's data preprocessing, and different preprocessing pipelines can be selected for different tasks
-
utils.py can generate a heatmap for each SAR image
1. Place a trained model in the specified folder. model_urls is the location of the model file, and model_dir is the folder in which the model is saved, for example:

    model_urls = {'SAR_BagNet': 'D:/SAR-bagnet/saved_model/model.pth'}
    model_dir = 'D:/SAR-bagnet/saved_model'

The above code is in the SAR_BagNet.py file. Modify it to match your own file location.
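
If you prefer not to type the two paths by hand, they can be built from a single project folder. This is only a sketch: `build_model_paths` and its `project_root` argument are hypothetical names, and the `saved_model/model.pth` layout is copied from the example above.

```python
from pathlib import Path


def build_model_paths(project_root):
    """Build model_urls and model_dir from a project root folder."""
    model_dir = Path(project_root) / "saved_model"
    model_urls = {"SAR_BagNet": (model_dir / "model.pth").as_posix()}
    return model_urls, model_dir.as_posix()


# Example: the project is assumed to live in D:/SAR-bagnet, as in the text above.
model_urls, model_dir = build_model_paths("D:/SAR-bagnet")
print(model_urls["SAR_BagNet"])
print(model_dir)
```

Forward slashes are kept so that the strings match the style used in SAR_BagNet.py.
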
2. Replace the forward method (def forward(self, x)) of class ResNet(nn.Module) in SAR_BagNet.py with the following code:

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Collect the class-C logit at every spatial position (i, j) of the feature map.
        logits_list1 = []
        for i in range(N):
            for j in range(N):
                x1 = x[:, :, i, j]
                x1 = x1.view(x1.size(0), -1)
                logits1 = self.fc(x1)
                logits1 = logits1[:, C]
                logits_list1.append(logits1.data.cpu().numpy().copy())
        logits2 = np.hstack(logits_list1)
        logits2 = logits2.reshape((N, N))  # only valid for a batch size of 1
        if self.avg_pool:
            x = nn.AvgPool2d(x.size()[2], stride=1)(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
        else:
            x = x.permute(0, 2, 3, 1)
            x = self.fc(x)
        return x, logits2

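N is the side length of the heatmap (the heatmap is N x N), and C is the index of the class for which the heatmap is generated. Both must be defined before forward is called. The code also uses NumPy as np, so add import numpy as np to the file if it is not already there.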
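
To see what the double loop computes, the same per-position logits can be reproduced on random data with NumPy. This is an illustrative sketch, not the repository's code: the feature map `x`, the weight matrix `W`, the bias `b`, and all sizes are made-up stand-ins for the output of `layer4` and for `self.fc`.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C = 4, 2                    # heatmap size and class index (example values)
channels, num_classes = 8, 10  # made-up sizes for illustration

x = rng.standard_normal((1, channels, N, N))      # stand-in for the layer4 output, batch size 1
W = rng.standard_normal((num_classes, channels))  # stand-in for the fc weight
b = rng.standard_normal(num_classes)              # stand-in for the fc bias

# Loop version, mirroring forward(): one linear layer per spatial position.
logits_list = []
for i in range(N):
    for j in range(N):
        x1 = x[:, :, i, j]      # shape (1, channels)
        logits1 = x1 @ W.T + b  # shape (1, num_classes)
        logits_list.append(logits1[:, C])
heatmap_loop = np.hstack(logits_list).reshape((N, N))

# Vectorized version: apply the same linear layer at every position at once.
heatmap_vec = np.einsum("bcij,kc->bijk", x, W)[0, :, :, C] + b[C]

print(np.allclose(heatmap_loop, heatmap_vec))
```

Both versions give the same N x N array, so each heatmap cell is the class-C logit of one spatial position of the feature map.
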
3. Run utils.py to generate the heatmap.
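
If you want to inspect the N x N logit array returned by forward yourself, one common option is to min-max scale it to 8-bit grey levels. The sketch below is an assumption about post-processing, not a description of what utils.py does; `to_uint8` is a hypothetical helper name.

```python
import numpy as np


def to_uint8(heatmap):
    """Min-max scale a 2-D array of logits to the range 0..255."""
    heatmap = np.asarray(heatmap, dtype=np.float64)
    span = heatmap.max() - heatmap.min()
    if span == 0:
        # A constant map has no contrast to show.
        return np.zeros(heatmap.shape, dtype=np.uint8)
    scaled = (heatmap - heatmap.min()) / span
    return np.round(scaled * 255).astype(np.uint8)


example = np.array([[-1.0, 0.0], [1.0, 3.0]])
print(to_uint8(example))
```

The resulting array can be saved or colour-mapped with the image library of your choice.
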
If you have any questions, please contact me at 1441771519@qq.com